[Feature] use pytest for sgl-kernel (#4896)
This commit is contained in:
committed by
GitHub
parent
4ede6770cd
commit
9fccda3111
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import tree_speculative_sampling_target_only
|
||||
@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
|
||||
print(f"{accept_index=}")
|
||||
print(f"{accept_token_num=}")
|
||||
|
||||
return predicts, accept_index, accept_token_num
|
||||
if threshold_single == 1 and threshold_acc == 1:
|
||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 3, 4, 5],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [3, 2]
|
||||
elif threshold_single == 0 and threshold_acc == 0:
|
||||
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 1, 2, -1],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [2, 2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predicts, accept_index, accept_token_num = (
|
||||
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
|
||||
)
|
||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 3, 4, 5],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [3, 2]
|
||||
|
||||
predicts, accept_index, accept_token_num = (
|
||||
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
|
||||
)
|
||||
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 1, 2, -1],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [2, 2]
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user