From b13d22bf5a2d45ee73537e91513ae160bb72ed30 Mon Sep 17 00:00:00 2001 From: Yizhou <136800916+yiz-liu@users.noreply.github.com> Date: Thu, 23 Oct 2025 09:12:50 +0800 Subject: [PATCH] [Fix] Fixes attribute error in MLA implementation (#3618) ### What this PR does / why we need it? Corrects the attribute access for retrieving the device from `q_a_proj` to `q_proj`. This prevents an `AttributeError` as `q_a_proj` does not exist on the class instance. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Need MLAPO tests. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Yizhou Liu --- .../spec_decode_v1/test_v1_mtp_correctness.py | 19 ------------------- vllm_ascend/attention/mla_v1.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index adb0e6ab..b6d8b669 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,8 +1,5 @@ from __future__ import annotations -import os -from unittest.mock import patch - import pytest from vllm import SamplingParams from vllm.config import CompilationConfig, CUDAGraphMode @@ -111,19 +108,3 @@ def test_mtp2_correctness_full_graph( model_name: str, ): mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_piecewise_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_full_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index bcdb6161..86418720 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -737,7 +737,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.qb_qt_bias = qb_qt_bias.reshape( self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - device = self.q_a_proj.weight.device + device = self.q_proj.weight.device self.gamma1 = self.q_a_layernorm.weight.data self.beta1 = self.q_a_layernorm.bias.data self.gamma2 = self.kv_a_layernorm.weight.data