[Fix][MoE] Refine MoE communication strategy (#2734)

### What this PR does / why we need it?
Refactors the Mixture-of-Experts (MoE) communication method selection
logic. The choice between all-gather, all-to-all, and mc2 is now
determined by expert parallel configuration, SoC version (A2/A3), and
token count for better performance.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Added.


- vLLM version: v0.10.1.1
- vLLM main: eafa8dcde6

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
yiz-liu
2025-09-05 09:04:04 +08:00
committed by GitHub
parent 4c90fa79ca
commit 83eb40a51c
3 changed files with 123 additions and 9 deletions


@@ -1434,14 +1434,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
     def _select_moe_comm_method(self, num_tokens: int) -> str:
+        """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
+        are designed for expert parallelism.
+        2. If expert parallel is enabled, we need to consider the SoC version and the
+        number of tokens. This is based on the observation that all-gather is more
+        efficient than all-to-all when running on A2.
+            a. For A2, we choose from MC2 and all-gather.
+            b. For A3, we choose from MC2 and all-to-all.
+        In both cases, we use MC2 when the number of tokens does not exceed
+        its capacity threshold.
+        Args:
+            num_tokens (int): The number of tokens in the current batch.
+        Raises:
+            ValueError: If the SoC version is unsupported.
+        Returns:
+            str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
+        """
         soc_version = get_ascend_soc_version()
-        if num_tokens <= self.mc2_tokens_capacity:
-            moe_comm_method = "mc2"
-        elif soc_version in {AscendSocVersion.A2}:
+        if not self.parallel_config.enable_expert_parallel:
             moe_comm_method = "allgather"
+        elif soc_version in {AscendSocVersion.A2}:
+            if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size >= 16:
+                moe_comm_method = "mc2"
+            else:
+                moe_comm_method = "allgather"
         elif soc_version in {AscendSocVersion.A3}:
-            moe_comm_method = "alltoall"
+            moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
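
The selection rules in this hunk can be read in isolation as a pure function. The following is a minimal sketch for illustration only, not the code in the repository: `SocVersion` is a hypothetical stand-in for `AscendSocVersion`, and the free function `select_moe_comm_method` with explicit parameters is an assumption that replaces the `self.parallel_config` and `self.mc2_tokens_capacity` lookups.

```python
from enum import Enum


class SocVersion(Enum):
    """Hypothetical stand-in for AscendSocVersion (assumption, not the real enum)."""
    A2 = "A2"
    A3 = "A3"
    OTHER = "other"


def select_moe_comm_method(
    num_tokens: int,
    soc_version: SocVersion,
    enable_expert_parallel: bool,
    world_size: int,
    mc2_tokens_capacity: int,
) -> str:
    """Mirror the post-change branching with all inputs passed explicitly."""
    # Without expert parallelism, MC2 and all-to-all do not apply.
    if not enable_expert_parallel:
        return "allgather"
    if soc_version is SocVersion.A2:
        # On A2, MC2 is chosen only within capacity and with at least 16 ranks.
        if num_tokens <= mc2_tokens_capacity and world_size >= 16:
            return "mc2"
        return "allgather"
    if soc_version is SocVersion.A3:
        # On A3, MC2 within capacity, all-to-all beyond it.
        return "mc2" if num_tokens <= mc2_tokens_capacity else "alltoall"
    raise ValueError(f"Unsupported soc_version: {soc_version}")


# Example calls with an illustrative capacity of 512 tokens.
print(select_moe_comm_method(256, SocVersion.A2, True, 16, 512))
print(select_moe_comm_method(256, SocVersion.A2, True, 8, 512))
print(select_moe_comm_method(1024, SocVersion.A3, True, 16, 512))
```

Two details are easy to miss when reading the diff: the capacity comparison is inclusive (`<=`), and on A2 the MC2 path additionally requires `world_size >= 16`, a condition the docstring does not mention.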