Add greedy verification kernel (#4383)
This commit is contained in:
98
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
98
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
|
||||
def test_verify_tree_greedy():
    """Run ``verify_tree_greedy`` on a hand-built two-request batch and check its outputs.

    The inputs are fixed int32 CUDA tensors of shape ``(2, 6)`` (two requests,
    six draft tokens each) plus a ``(2, 6, 20)`` logits tensor whose per-row
    argmax is fully determined by the values written below. The kernel fills
    the three output buffers (``predicts``, ``accept_index``,
    ``accept_token_num``) in place; they are then compared against the
    expected values.

    The assertions live inside the test body so that a pytest run actually
    verifies the kernel output instead of only executing it.

    Returns:
        tuple: ``(predicts, accept_index, accept_token_num)`` after the kernel
        call. Kept so that script-style callers that unpack the result keep
        working.
    """
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    # NOTE(review): the three ``retrive_*`` tensors presumably encode the draft
    # tree layout (flat output position, first child, next sibling; -1 meaning
    # "none") -- confirm against the kernel documentation. The spelling
    # "retrive" is the kernel's keyword-argument name and must stay as is.
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )

    # Baseline logit of 1 everywhere, with a single dominant logit (10) at the
    # token each selected position should predict.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    # Every position without an explicit peak gets token 18 as its argmax, so
    # the greedy prediction is unambiguous for all rows.
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i, j]) < 10:
                target_logits[i, j, 18] = 10

    print(f"{target_logits=}")
    target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)

    bs = candidates.shape[0]
    num_spec_step = 4
    num_draft_tokens = candidates.shape[1]
    # One prediction slot per (request, draft token); derived from the input
    # shape instead of a hard-coded 12 so it stays in sync with `candidates`.
    predict_shape = (bs * num_draft_tokens,)

    predicts = torch.full(
        predict_shape, -1, dtype=torch.int32, device="cuda"
    )  # mutable
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
    )  # mutable
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable

    print(f"{candidates=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{target_predict=}")

    verify_tree_greedy(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        target_predict=target_predict,
    )

    print(f"{predicts=}")
    print(f"{accept_index=}")
    print(f"{accept_token_num=}")

    # Checked here (not under the __main__ guard) so pytest enforces them.
    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]

    return predicts, accept_index, accept_token_num


if __name__ == "__main__":
    # The test asserts on its own outputs; running it is sufficient.
    test_verify_tree_greedy()
|
||||
@@ -3,7 +3,10 @@ import torch.nn.functional as F
|
||||
from sgl_kernel import tree_speculative_sampling_target_only
|
||||
|
||||
|
||||
def test_tree_speculative_sampling_target_only():
|
||||
def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
|
||||
print(
|
||||
f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
|
||||
)
|
||||
candidates = torch.tensor(
|
||||
[
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
@@ -37,7 +40,7 @@ def test_tree_speculative_sampling_target_only():
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda")
|
||||
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
|
||||
target_logits[0, 0, 3] = 10
|
||||
target_logits[0, 3, 4] = 10
|
||||
target_logits[0, 4, 5] = 10
|
||||
@@ -85,6 +88,8 @@ def test_tree_speculative_sampling_target_only():
|
||||
uniform_samples=coins,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=threshold_single,
|
||||
threshold_acc=threshold_acc,
|
||||
deterministic=True,
|
||||
)
|
||||
|
||||
@@ -92,6 +97,13 @@ def test_tree_speculative_sampling_target_only():
|
||||
print(f"{accept_index=}")
|
||||
print(f"{accept_token_num=}")
|
||||
|
||||
return predicts, accept_index, accept_token_num
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predicts, accept_index, accept_token_num = (
|
||||
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
|
||||
)
|
||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 3, 4, 5],
|
||||
@@ -99,6 +111,12 @@ def test_tree_speculative_sampling_target_only():
|
||||
]
|
||||
assert accept_token_num.tolist() == [3, 2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tree_speculative_sampling_target_only()
|
||||
predicts, accept_index, accept_token_num = (
|
||||
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
|
||||
)
|
||||
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
|
||||
assert accept_index.tolist() == [
|
||||
[0, 1, 2, -1],
|
||||
[6, 10, 11, -1],
|
||||
]
|
||||
assert accept_token_num.tolist() == [2, 2]
|
||||
Reference in New Issue
Block a user