[Misc] Clean sgl-kernel test (#5216)
This commit is contained in:
@@ -49,7 +49,6 @@ def test_verify_tree_greedy():
|
||||
if torch.max(target_logits[i][j]) < 10:
|
||||
target_logits[i][j][18] = 10
|
||||
|
||||
print(f"{target_logits=}")
|
||||
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
|
||||
predict_shape = (12,)
|
||||
|
||||
@@ -65,12 +64,6 @@ def test_verify_tree_greedy():
|
||||
) # mutable
|
||||
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
|
||||
|
||||
print(f"{candidates=}")
|
||||
print(f"{retrive_index=}")
|
||||
print(f"{retrive_next_token=}")
|
||||
print(f"{retrive_next_sibling=}")
|
||||
print(f"{target_predict=}")
|
||||
|
||||
verify_tree_greedy(
|
||||
predicts=predicts,
|
||||
accept_index=accept_index,
|
||||
@@ -82,10 +75,6 @@ def test_verify_tree_greedy():
|
||||
target_predict=target_predict,
|
||||
)
|
||||
|
||||
print(f"{predicts=}")
|
||||
print(f"{accept_index=}")
|
||||
print(f"{accept_token_num=}")
|
||||
|
||||
# Check the expected output.
|
||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
|
||||
Reference in New Issue
Block a user