From 84b006b27833d93045ae5552e2cebb13f5140ab5 Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Fri, 15 Aug 2025 02:28:33 -0700
Subject: [PATCH] Cleanup MoE Refactor (#9223)

---
 python/sglang/srt/layers/quantization/mxfp4.py | 11 ++++++++---
 python/sglang/srt/models/deepseek_v2.py        | 13 ++++++-------
 python/sglang/srt/models/glm4_moe.py           | 10 ++++------
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 5eaa21d1e..fedf4c0b0 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
 
-            top_k, router_logits = topk_output
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None,  # n_group
-                None,  # topk_group
+                None,  # n_group  # TODO: support n_group
+                None,  # topk_group  # TODO: support topk_group
                 self.intermediate_size,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8d51d7823..2ba57f958 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
+
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
+
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
 
-        final_hidden_states = self.experts(**kwargs)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 6e4b16e78..f75531bd8 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
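
Note on the refactored access pattern (not part of the patch above): after this change, the mxfp4 FlashInfer path no longer tuple-unpacks topk_output; it asserts the output is in the "bypassed" format and reads top_k and router_logits as attributes, while the model code passes topk_output positionally to self.experts(). The sketch below illustrates only that access pattern; TopKConfig and BypassedTopKOutput are simplified hypothetical stand-ins for illustration, not sglang's real classes.

from dataclasses import dataclass

import torch


@dataclass
class TopKConfig:
    # stand-in: carries the routing hyperparameter consumed by the fused kernel
    top_k: int


@dataclass
class BypassedTopKOutput:
    # stand-in: router logits are passed through untouched; the top-k selection
    # is deferred to the fused MoE kernel (trtllm_fp4_block_scale_moe in the patch)
    topk_config: TopKConfig
    router_logits: torch.Tensor


topk_output = BypassedTopKOutput(
    topk_config=TopKConfig(top_k=2),
    router_logits=torch.randn(4, 8),  # (num_tokens, n_experts)
)

# Old style, removed by the patch: top_k, router_logits = topk_output
# New style, as in Mxfp4MoEMethod.apply after this change:
top_k = topk_output.topk_config.top_k
router_logits = topk_output.router_logits
print(top_k, tuple(router_logits.shape))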