Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -37,6 +37,7 @@ void tree_speculative_sampling_target_only(
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor uniform_samples_for_final_sampling,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
double threshold_single,
|
||||
@@ -48,6 +49,7 @@ void tree_speculative_sampling_target_only(
|
||||
CHECK_INPUT(retrive_next_token);
|
||||
CHECK_INPUT(retrive_next_sibling);
|
||||
CHECK_INPUT(uniform_samples);
|
||||
CHECK_INPUT(uniform_samples_for_final_sampling);
|
||||
CHECK_INPUT(target_probs);
|
||||
auto device = target_probs.device();
|
||||
CHECK_EQ(candidates.device(), device);
|
||||
@@ -55,6 +57,7 @@ void tree_speculative_sampling_target_only(
|
||||
CHECK_EQ(retrive_next_token.device(), device);
|
||||
CHECK_EQ(retrive_next_sibling.device(), device);
|
||||
CHECK_EQ(uniform_samples.device(), device);
|
||||
CHECK_EQ(uniform_samples_for_final_sampling.device(), device);
|
||||
CHECK_EQ(target_probs.device(), device);
|
||||
CHECK_DIM(1, predicts);
|
||||
CHECK_DIM(2, accept_index);
|
||||
@@ -92,21 +95,24 @@ void tree_speculative_sampling_target_only(
|
||||
if (accept_token_num.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||
}
|
||||
if (candidates.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
||||
if (candidates.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
||||
if (retrive_index.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
||||
if (retrive_next_token.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
||||
if (retrive_next_sibling.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
|
||||
}
|
||||
if (uniform_samples.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
|
||||
}
|
||||
if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32).");
|
||||
}
|
||||
if (target_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
@@ -119,15 +125,16 @@ void tree_speculative_sampling_target_only(
|
||||
CHECK_GE(1, threshold_acc);
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
||||
static_cast<int*>(predicts.data_ptr()),
|
||||
static_cast<int*>(accept_index.data_ptr()),
|
||||
static_cast<int*>(accept_token_num.data_ptr()),
|
||||
static_cast<int*>(candidates.data_ptr()),
|
||||
static_cast<int*>(retrive_index.data_ptr()),
|
||||
static_cast<int*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
|
||||
static_cast<int32_t*>(predicts.data_ptr()),
|
||||
static_cast<int32_t*>(accept_index.data_ptr()),
|
||||
static_cast<int32_t*>(accept_token_num.data_ptr()),
|
||||
static_cast<int64_t*>(candidates.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
static_cast<float*>(uniform_samples.data_ptr()),
|
||||
static_cast<float*>(uniform_samples_for_final_sampling.data_ptr()),
|
||||
static_cast<float*>(target_probs.data_ptr()),
|
||||
static_cast<float*>(draft_probs.data_ptr()),
|
||||
batch_size,
|
||||
|
||||
Reference in New Issue
Block a user