[Feature] use pytest for sgl-kernel (#4896)

This commit is contained in:
Adarsh Shirawalmath
2025-03-30 23:06:52 +05:30
committed by GitHub
parent 4ede6770cd
commit 9fccda3111
10 changed files with 263 additions and 290 deletions

View File

@@ -1,3 +1,4 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only
@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
print(f"{accept_index=}")
print(f"{accept_token_num=}")
return predicts, accept_index, accept_token_num
if threshold_single == 1 and threshold_acc == 1:
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
elif threshold_single == 0 and threshold_acc == 0:
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
if __name__ == "__main__":
predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
)
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
)
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
pytest.main([__file__])