Adjust the communication method in graph mode (#1194)
### What this PR does / why we need it?
Communication performance optimization: replace the allreduce in the MLA layer's TP group with reduce_scatter + all_gather, removing the strided slice and all_gather in the MoE layer. When tp > 1, the optimization is enabled during the decode phase of graph mode, provided enable_multistream_moe, MLA, use_v1, and MC2 are all in use. According to end-to-end RL inference test results, this PR brings a 3% gain in the decode stage.

**Before Improvement**
_Profiling kernel_details and evaluation screenshots_

**After Improvement**
_Profiling kernel_details and evaluation screenshots_

### Does this PR introduce _any_ user-facing change?
Users need to configure enable_multistream_moe=True.

### How was this patch tested?
Added e2e test cases to cover the code logic.

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
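For intuition, the sketch below (not code from this PR; the function name, shapes, and launch setup are illustrative assumptions) shows the decomposition the optimization relies on: a SUM all_reduce over the TP group equals a reduce_scatter followed by an all_gather, so the collective can be split and the MoE block run in between on the reduced shard.

```python
import torch
import torch.distributed as dist

def decomposed_all_reduce(x: torch.Tensor, group=None) -> torch.Tensor:
    """reduce_scatter + all_gather over dim 0; equivalent to a SUM
    all_reduce when x.shape[0] is divisible by the group size."""
    world_size = dist.get_world_size(group)
    shard = torch.empty(x.shape[0] // world_size, *x.shape[1:],
                        dtype=x.dtype, device=x.device)
    # Each rank reduces the full tensor but keeps only its own shard.
    dist.reduce_scatter_tensor(shard, x, group=group)
    # ... the MoE block could run here on the smaller shard, making its
    # own strided slice (before) and all_gather (after) unnecessary ...
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard, group=group)
    return out

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=<tp_size> decomposed_allreduce.py
    # (assumes an NCCL backend and one GPU per process).
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    x = torch.randn(8, 16, device="cuda")
    ref = x.clone()
    dist.all_reduce(ref)
    assert torch.allclose(decomposed_all_reduce(x), ref, atol=1e-5)
    dist.destroy_process_group()
```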
```diff
@@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE):
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
-                shared_experts: Optional[Any] = None):
+                shared_experts: Optional[Any] = None,
+                replace_allreduce: bool = False):
         assert self.quant_method is not None
 
         if top_k:
@@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE):
 
         tp_size = get_tensor_model_parallel_world_size()
         if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP):
+                and fused_moe_state != FusedMoEState.AllGatherEP
+                and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
         if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP):
+                and fused_moe_state != FusedMoEState.AllGatherEP
+                and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
```
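A minimal, hypothetical sketch of the code path the new flag bypasses (names are simplified, and the fused_moe_state checks from the diff are omitted): with replace_allreduce=False the forward pads and slices hidden_states across the TP group before the experts and all_gathers the results afterwards; with replace_allreduce=True both steps are skipped, because the surrounding MLA layer already frames the MoE block with reduce_scatter/all_gather.

```python
import torch

def maybe_shard_for_tp(hidden_states: torch.Tensor, tp_size: int,
                       tp_rank: int, replace_allreduce: bool) -> torch.Tensor:
    """Pre-MoE sharding, skipped when the MLA layer already reduce_scatters."""
    if tp_size > 1 and not replace_allreduce:
        num_tokens = hidden_states.shape[0]
        if num_tokens < tp_size:
            # Pad so every TP rank receives a non-empty chunk.
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, tp_size - num_tokens))
        # Keep only this rank's slice (the "strided slice" the PR removes).
        hidden_states = torch.tensor_split(hidden_states, tp_size,
                                           dim=0)[tp_rank]
    return hidden_states
```

Skipping this pad/slice and the trailing all_gather + torch.cat on the decode path is, per the description above, where the 3% end-to-end decode gain comes from.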