[refactor] slightly tidy fp8 module (#5993)
This commit is contained in:
@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_mla_deep_gemm_masked_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
block_quant_to_tensor_quant,
|
||||
@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
)
|
||||
per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
|
||||
)
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||
attn_output.transpose(0, 1)
|
||||
)
|
||||
)
|
||||
attn_bmm_output = attn_output.new_empty(
|
||||
|
||||
Reference in New Issue
Block a user