From c58accc15e2a4e8a672c754035168a09ccf421f6 Mon Sep 17 00:00:00 2001 From: ApsarasX Date: Sun, 6 Jul 2025 15:29:36 +0800 Subject: [PATCH] [Bugfix] Support Qwen3-MOE on aclgraph mode (#1381) ### What this PR does / why we need it? Fix the shape of the `npu_moe_init_routing` input parameters to support aclgraph mode on qwen3-moe In addition to this PR, resolving the `gatherv3` error might be necessary. See related PR https://github.com/vllm-project/vllm-ascend/pull/1297 https://github.com/vllm-project/vllm-ascend/pull/1446 Thanks to @yiz-liu for providing the idea ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tested on Qwen3-30B-A3B Closes: https://github.com/vllm-project/vllm-ascend/issues/1368 --------- Signed-off-by: ApsarasX Signed-off-by: Yikun Jiang Co-authored-by: Yizhou Liu Co-authored-by: Yikun Jiang --- tests/e2e/singlecard/test_aclgraph.py | 2 +- vllm_ascend/ops/common_fused_moe.py | 18 +++++++++++++++++- vllm_ascend/ops/fused_moe.py | 4 +++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index e0bfb65..4fc23aa 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -29,7 +29,7 @@ from vllm import LLM, SamplingParams from tests.conftest import VllmRunner from tests.model_utils import check_outputs_equal -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "vllm-ascend/Qwen3-30B-A3B-Puring"] @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 4e21c74..3aa23a2 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -18,6 +18,7 @@ from typing import Callable, Optional import torch +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.model_executor.layers.fused_moe.layer import \ UnquantizedFusedMoEMethod @@ -25,6 +26,15 @@ from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge, select_experts) from vllm_ascend.utils import is_310p +original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ + + +def unquantized_fused_moe_init_func(self, *args, **kwargs): + original_unquantized_fused_moe_init_func(self, *args, **kwargs) + vllm_config = get_current_vllm_config() + self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager + def forward_oot( self, @@ -71,6 +81,10 @@ def forward_oot( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) + # If use aclgraph, we need to set max_num_tokens to make + # the input shape of `npu_moe_init_routing` fixed + max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None + return fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -79,7 +93,9 @@ def forward_oot( topk_ids=topk_ids, top_k=top_k, expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + max_num_tokens=max_num_tokens) +UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c9fd8f2..da5e8e3 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -655,6 +655,7 @@ def fused_experts( top_k: int, expert_map: torch.Tensor = None, apply_router_weight_on_input: bool = False, + max_num_tokens: Optional[int] = None, ) -> torch.Tensor: """ Fused experts with top-k routing. @@ -748,11 +749,12 @@ def fused_experts( dtype=torch.int32, device=device).view(top_k, -1).permute( 1, 0).contiguous()) + active_num = max_num_tokens if max_num_tokens is not None else num_tokens sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, expert_idx=topk_ids, - active_num=num_tokens) + active_num=active_num) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expanded_expert_idx, num_experts)