[BugFix] Fix the bug that qwen3 moe doesn't work with aclgraph (#2183)

What this PR does:
1. Move AscendSparseMoeBlock into the qwen3 model file, since it is only
used by the qwen3 model.
2. Disable AscendSparseMoeBlock when aclgraph is enabled;
AscendSparseMoeBlock does not currently work with aclgraph (see the
usage sketch below).
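
For reference, a minimal offline-inference sketch (assuming the standard vLLM `LLM` API; the model name, prompt, and sampling settings are illustrative). With this fix the model loads in both modes; only the MoE block implementation differs:

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
sampling = SamplingParams(temperature=0.0, max_tokens=5)  # greedy

# Default path: piecewise compilation keeps aclgraph on, so the decoder
# layer falls back to the upstream Qwen3MoeSparseMoeBlock.
llm = LLM(model="Qwen/Qwen3-30B-A3B", tensor_parallel_size=2)

# Eager path: aclgraph is off, so the Ascend-optimized MoE block
# (the moved AscendSparseMoeBlock) is used instead.
# llm = LLM(model="Qwen/Qwen3-30B-A3B", tensor_parallel_size=2,
#           enforce_eager=True)

print(llm.generate(prompts, sampling)[0].outputs[0].text)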

- vLLM version: v0.10.0
- vLLM main:
cdfd6871a5

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-08-05 17:42:52 +08:00
Committed by: GitHub
Parent: 583ad8f347
Commit: 458ab2db12
3 changed files with 151 additions and 86 deletions

File 1 (new): end-to-end test for Qwen3 MoE distributed inference

@@ -0,0 +1,55 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/test_offline_inference.py`.
"""
from tests.e2e.conftest import VllmRunner


def test_models_distributed_Qwen3_MOE_TP2():
    example_prompts = [
        "Hello, my name is",
    ]
    dtype = "half"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
    example_prompts = [
        "Hello, my name is",
    ]
    dtype = "half"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=2,
            enable_expert_parallel=True,
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
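
The second test differs from the first only by `enable_expert_parallel=True`, which shards the experts across the two ranks instead of replicating them, so the MoE path is exercised under both plain TP and TP with expert parallelism. Outside the test harness, that configuration corresponds roughly to (a sketch, assuming the standard vLLM API):

from vllm import LLM

# TP=2 with experts sharded across the two ranks rather than replicated.
llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=2,
          enable_expert_parallel=True)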

File 2: Qwen3 MoE model definition in vllm-ascend

@@ -17,11 +17,17 @@
 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -29,13 +35,84 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
@@ -45,6 +122,7 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -73,12 +151,22 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
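
The `use_aclgraph` condition above is the core of the fix. Restated standalone for clarity (a sketch mirroring the diff; the helper name is ours, not part of the change):

from typing import Optional

from vllm.config import CompilationLevel, VllmConfig


def _aclgraph_in_use(vllm_config: Optional[VllmConfig]) -> bool:
    # aclgraph capture is active when piecewise compilation is enabled
    # and eager execution has not been forced; in that case the decoder
    # layer must fall back to the upstream Qwen3MoeSparseMoeBlock.
    return (vllm_config is not None
            and vllm_config.compilation_config.level
            == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)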

File 3: vllm_ascend/ops/fused_moe.py

@@ -22,8 +22,6 @@ import torch
 import torch.distributed as dist
 import torch_npu
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -37,7 +35,6 @@ from vllm.model_executor.layers.fused_moe.config import \
     FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
@@ -1546,79 +1543,3 @@ class AscendFusedMoE(FusedMoE):
         )
         return hidden_states
-
-
-class AscendSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = (
-            ascend_config.torchair_graph_config.enable_multistream_moe)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None,
-    ) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-        )
-
-        return hidden_states