Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_index = torch.tensor(
|
||||
@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10, 11],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_next_token = torch.tensor(
|
||||
@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
|
||||
[1, 2, -1, 4, 5, -1],
|
||||
[4, 2, 3, -1, 5, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_next_sibling = torch.tensor(
|
||||
@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
|
||||
[-1, 3, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, 1, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
|
||||
if torch.max(target_logits[i][j]) < 10:
|
||||
target_logits[i][j][18] = 10
|
||||
|
||||
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
|
||||
target_predict = torch.argmax(target_logits, dim=-1)
|
||||
predict_shape = (12,)
|
||||
|
||||
bs = candidates.shape[0]
|
||||
num_spec_step = 4
|
||||
num_draft_tokens = candidates.shape[1]
|
||||
|
||||
predicts = torch.full(
|
||||
predict_shape, -1, dtype=torch.int32, device="cuda"
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
retrive_index = torch.tensor(
|
||||
@@ -50,7 +50,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10, 11],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
retrive_next_token = torch.tensor(
|
||||
@@ -58,7 +58,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
[1, 2, -1, 4, 5, -1],
|
||||
[4, 2, 3, -1, 5, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
retrive_next_sibling = torch.tensor(
|
||||
@@ -66,7 +66,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
[-1, 3, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, 1, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
|
||||
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
|
||||
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
|
||||
coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)
|
||||
|
||||
tree_speculative_sampling_target_only(
|
||||
predicts=predicts,
|
||||
@@ -105,6 +106,7 @@ def test_tree_speculative_sampling_target_only(
|
||||
retrive_next_token=retrive_next_token,
|
||||
retrive_next_sibling=retrive_next_sibling,
|
||||
uniform_samples=coins,
|
||||
uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=threshold_single,
|
||||
|
||||
Reference in New Issue
Block a user