[Bugfix] Support Qwen3-MOE on aclgraph mode (#1381)

### What this PR does / why we need it?
Fix the shape of the `npu_moe_init_routing` input parameters so that
Qwen3-MoE works in aclgraph mode.

In addition to this PR, the `gatherv3` error may also need to be
resolved. See the related PRs:
https://github.com/vllm-project/vllm-ascend/pull/1297
https://github.com/vllm-project/vllm-ascend/pull/1446

Thanks to @yiz-liu for providing the idea.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested on Qwen3-30B-A3B

Closes: https://github.com/vllm-project/vllm-ascend/issues/1368

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
ApsarasX
2025-07-06 15:29:36 +08:00
committed by GitHub
parent 14373f65d7
commit c58accc15e
3 changed files with 21 additions and 3 deletions

View File

@@ -29,7 +29,7 @@ from vllm import LLM, SamplingParams
from tests.conftest import VllmRunner from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal from tests.model_utils import check_outputs_equal
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "vllm-ascend/Qwen3-30B-A3B-Puring"]
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",

View File

@@ -18,6 +18,7 @@
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.fused_moe.layer import \ from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod UnquantizedFusedMoEMethod
@@ -25,6 +26,15 @@ from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts) select_experts)
from vllm_ascend.utils import is_310p from vllm_ascend.utils import is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
vllm_config = get_current_vllm_config()
self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
def forward_oot( def forward_oot(
self, self,
@@ -71,6 +81,10 @@ def forward_oot(
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
# If use aclgraph, we need to set max_num_tokens to make
# the input shape of `npu_moe_init_routing` fixed
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -79,7 +93,9 @@ def forward_oot(
topk_ids=topk_ids, topk_ids=topk_ids,
top_k=top_k, top_k=top_k,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input,
max_num_tokens=max_num_tokens)
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.forward_oot = forward_oot UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -655,6 +655,7 @@ def fused_experts(
top_k: int, top_k: int,
expert_map: torch.Tensor = None, expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
max_num_tokens: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Fused experts with top-k routing. Fused experts with top-k routing.
@@ -748,11 +749,12 @@ def fused_experts(
dtype=torch.int32, dtype=torch.int32,
device=device).view(top_k, -1).permute( device=device).view(top_k, -1).permute(
1, 0).contiguous()) 1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states, hidden_states,
row_idx=row_idx, row_idx=row_idx,
expert_idx=topk_ids, expert_idx=topk_ids,
active_num=num_tokens) active_num=active_num)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts) expanded_expert_idx, num_experts)