Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_index = torch.tensor(
@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_next_token = torch.tensor(
@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_next_sibling = torch.tensor(
@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
target_predict = torch.argmax(target_logits, dim=-1)
predict_shape = (12,)
bs = candidates.shape[0]
num_spec_step = 4
num_draft_tokens = candidates.shape[1]
predicts = torch.full(
predict_shape, -1, dtype=torch.int32, device="cuda"