[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)

### What this PR does / why we need it?

- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when:
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model = False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-12-18 23:34:31 +08:00
committed by GitHub
parent 73e4b4f496
commit 1b47fca0e8
7 changed files with 89 additions and 75 deletions

View File

@@ -22,6 +22,7 @@ import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
moe_config)
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
return PrepareAndFinalizeWithAll2All(self.moe_config)
class FusedAlltoAllCommImpl(MoECommMethod):
class FusedMC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_grouped_matmul` is available.
This implementation uses all-to-all communication to exchange tokens
between data parallel ranks before and after the MLP computation. It should
have better performance than AllGatherCommImpl when DP size > 1.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
return TokenDispatcherWithMC2()
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
return PrepareAndFinalizeWithMC2(self.moe_config)
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
out = torch.empty_like(hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1,
weight2=w2,
expert_idx=topk_ids,
scale1=w1_scale,
scale2=w2_scale,
probs=topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1[0],
weight2=w2[0],
expert_idx=topk_ids,
scale1=w1_scale[0],
scale2=w2_scale[0],
probs=topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
raise NotImplementedError()
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return out