[RL] fix update weight for FusedMoE with EP (#8676)
This commit is contained in:
@@ -124,15 +124,18 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if self.moe_ep_size > 1:
|
if self.moe_ep_size > 1:
|
||||||
# TODO(ch-wan): support shared experts fusion
|
# TODO(ch-wan): support shared experts fusion
|
||||||
# Create a tensor of size num_experts filled with -1
|
# Create a tensor of size num_experts filled with -1
|
||||||
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
self.expert_map_cpu = torch.full(
|
||||||
|
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||||
|
)
|
||||||
|
self.expert_map_cpu = torch.full(
|
||||||
|
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||||
|
)
|
||||||
# Create an expert map for the local experts
|
# Create an expert map for the local experts
|
||||||
self.expert_map_cpu[
|
self.expert_map_cpu[
|
||||||
self.moe_ep_rank
|
self.moe_ep_rank
|
||||||
* self.num_local_experts : (self.moe_ep_rank + 1)
|
* self.num_local_experts : (self.moe_ep_rank + 1)
|
||||||
* self.num_local_experts
|
* self.num_local_experts
|
||||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||||
if not self.enable_flashinfer_cutlass_moe:
|
|
||||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
|
||||||
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
assert intermediate_size % self.moe_tp_size == 0
|
assert intermediate_size % self.moe_tp_size == 0
|
||||||
@@ -624,6 +627,11 @@ class FusedMoE(torch.nn.Module):
|
|||||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||||
|
if self.expert_map_cpu is not None and self.expert_map_gpu is None:
|
||||||
|
# If we are in EP mode, we need to move the expert map to GPU.
|
||||||
|
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||||
|
|
||||||
if self.expert_map_gpu is not None:
|
if self.expert_map_gpu is not None:
|
||||||
topk_output = topk_output._replace(
|
topk_output = topk_output._replace(
|
||||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||||
|
|||||||
Reference in New Issue
Block a user