From 4a0d19198bf9222edcb9879028990b481f8ffe56 Mon Sep 17 00:00:00 2001 From: likesen-alibaba Date: Thu, 10 Jul 2025 16:19:56 +0800 Subject: [PATCH] Fix bug of deepseek-v3 under DP+EP mode with large batchsize/seqlen (#6449) --- python/sglang/srt/layers/quantization/fp8_kernel.py | 4 ++-- sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index c28454fc4..d73f5bbab 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor( """ # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size - y_q_ptr += g_id * group_size + y_ptr += g_id.to(tl.int64) * group_size + y_q_ptr += g_id.to(tl.int64) * group_size # Convert g_id the flattened block coordinate to 2D so we can index # into the output y_scales matrix diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index c9474b96e..25b57c8f4 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -35,12 +35,12 @@ __global__ void per_token_group_quant_8bit_kernel( const int scale_num_rows = 0, const int scale_stride = 0) { const int threads_per_group = 16; - const int local_group_id = threadIdx.x / threads_per_group; + const int64_t local_group_id = threadIdx.x / threads_per_group; const int lane_id = threadIdx.x % threads_per_group; - const int block_group_id = blockIdx.x * groups_per_block; - const int global_group_id = block_group_id + local_group_id; - const int block_group_offset = global_group_id * group_size; + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; float local_absmax = eps;