diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index c28454fc4..d73f5bbab 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor( """ # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size - y_q_ptr += g_id * group_size + y_ptr += g_id.to(tl.int64) * group_size + y_q_ptr += g_id.to(tl.int64) * group_size # Convert g_id the flattened block coordinate to 2D so we can index # into the output y_scales matrix diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index c9474b96e..25b57c8f4 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -35,12 +35,12 @@ __global__ void per_token_group_quant_8bit_kernel( const int scale_num_rows = 0, const int scale_stride = 0) { const int threads_per_group = 16; - const int local_group_id = threadIdx.x / threads_per_group; + const int64_t local_group_id = threadIdx.x / threads_per_group; const int lane_id = threadIdx.x % threads_per_group; - const int block_group_id = blockIdx.x * groups_per_block; - const int global_group_id = block_group_id + local_group_id; - const int block_group_offset = global_group_id * group_size; + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; float local_absmax = eps;