From a370dfa9623e648439b724569931988a852e462e Mon Sep 17 00:00:00 2001
From: asunxiao <43368867+asunxiao@users.noreply.github.com>
Date: Tue, 17 Mar 2026 19:53:02 +0800
Subject: [PATCH] [bugfix] Enable dispatch_ffn_combine feature for qwen3.5 (#7066)

### What this PR does / why we need it?

Enable the dispatch_ffn_combine fusion operator for Qwen3.5 MoE.

Problem fixed: in the w8a8 quantization scenario, the Qwen3.5 model's `config.json` lacks the `quantize` field, so `quant_type` never equals `"w8a8_dynamic"`. The previous logic strictly required `quant_type == "w8a8_dynamic"` before honoring `VLLM_ASCEND_ENABLE_FUSED_MC2`, which prevented the dispatch_ffn_combine fusion operator from activating even when the environment variable was set. This patch drops that blanket requirement, keeping it only for the dispatch_gmm_combine_decode path (which still needs w8a8_dynamic), and thereby also enables the dispatch_ffn_combine fusion operator for BF16 scenarios.
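To make the gating change concrete, here is a minimal runnable sketch of the decode-path logic before and after this patch. It is simplified from `select_moe_comm_method` in the diff below; the standalone function names and the boolean parameters (`dispatch_ffn_combine_ok`, `spec_decode_ok`, standing in for the EP-size/draft-model and speculative-decode guards) are illustrative placeholders, not code from this repo:

```python
from typing import Optional

# env_fused_mc2 stands in for envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
# (0 = off, 1 = dispatch_ffn_combine, 2 = dispatch_gmm_combine_decode).

def fused_decode_enable_before(env_fused_mc2: int,
                               quant_type: Optional[str],
                               dispatch_ffn_combine_ok: bool,
                               spec_decode_ok: bool) -> bool:
    # Old logic: quant_type gated BOTH fusion modes, so a Qwen3.5 checkpoint
    # whose config.json has no "quantize" field (quant_type is None) could
    # never reach FUSED_MC2, no matter how the env var was set.
    fused_mc2 = bool(env_fused_mc2) and quant_type == "w8a8_dynamic"
    if env_fused_mc2 == 1:
        return fused_mc2 and dispatch_ffn_combine_ok
    if env_fused_mc2 == 2:
        return fused_mc2 and spec_decode_ok
    return fused_mc2

def fused_decode_enable_after(env_fused_mc2: int,
                              quant_type: Optional[str],
                              dispatch_ffn_combine_ok: bool,
                              spec_decode_ok: bool) -> bool:
    # New logic: the quant_type requirement only guards mode 2
    # (dispatch_gmm_combine_decode still needs w8a8_dynamic), so mode 1
    # (dispatch_ffn_combine) also activates for BF16 weights.
    fused_mc2 = bool(env_fused_mc2)
    if env_fused_mc2 == 1:
        return fused_mc2 and dispatch_ffn_combine_ok
    if env_fused_mc2 == 2:
        return fused_mc2 and spec_decode_ok and quant_type == "w8a8_dynamic"
    return fused_mc2

# BF16 Qwen3.5 (quant_type=None) with VLLM_ASCEND_ENABLE_FUSED_MC2=1:
assert not fused_decode_enable_before(1, None, True, True)  # bug: fusion stays off
assert fused_decode_enable_after(1, None, True, True)       # fixed: fusion turns on
```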
- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: asunxiao
---
 vllm_ascend/ascend_forward_context.py  |  8 +++--
 vllm_ascend/ops/fused_moe/fused_moe.py | 44 +++++++++++++++++++++++---
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index afca4d97..7c7242ef 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -245,14 +245,18 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
     elif soc_version in {AscendDeviceType.A3}:
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
         # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
-        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
         dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
         if num_tokens <= mc2_tokens_capacity:
             fused_decode_enable = fused_mc2_enable
             if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                 fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
             elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
-                fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
+                fused_decode_enable = (
+                    fused_mc2_enable
+                    and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
+                    and quant_type == "w8a8_dynamic"
+                )
             moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
         else:
             fused_prefill_enable = fused_mc2_enable
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 5aa5670a..3d858f7f 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -20,6 +20,7 @@ from functools import wraps
 
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
 from vllm.forward_context import get_forward_context
@@ -32,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import FusedMo
 from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner  # type: ignore
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -41,6 +43,7 @@ from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_expe
 from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
 from vllm_ascend.quantization.methods.base import QuantType
 from vllm_ascend.utils import (
+    ACL_FORMAT_FRACTAL_NZ,
     enable_sp,
     maybe_trans_nz,
     npu_stream_switch,
@@ -77,8 +80,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
         layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
-        layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
-        layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
+        # TODO: The current dispatch_ffn_combine fusion operator ONLY supports the
+        # NZ format, so we must cast weights to NZ when the fusion is enabled.
+        # Once the underlying dispatch_ffn_combine operator is updated to support
+        # ND (or other formats), remove this 'if' check and the forced
+        # npu_format_cast; at that point the operator should handle weights in
+        # their native format without explicit casting here.
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2:
+            layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+            layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        else:
+            layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
+            layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
 
     def apply(
         self,
@@ -144,10 +157,33 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = _EXTRA_CTX.moe_comm_method
+        # NOTE: In the MoECommType.FUSED_MC2 branch we wrap the weights (w1, w2)
+        # into lists and provide dummy scales (w1_scale, w2_scale), because the
+        # underlying Ascend fused operator (e.g. dispatch_ffn_combine) expects
+        # its inputs in list format.
+        # TODO: Passing an empty tensor as the scale in float (BF16) cases is
+        # semantically incorrect; the ideal solution is to pass None. However,
+        # since the dispatch_ffn_combine C++ operator does not support None for
+        # the scale argument (a signature constraint), we are forced to use a
+        # placeholder empty tensor. This TODO tracks updating the C++ operator
+        # to accept Optional[Tensor] / None scales in non-quantized scenarios.
+        if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2:
+            w1 = [layer.w13_weight]
+            w1_scale = [torch.tensor([], dtype=torch.int64)]
+            w2 = [layer.w2_weight]
+            w2_scale = [torch.tensor([], dtype=torch.int64)]
+        else:
+            w1 = layer.w13_weight
+            w1_scale = None
+            w2 = layer.w2_weight
+            w2_scale = None
+
         final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
             w1_bias=layer.w13_bias if self.moe.has_bias else None,
             w2_bias=layer.w2_bias if self.moe.has_bias else None,
             activation=activation,