From c80a96dae93d337ecf0097705d27f474b33aa3cf Mon Sep 17 00:00:00 2001
From: Liu-congo <51957663+Liu-congo@users.noreply.github.com>
Date: Sat, 11 Oct 2025 12:14:24 +0800
Subject: [PATCH] [BugFix] test_mla_fp8.py fails on Cublas 12.9 (#11360)

Signed-off-by: Liu-congo <1502632128@qq.com>
---
 python/sglang/srt/models/deepseek_v2.py | 28 ++++++++++++++++++--------
 python/sglang/srt/utils/common.py       | 11 ++++++++++
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8877fe602..e66dd4a1f 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -94,6 +94,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -189,6 +191,7 @@ else:
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
@@ -1572,10 +1575,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
+            if _is_cublas_ge_129:
+                q_nope_val, q_nope_scale = input_to_float8(
+                    q_nope.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
+                )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
@@ -1716,10 +1724,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
 
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            if _is_cublas_ge_129:
+                attn_output_val, attn_output_scale = input_to_float8(
+                    attn_output.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
+                )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 732690cac..44ac84fbc 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -263,6 +263,17 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def is_nvidia_cublas_cu12_version_ge_12_9():
+    """
+    temporary fix for issue #11272
+    """
+    try:
+        installed_version = version("nvidia-cublas-cu12")
+    except PackageNotFoundError:
+        return False
+    return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
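
Note (reviewer sketch, not part of the patch): the new helper in common.py relies on
version/PackageNotFoundError from importlib.metadata and on pkg_version from the
packaging library, which the patch assumes are already imported in that module. A
self-contained, runnable sketch of the same version gate, with only these assumed
imports made explicit, might look like this:

    from importlib.metadata import PackageNotFoundError, version

    from packaging import version as pkg_version


    def is_nvidia_cublas_cu12_version_ge_12_9() -> bool:
        # True only when the nvidia-cublas-cu12 wheel is installed and its
        # version parses as >= 12.9; a missing package yields False.
        try:
            installed_version = version("nvidia-cublas-cu12")
        except PackageNotFoundError:
            return False
        return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")


    if __name__ == "__main__":
        # Example usage: prints True on environments shipping cuBLAS 12.9+ wheels,
        # which is the condition under which the patch switches the MLA FP8 path
        # from per_tensor_quant_mla_fp8 to input_to_float8.
        print(is_nvidia_cublas_cu12_version_ge_12_9())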