From 88f484ce4c73bff72fe417234d819800d2263d4b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 2 Jul 2025 12:30:18 -0700 Subject: [PATCH] Apply dsv3 router gemm kernel for deepseek-r1 fp4 (#7677) --- python/sglang/srt/models/deepseek_v2.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bf867407d..1646a2858 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -111,9 +111,16 @@ _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_device_sm = get_device_sm() if _is_cuda: - from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, merge_state_v2 + from sgl_kernel import ( + awq_dequantize, + bmm_fp8, + dsv3_fused_a_gemm, + dsv3_router_gemm, + merge_state_v2, + ) elif _is_cpu and _is_cpu_amx_available: pass else: @@ -225,7 +232,18 @@ class MoEGate(nn.Module): True, # is_vnni ) - logits = F.linear(hidden_states, self.weight, None) + if ( + hidden_states.shape[0] < 4 + and hidden_states.shape[1] == 7168 + and self.weight.shape[0] == 256 + and _device_sm >= 90 + ): + logits = dsv3_router_gemm(hidden_states, self.weight).to( + hidden_states.dtype + ) + else: + logits = F.linear(hidden_states, self.weight, None) + return logits @@ -882,7 +900,7 @@ class DeepseekV2AttentionMLA(nn.Module): and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112 and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168 and is_cuda - and get_device_sm() >= 90 + and _device_sm >= 90 ) self.qkv_proj_with_rope_is_int8 = (