Minor style fixes for sgl-kernel (#9289)

This commit is contained in:
Lianmin Zheng
2025-08-18 09:38:35 -07:00
committed by GitHub
parent 6e316588f8
commit c480a3f6ea
17 changed files with 439 additions and 109 deletions

View File

@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
}
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
// positions [bs * draft_token]
// retrive_index [bs, draft_token]
// retrive_next_token [bs, draft_token]
// retrive_next_sibling [bs, draft_token]
__global__ void build_tree_efficient_partial_packed(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
uint8_t* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
size_t num_bytes_per_item) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
tree_mask[token_tree_idx] = 1; // little endian
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
int byte_idx = (cur_position + 1) / 8;
int bit_idx = (cur_position + 1) % 8;
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
throw std::runtime_error("Not implemented");
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<uint8_t*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
num_bytes_per_item);
} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),