[Fix] Fixes attribute error in MLA implementation (#3618)

### What this PR does / why we need it?
Corrects the attribute access for retrieving the device from `q_a_proj`
to `q_proj`. This prevents an `AttributeError` as `q_a_proj` does not
exist on the class instance.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Need MLAPO tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-10-23 09:12:50 +08:00
committed by GitHub
parent 179b897b52
commit b13d22bf5a
2 changed files with 1 additions and 20 deletions

View File

@@ -1,8 +1,5 @@
from __future__ import annotations from __future__ import annotations
import os
from unittest.mock import patch
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode from vllm.config import CompilationConfig, CUDAGraphMode
@@ -111,19 +108,3 @@ def test_mtp2_correctness_full_graph(
model_name: str, model_name: str,
): ):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_full_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)

View File

@@ -737,7 +737,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qb_qt_bias = qb_qt_bias.reshape( self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_a_proj.weight.device device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data self.gamma2 = self.kv_a_layernorm.weight.data