Rename files in sgl kernel to avoid nested folder structure (#4213)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Lianmin Zheng
2025-03-08 22:54:51 -08:00
committed by GitHub
parent ee132a4515
commit 8abf74e3c9
47 changed files with 184 additions and 199 deletions

View File

@@ -0,0 +1,251 @@
/*
* Copyright (c) 2025 by SGLang team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped.");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
return;
}
int depends_order[10];
int cur_position = tid - 1;
while (true) {
depends_order[position] = cur_position + 1;
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
if (cur_position == draft_token_num) {
printf(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped.");
break;
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
int is_leaf = 0;
for (int i = 1; i < draft_token_num; i++) {
if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
is_leaf++;
}
}
if (is_leaf == 1) {
for (int i = 0; i < position; i++) {
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
depends_order[i] + bid * draft_token_num;
}
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
}
}
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}

View File

@@ -0,0 +1,138 @@
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"
#include "speculative_sampling.cuh"
using namespace flashinfer;
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
CHECK_INPUT(retrive_next_sibling);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(target_probs);
auto device = target_probs.device();
CHECK_EQ(candidates.device(), device);
CHECK_EQ(retrive_index.device(), device);
CHECK_EQ(retrive_next_token.device(), device);
CHECK_EQ(retrive_next_sibling.device(), device);
CHECK_EQ(uniform_samples.device(), device);
CHECK_EQ(target_probs.device(), device);
CHECK_DIM(1, predicts);
CHECK_DIM(2, accept_index);
CHECK_DIM(1, accept_token_num);
CHECK_DIM(2, candidates);
CHECK_DIM(2, retrive_index);
CHECK_DIM(2, retrive_next_token);
CHECK_DIM(2, retrive_next_sibling);
CHECK_DIM(2, uniform_samples);
CHECK_DIM(3, target_probs);
CHECK_DIM(3, draft_probs);
unsigned int batch_size = uniform_samples.size(0);
unsigned int num_spec_step = accept_index.size(1);
unsigned int num_draft_tokens = candidates.size(1);
unsigned int vocab_size = target_probs.size(2);
CHECK_EQ(batch_size, candidates.size(0));
CHECK_EQ(batch_size, retrive_index.size(0));
CHECK_EQ(batch_size, retrive_next_token.size(0));
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
CHECK_EQ(batch_size, target_probs.size(0));
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
CHECK_EQ(num_draft_tokens, uniform_samples.size(1));
CHECK_EQ(num_draft_tokens, target_probs.size(1));
CHECK_EQ(vocab_size, target_probs.size(2));
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
if (predicts.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
}
if (accept_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
}
if (accept_token_num.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
}
if (candidates.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
}
if (retrive_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
}
if (retrive_next_token.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
}
if (retrive_next_sibling.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
}
if (uniform_samples.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
}
if (target_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
if (draft_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
static_cast<int*>(predicts.data_ptr()),
static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()),
static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()),
static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(target_probs.data_ptr()),
static_cast<float*>(draft_probs.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens,
vocab_size,
deterministic,
stream);
TORCH_CHECK(
status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}

View File

@@ -0,0 +1,215 @@
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2024-2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SPECULATIVE_SAMPLING_CUH_
#define SPECULATIVE_SAMPLING_CUH_
#include <assert.h>
#include <flashinfer/sampling.cuh>
namespace flashinfer {
namespace sampling {
using namespace cub;
template <
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* accept_index,
IdType* accept_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
DType prob_acc = 0.0;
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
DType coin = uniform_samples[bx * num_draft_tokens];
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
uint32_t num_accepted_tokens = 0;
IdType cur_index = 0;
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
while (cur_index != -1) {
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
prob_acc += target_probs[cur_prob_offset + draft_token_id];
if (coin < prob_acc) {
// accept token
prob_acc = 0.;
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
coin = uniform_samples[bx * num_draft_tokens + cur_index];
predicts[last_accepted_retrive_idx] = draft_token_id;
++num_accepted_tokens;
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
last_accepted_retrive_idx = draft_index;
break;
} else {
// FIXME: leverage draft probs
draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
}
}
if (cur_index == -1) break;
}
accept_token_num[bx] = num_accepted_tokens;
// sample from relu(target_probs - draft_probs)
DType sum_relu_q_minus_p(0);
vec_t<DType, VEC_SIZE> q_vec, p_vec;
DType relu_q_minus_p[VEC_SIZE];
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(DType(0));
p_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
if (num_accepted_tokens != num_speculative_tokens - 1) {
// there is no draft_probs for the bonus token
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
}
sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(relu_q_minus_p);
__syncthreads();
}
if (tx == 0) {
temp_storage.block_aggregate.value = sum_relu_q_minus_p;
}
// init the first rejected token to (d - 1)
temp_storage.sampled_id = d - 1;
__syncthreads();
sum_relu_q_minus_p = temp_storage.block_aggregate.value;
DType u = coin * sum_relu_q_minus_p;
DType aggregate_relu_q_minus_p(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(DType(0));
p_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
if (num_accepted_tokens != num_speculative_tokens - 1) {
// there is no draft_probs for the bonus token
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
if (aggregate_relu_q_minus_p > u) {
break;
}
}
__syncthreads();
// set the first rejected token
predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
// value at not used indices are undefined
}
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* output_token_ids,
IdType* output_accepted_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
bool deterministic,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {
&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TreeSpeculativeSamplingTargetOnly<
BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
DType,
IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
}
} // namespace sampling
} // namespace flashinfer
#endif // SPECULATIVE_SAMPLING_CUH_