Fuse two kernels of hidden states padding into quantization kernel (#9005)

Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
fzyzcjy
2025-08-12 16:20:13 +08:00
committed by GitHub
parent 5438886c87
commit 5190ba7f42
2 changed files with 5 additions and 9 deletions

View File

@@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module):
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.use_enable_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.hidden_size = hidden_size
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
@@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
origin_hidden_states_dim = hidden_states.shape[-1]
if self.hidden_size != origin_hidden_states_dim:
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, self.hidden_size - origin_hidden_states_dim),
mode="constant",
value=0.0,
)
assert self.quant_method is not None
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:

View File

@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_quant, x_scale = mxfp8_quantize(
x, False, alignment=self.hidden_size
) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
assert x_quant.shape[-1] == self.hidden_size
top_k, router_logits = topk_output