diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py index ccc31d4..dcac7a8 100644 --- a/tests/e2e/multicard/test_qwen3_moe.py +++ b/tests/e2e/multicard/test_qwen3_moe.py @@ -18,9 +18,11 @@ # """Compare the short outputs of HF and vLLM when using greedy sampling. -Run `pytest tests/test_offline_inference.py`. +Run `pytest tests/e2e/multicard/test_qwen3_moe.py`. """ +from modelscope import snapshot_download # type: ignore + from tests.e2e.conftest import VllmRunner @@ -53,3 +55,20 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_EP(): distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_MOE_W8A8(): + example_prompts = [ + "Hello, my name is", + ] + dtype = "auto" + max_tokens = 5 + with VllmRunner( + snapshot_download("vllm-ascend/Qwen3-30B-A3B-W8A8"), + max_model_len=8192, + dtype=dtype, + tensor_parallel_size=2, + quantization="ascend", + enforce_eager=False, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index f29c2a5..f3598cc 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. -# +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,20 +27,23 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (MixtureOfExperts, + SupportsLoRA, SupportsPP) from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, Qwen3MoeDecoderLayer, Qwen3MoeForCausalLM, Qwen3MoeMLP, Qwen3MoeModel, Qwen3MoeSparseMoeBlock) from vllm.model_executor.models.utils import ( - extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm_ascend.ops.fused_moe import AscendFusedMoE @@ -230,6 +234,9 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) + SupportsPP.__init__(self) + SupportsLoRA.__init__(self) + MixtureOfExperts.__init__(self) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config @@ -238,9 +245,31 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = 
LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights: list[torch.Tensor] = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3MoeDecoderLayer) + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3MoE layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0