From 5c45c227dc254591f4a9345e67a84a0d5fe1c345 Mon Sep 17 00:00:00 2001
From: elilzhu <2435754260@qq.com>
Date: Tue, 14 Oct 2025 17:31:26 +0800
Subject: [PATCH] [BugFix] fix qwen2.5vl quant bug (#3426)

### What this PR does / why we need it?
This PR fixes issues:
1. Resolve the issue of qwen2.5-VL quantization service startup failure: AttributeError, 'Parameter' object has no attribute 'weight_loader'.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
- ci & e2e

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: elilzhu <2435754260@qq.com>
---
 vllm_ascend/quantization/quant_config.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index ed84d0d..f2d7176 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     UnquantizedEmbeddingMethod, VocabParallelEmbedding)
+from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
@@ -250,6 +251,7 @@ class AscendLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ) -> None:
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                    output_size_per_partition,
@@ -262,7 +264,8 @@ class AscendLinearMethod(LinearMethodBase):
 
         pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
         for pertensor_name, pertensor_param in pertensor_dict.items():
-            param = torch.nn.Parameter(pertensor_param, requires_grad=False)
+            param = PerTensorScaleParameter(data=pertensor_param,
+                                            weight_loader=weight_loader)
             # disable warning
             param.ignore_warning = True
             layer.register_parameter(pertensor_name, param)