Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -24,7 +24,12 @@ using namespace flashinfer;
|
||||
|
||||
// bitorder = "little"
|
||||
void segment_packbits(
|
||||
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
|
||||
at::Tensor x,
|
||||
at::Tensor input_indptr,
|
||||
at::Tensor output_indptr,
|
||||
at::Tensor y,
|
||||
int64_t batch_size,
|
||||
int64_t cuda_stream) {
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(input_indptr);
|
||||
CHECK_INPUT(output_indptr);
|
||||
@@ -32,8 +37,7 @@ void segment_packbits(
|
||||
CHECK_EQ(input_indptr.device(), device);
|
||||
CHECK_EQ(output_indptr.device(), device);
|
||||
CHECK_EQ(y.device(), device);
|
||||
unsigned int batch_size = input_indptr.size(0) - 1;
|
||||
CHECK_EQ(output_indptr.size(0), batch_size + 1);
|
||||
CHECK_GE(output_indptr.size(0), batch_size + 1);
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
cudaError_t status = quantization::SegmentPackBits(
|
||||
|
||||
Reference in New Issue
Block a user