Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
|
||||
}
|
||||
}
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
|
||||
// positions [bs * draft_token]
|
||||
// retrive_index [bs, draft_token]
|
||||
// retrive_next_token [bs, draft_token]
|
||||
// retrive_next_sibling [bs, draft_token]
|
||||
__global__ void build_tree_efficient_partial_packed(
|
||||
int64_t* parent_list,
|
||||
int64_t* selected_index,
|
||||
int64_t* verified_seq_len,
|
||||
uint8_t* tree_mask,
|
||||
int64_t* positions,
|
||||
int64_t* retrive_index,
|
||||
int64_t* retrive_next_token,
|
||||
int64_t* retrive_next_sibling,
|
||||
int topk,
|
||||
int depth,
|
||||
int draft_token_num,
|
||||
size_t num_bytes_per_item) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid >= draft_token_num) {
|
||||
return;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
|
||||
tree_mask[token_tree_idx] = 1; // little endian
|
||||
|
||||
int position = 0;
|
||||
if (tid == 0) {
|
||||
positions[bid * draft_token_num] = seq_len;
|
||||
|
||||
int retrive_index_offset = bid * draft_token_num;
|
||||
for (int i = draft_token_num - 1; i > 0; --i) {
|
||||
int current_token_idx = retrive_index_offset + i;
|
||||
retrive_index[bid * draft_token_num + i] = current_token_idx;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
|
||||
int parent_position = 0;
|
||||
if (parent_tb_idx > 0) {
|
||||
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (; parent_position < draft_token_num; ++parent_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
|
||||
++parent_position;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent_position == draft_token_num) {
|
||||
printf(
|
||||
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
|
||||
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
} else {
|
||||
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
|
||||
}
|
||||
}
|
||||
retrive_index[bid * draft_token_num] = bid * draft_token_num;
|
||||
} else {
|
||||
int cur_position = tid - 1;
|
||||
while (true) {
|
||||
position += 1;
|
||||
int byte_idx = (cur_position + 1) / 8;
|
||||
int bit_idx = (cur_position + 1) % 8;
|
||||
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||
if (parent_tb_idx == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
|
||||
} else if (draft_token_num > 8) {
|
||||
num_bytes_per_item = 2;
|
||||
}
|
||||
throw std::runtime_error("Not implemented");
|
||||
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<uint8_t*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num),
|
||||
num_bytes_per_item);
|
||||
} else {
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
|
||||
Reference in New Issue
Block a user