support fused_moe_allgather_ep (#1335)
### What this PR does / why we need it? support fused_moe_allgather_ep ### How was this patch tested? It was tested by UT. Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
This commit is contained in:
@@ -988,8 +988,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
@@ -1025,7 +1026,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill)
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
@@ -1219,15 +1220,17 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
||||
is_prefill)
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
if shared_experts:
|
||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
||||
if num_tokens < tp_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||
@@ -1285,7 +1288,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
@@ -1303,7 +1307,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
|
||||
or fused_moe_state == FusedMoEState.AllGatherEP):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user