support fused_moe_allgather_ep (#1335)

### What this PR does / why we need it?
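Add support for fused_moe_allgather_ep: `get_fused_moe_state` now also receives an `is_deepseek_v3_r1` flag (`global_num_experts == 256`), and the tensor-parallel communication branches in `AscendFusedMoE` treat `FusedMoEState.AllGatherEP` the same way as `FusedMoEState.AllGather`.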

### How was this patch tested?
It was tested with unit tests (UT).

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
This commit is contained in:
lyj-jjj
2025-06-23 22:03:38 +08:00
committed by GitHub
parent 917c6b71af
commit 5177bef87a
5 changed files with 218 additions and 14 deletions

View File

@@ -988,8 +988,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
**kwargs,
) -> torch.Tensor:
is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
@@ -1025,7 +1026,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1219,15 +1220,17 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1285,7 +1288,8 @@ class AscendFusedMoE(FusedMoE):
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
@@ -1303,7 +1307,8 @@ class AscendFusedMoE(FusedMoE):
else:
final_hidden_states = e_hidden_states
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
or fused_moe_state == FusedMoEState.AllGatherEP):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)