Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_index = torch.tensor(
|
||||
@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10, 11],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_next_token = torch.tensor(
|
||||
@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
|
||||
[1, 2, -1, 4, 5, -1],
|
||||
[4, 2, 3, -1, 5, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
retrive_next_sibling = torch.tensor(
|
||||
@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
|
||||
[-1, 3, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, 1, -1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
|
||||
if torch.max(target_logits[i][j]) < 10:
|
||||
target_logits[i][j][18] = 10
|
||||
|
||||
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
|
||||
target_predict = torch.argmax(target_logits, dim=-1)
|
||||
predict_shape = (12,)
|
||||
|
||||
bs = candidates.shape[0]
|
||||
num_spec_step = 4
|
||||
num_draft_tokens = candidates.shape[1]
|
||||
|
||||
predicts = torch.full(
|
||||
predict_shape, -1, dtype=torch.int32, device="cuda"
|
||||
|
||||
Reference in New Issue
Block a user