[Feature] use pytest for sgl-kernel (#4896)
This commit is contained in:
committed by
GitHub
parent
4ede6770cd
commit
9fccda3111
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
|
||||
print(f"{accept_index=}")
|
||||
print(f"{accept_token_num=}")
|
||||
|
||||
return predicts, accept_index, accept_token_num
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predicts, accept_index, accept_token_num = test_verify_tree_greedy()
|
||||
# Check the expected output.
|
||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 3, 4, 5],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [3, 2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user