[RL] fix update weight for FusedMoE with EP (#8676)
This commit is contained in:
@@ -124,15 +124,18 @@ class FusedMoE(torch.nn.Module):
|
||||
if self.moe_ep_size > 1:
|
||||
# TODO(ch-wan): support shared experts fusion
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
||||
self.expert_map_cpu = torch.full(
|
||||
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.expert_map_cpu = torch.full(
|
||||
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
# Create an expert map for the local experts
|
||||
self.expert_map_cpu[
|
||||
self.moe_ep_rank
|
||||
* self.num_local_experts : (self.moe_ep_rank + 1)
|
||||
* self.num_local_experts
|
||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||
if not self.enable_flashinfer_cutlass_moe:
|
||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
assert intermediate_size % self.moe_tp_size == 0
|
||||
@@ -624,6 +627,11 @@ class FusedMoE(torch.nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||
assert self.quant_method is not None
|
||||
|
||||
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||
if self.expert_map_cpu is not None and self.expert_map_gpu is None:
|
||||
# If we are in EP mode, we need to move the expert map to GPU.
|
||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||
|
||||
if self.expert_map_gpu is not None:
|
||||
topk_output = topk_output._replace(
|
||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||
|
||||
Reference in New Issue
Block a user