Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -24,7 +24,12 @@ using namespace flashinfer;
// bitorder = "little"
void segment_packbits(
-at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
+at::Tensor x,
+at::Tensor input_indptr,
+at::Tensor output_indptr,
+at::Tensor y,
+int64_t batch_size,
+int64_t cuda_stream) {
CHECK_INPUT(x);
CHECK_INPUT(input_indptr);
CHECK_INPUT(output_indptr);
@@ -32,8 +37,7 @@ void segment_packbits(
CHECK_EQ(input_indptr.device(), device);
CHECK_EQ(output_indptr.device(), device);
CHECK_EQ(y.device(), device);
-unsigned int batch_size = input_indptr.size(0) - 1;
-CHECK_EQ(output_indptr.size(0), batch_size + 1);
+CHECK_GE(output_indptr.size(0), batch_size + 1);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = quantization::SegmentPackBits(