[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)

### What this PR does / why we need it?

- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when:
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model=False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-12-18 23:34:31 +08:00
committed by GitHub
parent 73e4b4f496
commit 1b47fca0e8
7 changed files with 89 additions and 75 deletions

View File

@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
fused_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_ALLTOALL
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1[0] if fused_flag else w1,
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
w2=w2[0] if fused_flag else w2,
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
w1=w1,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w2=w2,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,