[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
76
sgl-kernel/tests/speculative/test_lookahead_utils.py
Normal file
76
sgl-kernel/tests/speculative/test_lookahead_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import reconstruct_indices_from_tree_mask
|
||||
|
||||
|
||||
def test_reconstruct_indices_from_tree_mask():
|
||||
bs = 1
|
||||
num_branch_token = 4
|
||||
seq_lens = torch.tensor([12], device="cuda", dtype=torch.int64)
|
||||
|
||||
retrive_index = torch.full(
|
||||
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
|
||||
)
|
||||
retrive_next_token = torch.full(
|
||||
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
|
||||
)
|
||||
retrive_next_sibling = torch.full(
|
||||
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
|
||||
)
|
||||
positions = torch.empty((bs * num_branch_token), device="cuda", dtype=torch.int64)
|
||||
|
||||
tree_mask = torch.tensor(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
).to(torch.bool)
|
||||
|
||||
reconstruct_indices_from_tree_mask(
|
||||
tree_mask,
|
||||
seq_lens,
|
||||
positions, # mutable
|
||||
retrive_index, # mutable
|
||||
retrive_next_token, # mutable
|
||||
retrive_next_sibling, # mutable
|
||||
bs,
|
||||
num_branch_token,
|
||||
)
|
||||
# print(f"debug: \n\n{tree_mask=}, {retrive_index=}, {retrive_next_token=}, {retrive_next_sibling=}, {positions=}\n\n")
|
||||
assert retrive_index.tolist() == [
|
||||
[0, 1, 2, 3],
|
||||
], f"{retrive_index=}"
|
||||
assert retrive_next_token.tolist() == [
|
||||
[1, -1, 3, -1],
|
||||
], f"{retrive_next_token=}"
|
||||
assert retrive_next_sibling.tolist() == [
|
||||
[-1, 2, -1, -1],
|
||||
], f"{retrive_next_sibling=}"
|
||||
assert positions.tolist() == [
|
||||
12,
|
||||
13,
|
||||
13,
|
||||
14,
|
||||
], f"{positions=}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_reconstruct_indices_from_tree_mask()
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user