diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index adb0e6ab..b6d8b669 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,8 +1,5 @@ from __future__ import annotations -import os -from unittest.mock import patch - import pytest from vllm import SamplingParams from vllm.config import CompilationConfig, CUDAGraphMode @@ -111,19 +108,3 @@ def test_mtp2_correctness_full_graph( model_name: str, ): mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_piecewise_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_full_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index bcdb6161..86418720 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -737,7 +737,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.qb_qt_bias = qb_qt_bias.reshape( self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - device = self.q_a_proj.weight.device + device = self.q_proj.weight.device self.gamma1 = self.q_a_layernorm.weight.data self.beta1 = self.q_a_layernorm.bias.data self.gamma2 = self.kv_a_layernorm.weight.data