Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_index = torch.tensor(
@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_next_token = torch.tensor(
@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
retrive_next_sibling = torch.tensor(
@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device="cuda",
)
@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
target_predict = torch.argmax(target_logits, dim=-1)
predict_shape = (12,)
bs = candidates.shape[0]
num_spec_step = 4
num_draft_tokens = candidates.shape[1]
predicts = torch.full(
predict_shape, -1, dtype=torch.int32, device="cuda"

View File

@@ -42,7 +42,7 @@ def test_tree_speculative_sampling_target_only(
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_index = torch.tensor(
@@ -50,7 +50,7 @@ def test_tree_speculative_sampling_target_only(
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_next_token = torch.tensor(
@@ -58,7 +58,7 @@ def test_tree_speculative_sampling_target_only(
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_next_sibling = torch.tensor(
@@ -66,7 +66,7 @@ def test_tree_speculative_sampling_target_only(
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
@@ -95,6 +95,7 @@ def test_tree_speculative_sampling_target_only(
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)
tree_speculative_sampling_target_only(
predicts=predicts,
@@ -105,6 +106,7 @@ def test_tree_speculative_sampling_target_only(
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
uniform_samples=coins,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=threshold_single,