Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)

Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-07-05 12:17:05 -07:00
committed by GitHub
parent c04a8a820b
commit 5589b75024
6 changed files with 101 additions and 36 deletions

View File

@@ -1,10 +1,12 @@
# NOTE: Please run this file to make sure the test cases are correct.
from typing import List
import math
from enum import IntEnum
from typing import List, Optional
import torch
from sglang.srt.utils import is_cuda, is_hip, rank0_log
from sglang.srt.utils import is_cuda, is_hip
if is_cuda() or is_hip():
from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
return parent_list, top_scores_index, draft_tokens
# Layout selector for the `tree_mask` buffer produced by
# `build_tree_kernel_efficient`. The chosen value is also forwarded as the
# last argument of `sgl_build_tree_kernel_efficient`, so the integer values
# below must stay in sync with whatever that kernel expects.
# When a caller supplies `tree_mask_buf`, that buffer is used as-is and none
# of the allocations described below take place.
class TreeMaskMode(IntEnum):
# Default mode. The mask is allocated with
# `seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs`
# entries, all initialised to True.
FULL_MASK = 0
# Boolean mask with `num_verify_tokens * bs * num_verify_tokens` entries,
# all initialised to True (no `seq_lens_sum` term).
QLEN_ONLY = 1
# Zero-initialised integer mask with `num_verify_tokens * bs` entries whose
# dtype is picked from (uint8, uint16, uint32) based on `num_verify_tokens`.
# Presumably one bit per verify token within each entry - confirm against
# the kernel.
# NOTE(review): the dtype table has only three entries, so
# `num_verify_tokens > 32` indexes past its end in the allocation code.
QLEN_ONLY_BITPACKING = 2
def build_tree_kernel_efficient(
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
topk: int,
spec_steps: int,
num_verify_tokens: int,
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
tree_mask_buf: Optional[torch.Tensor] = None,
position_buf: Optional[torch.Tensor] = None,
):
parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
device = seq_lens.device
# e.g. for bs=1 with TreeMaskMode.FULL_MASK, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
# where each row indicates the attending pattern of each draft token
# with TreeMaskMode.QLEN_ONLY, tree_mask: num_draft_token, num_draft_token (flattened)
# with TreeMaskMode.QLEN_ONLY_BITPACKING, tree_mask: num_draft_token (flattened, packed)
if tree_mask_buf is not None:
tree_mask = tree_mask_buf
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
tree_mask = torch.full(
(num_verify_tokens * bs * num_verify_tokens,),
True,
dtype=torch.bool,
device=device,
)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
tree_mask = torch.zeros(
(num_verify_tokens * bs,),
dtype=packed_dtypes[packed_dtype_idx],
device=device,
)
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
else:
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full(
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
)
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
# positions: the sequence position assigned to each draft token
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
# then, positions = [7, 8, 8, 9]
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
if position_buf is not None:
positions = position_buf
else:
positions = torch.empty(
(bs * num_verify_tokens,), device=device, dtype=torch.long
)
sgl_build_tree_kernel_efficient(
parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
topk,
spec_steps,
num_verify_tokens,
tree_mask_mode,
)
return (
tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
num_verify_tokens=num_draft_token,
)
rank0_log("=========== build tree kernel efficient ==========")
# rank0_log(f"{tree_mask=}")
rank0_log(f"{position=}")
rank0_log(f"{retrive_index=}")
rank0_log(f"{retrive_next_token=}")
rank0_log(f"{retrive_next_sibling=}")
rank0_log(f"{draft_tokens=}")
print("=========== build tree kernel efficient ==========")
print(f"{tree_mask=}")
print(f"{position=}")
print(f"{retrive_index=}")
print(f"{retrive_next_token=}")
print(f"{retrive_next_sibling=}")
print(f"{draft_tokens=}")
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [
[0, 1, 2, 3, 4, 5, 6, 7],