[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)

### What this PR does / why we need it?

- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when:
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model = False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-12-18 23:34:31 +08:00
committed by GitHub
parent 73e4b4f496
commit 1b47fca0e8
7 changed files with 89 additions and 75 deletions

View File

@@ -26,7 +26,7 @@ class MoECommType(Enum):
ALLGATHER = 0
MC2 = 1
ALLTOALL = 2
FUSED_ALLTOALL = 3
FUSED_MC2 = 3
@contextmanager
@@ -62,11 +62,8 @@ def set_ascend_forward_context(
from vllm_ascend.ops.fused_moe.moe_comm_method import \
get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
# TODO: remove this after moe_comm_type selection logic is finalized
if is_mtp_model:
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
is_mtp_model)
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -210,28 +207,29 @@ def get_mc2_mask():
def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
number of tokens. This is based on the observation that all-gather is more
efficient than all-to-all when running on A2.
vllm_config: VllmConfig,
is_mtp_model=False) -> Optional[MoECommType]:
"""Select the MoE communication method according to parallel settings,
device generation, token count, and quantization.
a. For A2, we choose from MC2 and all-gather.
b. For A3, we choose from MC2 and all-to-all.
In both cases, we use MC2 when the number of tokens is smaller than
a its capacity threshold.
1. Non-MoE models return `None`.
2. Without expert parallel, fall back to all-gather.
3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
and the DP size is large enough; otherwise use all-gather.
4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
quantization with small EP size, no dynamic_eplb, and not in MTP
mode; otherwise use MC2 within capacity or all-to-all.
Args:
num_tokens (int): The number of tokens in the current batch.
vllm_config (VllmConfig): Runtime configuration for the model.
is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).
Raises:
ValueError: If the soc version is unsupported.
Returns:
MoECommType: The selected MoE communication method.
MoECommType | None: The selected MoE communication method.
"""
if not is_moe_model(vllm_config):
return None
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
ascend_config = get_ascend_config()
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb)
moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
else MoECommType.FUSED_ALLTOALL
if fused_all2all_enable else MoECommType.ALLTOALL)
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
if num_tokens <= mc2_tokens_capacity:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
else:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type

View File

@@ -132,6 +132,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to enable fused MC2 (dispatch_gmm_combine_decode / dispatch_ffn_combine operators)
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
}
# end-env-vars-definition

View File

@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:

View File

@@ -22,6 +22,7 @@ import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
moe_config)
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
return PrepareAndFinalizeWithAll2All(self.moe_config)
class FusedAlltoAllCommImpl(MoECommMethod):
class FusedMC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_grouped_matmul` is available.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses all-to-all communication to exchange tokens
between data parallel ranks before and after the MLP computation. It should
have better performance than AllGatherCommImpl when DP size > 1.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
return TokenDispatcherWithMC2()
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
return PrepareAndFinalizeWithMC2(self.moe_config)
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
out = torch.empty_like(hidden_states)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1,
weight2=w2,
weight1=w1[0],
weight2=w2[0],
expert_idx=topk_ids,
scale1=w1_scale,
scale2=w2_scale,
scale1=w1_scale[0],
scale2=w2_scale[0],
probs=topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
raise NotImplementedError()
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return out

View File

@@ -130,7 +130,8 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
with torch.npu.stream(prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
x_dependency, mlp_gate_up_prefetch_size)
return
@@ -185,7 +186,8 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
with torch.npu.stream(prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
x_dependency, mlp_down_prefetch_size)
forward_context.layer_idx += 1
return
@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
} or forward_context.sp_enabled:
return final_hidden_states
else:

View File

@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
fused_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_ALLTOALL
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1[0] if fused_flag else w1,
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
w2=w2[0] if fused_flag else w2,
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
w1=w1,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w2=w2,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,

View File

@@ -430,7 +430,8 @@ class NPUModelRunner(GPUModelRunner):
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
# nodes. So here we check whether recompute_scheduler_enable is True.
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
potential_max_num_tokens,
self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
def _sync_metadata_across_dp(
self, num_tokens: int,
@@ -2203,7 +2204,7 @@ class NPUModelRunner(GPUModelRunner):
def profile_run(self) -> None:
mc2_tokens_capacity = get_mc2_tokens_capacity()
if self.max_num_tokens > mc2_tokens_capacity and \
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
self._dummy_run(mc2_tokens_capacity,
with_prefill=True,
is_profile=True)