From df7e0fe9169529cb12a4a885d5accba52ad7021a Mon Sep 17 00:00:00 2001
From: Levi <54832289+Levi-JQ@users.noreply.github.com>
Date: Mon, 15 Dec 2025 16:39:58 +0800
Subject: [PATCH] [Bugfix] qwen3-vl-235b-w8a8 load weight ERROR when start
 service (#4292)

### What this PR does / why we need it?
Fix the qwen3-vl-w8a8 weight-loading error that occurred at service startup.
With this PR, 0.12.0rc1 can start qwen3-vl-235b-w8a8.

- vLLM version: v0.11.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

---------

Signed-off-by: Levi-JQ
Co-authored-by: Levi-JQ
---
 vllm_ascend/quantization/quant_config.py | 34 ++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 6669fd2d..e358b253 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     UnquantizedEmbeddingMethod, VocabParallelEmbedding)
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
@@ -103,6 +104,15 @@ class AscendQuantConfig(QuantizationConfig):
             return ASCEND_QUANTIZATION_METHOD
         return None
 
+    def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
+        # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
+        prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
+        if prefix_mapping:
+            hf_to_vllm_mapper = WeightsMapper(
+                orig_to_new_prefix=prefix_mapping)
+            return hf_to_vllm_mapper._map_name(prefix)
+        return prefix
+
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         vllm_config = get_current_vllm_config()
@@ -110,6 +120,7 @@
         if model_type in packed_modules_model_mapping:
             self.packed_modules_mapping = packed_modules_model_mapping[
                 model_type]
+        prefix = self.quant_prefix_mapper(model_type, prefix)
         from vllm.attention.layer import Attention
         if prefix.startswith("language_model"):
             prefix = prefix.split('.', 1)[-1]
@@ -174,6 +185,16 @@ class AscendQuantConfig(QuantizationConfig):
         return []
 
 
+# key: model_type
+# value: orig_to_new_prefix
+QUANT_MODEL_PREFIX_MAPPINGS = {
+    "qwen3_vl_moe": {
+        "visual.": "model.visual.",
+        "language_model.lm_head.": "lm_head.",
+        "language_model.model.": "model.language_model.",
+    },
+}
+
 packed_modules_model_mapping = {
     "qwen3_moe": {
         "qkv_proj": [
             "q_proj",
             "k_proj",
             "v_proj",
         ],
         "gate_up_proj": [
             "gate_proj",
             "up_proj",
         ],
     },
+    "qwen3_vl_moe": {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+    },
     "glm4_moe": {
         "qkv_proj": [
             "q_proj",
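
Note for reviewers: the sketch below illustrates the prefix remapping this patch adds. It reimplements, in plain Python, the `orig_to_new_prefix` handling that `quant_prefix_mapper` delegates to vLLM's `WeightsMapper` (the patched code calls `hf_to_vllm_mapper._map_name(prefix)`), so it runs without vLLM installed. The `map_prefix` helper and the example layer name are illustrative assumptions, not part of the patch.

```python
# Standalone sketch of the remapping added in this patch. It mirrors the
# prefix branch of vLLM's WeightsMapper rather than importing it; map_prefix
# is a hypothetical helper used only for this illustration.

QUANT_MODEL_PREFIX_MAPPINGS = {
    "qwen3_vl_moe": {
        "visual.": "model.visual.",
        "language_model.lm_head.": "lm_head.",
        "language_model.model.": "model.language_model.",
    },
}


def map_prefix(model_type: str, prefix: str) -> str:
    """Rewrite a module prefix via the first matching orig -> new entry."""
    for old, new in QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {}).items():
        if prefix.startswith(old):
            return new + prefix[len(old):]
    return prefix


# Example layer name (illustrative): the module path on the left is
# rewritten so the lookup into the w8a8 quantization description succeeds.
print(map_prefix("qwen3_vl_moe",
                 "language_model.model.layers.0.self_attn.qkv_proj"))
# model.language_model.layers.0.self_attn.qkv_proj
```

Without this rewrite, `get_quant_method` looks up the unmapped prefix, misses the corresponding entry in the quantization description, and weight loading fails at startup, which is the qwen3-vl-235b-w8a8 error this PR fixes.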