From 350b95efcf6ebedaae7b87ba621c38c021b471a8 Mon Sep 17 00:00:00 2001 From: wangqiankun13 Date: Sun, 4 Jan 2026 17:51:28 +0800 Subject: [PATCH] [BugFix]Disable dispatch_gmm_combine_decode operator when mtp drafter model uses non-w8a8 while main model uses w8a8, or drafter model is eagle series (#5293) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …w8a8 while main model uses w8a8 ### What this PR does / why we need it? Disable dispatch_gmm_combine_decode operator when mtp drafter model uses non-w8a8 while main model uses w8a8, or drafter model is eagle series. More info about this operator, please refer to RFC: issue https://github.com/vllm-project/vllm-ascend/issues/5476 - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: wangqiankun --- vllm_ascend/ascend_forward_context.py | 8 ++++++-- vllm_ascend/utils.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 8ca7b255..6baa199b 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -14,7 +14,8 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable, get_ascend_device_type, has_layer_idx, - is_moe_model) + is_moe_model, + speculative_enable_dispatch_gmm_combine_decode) class MoECommType(Enum): @@ -242,7 +243,7 @@ def select_moe_comm_method(num_tokens: int, dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs - # TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8 + # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16 fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and ( not dynamic_eplb) if num_tokens <= mc2_tokens_capacity: @@ -250,6 +251,9 @@ def select_moe_comm_method(num_tokens: int, if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: fused_decode_enable = fused_mc2_enable and get_ep_group( ).world_size <= 16 and (not is_draft_model) + elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: + fused_decode_enable = fused_mc2_enable and \ + speculative_enable_dispatch_gmm_combine_decode(vllm_config) moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2 else: fused_prefill_enable = fused_mc2_enable diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index e959335e..cecb88cd 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -831,6 +831,23 @@ def is_moe_model(vllm_config: VllmConfig): return _IS_MOE_MODEL +def speculative_enable_dispatch_gmm_combine_decode( + vllm_config: VllmConfig) -> bool: + if vllm_config.speculative_config is None: + return True + speculative_method = getattr(vllm_config.speculative_config, "method", + None) + if speculative_method in [None, "ngram", "suffix"]: + return True + if speculative_method in ["eagle", "eagle3"]: + return False + if speculative_method == "mtp": + mtp_quant_type = getattr(vllm_config.model_config.hf_config, + "mtp_quantize", None) + return mtp_quant_type == "w8a8_dynamic" + return False + + def _is_contain_expert(config: Any): if isinstance(config, dict): for k, v in config.items():