[BugFix] test_mla_fp8.py fails on cuBLAS 12.9 (#11360)
Signed-off-by: Liu-congo <1502632128@qq.com>
@@ -94,6 +94,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -189,6 +191,7 @@ else:
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 
 logger = logging.getLogger(__name__)
@@ -1572,10 +1575,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
+            if _is_cublas_ge_129:
+                q_nope_val, q_nope_scale = input_to_float8(
+                    q_nope.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
+                )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
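For intuition: `input_to_float8` is a plain PyTorch-side per-tensor FP8 quantization, which lets the hot path avoid the `per_tensor_quant_mla_fp8` kernel that misbehaves under cuBLAS 12.9. A minimal sketch of that pattern (the helper name and exact numerics here are illustrative, not sglang's implementation):

import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    """Quantize x to FP8 with a single scale for the whole tensor."""
    finfo = torch.finfo(dtype)
    # The largest magnitude in the tensor determines the scale.
    amax = x.abs().amax().clamp(min=1e-12).float()
    scale = finfo.max / amax
    # Scale, saturate to the representable FP8 range, then cast.
    x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the inverse scale so callers can dequantize: x ~= x_q * inv_scale.
    return x_q.contiguous(), scale.reciprocal()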
@@ -1716,10 +1724,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
 
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            if _is_cublas_ge_129:
+                attn_output_val, attn_output_scale = input_to_float8(
+                    attn_output.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
+                )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
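Both branches produce the same interface: an FP8 tensor plus a per-tensor scale that `bmm_fp8` consumes. Assuming `bmm_fp8(a, b, a_scale, b_scale, dtype)` computes a scale-corrected batched matmul (a reading of the call sites above, not a statement of the kernel's internals), a dequantized reference equivalent is:

import torch

def bmm_fp8_reference(a_q, b_q, a_scale, b_scale, out_dtype=torch.bfloat16):
    """Dequantize both operands, run a batched matmul, cast to out_dtype."""
    a = a_q.float() * a_scale  # a_scale/b_scale taken as inverse (dequant) scales
    b = b_q.float() * b_scale
    return torch.bmm(a, b).to(out_dtype)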
@@ -263,6 +263,17 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def is_nvidia_cublas_cu12_version_ge_12_9():
+    """
+    temporary fix for issue #11272
+    """
+    try:
+        installed_version = version("nvidia-cublas-cu12")
+    except PackageNotFoundError:
+        return False
+    return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
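The new helper relies on `version()` / `PackageNotFoundError` from `importlib.metadata` and on `packaging` as `pkg_version`; since the hunk adds no imports, those are presumably already present in utils.py. A self-contained sketch of the same check (hypothetical standalone function, not the file's exact contents):

from importlib.metadata import PackageNotFoundError, version

from packaging import version as pkg_version


def cublas_cu12_at_least(minimum: str = "12.9") -> bool:
    """True if the installed nvidia-cublas-cu12 wheel is at least `minimum`."""
    try:
        installed = version("nvidia-cublas-cu12")
    except PackageNotFoundError:
        return False  # wheel absent: treat as an older/unknown cuBLAS
    return pkg_version.parse(installed) >= pkg_version.parse(minimum)


if __name__ == "__main__":
    print(cublas_cu12_at_least())  # True on a CUDA 12.9+ wheel install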