Files
sglang/sgl-kernel/csrc/speculative/ngram_utils.cu
2025-09-28 21:06:59 -07:00

106 lines
3.1 KiB
Plaintext

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
// Rebuilds per-token tree metadata for a batch of draft-token trees from a
// flattened per-request mask.
//
// Launch layout: grid = (batch_size), block = (draft_token_num).
// blockIdx.x selects the request b; threadIdx.x selects the draft token t.
// Out-of-range blocks/threads return immediately. No shared memory is used.
//
// With n = draft_token_num:
// tree_mask: [bs * n * n]. Row t of request b starts at b * n * n + t * n.
//   A set entry (t, c) with c < t is treated as "c is an ancestor of t".
//   Entries at or after the diagonal (c >= t) are never read as ancestors.
//   NOTE(review): this presumes ancestors always have smaller indices than
//   their descendants; confirm with the producer of the mask.
// verified_seq_len: [bs] (read-only)
// positions: [bs * n] (output). depth(t) + verified_seq_len[b], where depth(t)
//   is the number of set entries (t, c) with c < t.
// retrive_index: [bs, n] (output). The flat token index b * n + t.
// retrive_next_token: [bs, n] (output). The smallest i > t with (i, t) set,
//   else -1.
// retrive_next_sibling: [bs, n] (output). Let p = the largest c < t with
//   (t, c) set. The output is the smallest i > t such that (i, p) is set and
//   (i, c) is clear for every p < c < i. It is -1 if no such i exists or if
//   row t has no set entry before the diagonal.
__global__ void reconstructIndicesFromTreeMask(
    const bool* __restrict__ tree_mask,
    const int64_t* __restrict__ verified_seq_len,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int batch_size,
    int draft_token_num) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;
  if (bid >= batch_size || tid >= draft_token_num) {
    return;
  }

  // All offsets are 64-bit: bid * n * n overflows a 32-bit int for large
  // batches / trees, which would silently read and write the wrong rows.
  const int64_t n = draft_token_num;
  const int64_t token_idx = static_cast<int64_t>(bid) * n;  // first token of request bid
  const int64_t mask_base = token_idx * n;                   // first mask entry of request bid
  const bool* my_row = tree_mask + mask_base + static_cast<int64_t>(tid) * n;

  // Scan backwards from the diagonal: count set entries (depth). The first hit
  // is the nearest one (parent).
  int depth = 0;
  int parent_idx = -1;
  for (int i = tid - 1; i >= 0; --i) {
    if (my_row[i]) {
      ++depth;
      if (parent_idx == -1) {
        parent_idx = i;
      }
    }
  }
  retrive_index[token_idx + tid] = token_idx + tid;
  positions[token_idx + tid] = depth + verified_seq_len[bid];

  // First later token whose row marks this token as an ancestor
  // (reads column tid of subsequent rows).
  int next_token_idx = -1;
  for (int i = tid + 1; i < draft_token_num; ++i) {
    if (tree_mask[mask_base + static_cast<int64_t>(i) * n + tid]) {
      next_token_idx = i;
      break;
    }
  }
  retrive_next_token[token_idx + tid] = next_token_idx;

  // A later token i counts as a sibling when the nearest set entry before its
  // diagonal is the same parent: (i, parent) set and nothing set strictly
  // between parent and i.
  int next_sibling_idx = -1;
  if (parent_idx != -1) {
    for (int i = tid + 1; i < draft_token_num; ++i) {
      const bool* row = tree_mask + mask_base + static_cast<int64_t>(i) * n;
      if (!row[parent_idx]) {
        continue;
      }
      bool is_sibling = true;
      for (int j = parent_idx + 1; j < i; ++j) {
        if (row[j]) {
          is_sibling = false;
          break;
        }
      }
      if (is_sibling) {
        next_sibling_idx = i;
        break;
      }
    }
  }
  retrive_next_sibling[token_idx + tid] = next_sibling_idx;
}
// Host launcher for reconstructIndicesFromTreeMask on the current CUDA stream.
//
// Launches one block per request (batch_size blocks) and one thread per draft
// token (draft_token_num threads per block).
//
// Preconditions, enforced with TORCH_CHECK (throws c10::Error on violation):
//   - 0 <= batch_size <= INT32_MAX.
//   - 0 <= draft_token_num <= 1024 (one thread per draft token in one block).
//   - Every tensor is a contiguous CUDA tensor.
//   - tree_mask uses 1-byte elements (read as bool) and holds at least
//     batch_size * draft_token_num^2 entries.
//   - verified_seq_len is int64 with at least batch_size elements.
//   - positions, retrive_index, retrive_next_token and retrive_next_sibling are
//     int64 with at least batch_size * draft_token_num elements. These four are
//     written in place.
// If batch_size or draft_token_num is 0, nothing is launched.
// A failed launch is reported through TORCH_CHECK. Errors raised during
// asynchronous execution surface at the next synchronizing call.
void reconstruct_indices_from_tree_mask(
    at::Tensor tree_mask,
    at::Tensor verified_seq_len,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t batch_size,
    int64_t draft_token_num) {
  TORCH_CHECK(
      batch_size >= 0 && batch_size <= INT32_MAX,
      "reconstruct_indices_from_tree_mask: batch_size must be in [0, INT32_MAX], got ",
      batch_size);
  TORCH_CHECK(
      draft_token_num >= 0 && draft_token_num <= 1024,
      "reconstruct_indices_from_tree_mask: draft_token_num must be in [0, 1024] "
      "(one thread per draft token), got ",
      draft_token_num);
  // A zero-sized grid or block is an invalid launch configuration; there is no
  // work to do.
  if (batch_size == 0 || draft_token_num == 0) {
    return;
  }

  const int64_t num_tokens = batch_size * draft_token_num;

  // The kernel reads tree_mask through a bool* and indexes it densely.
  TORCH_CHECK(tree_mask.is_cuda(), "reconstruct_indices_from_tree_mask: tree_mask must be a CUDA tensor");
  TORCH_CHECK(tree_mask.is_contiguous(), "reconstruct_indices_from_tree_mask: tree_mask must be contiguous");
  TORCH_CHECK(
      tree_mask.element_size() == 1,
      "reconstruct_indices_from_tree_mask: tree_mask must have 1-byte (bool) elements, got ",
      tree_mask.scalar_type());
  TORCH_CHECK(
      tree_mask.numel() >= num_tokens * draft_token_num,
      "reconstruct_indices_from_tree_mask: tree_mask needs at least ",
      num_tokens * draft_token_num,
      " elements, got ",
      tree_mask.numel());

  // Every other argument is reinterpreted as a dense int64_t buffer.
  auto check_int64 = [](const at::Tensor& t, const char* name, int64_t min_numel) {
    TORCH_CHECK(t.is_cuda(), "reconstruct_indices_from_tree_mask: ", name, " must be a CUDA tensor");
    TORCH_CHECK(t.is_contiguous(), "reconstruct_indices_from_tree_mask: ", name, " must be contiguous");
    TORCH_CHECK(
        t.scalar_type() == at::kLong,
        "reconstruct_indices_from_tree_mask: ",
        name,
        " must be int64, got ",
        t.scalar_type());
    TORCH_CHECK(
        t.numel() >= min_numel,
        "reconstruct_indices_from_tree_mask: ",
        name,
        " needs at least ",
        min_numel,
        " elements, got ",
        t.numel());
  };
  check_int64(verified_seq_len, "verified_seq_len", batch_size);
  check_int64(positions, "positions", num_tokens);
  check_int64(retrive_index, "retrive_index", num_tokens);
  check_int64(retrive_next_token, "retrive_next_token", num_tokens);
  check_int64(retrive_next_sibling, "retrive_next_sibling", num_tokens);

  dim3 grid(batch_size);
  dim3 block(draft_token_num);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  reconstructIndicesFromTreeMask<<<grid, block, 0, stream>>>(
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(verified_seq_len.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int(batch_size),
      int(draft_token_num));

  // A kernel launch does not return its status; check it explicitly so a bad
  // configuration is reported here rather than by some later, unrelated call.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(
      launch_status == cudaSuccess,
      "reconstruct_indices_from_tree_mask: kernel launch failed: ",
      cudaGetErrorString(launch_status));
}