From ccd3fb946e04518ab51277812bc5d7d8b9d9dae4 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Sat, 23 Aug 2025 01:48:40 -0700 Subject: [PATCH] [fix] Fix mxfp4 triton MoE tp bug (#9473) Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com> --- .../sglang/srt/layers/moe/fused_moe_triton/layer.py | 8 ++------ python/sglang/srt/layers/quantization/mxfp4.py | 7 +++++++ python/sglang/srt/models/gpt_oss.py | 11 +++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 2a00ddd00..7b3452525 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -111,9 +111,8 @@ class FusedMoE(torch.nn.Module): hidden_size: Input hidden state size of the transformer intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. - reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel - quant_config: Quantization configure. + reduce_results: Whether to apply all_reduce on the output of the layer + quant_config: Quantization configuration. inplace: suggestion to compute inplace (modify input activation). """ @@ -182,9 +181,6 @@ class FusedMoE(torch.nn.Module): self.expert_map_cpu = torch.full( (self.num_experts,), -1, dtype=torch.int32, device="cpu" ) - self.expert_map_cpu = torch.full( - (self.num_experts,), -1, dtype=torch.int32, device="cpu" - ) # Create a expert map for the local experts self.expert_map_cpu[ self.moe_ep_rank diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 1e46cc868..fa0b4410c 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = round_up( intermediate_size, 64 ) + elif has_triton_kernels: + # TODO: this is a hack to make + # intermediate_size_per_partition_after_pad the same as the + # per_rank_intermediate_size during weight loading + intermediate_size_per_partition_after_pad = round_up( + intermediate_size, mxfp4_block + ) self.intermediate_size = intermediate_size_per_partition_after_pad diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 829f40689..eda1ed7e7 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -793,12 +793,11 @@ class GptOssForCausalLM(nn.Module): intermediate_size % mxfp4_block == 0 ), f"{intermediate_size=} must be divisible by {mxfp4_block=}" intermediate_size_block = intermediate_size // mxfp4_block - if _is_sm100_supported: - per_rank_intermediate_size_block = math.ceil( - intermediate_size_block / moe_tp_size - ) - else: - per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size + + per_rank_intermediate_size_block = math.ceil( + intermediate_size_block / moe_tp_size + ) + per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block # Calculate common slicing bounds for current rank