[BugFix] replace the input_to_float8 used in dsv2 (#11612)
Signed-off-by: Liu-congo <1502632128@qq.com>
This commit is contained in:
@@ -92,7 +92,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
block_quant_dequant,
|
||||
block_quant_to_tensor_quant,
|
||||
channel_quant_to_tensor_quant,
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
quant_weight_ue8m0,
|
||||
requant_weight_ue8m0_inplace,
|
||||
@@ -1623,15 +1622,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
# TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
|
||||
if _is_cublas_ge_129:
|
||||
q_nope_val, q_nope_scale = input_to_float8(
|
||||
q_nope.transpose(0, 1), torch.float8_e4m3fn
|
||||
)
|
||||
else:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1), zero_allocator.allocate(1)
|
||||
)
|
||||
# fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1),
|
||||
(
|
||||
torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
|
||||
if _is_cublas_ge_129
|
||||
else zero_allocator.allocate(1)
|
||||
),
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
)
|
||||
@@ -1772,14 +1771,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
if _is_cublas_ge_129:
|
||||
attn_output_val, attn_output_scale = input_to_float8(
|
||||
attn_output.transpose(0, 1), torch.float8_e4m3fn
|
||||
)
|
||||
else:
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1), zero_allocator.allocate(1)
|
||||
)
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1),
|
||||
(
|
||||
torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
|
||||
if _is_cublas_ge_129
|
||||
else zero_allocator.allocate(1)
|
||||
),
|
||||
)
|
||||
attn_bmm_output = bmm_fp8(
|
||||
attn_output_val,
|
||||
self.w_vc,
|
||||
|
||||
Reference in New Issue
Block a user