[refactor] slightly tidy fp8 module (#5993)

JieXin Liang
2025-05-08 08:28:24 +08:00
committed by GitHub
parent e444c13fb4
commit b70957fcf8
12 changed files with 238 additions and 231 deletions


@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
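
The gist of the change: the masked DeepGEMM quantization helper is renamed from per_tensor_quant_mla_deep_gemm_masked_fp8 to per_token_group_quant_mla_deep_gemm_masked_fp8, and the explicit dtype=torch.float8_e4m3fn argument is dropped at both call sites, so the FP8 dtype is handled inside the kernel. For readers unfamiliar with per-token-group quantization, below is a minimal eager-mode sketch of the idea. This is not sglang's kernel: the real implementation in fp8_kernel.py is fused and also returns the masked_m/expected_m/aligned_m bookkeeping used above, all of which is elided here; the function name and the group size of 128 are illustrative assumptions.

import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    # Hypothetical helper (not the sglang API): quantize to FP8 e4m3 with one
    # scale per contiguous group of `group_size` elements along the last dim.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    orig_shape = x.shape
    # View the last dimension as (num_groups, group_size); requires the last
    # dim to be divisible by group_size.
    x_grouped = x.reshape(*orig_shape[:-1], -1, group_size)
    # Per-group scale from the group's absolute maximum, clamped so that
    # all-zero groups do not divide by zero.
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).float().clamp_min(1e-4)
    scale = amax / fp8_max
    x_q = (x_grouped / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(orig_shape), scale.squeeze(-1)

Compared with per-tensor quantization (a single scale for the whole tensor), a per-group scale tracks local dynamic range, which matters for FP8's narrow exponent range; that distinction is what the new name makes explicit.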