Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)
Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
# NOTE: Please run this file to make sure the test cases are correct.
|
||||
|
||||
from typing import List
|
||||
import math
|
||||
from enum import IntEnum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda, is_hip, rank0_log
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
if is_cuda() or is_hip():
|
||||
from sgl_kernel import (
|
||||
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
|
||||
return parent_list, top_scores_index, draft_tokens
|
||||
|
||||
|
||||
class TreeMaskMode(IntEnum):
|
||||
FULL_MASK = 0
|
||||
QLEN_ONLY = 1
|
||||
QLEN_ONLY_BITPACKING = 2
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
|
||||
verified_id: torch.Tensor,
|
||||
score_list: List[torch.Tensor],
|
||||
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
|
||||
topk: int,
|
||||
spec_steps: int,
|
||||
num_verify_tokens: int,
|
||||
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
|
||||
tree_mask_buf: Optional[torch.Tensor] = None,
|
||||
position_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
parent_list, top_scores_index, draft_tokens = (
|
||||
build_tree_kernel_efficient_preprocess(
|
||||
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
|
||||
device = seq_lens.device
|
||||
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
||||
# where each row indicates the attending pattern of each draft token
|
||||
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
|
||||
if tree_mask_buf is not None:
|
||||
tree_mask = tree_mask_buf
|
||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
||||
tree_mask = torch.full(
|
||||
(num_verify_tokens * bs * num_verify_tokens,),
|
||||
True,
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
||||
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
|
||||
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
|
||||
tree_mask = torch.zeros(
|
||||
(num_verify_tokens * bs,),
|
||||
dtype=packed_dtypes[packed_dtype_idx],
|
||||
device=device,
|
||||
)
|
||||
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
||||
tree_mask = torch.full(
|
||||
(
|
||||
seq_lens_sum * num_verify_tokens
|
||||
+ num_verify_tokens * num_verify_tokens * bs,
|
||||
),
|
||||
True,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
||||
|
||||
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
||||
tree_mask = torch.full(
|
||||
(
|
||||
seq_lens_sum * num_verify_tokens
|
||||
+ num_verify_tokens * num_verify_tokens * bs,
|
||||
),
|
||||
True,
|
||||
device=device,
|
||||
)
|
||||
retrive_index = torch.full(
|
||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||
)
|
||||
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
|
||||
# position: where each token belongs to
|
||||
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
||||
# then, positions = [7, 8, 8, 9]
|
||||
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
|
||||
if position_buf is not None:
|
||||
positions = position_buf
|
||||
else:
|
||||
positions = torch.empty(
|
||||
(bs * num_verify_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
|
||||
sgl_build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_tokens,
|
||||
tree_mask_mode,
|
||||
)
|
||||
return (
|
||||
tree_mask,
|
||||
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
|
||||
num_verify_tokens=num_draft_token,
|
||||
)
|
||||
|
||||
rank0_log("=========== build tree kernel efficient ==========")
|
||||
# rank0_log(f"{tree_mask=}")
|
||||
rank0_log(f"{position=}")
|
||||
rank0_log(f"{retrive_index=}")
|
||||
rank0_log(f"{retrive_next_token=}")
|
||||
rank0_log(f"{retrive_next_sibling=}")
|
||||
rank0_log(f"{draft_tokens=}")
|
||||
print("=========== build tree kernel efficient ==========")
|
||||
print(f"{tree_mask=}")
|
||||
print(f"{position=}")
|
||||
print(f"{retrive_index=}")
|
||||
print(f"{retrive_next_token=}")
|
||||
print(f"{retrive_next_sibling=}")
|
||||
print(f"{draft_tokens=}")
|
||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||
assert retrive_index.tolist() == [
|
||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
|
||||
Reference in New Issue
Block a user