From 5589b7502440c42d8e4b17fb7d3dfcfeebb83c05 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 5 Jul 2025 12:17:05 -0700 Subject: [PATCH] Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756) Co-authored-by: Pranjal Shankhdhar --- .../srt/speculative/build_eagle_tree.py | 75 ++++++++++++++----- sgl-kernel/csrc/common_extension.cc | 3 +- sgl-kernel/csrc/speculative/eagle_utils.cu | 51 +++++++++---- sgl-kernel/csrc/torch_extension_rocm.cc | 3 +- sgl-kernel/include/sgl_kernel_ops.h | 3 +- sgl-kernel/python/sgl_kernel/speculative.py | 2 + 6 files changed, 101 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index c53a13f4a..fd27f414c 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -1,10 +1,12 @@ # NOTE: Please run this file to make sure the test cases are correct. -from typing import List +import math +from enum import IntEnum +from typing import List, Optional import torch -from sglang.srt.utils import is_cuda, is_hip, rank0_log +from sglang.srt.utils import is_cuda, is_hip if is_cuda() or is_hip(): from sgl_kernel import ( @@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess( return parent_list, top_scores_index, draft_tokens +class TreeMaskMode(IntEnum): + FULL_MASK = 0 + QLEN_ONLY = 1 + QLEN_ONLY_BITPACKING = 2 + + def build_tree_kernel_efficient( verified_id: torch.Tensor, score_list: List[torch.Tensor], @@ -50,6 +58,9 @@ def build_tree_kernel_efficient( topk: int, spec_steps: int, num_verify_tokens: int, + tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK, + tree_mask_buf: Optional[torch.Tensor] = None, + position_buf: Optional[torch.Tensor] = None, ): parent_list, top_scores_index, draft_tokens = ( build_tree_kernel_efficient_preprocess( @@ -66,15 +77,37 @@ def build_tree_kernel_efficient( device = seq_lens.device # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) # where each row indicates the attending pattern of each draft token + # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed) + if tree_mask_buf is not None: + tree_mask = tree_mask_buf + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY: + tree_mask = torch.full( + (num_verify_tokens * bs * num_verify_tokens,), + True, + dtype=torch.bool, + device=device, + ) + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: + packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] + packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) + tree_mask = torch.zeros( + (num_verify_tokens * bs,), + dtype=packed_dtypes[packed_dtype_idx], + device=device, + ) + elif tree_mask_mode == TreeMaskMode.FULL_MASK: + tree_mask = torch.full( + ( + seq_lens_sum * num_verify_tokens + + num_verify_tokens * num_verify_tokens * bs, + ), + True, + device=device, + ) + else: + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") + # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` - tree_mask = torch.full( - ( - seq_lens_sum * num_verify_tokens - + num_verify_tokens * num_verify_tokens * bs, - ), - True, - device=device, - ) retrive_index = torch.full( (bs, num_verify_tokens), -1, device=device, dtype=torch.long ) @@ -87,7 +120,12 @@ def build_tree_kernel_efficient( # position: where each token belongs to # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 # then, positions = [7, 8, 8, 9] - positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long) + if position_buf is not None: + positions = position_buf + else: + positions = torch.empty( + (bs * num_verify_tokens,), device=device, dtype=torch.long + ) sgl_build_tree_kernel_efficient( parent_list, @@ -101,6 +139,7 @@ def build_tree_kernel_efficient( topk, spec_steps, num_verify_tokens, + tree_mask_mode, ) return ( tree_mask, @@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient(): num_verify_tokens=num_draft_token, ) - rank0_log("=========== build tree kernel efficient ==========") - # rank0_log(f"{tree_mask=}") - rank0_log(f"{position=}") - rank0_log(f"{retrive_index=}") - rank0_log(f"{retrive_next_token=}") - rank0_log(f"{retrive_next_sibling=}") - rank0_log(f"{draft_tokens=}") + print("=========== build tree kernel efficient ==========") + print(f"{tree_mask=}") + print(f"{position=}") + print(f"{retrive_index=}") + print(f"{retrive_next_token=}") + print(f"{retrive_next_sibling=}") + print(f"{draft_tokens=}") assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14] assert retrive_index.tolist() == [ [0, 1, 2, 3, 4, 5, 6, 7], diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index bdec3cf1a..f5eb9bfe5 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " - "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); + "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> " + "()"); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); m.def( diff --git a/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu index 8b0759765..9b463de9a 100644 --- a/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -23,6 +23,8 @@ #include "pytorch_extension_utils_rocm.h" #endif +typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode; + // parent_list [bs, topk * (depth - 1) + 1)] // selected_index [bs, draft_token_num - 1] // verified_seq_len [bs] @@ -40,7 +42,8 @@ __global__ void build_tree_efficient( int64_t* retrive_next_sibling, int topk, int depth, - int draft_token_num) { + int draft_token_num, + int tree_mask_mode) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -52,7 +55,13 @@ __global__ void build_tree_efficient( seq_tree_idx += verified_seq_len[i] * draft_token_num; } int seq_len = verified_seq_len[bid]; - int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + int token_tree_idx; + if (tree_mask_mode == FULL_MASK) { + token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + } else { + token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1; + } + tree_mask[token_tree_idx - 1] = true; for (int i = 0; i < draft_token_num - 1; i++) { tree_mask[token_tree_idx + i] = false; } @@ -124,7 +133,8 @@ void build_tree_kernel_efficient( at::Tensor retrive_next_sibling, int64_t topk, int64_t depth, - int64_t draft_token_num) { + int64_t draft_token_num, + int64_t tree_mask_mode) { // TODO (ying) check shape // TODO (ying) check type int bs = parent_list.size(0); @@ -132,18 +142,29 @@ void build_tree_kernel_efficient( dim3 block(draft_token_num); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - build_tree_efficient<<>>( - static_cast(parent_list.data_ptr()), - static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), - static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(retrive_index.data_ptr()), - static_cast(retrive_next_token.data_ptr()), - static_cast(retrive_next_sibling.data_ptr()), - int32_t(topk), - int32_t(depth), - int32_t(draft_token_num)); + if (tree_mask_mode == QLEN_ONLY_BITPACKING) { + size_t num_bytes_per_item = 1; + if (draft_token_num > 16) { + num_bytes_per_item = 4; + } else if (draft_token_num > 8) { + num_bytes_per_item = 2; + } + throw std::runtime_error("Not implemented"); + } else { + build_tree_efficient<<>>( + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + int32_t(tree_mask_mode)); + } } template diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 0e3f48e61..84f9d1e7a 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.def( "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " - "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); + "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> " + "()"); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 6811fdd55..4d4990041 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -374,7 +374,8 @@ void build_tree_kernel_efficient( at::Tensor retrive_next_sibling, int64_t topk, int64_t depth, - int64_t draft_token_num); + int64_t draft_token_num, + int64_t tree_mask_mode); void segment_packbits( at::Tensor x, diff --git a/sgl-kernel/python/sgl_kernel/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py index 0ff46148a..ea2e3ac8a 100644 --- a/sgl-kernel/python/sgl_kernel/speculative.py +++ b/sgl-kernel/python/sgl_kernel/speculative.py @@ -72,6 +72,7 @@ def build_tree_kernel_efficient( topk: int, depth: int, draft_token_num: int, + tree_mask_mode: int, ) -> None: torch.ops.sgl_kernel.build_tree_kernel_efficient.default( parent_list, @@ -85,6 +86,7 @@ def build_tree_kernel_efficient( topk, depth, draft_token_num, + tree_mask_mode, )