[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)

### What this PR does / why we need it?

- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when all of the following hold (condensed in the sketch below):
  - `enable_expert_parallel=True`
  - the new `VLLM_ASCEND_ENABLE_FUSED_MC2` environment variable is non-zero
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model=False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
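
Condensed, the new A3 selection logic behaves like the sketch below (simplified from the diff; the arguments are illustrative stand-ins for values the real function reads from `vllm_config` and `envs_ascend`):

```python
from vllm_ascend.ascend_forward_context import MoECommType

def _select_a3_sketch(num_tokens: int, mc2_tokens_capacity: int,
                      quant_type: str, ep_size: int, fused_mc2_env: int,
                      dynamic_eplb: bool, is_mtp_model: bool) -> MoECommType:
    # Mirrors the A3 branch of select_moe_comm_method after this PR.
    fused_mc2_enable = (bool(fused_mc2_env) and quant_type == "w8a8_dynamic"
                        and ep_size <= 16 and not dynamic_eplb
                        and not is_mtp_model)
    if fused_mc2_enable:
        # The fused path is not capacity-gated: it wins in both branches.
        return MoECommType.FUSED_MC2
    return (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
            else MoECommType.ALLTOALL)
```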

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Chen Chen <0109chenchen@gmail.com>


```diff
@@ -26,7 +26,7 @@ class MoECommType(Enum):
     ALLGATHER = 0
     MC2 = 1
     ALLTOALL = 2
-    FUSED_ALLTOALL = 3
+    FUSED_MC2 = 3
 
 @contextmanager
@@ -62,11 +62,8 @@ def set_ascend_forward_context(
         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method
-        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        if is_mtp_model:
-            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
+                                               is_mtp_model)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -93,7 +90,7 @@ def set_ascend_forward_context(
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
     forward_context.sp_enabled = sp_enabled
-    #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+    # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
     forward_context.flashcomm_v2_enabled = flashcomm2_enable(
     ) and tp_world_size > 1 and num_tokens is not None
@@ -210,28 +207,29 @@ def get_mc2_mask():
 def select_moe_comm_method(num_tokens: int,
-                           vllm_config: VllmConfig) -> Optional[MoECommType]:
-    """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
-    are designed for expert parallelism.
-    2. If expert parallel is enabled, we need to consider the soc version and the
-    number of tokens. This is based on the observation that all-gather is more
-    efficient than all-to-all when running on A2.
-        a. For A2, we choose from MC2 and all-gather.
-        b. For A3, we choose from MC2 and all-to-all.
-    In both cases, we use MC2 when the number of tokens is smaller than
-    a its capacity threshold.
+                           vllm_config: VllmConfig,
+                           is_mtp_model=False) -> Optional[MoECommType]:
+    """Select the MoE communication method according to parallel settings,
+    device generation, token count, and quantization.
+
+    1. Non-MoE models return `None`.
+    2. Without expert parallel, fall back to all-gather.
+    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
+       and the DP size is large enough; otherwise use all-gather.
+    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
+       quantization with small EP size, no dynamic_eplb, and not in MTP
+       mode; otherwise use MC2 within capacity or all-to-all.
 
     Args:
         num_tokens (int): The number of tokens in the current batch.
+        vllm_config (VllmConfig): Runtime configuration for the model.
+        is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).
 
     Raises:
         ValueError: If the soc version is unsupported.
 
     Returns:
-        MoECommType: The selected MoE communication method.
+        MoECommType | None: The selected MoE communication method.
     """
     if not is_moe_model(vllm_config):
         return None
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
         ascend_config = get_ascend_config()
         dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-        fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
-        ).world_size <= 16 and (not dynamic_eplb)
-        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
-                         else MoECommType.FUSED_ALLTOALL
-                         if fused_all2all_enable else MoECommType.ALLTOALL)
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
+        ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
+        if num_tokens <= mc2_tokens_capacity:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
+        else:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
 
     return moe_comm_type
```


```diff
@@ -132,6 +132,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # Whether to enable dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to enable fused MC2 (dispatch_gmm_combine_decode/dispatch_ffn_combine operators)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
 }
 # end-env-vars-definition
```
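
As a usage note (not part of the diff): the flag is parsed as an int and read through `vllm_ascend.envs`, so it can be toggled like any other vLLM-Ascend env var. A minimal sketch, assuming the usual lazy-env access pattern of `vllm_ascend.envs`:

```python
import os

# Set before the engine starts: "1" selects the dispatch_ffn_combine fused
# path, "2" is reserved (currently raises NotImplementedError), and "0"
# (the default) leaves fused MC2 disabled.
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"

import vllm_ascend.envs as envs_ascend

assert envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
```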


```diff
@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                 and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         else:
```


```diff
@@ -22,6 +22,7 @@ import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
-    _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
-        moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
 
 class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
         return PrepareAndFinalizeWithAll2All(self.moe_config)
 
-class FusedAlltoAllCommImpl(MoECommMethod):
+class FusedMC2CommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=True`.
-    2. `npu_grouped_matmul` is available.
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
+    3. `enable_expert_parallel=False` is not supported.
 
-    This implementation uses all-to-all communication to exchange tokens
-    between data parallel ranks before and after the MLP computation. It should
-    have better performance than AllGatherCommImpl when DP size > 1.
+    This implementation uses the MC2 communication method, which is optimized for
+    Communication and Computation parallelism on Ascend devices.
     """
 
     def _get_token_dispatcher(self):
-        return TokenDispatcherWithAll2AllV(
-            top_k=self.moe_config.experts_per_token,
-            num_experts=self.moe_config.num_experts,
-            num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config)
 
     def fused_experts(
             self,
             hidden_states: torch.Tensor,
-            w1: torch.Tensor,
-            w2: torch.Tensor,
+            w1: torch.Tensor | list[torch.Tensor],
+            w2: torch.Tensor | list[torch.Tensor],
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             activation: str = "silu",
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             use_int4_w4a16: bool = False,
             global_num_experts: Optional[int] = None,
             expert_map: Optional[torch.Tensor] = None,
-            w1_scale: Optional[torch.Tensor] = None,
-            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale: Optional[list[torch.Tensor]] = None,
+            w2_scale: Optional[list[torch.Tensor]] = None,
             w1_scale_bias: torch.Tensor = None,
             w2_scale_bias: torch.Tensor = None,
             w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             dynamic_eplb: bool = False,
             mc2_mask: torch.Tensor = None,
             pertoken_scale: Optional[torch.Tensor] = None):
+        assert not (
+            w1_scale is None or w2_scale is None
+        ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
         out = torch.empty_like(hidden_states)
-        torch.ops._C_ascend.dispatch_ffn_combine(
-            x=hidden_states,
-            weight1=w1,
-            weight2=w2,
-            expert_idx=topk_ids,
-            scale1=w1_scale,
-            scale2=w2_scale,
-            probs=topk_weights.to(torch.float32),
-            group=self.token_dispatcher.moe_all_to_all_group_name,
-            max_output_size=65536,
-            out=out,
-        )
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            torch.ops._C_ascend.dispatch_ffn_combine(
+                x=hidden_states,
+                weight1=w1[0],
+                weight2=w2[0],
+                expert_idx=topk_ids,
+                scale1=w1_scale[0],
+                scale2=w2_scale[0],
+                probs=topk_weights.to(torch.float32),
+                group=self.token_dispatcher.moe_all_to_all_group_name,
+                max_output_size=65536,
+                out=out,
+            )
+        elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+            raise NotImplementedError()
+        else:
+            raise ValueError(
+                f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
         return out
```
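
A hedged illustration of the resulting call contract: with fused MC2, `w1`/`w2` and their scales arrive as single-element lists and the kernel consumes element 0. The tensor names below are hypothetical placeholders, not identifiers from the PR:

```python
# Hypothetical placeholder tensors; shapes/dtypes depend on the model.
w1 = [w13_weight]            # packed gate/up projection weight
w2 = [w2_weight]             # down projection weight
w1_scale = [fused_w1_scale]  # fused per-channel scales (see the quant diff below)
w2_scale = [fused_w2_scale]

# fused_mc2_impl: a FusedMC2CommImpl instance obtained via get_moe_comm_method.
out = fused_mc2_impl.fused_experts(
    hidden_states=x,
    w1=w1, w2=w2,
    w1_scale=w1_scale, w2_scale=w2_scale,
    topk_weights=topk_weights, topk_ids=topk_ids,
    use_int8_w8a8=True,
)
```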


```diff
@@ -130,7 +130,8 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     with torch.npu.stream(prefetch_stream):
         mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
             x_dependency, mlp_gate_up_prefetch_size)
 
     return
@@ -185,7 +186,8 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
     with torch.npu.stream(prefetch_stream):
         mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.down_proj.weight,
             x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return
@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:
```
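
The invariant behind this set membership, as a small sketch: for these comm types the combine step already returns complete hidden states, so the extra tensor-parallel all-reduce is skipped. The helper name is illustrative, not part of the PR:

```python
from vllm_ascend.ascend_forward_context import MoECommType

# Comm types whose combine step already aggregates expert outputs,
# making a further tensor-parallel all-reduce unnecessary.
_SKIP_TP_ALL_REDUCE = {
    MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
}

def needs_tp_all_reduce(moe_comm_type: MoECommType, sp_enabled: bool) -> bool:
    # Illustrative inverse of _maybe_all_reduce_tensor_model_parallel_impl.
    return moe_comm_type not in _SKIP_TP_ALL_REDUCE and not sp_enabled
```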


```diff
@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
         w2 = [layer.w2_weight]
         w2_scale = [layer.w2_weight_scale]
 
-        fused_flag = get_forward_context(
-        ).moe_comm_type == MoECommType.FUSED_ALLTOALL
+        fused_scale_flag = (get_forward_context().moe_comm_type
+                            == MoECommType.FUSED_MC2
+                            and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
 
         return moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
-            w1=w1[0] if fused_flag else w1,
-            w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
-            w2=w2[0] if fused_flag else w2,
-            w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
+            w1=w1,
+            w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
+            w2=w2,
+            w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
```


```diff
@@ -430,7 +430,8 @@ class NPUModelRunner(GPUModelRunner):
         # moe_comm_method of each rank is MC2 and recomputation would never happen in D
         # nodes. So here we check whether recompute_scheduler_enable is True.
         return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
-            potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
+            potential_max_num_tokens,
+            self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
 
     def _sync_metadata_across_dp(
             self, num_tokens: int,
@@ -1058,7 +1059,7 @@ class NPUModelRunner(GPUModelRunner):
         # (num_reqs_d + num_reqs_p, max_num_blocks),
         # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
         # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \
+        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
            self.query_start_loc_pcp_full.cpu[:num_reqs]
         num_prefill_reqs = (ori_query_lens
                             > self.decode_threshold).sum().item()
@@ -2203,7 +2204,7 @@ class NPUModelRunner(GPUModelRunner):
     def profile_run(self) -> None:
         mc2_tokens_capacity = get_mc2_tokens_capacity()
         if self.max_num_tokens > mc2_tokens_capacity and \
-                select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
+                select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
             self._dummy_run(mc2_tokens_capacity,
                             with_prefill=True,
                             is_profile=True)
```
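
The profile-run guard appears to exist because MC2-family kernels only engage up to `mc2_tokens_capacity` tokens: when `max_num_tokens` exceeds that cap, profiling should also exercise the largest batch the MC2 or fused-MC2 path will actually run. A condensed sketch of the updated guard, with an illustrative wrapper name (not part of the PR):

```python
from vllm_ascend.ascend_forward_context import (MoECommType,
                                                select_moe_comm_method)

def _maybe_profile_mc2_shape(runner) -> None:
    # Illustrative wrapper around the logic in profile_run above; `runner`
    # stands in for an NPUModelRunner instance.
    mc2_tokens_capacity = get_mc2_tokens_capacity()
    if runner.max_num_tokens > mc2_tokens_capacity and \
            select_moe_comm_method(mc2_tokens_capacity, runner.vllm_config) in \
            {MoECommType.MC2, MoECommType.FUSED_MC2}:
        runner._dummy_run(mc2_tokens_capacity,
                          with_prefill=True,
                          is_profile=True)
```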