From 5190ba7f421692bcd35f2386dd62830d3cadbffa Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Tue, 12 Aug 2025 16:20:13 +0800 Subject: [PATCH] Fuse two kernels of hidden states padding into quantization kernel (#9005) Co-authored-by: Qiaolin-Yu --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 9 +-------- python/sglang/srt/layers/quantization/mxfp4.py | 5 ++++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8aa57bbf2..990d88aed 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module): self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get( "enable_flashinfer_mxfp4_moe", False ) + # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic if ( self.quant_config is not None and self.quant_config.get_name() == "mxfp4" and self.use_enable_flashinfer_mxfp4_moe ): hidden_size = round_up(hidden_size, 256) - self.hidden_size = hidden_size self.quant_method.create_weights( layer=self, num_experts=self.num_local_experts, @@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module): def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): origin_hidden_states_dim = hidden_states.shape[-1] - if self.hidden_size != origin_hidden_states_dim: - hidden_states = torch.nn.functional.pad( - hidden_states, - (0, self.hidden_size - origin_hidden_states_dim), - mode="constant", - value=0.0, - ) assert self.quant_method is not None if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe: diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 62bfaf887..ee73fb4ce 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ) -> torch.Tensor: if self.use_flashinfer: # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance - x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + x_quant, x_scale = mxfp8_quantize( + x, False, alignment=self.hidden_size + ) # to mxfp8 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + assert x_quant.shape[-1] == self.hidden_size top_k, router_logits = topk_output