Fix wrong gemm branch causing a 250us slowdown (#7969)

This commit is contained in:
fzyzcjy
2025-07-12 10:45:09 +08:00
committed by GitHub
parent aa2056091a
commit 2a2d3478af

View File

@@ -2193,7 +2193,6 @@ class DeepseekV2ForCausalLM(nn.Module):
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
model_dtype = torch.get_default_dtype()
if w.dtype in (
torch.float8_e4m3fn,
@@ -2219,7 +2218,6 @@ class DeepseekV2ForCausalLM(nn.Module):
_is_cuda
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -2233,7 +2231,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight,
weight_scale,
weight_block_size,
model_dtype,
torch.bfloat16,
)
else:
w, scale = block_quant_to_tensor_quant(