From 3435a24e815760e5b5ccfec1571e971a57e4e959 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Mon, 4 Aug 2025 01:20:39 +0800 Subject: [PATCH] [RL] fix update weight for FusedMoE with EP (#8676) --- .../srt/layers/moe/fused_moe_triton/layer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 3960e22a6..d0a9ed132 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -124,15 +124,18 @@ class FusedMoE(torch.nn.Module): if self.moe_ep_size > 1: # TODO(ch-wan): support shared experts fusion # Create a tensor of size num_experts filled with -1 - self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32) + self.expert_map_cpu = torch.full( + (self.num_experts,), -1, dtype=torch.int32, device="cpu" + ) + self.expert_map_cpu = torch.full( + (self.num_experts,), -1, dtype=torch.int32, device="cpu" + ) # Create a expert map for the local experts self.expert_map_cpu[ self.moe_ep_rank * self.num_local_experts : (self.moe_ep_rank + 1) * self.num_local_experts ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") - if not self.enable_flashinfer_cutlass_moe: - self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") self.routed_scaling_factor = routed_scaling_factor assert intermediate_size % self.moe_tp_size == 0 @@ -624,6 +627,11 @@ class FusedMoE(torch.nn.Module): def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): assert self.quant_method is not None + if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe: + if self.expert_map_cpu is not None and self.expert_map_gpu is None: + # If we are in EP mode, we need to move the expert map to GPU. + self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") + if self.expert_map_gpu is not None: topk_output = topk_output._replace( topk_ids=self.expert_map_gpu[topk_output.topk_ids]