From 3eb4a800e82fc4e5d551e51e44c5e844c8278583 Mon Sep 17 00:00:00 2001
From: AniZpZ
Date: Wed, 18 Jun 2025 04:45:10 +0800
Subject: [PATCH] Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)

---
 python/sglang/srt/models/deepseek_v2.py |  8 +++++++-
 sgl-kernel/csrc/gemm/awq_kernel.cu      | 16 +++++++++-------
 sgl-kernel/tests/test_awq_dequant.py    |  5 ++---
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 192ab50ed..4cf898d43 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                         ):
                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            cat_dim = 0
+                            if (
+                                self.quant_config.get_name() == "awq"
+                                or self.quant_config.get_name() == "moe_wna16"
+                            ):
+                                cat_dim = 1
                             fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                             )
                             param_name = (
                                 name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
diff --git a/sgl-kernel/csrc/gemm/awq_kernel.cu b/sgl-kernel/csrc/gemm/awq_kernel.cu
index 188f0cb3f..eec933689 100644
--- a/sgl-kernel/csrc/gemm/awq_kernel.cu
+++ b/sgl-kernel/csrc/gemm/awq_kernel.cu
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
+  if (col >= qweight_cols || row >= qweight_rows) return;
 
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
 
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
 
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
   if (scales.scalar_type() == at::ScalarType::Half) {
     auto _scales = reinterpret_cast<half*>(scales.data_ptr());
     auto _output = reinterpret_cast<half*>(output.data_ptr());
-    dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr());
-    dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
 
   return output;
diff --git a/sgl-kernel/tests/test_awq_dequant.py b/sgl-kernel/tests/test_awq_dequant.py
index 60fc8a148..da68e88d0 100644
--- a/sgl-kernel/tests/test_awq_dequant.py
+++ b/sgl-kernel/tests/test_awq_dequant.py
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            [3584, 18944, 128, 256, 512, 1024, 1536],
+            [448, 576, 4736, 16, 32, 64, 128, 72],
             [True, False],
         )
     ),
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,