From a58013440a9c9c0b5220a60bc161025c5f5270a2 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:22:46 +0800 Subject: [PATCH] [BugFix][MLA] Fix attn_mask bug for ring mla (#2704) This PR fixes a bug related to the attention mask used in ring MLA. Ring MLA now supports a compressed mask, so we can directly use a 512 * 512 attention mask. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/b5ee1e3261d9edf94d76ba8b437ebdef7ac599ea --------- Signed-off-by: whx-sjtu <2952154980@qq.com> --- tests/ut/attention/test_mla_v1.py | 12 +++++++++--- vllm_ascend/attention/mla_v1.py | 14 ++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 4868e6e..6360504 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -238,13 +238,19 @@ class TestAscendMLAImpl(TestBase): new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_tensor_model_parallel_world_size", return_value=2) - @patch("vllm.config.get_current_vllm_config") + @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp): + def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, + mock_tp): mock_tp.world_size = 2 + vllm_config = MagicMock() speculative_config = MagicMock() + model_config = MagicMock() speculative_config.num_speculative_tokens = 4 vllm_config.speculative_config = speculative_config + model_config.dtype = torch.float16 + vllm_config.model_config = model_config + get_current_vllm_config.return_value = vllm_config num_heads = 256 head_size = 1024 @@ -622,4 +628,4 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[0], B) self.assertEqual(result.shape[1], N) - self.assertEqual(result.shape[2], HD) \ No 
newline at end of file + self.assertEqual(result.shape[2], HD) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index dcdd627..a386f63 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -483,10 +483,12 @@ class AscendMLAImpl(MLAAttentionImpl): self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 self.prefill_mask = None # Adapt torch air graph mode with spec decoding. - speculative_config = get_current_vllm_config().speculative_config + speculative_config = vllm_config.speculative_config if speculative_config is not None: self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 @@ -681,14 +683,10 @@ class AscendMLAImpl(MLAAttentionImpl): device=q_nope.device) if self.prefill_mask is None: self.prefill_mask = torch.triu( - torch.ones(512, - 512, + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, device=q_nope.device, - dtype=q_nope.dtype), - 1) # 512: mask only support 512 - if attn_metadata.num_prefills > 1: - self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat( - attn_metadata.num_prefills, 1, 1) + dtype=q_nope.dtype), 1) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe,