adjusting the communication method in graph mode (#1194)

### What this PR does / why we need it?
Communication performance optimization: replace `allreduce` with `reduce_scatter` + `all_gather` in the MLA layer's TP group, which removes the StridedSlice and `all_gather` in the MoE layer.
When tp > 1, the optimization is enabled during the decode phase of graph mode, provided that enable_multistream_moe, MLA, use_v1, and MC2 are all in use.
According to end-to-end RL inference test results, this PR brings about a 3% gain in the decode stage.
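The optimization rests on a standard identity: an `allreduce` over a TP group is equivalent to a `reduce_scatter` followed by an `all_gather`, which lets each rank work on only its shard in between. A minimal single-process NumPy sketch of that equivalence (the helper names here are stand-ins simulating the collectives, not the vLLM-Ascend implementation):

```python
import numpy as np

def all_reduce(per_rank):
    # Every rank ends up with the elementwise sum of all ranks' tensors.
    total = np.sum(per_rank, axis=0)
    return [total.copy() for _ in per_rank]

def reduce_scatter(per_rank):
    # Sum across ranks, then give each rank one contiguous shard of the result.
    total = np.sum(per_rank, axis=0)
    return [s.copy() for s in np.array_split(total, len(per_rank), axis=0)]

def all_gather(shards):
    # Every rank reassembles the full tensor from all shards.
    full = np.concatenate(shards, axis=0)
    return [full.copy() for _ in shards]

rng = np.random.default_rng(0)
tp = 4
per_rank = [rng.standard_normal((8, 16)) for _ in range(tp)]

ref = all_reduce(per_rank)                      # one collective
rs_ag = all_gather(reduce_scatter(per_rank))    # the two-step replacement
assert all(np.allclose(a, b) for a, b in zip(ref, rs_ag))
```

The two-step form costs the same total bandwidth but exposes a sharded intermediate, which is what lets the MoE layer skip its own slice-and-gather.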

**Before Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/1bb5dfa1-809b-410a-90c9-c5fd23cff003)
Evaluation

![image](https://github.com/user-attachments/assets/0b8ea0c7-88e7-410f-9ef4-f0cfe910cdc7)

![image](https://github.com/user-attachments/assets/94fde910-c125-4c2e-8de4-88fc3fafc057)

**After Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/55fac0e0-11f2-4654-8fd4-287949e0b29e)
Evaluation

![image](https://github.com/user-attachments/assets/e923f74b-29c4-4171-9382-40a00cf05df0)

![image](https://github.com/user-attachments/assets/5dba7967-07ea-4926-a8be-804bfd34e3e4)

### Does this PR introduce _any_ user-facing change?
Yes. Users need to set `enable_multistream_moe=True` to enable this optimization.

### How was this patch tested?
Added e2e test cases to cover the new code path.

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
Commit 941269a6c5 (parent 205cb85a1e)
Author: sharonyunyun, committed by GitHub
Date: 2025-06-25 19:56:49 +08:00
6 changed files with 195 additions and 37 deletions

@@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE):
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
-                shared_experts: Optional[Any] = None):
+                shared_experts: Optional[Any] = None,
+                replace_allreduce: bool = False):
         assert self.quant_method is not None
         if top_k:
@@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE):
         tp_size = get_tensor_model_parallel_world_size()
         if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP):
+                and fused_moe_state != FusedMoEState.AllGatherEP
+                and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE):
             e_hidden_states, shared_hidden_states = e_hidden_states
         if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP):
+                and fused_moe_state != FusedMoEState.AllGatherEP
+                and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
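The pad/chunk/gather path that `replace_allreduce` bypasses can be sketched in isolation. This is a NumPy approximation of the shapes involved, not the actual `nn.functional.pad` / `dist.all_gather` calls; the variable names mirror the diff but the single-process logic is illustrative only:

```python
import numpy as np

tp_size = 4
num_tokens = 3  # decode batches can have fewer tokens than TP ranks
hidden_states = np.arange(num_tokens * 2, dtype=float).reshape(num_tokens, 2)

# Pad the token dimension up to tp_size so every rank gets an
# equal-sized shard (mirrors the nn.functional.pad call in the diff).
pad_rows = tp_size - num_tokens
padded = np.vstack([hidden_states, np.zeros((pad_rows, hidden_states.shape[1]))])

# Each rank would process one chunk; all_gather then restores the full
# tensor, and the padding rows are dropped at the end.
chunk_hidden_states = np.array_split(padded, tp_size, axis=0)
final_hidden_states = np.concatenate(chunk_hidden_states, axis=0)[:num_tokens]

assert np.allclose(final_hidden_states, hidden_states)
```

When `replace_allreduce=True`, this pad-and-regather round trip is skipped entirely, which is where the removed StridedSlice/`all_gather` cost comes from.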