move eagle draft post process to cuda graph (#11434)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -1,427 +0,0 @@
|
|||||||
# NOTE: Please run this file to make sure the test cases are correct.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from enum import IntEnum
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
|
||||||
|
|
||||||
if is_cuda() or is_hip():
|
|
||||||
from sgl_kernel import (
|
|
||||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient_preprocess(
|
|
||||||
verified_id: torch.Tensor,
|
|
||||||
score_list: List[torch.Tensor],
|
|
||||||
token_list: List[torch.Tensor],
|
|
||||||
parents_list: List[torch.Tensor],
|
|
||||||
num_verify_tokens: int,
|
|
||||||
):
|
|
||||||
score_list = torch.cat(score_list, dim=1).flatten(
|
|
||||||
1
|
|
||||||
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
|
||||||
ss_token_list = torch.cat(
|
|
||||||
token_list, dim=1
|
|
||||||
) # b, (self.topk + (num_steps-1) * self.topk)
|
|
||||||
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
|
|
||||||
top_scores_index = top_scores.indices
|
|
||||||
top_scores_index = torch.sort(top_scores_index).values
|
|
||||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
|
||||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
|
||||||
|
|
||||||
if len(parents_list) > 1:
|
|
||||||
parent_list = torch.cat(parents_list[:-1], dim=1)
|
|
||||||
else:
|
|
||||||
batch_size = parents_list[0].shape[0]
|
|
||||||
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
|
|
||||||
|
|
||||||
return parent_list, top_scores_index, draft_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class TreeMaskMode(IntEnum):
|
|
||||||
FULL_MASK = 0
|
|
||||||
QLEN_ONLY = 1
|
|
||||||
QLEN_ONLY_BITPACKING = 2
|
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient(
|
|
||||||
verified_id: torch.Tensor,
|
|
||||||
score_list: List[torch.Tensor],
|
|
||||||
token_list: List[torch.Tensor],
|
|
||||||
parents_list: List[torch.Tensor],
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_sum: int,
|
|
||||||
topk: int,
|
|
||||||
spec_steps: int,
|
|
||||||
num_verify_tokens: int,
|
|
||||||
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
|
|
||||||
tree_mask_buf: Optional[torch.Tensor] = None,
|
|
||||||
position_buf: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
parent_list, top_scores_index, draft_tokens = (
|
|
||||||
build_tree_kernel_efficient_preprocess(
|
|
||||||
verified_id,
|
|
||||||
score_list,
|
|
||||||
token_list,
|
|
||||||
parents_list,
|
|
||||||
num_verify_tokens,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
|
|
||||||
bs = seq_lens.numel()
|
|
||||||
device = seq_lens.device
|
|
||||||
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
|
||||||
# where each row indicates the attending pattern of each draft token
|
|
||||||
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
|
|
||||||
if tree_mask_buf is not None:
|
|
||||||
tree_mask = tree_mask_buf
|
|
||||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
|
||||||
tree_mask = torch.full(
|
|
||||||
(num_verify_tokens * bs * num_verify_tokens,),
|
|
||||||
True,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
|
||||||
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
|
|
||||||
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
|
|
||||||
tree_mask = torch.zeros(
|
|
||||||
(num_verify_tokens * bs,),
|
|
||||||
dtype=packed_dtypes[packed_dtype_idx],
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
|
||||||
tree_mask = torch.full(
|
|
||||||
(
|
|
||||||
seq_lens_sum * num_verify_tokens
|
|
||||||
+ num_verify_tokens * num_verify_tokens * bs,
|
|
||||||
),
|
|
||||||
True,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
|
||||||
|
|
||||||
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
|
||||||
retrive_index = torch.full(
|
|
||||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
retrive_next_token = torch.full(
|
|
||||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
retrive_next_sibling = torch.full(
|
|
||||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
# position: where each token belongs to
|
|
||||||
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
|
||||||
# then, positions = [7, 8, 8, 9]
|
|
||||||
if position_buf is not None:
|
|
||||||
positions = position_buf
|
|
||||||
else:
|
|
||||||
positions = torch.empty(
|
|
||||||
(bs * num_verify_tokens,), device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
sgl_build_tree_kernel_efficient(
|
|
||||||
parent_list,
|
|
||||||
top_scores_index,
|
|
||||||
seq_lens,
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
topk,
|
|
||||||
spec_steps,
|
|
||||||
num_verify_tokens,
|
|
||||||
tree_mask_mode,
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
draft_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_tree_kernel_efficient():
|
|
||||||
verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
|
|
||||||
score_list = [
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
|
|
||||||
[[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
|
|
||||||
[2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
|
|
||||||
[2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
|
|
||||||
[1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
|
|
||||||
[2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
|
|
||||||
[2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
|
|
||||||
[1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
|
|
||||||
[2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
|
|
||||||
[1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
|
|
||||||
[1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
|
|
||||||
[6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
|
|
||||||
[4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
|
|
||||||
[7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
|
|
||||||
[1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
|
|
||||||
[2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
|
|
||||||
[1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
|
|
||||||
[1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
|
|
||||||
[6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
|
|
||||||
[1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
token_list = [
|
|
||||||
torch.tensor(
|
|
||||||
[[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
|
|
||||||
dtype=torch.int64,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
29889,
|
|
||||||
29974,
|
|
||||||
29945,
|
|
||||||
29900,
|
|
||||||
29974,
|
|
||||||
29922,
|
|
||||||
29930,
|
|
||||||
29958,
|
|
||||||
29889,
|
|
||||||
29974,
|
|
||||||
29930,
|
|
||||||
29945,
|
|
||||||
29974,
|
|
||||||
29922,
|
|
||||||
29930,
|
|
||||||
29958,
|
|
||||||
],
|
|
||||||
[
|
|
||||||
22550,
|
|
||||||
4136,
|
|
||||||
16492,
|
|
||||||
8439,
|
|
||||||
29871,
|
|
||||||
2,
|
|
||||||
3001,
|
|
||||||
13,
|
|
||||||
2,
|
|
||||||
13,
|
|
||||||
29906,
|
|
||||||
29946,
|
|
||||||
2,
|
|
||||||
13,
|
|
||||||
29871,
|
|
||||||
259,
|
|
||||||
],
|
|
||||||
],
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
29946,
|
|
||||||
29945,
|
|
||||||
29953,
|
|
||||||
29906,
|
|
||||||
29896,
|
|
||||||
29945,
|
|
||||||
29900,
|
|
||||||
29906,
|
|
||||||
29896,
|
|
||||||
29945,
|
|
||||||
29906,
|
|
||||||
29953,
|
|
||||||
29896,
|
|
||||||
29945,
|
|
||||||
29906,
|
|
||||||
29946,
|
|
||||||
],
|
|
||||||
[
|
|
||||||
29871,
|
|
||||||
2,
|
|
||||||
29901,
|
|
||||||
29889,
|
|
||||||
29871,
|
|
||||||
2,
|
|
||||||
395,
|
|
||||||
259,
|
|
||||||
29901,
|
|
||||||
29871,
|
|
||||||
2,
|
|
||||||
29889,
|
|
||||||
3001,
|
|
||||||
1234,
|
|
||||||
7146,
|
|
||||||
2186,
|
|
||||||
],
|
|
||||||
],
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
29946,
|
|
||||||
29974,
|
|
||||||
29945,
|
|
||||||
29930,
|
|
||||||
29889,
|
|
||||||
29922,
|
|
||||||
29974,
|
|
||||||
29930,
|
|
||||||
29974,
|
|
||||||
29946,
|
|
||||||
29930,
|
|
||||||
29922,
|
|
||||||
29889,
|
|
||||||
29974,
|
|
||||||
29945,
|
|
||||||
29922,
|
|
||||||
],
|
|
||||||
[
|
|
||||||
29941,
|
|
||||||
29906,
|
|
||||||
2,
|
|
||||||
29946,
|
|
||||||
29871,
|
|
||||||
450,
|
|
||||||
319,
|
|
||||||
14990,
|
|
||||||
29946,
|
|
||||||
29941,
|
|
||||||
2,
|
|
||||||
29906,
|
|
||||||
29871,
|
|
||||||
2,
|
|
||||||
3001,
|
|
||||||
13,
|
|
||||||
],
|
|
||||||
],
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
parents_list = [
|
|
||||||
torch.tensor(
|
|
||||||
[[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
|
|
||||||
),
|
|
||||||
torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
|
|
||||||
torch.tensor(
|
|
||||||
[[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
|
|
||||||
),
|
|
||||||
torch.tensor(
|
|
||||||
[[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
|
|
||||||
topk = 4
|
|
||||||
depth = 4
|
|
||||||
num_draft_token = 8
|
|
||||||
|
|
||||||
(
|
|
||||||
tree_mask,
|
|
||||||
position,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
draft_tokens,
|
|
||||||
) = build_tree_kernel_efficient(
|
|
||||||
verified_id=verified_id,
|
|
||||||
score_list=score_list,
|
|
||||||
token_list=token_list,
|
|
||||||
parents_list=parents_list,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_sum=torch.sum(seq_lens).item(),
|
|
||||||
topk=topk,
|
|
||||||
spec_steps=depth,
|
|
||||||
num_verify_tokens=num_draft_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("=========== build tree kernel efficient ==========")
|
|
||||||
print(f"{tree_mask=}")
|
|
||||||
print(f"{position=}")
|
|
||||||
print(f"{retrive_index=}")
|
|
||||||
print(f"{retrive_next_token=}")
|
|
||||||
print(f"{retrive_next_sibling=}")
|
|
||||||
print(f"{draft_tokens=}")
|
|
||||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
|
||||||
assert retrive_index.tolist() == [
|
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
|
||||||
[8, 9, 10, 11, 12, 13, 14, 15],
|
|
||||||
]
|
|
||||||
assert retrive_next_token.tolist() == [
|
|
||||||
[1, 3, 4, 5, 6, 7, -1, -1],
|
|
||||||
[1, 2, -1, 6, -1, -1, 7, -1],
|
|
||||||
]
|
|
||||||
assert retrive_next_sibling.tolist() == [
|
|
||||||
[-1, 2, -1, -1, -1, -1, -1, -1],
|
|
||||||
[-1, -1, 3, 4, 5, -1, -1, -1],
|
|
||||||
]
|
|
||||||
assert draft_tokens.tolist() == [
|
|
||||||
29974,
|
|
||||||
29896,
|
|
||||||
29906,
|
|
||||||
29889,
|
|
||||||
29974,
|
|
||||||
29946,
|
|
||||||
29896,
|
|
||||||
29946,
|
|
||||||
13,
|
|
||||||
13,
|
|
||||||
22550,
|
|
||||||
4136,
|
|
||||||
16492,
|
|
||||||
8439,
|
|
||||||
29871,
|
|
||||||
29941,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_build_tree_kernel_efficient()
|
|
||||||
@@ -276,11 +276,9 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
def _postprocess_output_to_raw_bs(self, out, raw_bs):
|
def _postprocess_output_to_raw_bs(self, out, raw_bs):
|
||||||
score_list, token_list, parents_list = out
|
# Keep the variables name for readability
|
||||||
score_list = [x[:raw_bs] for x in score_list]
|
parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
|
||||||
token_list = [x[:raw_bs] for x in token_list]
|
return parent_list, top_scores_index, draft_tokens
|
||||||
parents_list = [x[:raw_bs] for x in parents_list]
|
|
||||||
return (score_list, token_list, parents_list)
|
|
||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -19,7 +18,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import get_global_server_args
|
from sglang.srt.server_args import get_global_server_args
|
||||||
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
|
|
||||||
from sglang.srt.speculative.spec_utils import (
|
from sglang.srt.speculative.spec_utils import (
|
||||||
SIMULATE_ACC_LEN,
|
SIMULATE_ACC_LEN,
|
||||||
generate_simulated_accept_index,
|
generate_simulated_accept_index,
|
||||||
@@ -286,110 +284,6 @@ class EagleVerifyInputV2Mixin:
|
|||||||
return predict, accept_length, accept_index
|
return predict, accept_length, accept_index
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient_tmp(
|
|
||||||
verified_id: torch.Tensor,
|
|
||||||
parent_list: List[torch.Tensor],
|
|
||||||
top_scores_index: torch.Tensor,
|
|
||||||
draft_tokens: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_sum: int,
|
|
||||||
topk: int,
|
|
||||||
spec_steps: int,
|
|
||||||
num_verify_tokens: int,
|
|
||||||
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
|
|
||||||
tree_mask_buf: Optional[torch.Tensor] = None,
|
|
||||||
position_buf: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
# TODO(lsyin): make it compatible with default code path
|
|
||||||
# TODO(lsyin): support cuda graph graph padding for eagle
|
|
||||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
|
||||||
|
|
||||||
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
|
|
||||||
bs = seq_lens.numel()
|
|
||||||
device = seq_lens.device
|
|
||||||
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
|
||||||
# where each row indicates the attending pattern of each draft token
|
|
||||||
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
|
|
||||||
if tree_mask_buf is not None:
|
|
||||||
tree_mask = tree_mask_buf
|
|
||||||
if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
|
||||||
tree_mask.fill_(True)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
|
||||||
tree_mask.fill_(0)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
|
||||||
tree_mask.fill_(True)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
|
||||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
|
||||||
tree_mask = torch.full(
|
|
||||||
(num_verify_tokens * bs * num_verify_tokens,),
|
|
||||||
True,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
|
||||||
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
|
|
||||||
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
|
|
||||||
tree_mask = torch.zeros(
|
|
||||||
(num_verify_tokens * bs,),
|
|
||||||
dtype=packed_dtypes[packed_dtype_idx],
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
|
||||||
tree_mask = torch.full(
|
|
||||||
(
|
|
||||||
seq_lens_sum * num_verify_tokens
|
|
||||||
+ num_verify_tokens * num_verify_tokens * bs,
|
|
||||||
),
|
|
||||||
True,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
|
||||||
|
|
||||||
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
|
||||||
retrive_buf = torch.full(
|
|
||||||
(3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
|
|
||||||
# position: where each token belongs to
|
|
||||||
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
|
||||||
# then, positions = [7, 8, 8, 9]
|
|
||||||
if position_buf is not None:
|
|
||||||
positions = position_buf
|
|
||||||
else:
|
|
||||||
positions = torch.empty(
|
|
||||||
(bs * num_verify_tokens,), device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
from sgl_kernel import (
|
|
||||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
|
||||||
)
|
|
||||||
|
|
||||||
sgl_build_tree_kernel_efficient(
|
|
||||||
parent_list,
|
|
||||||
top_scores_index,
|
|
||||||
seq_lens,
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
topk,
|
|
||||||
spec_steps,
|
|
||||||
num_verify_tokens,
|
|
||||||
tree_mask_mode,
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
draft_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True)
|
||||||
def select_top_k_tokens_tmp(
|
def select_top_k_tokens_tmp(
|
||||||
i: int,
|
i: int,
|
||||||
|
|||||||
138
python/sglang/srt/speculative/eagle_utils.py
Normal file
138
python/sglang/srt/speculative/eagle_utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import math
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
|
if is_cuda() or is_hip():
|
||||||
|
from sgl_kernel import (
|
||||||
|
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def organize_draft_results(
|
||||||
|
score_list: List[torch.Tensor],
|
||||||
|
token_list: List[torch.Tensor],
|
||||||
|
parents_list: List[torch.Tensor],
|
||||||
|
num_draft_token: int,
|
||||||
|
):
|
||||||
|
score_list = torch.cat(score_list, dim=1).flatten(1)
|
||||||
|
ss_token_list = torch.cat(token_list, dim=1)
|
||||||
|
top_scores = torch.topk(score_list, num_draft_token - 1, dim=-1)
|
||||||
|
top_scores_index = top_scores.indices
|
||||||
|
top_scores_index = torch.sort(top_scores_index).values
|
||||||
|
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||||
|
|
||||||
|
if len(parents_list) > 1:
|
||||||
|
parent_list = torch.cat(parents_list[:-1], dim=1)
|
||||||
|
else:
|
||||||
|
batch_size = parents_list[0].shape[0]
|
||||||
|
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
|
||||||
|
|
||||||
|
return parent_list, top_scores_index, draft_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TreeMaskMode(IntEnum):
|
||||||
|
FULL_MASK = 0
|
||||||
|
QLEN_ONLY = 1
|
||||||
|
QLEN_ONLY_BITPACKING = 2
|
||||||
|
|
||||||
|
|
||||||
|
def build_tree_kernel_efficient(
|
||||||
|
verified_id: torch.Tensor,
|
||||||
|
parent_list: List[torch.Tensor],
|
||||||
|
top_scores_index: torch.Tensor,
|
||||||
|
draft_tokens: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
topk: int,
|
||||||
|
spec_steps: int,
|
||||||
|
num_verify_tokens: int,
|
||||||
|
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
|
||||||
|
tree_mask_buf: Optional[torch.Tensor] = None,
|
||||||
|
position_buf: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
||||||
|
|
||||||
|
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
|
||||||
|
bs = seq_lens.numel()
|
||||||
|
device = seq_lens.device
|
||||||
|
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
||||||
|
# where each row indicates the attending pattern of each draft token
|
||||||
|
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
|
||||||
|
if tree_mask_buf is not None:
|
||||||
|
tree_mask = tree_mask_buf
|
||||||
|
if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
||||||
|
tree_mask.fill_(True)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
||||||
|
tree_mask.fill_(0)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
||||||
|
tree_mask.fill_(True)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
||||||
|
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(num_verify_tokens * bs * num_verify_tokens,),
|
||||||
|
True,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
||||||
|
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
|
||||||
|
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
|
||||||
|
tree_mask = torch.zeros(
|
||||||
|
(num_verify_tokens * bs,),
|
||||||
|
dtype=packed_dtypes[packed_dtype_idx],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(
|
||||||
|
seq_lens_sum * num_verify_tokens
|
||||||
|
+ num_verify_tokens * num_verify_tokens * bs,
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
||||||
|
|
||||||
|
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
||||||
|
retrive_buf = torch.full(
|
||||||
|
(3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
|
||||||
|
# position: where each token belongs to
|
||||||
|
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
||||||
|
# then, positions = [7, 8, 8, 9]
|
||||||
|
if position_buf is not None:
|
||||||
|
positions = position_buf
|
||||||
|
else:
|
||||||
|
positions = torch.empty(
|
||||||
|
(bs * num_verify_tokens,), device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
sgl_build_tree_kernel_efficient(
|
||||||
|
parent_list,
|
||||||
|
top_scores_index,
|
||||||
|
seq_lens,
|
||||||
|
tree_mask,
|
||||||
|
positions,
|
||||||
|
retrive_index,
|
||||||
|
retrive_next_token,
|
||||||
|
retrive_next_sibling,
|
||||||
|
topk,
|
||||||
|
spec_steps,
|
||||||
|
num_verify_tokens,
|
||||||
|
tree_mask_mode,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
tree_mask,
|
||||||
|
positions,
|
||||||
|
retrive_index,
|
||||||
|
retrive_next_token,
|
||||||
|
retrive_next_sibling,
|
||||||
|
draft_tokens,
|
||||||
|
)
|
||||||
@@ -28,7 +28,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
|
||||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||||
EAGLEDraftCudaGraphRunner,
|
EAGLEDraftCudaGraphRunner,
|
||||||
)
|
)
|
||||||
@@ -40,6 +39,10 @@ from sglang.srt.speculative.eagle_info import (
|
|||||||
EagleVerifyInput,
|
EagleVerifyInput,
|
||||||
EagleVerifyOutput,
|
EagleVerifyOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.speculative.eagle_utils import (
|
||||||
|
build_tree_kernel_efficient,
|
||||||
|
organize_draft_results,
|
||||||
|
)
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.speculative.spec_utils import (
|
from sglang.srt.speculative.spec_utils import (
|
||||||
assign_draft_cache_locs,
|
assign_draft_cache_locs,
|
||||||
@@ -677,7 +680,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
forward_batch
|
forward_batch
|
||||||
)
|
)
|
||||||
if can_cuda_graph:
|
if can_cuda_graph:
|
||||||
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
|
parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
|
||||||
forward_batch
|
forward_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -686,7 +689,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# Initialize attention backend
|
# Initialize attention backend
|
||||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||||
# Run forward steps
|
# Run forward steps
|
||||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
parent_list, top_scores_index, draft_tokens = self.draft_forward(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
if batch.forward_mode.is_idle():
|
if batch.forward_mode.is_idle():
|
||||||
return EagleVerifyInput.create_idle_input(
|
return EagleVerifyInput.create_idle_input(
|
||||||
@@ -704,9 +709,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
draft_tokens,
|
draft_tokens,
|
||||||
) = build_tree_kernel_efficient(
|
) = build_tree_kernel_efficient(
|
||||||
spec_info.verified_id,
|
spec_info.verified_id,
|
||||||
score_list,
|
parent_list,
|
||||||
token_list,
|
top_scores_index,
|
||||||
parents_list,
|
draft_tokens,
|
||||||
batch.seq_lens,
|
batch.seq_lens,
|
||||||
batch.seq_lens_sum,
|
batch.seq_lens_sum,
|
||||||
self.topk,
|
self.topk,
|
||||||
@@ -795,7 +800,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
topk_index = self.hot_token_id[topk_index]
|
topk_index = self.hot_token_id[topk_index]
|
||||||
hidden_states = logits_output.hidden_states
|
hidden_states = logits_output.hidden_states
|
||||||
|
|
||||||
return score_list, token_list, parents_list
|
parent_list, top_scores_index, draft_tokens = organize_draft_results(
|
||||||
|
score_list, token_list, parents_list, self.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
return parent_list, top_scores_index, draft_tokens
|
||||||
|
|
||||||
def clear_cache_pool(self):
|
def clear_cache_pool(self):
|
||||||
self.model_runner.req_to_token_pool.clear()
|
self.model_runner.req_to_token_pool.clear()
|
||||||
|
|||||||
@@ -12,15 +12,14 @@ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
|||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
|
|
||||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.eagle_info_v2 import (
|
from sglang.srt.speculative.eagle_info_v2 import (
|
||||||
assign_extend_cache_locs,
|
assign_extend_cache_locs,
|
||||||
build_tree_kernel_efficient_tmp,
|
|
||||||
fill_accepted_out_cache_loc,
|
fill_accepted_out_cache_loc,
|
||||||
fill_new_verified_id,
|
fill_new_verified_id,
|
||||||
select_top_k_tokens_tmp,
|
select_top_k_tokens_tmp,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
|
||||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||||
from sglang.srt.utils.common import fast_topk, next_power_of_2
|
from sglang.srt.utils.common import fast_topk, next_power_of_2
|
||||||
|
|
||||||
@@ -116,7 +115,7 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
retrive_next_token,
|
retrive_next_token,
|
||||||
retrive_next_sibling,
|
retrive_next_sibling,
|
||||||
draft_tokens,
|
draft_tokens,
|
||||||
) = build_tree_kernel_efficient_tmp(
|
) = build_tree_kernel_efficient(
|
||||||
draft_input.verified_id,
|
draft_input.verified_id,
|
||||||
parent_list,
|
parent_list,
|
||||||
top_scores_index,
|
top_scores_index,
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ suites = {
|
|||||||
TestFile("test_chunked_prefill.py", 313),
|
TestFile("test_chunked_prefill.py", 313),
|
||||||
TestFile("test_create_kvindices.py", 2),
|
TestFile("test_create_kvindices.py", 2),
|
||||||
TestFile("test_deterministic.py", 300),
|
TestFile("test_deterministic.py", 300),
|
||||||
|
TestFile("test_build_eagle_tree.py", 8),
|
||||||
TestFile("test_eagle_infer_a.py", 370),
|
TestFile("test_eagle_infer_a.py", 370),
|
||||||
TestFile("test_eagle_infer_b.py", 700),
|
TestFile("test_eagle_infer_b.py", 700),
|
||||||
TestFile("test_eagle_infer_beta.py", 300),
|
TestFile("test_eagle_infer_beta.py", 300),
|
||||||
|
|||||||
308
test/srt/test_build_eagle_tree.py
Normal file
308
test/srt/test_build_eagle_tree.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.speculative.eagle_utils import (
|
||||||
|
build_tree_kernel_efficient,
|
||||||
|
organize_draft_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildEagleTree(unittest.TestCase):
|
||||||
|
"""Unit tests for build_eagle_tree functionality."""
|
||||||
|
|
||||||
|
def test_build_tree_kernel_efficient(self):
|
||||||
|
"""Test the build_tree_kernel_efficient function with known inputs and expected outputs."""
|
||||||
|
verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
|
||||||
|
score_list = [
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
|
||||||
|
[[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
|
||||||
|
[2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
|
||||||
|
[2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
|
||||||
|
[1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
|
||||||
|
[2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
|
||||||
|
[2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
|
||||||
|
[1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
|
||||||
|
[2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
|
||||||
|
[1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
|
||||||
|
[1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
|
||||||
|
[6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
|
||||||
|
[4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
|
||||||
|
[7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
|
||||||
|
[1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
|
||||||
|
[2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
|
||||||
|
[1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
|
||||||
|
[1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
|
||||||
|
[6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
|
||||||
|
[1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
token_list = [
|
||||||
|
torch.tensor(
|
||||||
|
[[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
29889,
|
||||||
|
29974,
|
||||||
|
29945,
|
||||||
|
29900,
|
||||||
|
29974,
|
||||||
|
29922,
|
||||||
|
29930,
|
||||||
|
29958,
|
||||||
|
29889,
|
||||||
|
29974,
|
||||||
|
29930,
|
||||||
|
29945,
|
||||||
|
29974,
|
||||||
|
29922,
|
||||||
|
29930,
|
||||||
|
29958,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
22550,
|
||||||
|
4136,
|
||||||
|
16492,
|
||||||
|
8439,
|
||||||
|
29871,
|
||||||
|
2,
|
||||||
|
3001,
|
||||||
|
13,
|
||||||
|
2,
|
||||||
|
13,
|
||||||
|
29906,
|
||||||
|
29946,
|
||||||
|
2,
|
||||||
|
13,
|
||||||
|
29871,
|
||||||
|
259,
|
||||||
|
],
|
||||||
|
],
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
29946,
|
||||||
|
29945,
|
||||||
|
29953,
|
||||||
|
29906,
|
||||||
|
29896,
|
||||||
|
29945,
|
||||||
|
29900,
|
||||||
|
29906,
|
||||||
|
29896,
|
||||||
|
29945,
|
||||||
|
29906,
|
||||||
|
29953,
|
||||||
|
29896,
|
||||||
|
29945,
|
||||||
|
29906,
|
||||||
|
29946,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
29871,
|
||||||
|
2,
|
||||||
|
29901,
|
||||||
|
29889,
|
||||||
|
29871,
|
||||||
|
2,
|
||||||
|
395,
|
||||||
|
259,
|
||||||
|
29901,
|
||||||
|
29871,
|
||||||
|
2,
|
||||||
|
29889,
|
||||||
|
3001,
|
||||||
|
1234,
|
||||||
|
7146,
|
||||||
|
2186,
|
||||||
|
],
|
||||||
|
],
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
29946,
|
||||||
|
29974,
|
||||||
|
29945,
|
||||||
|
29930,
|
||||||
|
29889,
|
||||||
|
29922,
|
||||||
|
29974,
|
||||||
|
29930,
|
||||||
|
29974,
|
||||||
|
29946,
|
||||||
|
29930,
|
||||||
|
29922,
|
||||||
|
29889,
|
||||||
|
29974,
|
||||||
|
29945,
|
||||||
|
29922,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
29941,
|
||||||
|
29906,
|
||||||
|
2,
|
||||||
|
29946,
|
||||||
|
29871,
|
||||||
|
450,
|
||||||
|
319,
|
||||||
|
14990,
|
||||||
|
29946,
|
||||||
|
29941,
|
||||||
|
2,
|
||||||
|
29906,
|
||||||
|
29871,
|
||||||
|
2,
|
||||||
|
3001,
|
||||||
|
13,
|
||||||
|
],
|
||||||
|
],
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
parents_list = [
|
||||||
|
torch.tensor(
|
||||||
|
[[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
|
||||||
|
),
|
||||||
|
torch.tensor(
|
||||||
|
[[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
|
||||||
|
topk = 4
|
||||||
|
depth = 4
|
||||||
|
num_draft_token = 8
|
||||||
|
|
||||||
|
parent_list, top_scores_index, draft_tokens = organize_draft_results(
|
||||||
|
score_list, token_list, parents_list, num_draft_token
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrieve_index,
|
||||||
|
retrieve_next_token,
|
||||||
|
retrieve_next_sibling,
|
||||||
|
draft_tokens,
|
||||||
|
) = build_tree_kernel_efficient(
|
||||||
|
verified_id=verified_id,
|
||||||
|
parent_list=parent_list,
|
||||||
|
top_scores_index=top_scores_index,
|
||||||
|
draft_tokens=draft_tokens,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_sum=torch.sum(seq_lens).item(),
|
||||||
|
topk=topk,
|
||||||
|
spec_steps=depth,
|
||||||
|
num_verify_tokens=num_draft_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify expected outputs
|
||||||
|
self.assertEqual(
|
||||||
|
position.tolist(),
|
||||||
|
[5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14],
|
||||||
|
"Position tensor does not match expected values",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
retrieve_index.tolist(),
|
||||||
|
[
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
[8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
|
],
|
||||||
|
"Retrieve index tensor does not match expected values",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
retrieve_next_token.tolist(),
|
||||||
|
[
|
||||||
|
[1, 3, 4, 5, 6, 7, -1, -1],
|
||||||
|
[1, 2, -1, 6, -1, -1, 7, -1],
|
||||||
|
],
|
||||||
|
"Retrieve next token tensor does not match expected values",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
retrieve_next_sibling.tolist(),
|
||||||
|
[
|
||||||
|
[-1, 2, -1, -1, -1, -1, -1, -1],
|
||||||
|
[-1, -1, 3, 4, 5, -1, -1, -1],
|
||||||
|
],
|
||||||
|
"Retrieve next sibling tensor does not match expected values",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
draft_tokens.tolist(),
|
||||||
|
[
|
||||||
|
29974,
|
||||||
|
29896,
|
||||||
|
29906,
|
||||||
|
29889,
|
||||||
|
29974,
|
||||||
|
29946,
|
||||||
|
29896,
|
||||||
|
29946,
|
||||||
|
13,
|
||||||
|
13,
|
||||||
|
22550,
|
||||||
|
4136,
|
||||||
|
16492,
|
||||||
|
8439,
|
||||||
|
29871,
|
||||||
|
29941,
|
||||||
|
],
|
||||||
|
"Draft tokens tensor does not match expected values",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user