Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -37,6 +37,7 @@ void tree_speculative_sampling_target_only(
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor uniform_samples_for_final_sampling,
at::Tensor target_probs,
at::Tensor draft_probs,
double threshold_single,
@@ -48,6 +49,7 @@ void tree_speculative_sampling_target_only(
CHECK_INPUT(retrive_next_token);
CHECK_INPUT(retrive_next_sibling);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(uniform_samples_for_final_sampling);
CHECK_INPUT(target_probs);
auto device = target_probs.device();
CHECK_EQ(candidates.device(), device);
@@ -55,6 +57,7 @@ void tree_speculative_sampling_target_only(
CHECK_EQ(retrive_next_token.device(), device);
CHECK_EQ(retrive_next_sibling.device(), device);
CHECK_EQ(uniform_samples.device(), device);
CHECK_EQ(uniform_samples_for_final_sampling.device(), device);
CHECK_EQ(target_probs.device(), device);
CHECK_DIM(1, predicts);
CHECK_DIM(2, accept_index);
@@ -92,21 +95,24 @@ void tree_speculative_sampling_target_only(
if (accept_token_num.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
}
if (candidates.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
if (candidates.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
}
if (retrive_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
if (retrive_index.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
}
if (retrive_next_token.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
if (retrive_next_token.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
}
if (retrive_next_sibling.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
if (retrive_next_sibling.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
}
if (uniform_samples.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
}
if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32).");
}
if (target_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
@@ -119,15 +125,16 @@ void tree_speculative_sampling_target_only(
CHECK_GE(1, threshold_acc);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
static_cast<int*>(predicts.data_ptr()),
static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()),
static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()),
static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()),
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
static_cast<int32_t*>(predicts.data_ptr()),
static_cast<int32_t*>(accept_index.data_ptr()),
static_cast<int32_t*>(accept_token_num.data_ptr()),
static_cast<int64_t*>(candidates.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(uniform_samples_for_final_sampling.data_ptr()),
static_cast<float*>(target_probs.data_ptr()),
static_cast<float*>(draft_probs.data_ptr()),
batch_size,