[BugFix] test_mla_fp8.py fails on cuBLAS 12.9 (#11360)

Signed-off-by: Liu-congo <1502632128@qq.com>
Author: Liu-congo
Date: 2025-10-11 12:14:24 +08:00
Committed by: GitHub
Parent: eae9a9fb9d
Commit: c80a96dae9

2 changed files with 31 additions and 8 deletions


@@ -94,6 +94,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -189,6 +191,7 @@ else:
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 logger = logging.getLogger(__name__)
@@ -1572,10 +1575,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            # TODO: fix per_tensor_quant_mla_fp8 for cuBLAS >= 12.9
+            if _is_cublas_ge_129:
+                q_nope_val, q_nope_scale = input_to_float8(
+                    q_nope.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
+                )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
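For context, the cuBLAS >= 12.9 branch swaps the `zero_allocator`-backed `per_tensor_quant_mla_fp8` kernel for `input_to_float8`, which derives the per-tensor scale from the tensor itself. A minimal sketch of what per-tensor FP8 quantization along these lines looks like; the function name `input_to_float8_sketch` and the exact clamping details are assumptions, not the library's verbatim code:

```python
import torch

def input_to_float8_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Per-tensor quantization: one scale for the whole tensor, chosen so
    # the largest |value| maps to the FP8 max (448 for e4m3fn).
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12).float()
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the dequantization scale (the reciprocal), which is what the
    # bmm_fp8 call above appears to consume.
    return x_q.contiguous(), scale.reciprocal()
```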
@@ -1716,10 +1724,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            if _is_cublas_ge_129:
+                attn_output_val, attn_output_scale = input_to_float8(
+                    attn_output.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
+                )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,

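Both quantization branches feed `bmm_fp8`, the FP8 batched-matmul wrapper used above, which folds both per-tensor dequantization scales back into the output. A dequantize-then-matmul reference of the intended numerics; `bmm_fp8_reference` is a hypothetical name for illustration, assuming scalar per-tensor scales, not the actual GPU kernel:

```python
import torch

def bmm_fp8_reference(a_q, b_q, a_scale, b_scale, out_dtype=torch.bfloat16):
    # out[i] = (a_q[i] * a_scale) @ (b_q[i] * b_scale), cast to out_dtype.
    a = a_q.to(torch.float32) * a_scale
    b = b_q.to(torch.float32) * b_scale
    return torch.bmm(a, b).to(out_dtype)
```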

@@ -263,6 +263,17 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def is_nvidia_cublas_cu12_version_ge_12_9():
+    """
+    Temporary fix for issue #11272: True if nvidia-cublas-cu12 >= 12.9.
+    """
+    try:
+        installed_version = version("nvidia-cublas-cu12")
+    except PackageNotFoundError:
+        return False
+    return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
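The new helper assumes `version` / `PackageNotFoundError` from `importlib.metadata` and `packaging.version` (aliased as `pkg_version`) are already imported elsewhere in utils.py; the diff does not show those imports. A self-contained sketch of the same check, runnable on its own:

```python
from importlib.metadata import PackageNotFoundError, version
from packaging import version as pkg_version

try:
    cublas = version("nvidia-cublas-cu12")  # wheel metadata, e.g. "12.9.0.13"
except PackageNotFoundError:
    cublas = None  # pip wheel absent (e.g. system-wide CUDA install)

is_ge_129 = cublas is not None and pkg_version.parse(cublas) >= pkg_version.parse("12.9")
print(f"nvidia-cublas-cu12: {cublas}; >= 12.9: {is_ge_129}")
```

Note the check reads installed-package metadata rather than querying the cuBLAS runtime, so it only sees cuBLAS installed via the pip wheel; with a system CUDA install it returns False and the original `per_tensor_quant_mla_fp8` path is kept.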