From 904c18f929d1ad860e5fa9bb2ea7691928366fe0 Mon Sep 17 00:00:00 2001
From: wangqiankun13
Date: Sun, 21 Dec 2025 15:23:59 +0800
Subject: [PATCH] [Feature] Use DispatchGmmCombineDecode operator to replace
 MC2 (Optional) (#5040)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR adds model-side integration for the previously introduced experimental AscendC fused operator DispatchGmmCombineDecode, used in MoE decoding. The operator implementation itself was added in a prior PR [#4139](https://github.com/vllm-project/vllm-ascend/pull/4139). This change only adapts the model execution path to optionally use the fused operator.

When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=2 is set, the original MC2 path composed of multiple operators (A8W8 dispatch → GMM → SwiGLU → GMM → combine) may be replaced by the single fused operator DispatchGmmCombineDecode. By default, the existing multi-operator MC2 implementation is preserved.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: wangqiankun
---
 vllm_ascend/ascend_forward_context.py         | 20 +++++++++++++++----
 vllm_ascend/envs.py                           |  8 +++++++-
 vllm_ascend/ops/fused_moe/fused_moe.py        |  2 +-
 vllm_ascend/ops/fused_moe/moe_comm_method.py  | 18 +++++++++++++++--
 vllm_ascend/ops/fused_moe/token_dispatcher.py |  1 +
 vllm_ascend/quantization/w8a8_dynamic.py      | 11 +++++++++-
 6 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 1b68590d..f8b9d1cd 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -253,12 +253,24 @@ def select_moe_comm_method(num_tokens: int,
         ascend_config = get_ascend_config()
         dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
-        ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
+        # TODO: drop the dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
+        # TODO: add a guard for dispatch_gmm_combine_decode for the case where MTP uses float while MoE uses w8a8
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
+            not dynamic_eplb)
         if num_tokens <= mc2_tokens_capacity:
-            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
+            fused_decode_enable = fused_mc2_enable
+            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                fused_decode_enable = fused_mc2_enable and get_ep_group(
+                ).world_size <= 16 and (not is_mtp_model)
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
         else:
-            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
+            fused_prefill_enable = fused_mc2_enable
+            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                fused_prefill_enable = fused_mc2_enable and get_ep_group(
+                ).world_size <= 16 and (not is_mtp_model)
+            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+                fused_prefill_enable = False
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
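For reference, the selection behaviour introduced by the hunk above can be restated as a standalone sketch. The enum members and guard conditions are taken from the diff; the free-standing enum, the function name `pick_comm_type`, and its flattened argument list are illustrative assumptions, not the actual `select_moe_comm_method` signature in vllm-ascend.

```python
from enum import Enum, auto


class MoECommType(Enum):
    # Simplified stand-in for vllm_ascend's MoECommType; only the members used here.
    MC2 = auto()
    FUSED_MC2 = auto()
    ALLTOALL = auto()


def pick_comm_type(fused_mc2_env: int, quant_type: str, dynamic_eplb: bool,
                   is_mtp_model: bool, ep_world_size: int, num_tokens: int,
                   mc2_tokens_capacity: int) -> MoECommType:
    """Restates the guards added by this patch for VLLM_ASCEND_ENABLE_FUSED_MC2 = 0/1/2."""
    fused_ok = bool(fused_mc2_env) and quant_type == "w8a8_dynamic" and not dynamic_eplb
    if fused_mc2_env == 1:
        # dispatch_ffn_combine additionally requires EP <= 16 and a non-MTP model
        fused_ok = fused_ok and ep_world_size <= 16 and not is_mtp_model
    if num_tokens <= mc2_tokens_capacity:
        # decode-sized batches: mode 1 and mode 2 may both replace MC2
        return MoECommType.FUSED_MC2 if fused_ok else MoECommType.MC2
    if fused_mc2_env == 2:
        # mode 2 is decode-only: it never replaces ALLTOALL on prefill-sized batches
        fused_ok = False
    return MoECommType.FUSED_MC2 if fused_ok else MoECommType.ALLTOALL
```

For example, `pick_comm_type(2, "w8a8_dynamic", False, False, 32, 64, 512)` returns `FUSED_MC2`, while the same call with `num_tokens=2048` falls back to `ALLTOALL`.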
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index c3c5e967..75024663 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -135,7 +135,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # Whether to anbale dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
-    # Whether to anbale fused mc2(dispatch_gmm_combine_decode/dispatch_ffn_combine operator)
+    # Whether to enable fused mc2 (`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
+    # 0, or not set: the default ALLTOALL and MC2 paths are used.
+    # 1: ALLTOALL and MC2 might be replaced by the `dispatch_ffn_combine` operator.
+    #    `dispatch_ffn_combine` can be used only for MoE layers with W8A8, EP <= 16, non-MTP, non-dynamic-EPLB.
+    # 2: MC2 might be replaced by the `dispatch_gmm_combine_decode` operator.
+    #    `dispatch_gmm_combine_decode` can be used only for **decode node** MoE layers
+    #    with W8A8 and non-dynamic-EPLB, and the MTP layer must also be W8A8.
     "VLLM_ASCEND_ENABLE_FUSED_MC2":
     lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
 }
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 9d913f63..b4cbbb48 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE):
             shared_out = fc3_context.shared_experts(hidden_states)
             # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
             moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                     and not shared_expert_dp_enabled():
                 shared_out = tensor_model_parallel_all_reduce(shared_out)
             set_flash_common3_context(shared_out=shared_out)
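As the envs.py hunk above documents, the flag is parsed once as an integer with a default of 0, so the fused paths stay off unless the variable is exported explicitly (for example, setting VLLM_ASCEND_ENABLE_FUSED_MC2=2 in the serving environment). A minimal sketch of how the three values are interpreted, mirroring the lookup style in vllm_ascend/envs.py; the printed messages are illustrative only:

```python
import os

# Parsed once, exactly like the entry added to env_variables:
# unset or "0" keeps the default multi-operator ALLTOALL/MC2 paths.
VLLM_ASCEND_ENABLE_FUSED_MC2 = int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0"))

if VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
    print("ALLTOALL and MC2 may be replaced by dispatch_ffn_combine")
elif VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
    print("MC2 may be replaced by dispatch_gmm_combine_decode (decode only)")
else:
    print("default ALLTOALL and MC2 paths")
```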
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index e9dc7b0a..30d1e5c1 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -291,9 +291,9 @@ class FusedMC2CommImpl(MoECommMethod):
         assert not (
             w1_scale is None or w2_scale is None
         ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
-        out = torch.empty_like(hidden_states)

         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            out = torch.empty_like(hidden_states)
             torch.ops._C_ascend.dispatch_ffn_combine(
                 x=hidden_states,
                 weight1=w1[0],
@@ -307,7 +307,21 @@ class FusedMC2CommImpl(MoECommMethod):
                 out=out,
             )
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
-            raise NotImplementedError()
+            assert expert_map is not None, "expert_map cannot be None."
+            out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
+                x=hidden_states,
+                expert_ids=topk_ids,
+                gmm1_permuted_weight=w1[0],
+                gmm1_permuted_weight_scale=w1_scale[0],
+                gmm2_weight=w2[0],
+                gmm2_weight_scale=w2_scale[0],
+                expert_smooth_scales=None,
+                expert_scales=topk_weights.to(torch.float32),
+                group_ep=self.token_dispatcher.moe_all_to_all_group_name,
+                ep_rank_size=self.token_dispatcher.ep_world_size,
+                ep_rank_id=self.token_dispatcher.ep_rank_id,
+                moe_expert_num=len(expert_map),
+                global_bs=self.token_dispatcher.fused_global_bs)
         else:
             raise ValueError(
                 f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 1b18a488..aeb751d0 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -125,6 +125,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
         self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
+        self.fused_global_bs = max_num_tokens * self.ep_world_size

     def get_dispatch_mc2_kwargs(
         self,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index e32360ce..da6d3a69 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -231,6 +231,10 @@ class AscendW8A8DynamicFusedMoEMethod:
         topk_weights = topk_weights.to(self.in_dtype)

         moe_comm_method = get_forward_context().moe_comm_method
+        # When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, dispatch_gmm_combine_decode is used and needs an fp32 w2 scale
+        w2_weight_scale_fp32_flag = (
+            get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
+            and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
         if self.dynamic_eplb:
             w1 = layer.w13_weight_list
             w1_scale = layer.w13_weight_scale_fp32_list
@@ -240,7 +244,10 @@ class AscendW8A8DynamicFusedMoEMethod:
             w1 = [layer.w13_weight]
             w1_scale = [layer.w13_weight_scale_fp32]
             w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
+            w2_scale = [
+                layer.w2_weight_scale_fp32
+                if w2_weight_scale_fp32_flag else layer.w2_weight_scale
+            ]

         fused_scale_flag = (get_forward_context().moe_comm_type
                             == MoECommType.FUSED_MC2
@@ -279,6 +286,8 @@ class AscendW8A8DynamicFusedMoEMethod:
                 layer.w13_weight_offset.data.shape[0], -1)
             layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
                 layer.w2_weight_scale.data.shape[0], -1)
+            layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
+                torch.float32)
             layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
                 layer.w2_weight_offset.data.shape[0], -1)
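The last two w8a8_dynamic.py hunks work together: an fp32 copy of the per-expert w2 scale is materialised once after weight loading, and at runtime it is selected only when the fused decode operator (mode 2) is active, so no per-step cast is needed. Below is a minimal standalone sketch of that caching idea; the toy layer class, tensor shapes, and dtypes are illustrative assumptions, not the real quantized MoE layer.

```python
import torch


class ToyQuantMoELayer:
    """Illustrative stand-in for the quantized MoE layer (shapes/dtypes assumed)."""

    def __init__(self, num_experts: int = 8, hidden_size: int = 16):
        # Per-expert w2 dequantization scale as it might sit after checkpoint load.
        self.w2_weight_scale = torch.rand(num_experts, 1, hidden_size,
                                          dtype=torch.bfloat16)


def process_weights_after_loading(layer: ToyQuantMoELayer) -> None:
    # Flatten trailing dims, as the patch does for the real layer, then cache an
    # fp32 copy that the mode-2 fused decode operator can consume directly.
    layer.w2_weight_scale = layer.w2_weight_scale.view(
        layer.w2_weight_scale.shape[0], -1)
    layer.w2_weight_scale_fp32 = layer.w2_weight_scale.to(torch.float32)


layer = ToyQuantMoELayer()
process_weights_after_loading(layer)
assert layer.w2_weight_scale_fp32.dtype == torch.float32
assert layer.w2_weight_scale_fp32.shape == layer.w2_weight_scale.shape
```

At forward time, `w2_scale` then simply becomes `[layer.w2_weight_scale_fp32]` when `w2_weight_scale_fp32_flag` is set, which is exactly what the second hunk above selects.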