From be0058bc0584c44ea749e332be1432909a7228c9 Mon Sep 17 00:00:00 2001 From: Liu-congo <51957663+Liu-congo@users.noreply.github.com> Date: Mon, 20 Oct 2025 08:34:13 +0800 Subject: [PATCH] [BugFix] replace the input_to_float8 used in dsv2 (#11612) Signed-off-by: Liu-congo <1502632128@qq.com> --- python/sglang/srt/models/deepseek_v2.py | 35 ++++++++++++------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 31da13318..6ca168670 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -92,7 +92,6 @@ from sglang.srt.layers.quantization.fp8_utils import ( block_quant_dequant, block_quant_to_tensor_quant, channel_quant_to_tensor_quant, - input_to_float8, normalize_e4m3fn_to_e4m3fnuz, quant_weight_ue8m0, requant_weight_ue8m0_inplace, @@ -1623,15 +1622,15 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_kc.to(torch.bfloat16) * self.w_scale, ) elif self.w_kc.dtype == torch.float8_e4m3fn: - # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9 - if _is_cublas_ge_129: - q_nope_val, q_nope_scale = input_to_float8( - q_nope.transpose(0, 1), torch.float8_e4m3fn - ) - else: - q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( - q_nope.transpose(0, 1), zero_allocator.allocate(1) - ) + # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612 + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), + ( + torch.zeros((1,), dtype=torch.float32, device=q_nope.device) + if _is_cublas_ge_129 + else zero_allocator.allocate(1) + ), + ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 ) @@ -1772,14 +1771,14 @@ class DeepseekV2AttentionMLA(nn.Module): attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) elif self.w_vc.dtype == torch.float8_e4m3fn: - if _is_cublas_ge_129: - attn_output_val, attn_output_scale = input_to_float8( - attn_output.transpose(0, 1), torch.float8_e4m3fn - ) - else: - attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( - attn_output.transpose(0, 1), zero_allocator.allocate(1) - ) + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( + attn_output.transpose(0, 1), + ( + torch.zeros((1,), dtype=torch.float32, device=attn_output.device) + if _is_cublas_ge_129 + else zero_allocator.allocate(1) + ), + ) attn_bmm_output = bmm_fp8( attn_output_val, self.w_vc,