From 2a2d3478afe8cdb336888f2e6faa3775ac40254e Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 12 Jul 2025 10:45:09 +0800 Subject: [PATCH] Fix wrong gemm branch cause 250us slower (#7969) --- python/sglang/srt/models/deepseek_v2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d83a7bb06..5138e4a12 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2193,7 +2193,6 @@ class DeepseekV2ForCausalLM(nn.Module): # This may affect the accuracy of fp8 model. # Fix deepseek v3 blockwise bmm by using deep_gemm use_deep_gemm_bmm = False - model_dtype = torch.get_default_dtype() if w.dtype in ( torch.float8_e4m3fn, @@ -2219,7 +2218,6 @@ class DeepseekV2ForCausalLM(nn.Module): _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128 - and model_dtype == torch.bfloat16 ): if ( deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM @@ -2233,7 +2231,7 @@ class DeepseekV2ForCausalLM(nn.Module): weight, weight_scale, weight_block_size, - model_dtype, + torch.bfloat16, ) else: w, scale = block_quant_to_tensor_quant(