From 211d4b9da4c9e52018e4090b268ea3f57b1bfd5f Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Thu, 30 Oct 2025 00:35:50 +0800 Subject: [PATCH] [BugFix] Fix mlapo accuracy problem related to weight processing. (#3857) This PR fixes an mlapo accuracy problem related to weight processing. Furthermore, it modifies the mlapo-related e2e test to use a quantized deepseek model, making the test effective. Signed-off-by: whx-sjtu <2952154980@qq.com> --- .../spec_decode_v1/test_v1_mtp_correctness.py | 16 ---------------- vllm_ascend/attention/mla_v1.py | 4 ++-- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index adb0e6a..6dc2bc9 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -111,19 +111,3 @@ def test_mtp2_correctness_full_graph( model_name: str, ): mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_piecewise_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"}) -def test_mtp_correctness_full_graph_with_mlapo_kernel( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 977c114..177d91b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -676,9 +676,9 @@ class AscendMLAImpl(MLAAttentionImpl): ..., self.q_lora_rank:].contiguous() q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ ..., :self.q_lora_rank].contiguous() - kv_a_proj_wt = kv_a_proj_wt.contiguous() + kv_a_proj_wt = 
kv_a_proj_wt.t().contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) - kv_a_proj_wt = kv_a_proj_wt.contiguous() + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) wd_qkv = wd_qkv.t().contiguous() wd_qkv = transdata(wd_qkv,