Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)

Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-07-05 12:17:05 -07:00
committed by GitHub
parent c04a8a820b
commit 5589b75024
6 changed files with 101 additions and 36 deletions

View File

@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
m.def(

View File

@@ -23,6 +23,8 @@
#include "pytorch_extension_utils_rocm.h"
#endif
// Selects how build_tree_efficient addresses the tree_mask buffer.
typedef enum {
  // Row for draft token `tid` starts at
  // seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len.
  FULL_MASK = 0,
  // Row for draft token `tid` of request `bid` starts at
  // draft_token_num * draft_token_num * bid + draft_token_num * tid.
  QLEN_ONLY = 1,
  // Rejected by the host launcher with "Not implemented".
  QLEN_ONLY_BITPACKING = 2
} TreeMaskMode;
// parent_list [bs, topk * (depth - 1) + 1]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
@@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num) {
int draft_token_num,
int tree_mask_mode) {
int bid = blockIdx.x;
int tid = threadIdx.x;
@@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
int token_tree_idx;
if (tree_mask_mode == FULL_MASK) {
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
} else {
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
}
tree_mask[token_tree_idx - 1] = true;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
@@ -124,7 +133,8 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
int64_t draft_token_num,
int64_t tree_mask_mode) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
@@ -132,18 +142,29 @@ void build_tree_kernel_efficient(
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
size_t num_bytes_per_item = 1;
if (draft_token_num > 16) {
num_bytes_per_item = 4;
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
throw std::runtime_error("Not implemented");
} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
int32_t(tree_mask_mode));
}
}
template <typename IdType, typename IdType2>

View File

@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
}

View File

@@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
int64_t draft_token_num,
int64_t tree_mask_mode);
void segment_packbits(
at::Tensor x,

View File

@@ -72,6 +72,7 @@ def build_tree_kernel_efficient(
topk: int,
depth: int,
draft_token_num: int,
tree_mask_mode: int,
) -> None:
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
parent_list,
@@ -85,6 +86,7 @@ def build_tree_kernel_efficient(
topk,
depth,
draft_token_num,
tree_mask_mode,
)