From 1b47fca0e8e1795414452ca440617f41c9eecc96 Mon Sep 17 00:00:00 2001
From: Chen Chen <0109chenchen@gmail.com>
Date: Thu, 18 Dec 2025 23:34:31 +0800
Subject: [PATCH] [bugfix] Use FUSED_MC2 MoE comm path for the op `dispatch_ffn_combine` (#5156)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
  `MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
  Ascend A3 when:
  - the new env var `VLLM_ASCEND_ENABLE_FUSED_MC2` is set (off by default)
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model = False`
- Replaces the old “fused all-to-all” comm implementation with
  `FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
  `PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.

- vLLM version: v0.12.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
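
Roughly, the new A3 selection behaves like the sketch below. This is a
simplified, self-contained illustration for review purposes: the enum values
match the patch, but the function and parameter names here are illustrative,
not the actual `vllm_ascend` code.

```python
from enum import Enum


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


def select_a3_moe_comm_type(num_tokens: int,
                            mc2_tokens_capacity: int,
                            fused_mc2_env: int,
                            quant_type: str,
                            ep_world_size: int,
                            dynamic_eplb: bool,
                            is_mtp_model: bool) -> MoECommType:
    """Sketch of the A3 branch (expert parallel already enabled)."""
    # FUSED_MC2 is gated on the env switch plus every guard listed above.
    fused_mc2_enable = (bool(fused_mc2_env) and quant_type == "w8a8_dynamic"
                        and ep_world_size <= 16 and not dynamic_eplb
                        and not is_mtp_model)
    if fused_mc2_enable:
        return MoECommType.FUSED_MC2
    # Otherwise keep the previous behaviour: MC2 within capacity, else all-to-all.
    return (MoECommType.MC2
            if num_tokens <= mc2_tokens_capacity else MoECommType.ALLTOALL)
```

The new path is off by default (`VLLM_ASCEND_ENABLE_FUSED_MC2` defaults to `0`);
setting it to `1` routes eligible MoE layers through `dispatch_ffn_combine`.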

---------

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
---
 vllm_ascend/ascend_forward_context.py        | 62 +++++++++----------
 vllm_ascend/envs.py                          |  3 +
 vllm_ascend/ops/fused_moe/fused_moe.py       |  2 +-
 vllm_ascend/ops/fused_moe/moe_comm_method.py | 64 +++++++++++---------
 vllm_ascend/ops/register_custom_ops.py       | 12 ++--
 vllm_ascend/quantization/w8a8_dynamic.py     | 14 +++--
 vllm_ascend/worker/model_runner_v1.py        |  7 ++-
 7 files changed, 89 insertions(+), 75 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 6f402ca8..1b68590d 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -26,7 +26,7 @@ class MoECommType(Enum):
     ALLGATHER = 0
     MC2 = 1
     ALLTOALL = 2
-    FUSED_ALLTOALL = 3
+    FUSED_MC2 = 3


 @contextmanager
@@ -62,11 +62,8 @@ def set_ascend_forward_context(
         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method

-        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        if is_mtp_model:
-            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
+                                               is_mtp_model)

         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -93,7 +90,7 @@ def set_ascend_forward_context(
         forward_context.mmrs_fusion = mmrs_fusion
         forward_context.num_tokens = num_tokens
         forward_context.sp_enabled = sp_enabled
-        #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
         forward_context.flashcomm_v2_enabled = flashcomm2_enable(
         ) and tp_world_size > 1 and num_tokens is not None
@@ -210,29 +207,30 @@ def get_mc2_mask():


 def select_moe_comm_method(num_tokens: int,
-                           vllm_config: VllmConfig) -> Optional[MoECommType]:
-    """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
-    are designed for expert parallelism.
-    2. If expert parallel is enabled, we need to consider the soc version and the
-    number of tokens. This is based on the observation that all-gather is more
-    efficient than all-to-all when running on A2.
+                           vllm_config: VllmConfig,
+                           is_mtp_model=False) -> Optional[MoECommType]:
+    """Select the MoE communication method according to parallel settings,
+    device generation, token count, and quantization.

-        a. For A2, we choose from MC2 and all-gather.
+    1. Non-MoE models return `None`.
+    2. Without expert parallel, fall back to all-gather.
+    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
+       and the DP size is large enough; otherwise use all-gather.
+    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
+       quantization with small EP size, no dynamic_eplb, and not in MTP
+       mode; otherwise use MC2 within capacity or all-to-all.

-        b. For A3, we choose from MC2 and all-to-all.
+    Args:
+        num_tokens (int): The number of tokens in the current batch.
+        vllm_config (VllmConfig): Runtime configuration for the model.
+        is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).

-    In both cases, we use MC2 when the number of tokens is smaller than
-    a its capacity threshold.
+    Raises:
+        ValueError: If the soc version is unsupported.

-    Args:
-        num_tokens (int): The number of tokens in the current batch.
-
-    Raises:
-        ValueError: If the soc version is unsupported.
-
-    Returns:
-        MoECommType: The selected MoE communication method.
-    """
+    Returns:
+        MoECommType | None: The selected MoE communication method.
+    """
     if not is_moe_model(vllm_config):
         return None

     mc2_tokens_capacity = get_mc2_tokens_capacity()
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
         ascend_config = get_ascend_config()
         dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-        fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
-        ).world_size <= 16 and (not dynamic_eplb)
-        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
-                         else MoECommType.FUSED_ALLTOALL
-                         if fused_all2all_enable else MoECommType.ALLTOALL)
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
+        ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
+        if num_tokens <= mc2_tokens_capacity:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
+        else:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
+
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")

     return moe_comm_type
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 5e926b11..4e92d800 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -132,6 +132,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # Whether to anbale dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to enable fused MC2 (dispatch_gmm_combine_decode/dispatch_ffn_combine operators)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
 }

 # end-env-vars-definition
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 9adfe085..2d0e7afc 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
             and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         else:
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 93b79242..e9dc7b0a 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -22,6 +22,7 @@ import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
-    _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
-        moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


 class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
         return PrepareAndFinalizeWithAll2All(self.moe_config)


-class FusedAlltoAllCommImpl(MoECommMethod):
+class FusedMC2CommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:

     1. `enable_expert_parallel=True`.
-    2. `npu_grouped_matmul` is available.
-
-    This implementation uses all-to-all communication to exchange tokens
-    between data parallel ranks before and after the MLP computation. It should
-    have better performance than AllGatherCommImpl when DP size > 1.
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
+    3. `enable_expert_parallel=False` is not supported.
+
+    This implementation uses the MC2 communication method, which is optimized for
+    Communication and Computation parallelism on Ascend devices.
     """

     def _get_token_dispatcher(self):
-        return TokenDispatcherWithAll2AllV(
-            top_k=self.moe_config.experts_per_token,
-            num_experts=self.moe_config.num_experts,
-            num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithMC2()

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config)

     def fused_experts(
             self,
             hidden_states: torch.Tensor,
-            w1: torch.Tensor,
-            w2: torch.Tensor,
+            w1: torch.Tensor | list[torch.Tensor],
+            w2: torch.Tensor | list[torch.Tensor],
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             activation: str = "silu",
@@ -274,8 +271,8 @@
             use_int4_w4a16: bool = False,
             global_num_experts: Optional[int] = None,
             expert_map: Optional[torch.Tensor] = None,
-            w1_scale: Optional[torch.Tensor] = None,
-            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale: Optional[list[torch.Tensor]] = None,
+            w2_scale: Optional[list[torch.Tensor]] = None,
             w1_scale_bias: torch.Tensor = None,
             w2_scale_bias: torch.Tensor = None,
             w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@
             dynamic_eplb: bool = False,
             mc2_mask: torch.Tensor = None,
             pertoken_scale: Optional[torch.Tensor] = None):
+        assert not (
+            w1_scale is None or w2_scale is None
+        ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
         out = torch.empty_like(hidden_states)
-        torch.ops._C_ascend.dispatch_ffn_combine(
-            x=hidden_states,
-            weight1=w1,
-            weight2=w2,
-            expert_idx=topk_ids,
-            scale1=w1_scale,
-            scale2=w2_scale,
-            probs=topk_weights.to(torch.float32),
-            group=self.token_dispatcher.moe_all_to_all_group_name,
-            max_output_size=65536,
-            out=out,
-        )
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            torch.ops._C_ascend.dispatch_ffn_combine(
+                x=hidden_states,
+                weight1=w1[0],
+                weight2=w2[0],
+                expert_idx=topk_ids,
+                scale1=w1_scale[0],
+                scale2=w2_scale[0],
+                probs=topk_weights.to(torch.float32),
+                group=self.token_dispatcher.moe_all_to_all_group_name,
+                max_output_size=65536,
+                out=out,
+            )
+        elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+            raise NotImplementedError()
+        else:
+            raise ValueError(
+                f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
         return out
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index a534b719..6874687f 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -130,8 +130,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,

     with torch.npu.stream(prefetch_stream):
         mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
-                        x_dependency, mlp_gate_up_prefetch_size)
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
+            x_dependency, mlp_gate_up_prefetch_size)
     return
@@ -185,8 +186,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:

     with torch.npu.stream(prefetch_stream):
         mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
-                        x_dependency, mlp_down_prefetch_size)
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.down_proj.weight,
+            x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return
@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index cce8750b..8952d3cf 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
         w2 = [layer.w2_weight]
         w2_scale = [layer.w2_weight_scale]

-        fused_flag = get_forward_context(
-        ).moe_comm_type == MoECommType.FUSED_ALLTOALL
+        fused_scale_flag = (get_forward_context().moe_comm_type
+                            == MoECommType.FUSED_MC2
+                            and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)

         return moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
-            w1=w1[0] if fused_flag else w1,
-            w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
-            w2=w2[0] if fused_flag else w2,
-            w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
+            w1=w1,
+            w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
+            w2=w2,
+            w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8ef38d4f..d562f114 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -430,7 +430,8 @@ class NPUModelRunner(GPUModelRunner):
         # moe_comm_method of each rank is MC2 and recomputation would never happen in D
         # nodes. So here we check whether recompute_scheduler_enable is True.
         return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
-            potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
+            potential_max_num_tokens,
+            self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}

     def _sync_metadata_across_dp(
             self, num_tokens: int,
@@ -1058,7 +1059,7 @@ class NPUModelRunner(GPUModelRunner):
         #   (num_reqs_d + num_reqs_p, max_num_blocks),
         #   flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
         #   (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \
+        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
             self.query_start_loc_pcp_full.cpu[:num_reqs]
         num_prefill_reqs = (ori_query_lens >
                             self.decode_threshold).sum().item()
@@ -2203,7 +2204,7 @@ class NPUModelRunner(GPUModelRunner):
     def profile_run(self) -> None:
         mc2_tokens_capacity = get_mc2_tokens_capacity()
         if self.max_num_tokens > mc2_tokens_capacity and \
-            select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
+            select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
             self._dummy_run(mc2_tokens_capacity,
                             with_prefill=True,
                             is_profile=True)