[Fix][MoE] Refine MoE communication strategy (#2734)
### What this PR does / why we need it?
Refactors the Mixture-of-Experts (MoE) communication method selection
logic. The choice between all-gather, all-to-all, and mc2 is now
determined by expert parallel configuration, SoC version (A2/A3), and
token count for better performance.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Added.
- vLLM version: v0.10.1.1
- vLLM main:
eafa8dcde6
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -482,11 +482,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
|
||||
# TODO: Can we refactor this logic to model_runner?
|
||||
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
|
||||
if self.moe_config.ep_size < 16:
|
||||
moe_comm_method_name = "allgathercommimpl"
|
||||
|
||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
|
||||
@@ -1434,14 +1434,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
def _select_moe_comm_method(self, num_tokens: int) -> str:
|
||||
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
||||
are designed for expert parallelism.
|
||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
||||
number of tokens. This is based on the observation that all-gather is more
|
||||
efficient than all-to-all when running on A2.
|
||||
|
||||
a. For A2, we choose from MC2 and all-gather.
|
||||
|
||||
b. For A3, we choose from MC2 and all-to-all.
|
||||
|
||||
In both cases, we use MC2 when the number of tokens is smaller than
|
||||
a its capacity threshold.
|
||||
|
||||
Args:
|
||||
num_tokens (int): The number of tokens in the current batch.
|
||||
|
||||
Raises:
|
||||
ValueError: If the soc version is unsupported.
|
||||
|
||||
Returns:
|
||||
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
|
||||
"""
|
||||
soc_version = get_ascend_soc_version()
|
||||
|
||||
if num_tokens <= self.mc2_tokens_capacity:
|
||||
moe_comm_method = "mc2"
|
||||
elif soc_version in {AscendSocVersion.A2}:
|
||||
if not self.parallel_config.enable_expert_parallel:
|
||||
moe_comm_method = "allgather"
|
||||
elif soc_version in {AscendSocVersion.A2}:
|
||||
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size >= 16:
|
||||
moe_comm_method = "mc2"
|
||||
else:
|
||||
moe_comm_method = "allgather"
|
||||
elif soc_version in {AscendSocVersion.A3}:
|
||||
moe_comm_method = "alltoall"
|
||||
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
|
||||
else:
|
||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user