diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 4868e6e..6360504 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -238,13 +238,19 @@ class TestAscendMLAImpl(TestBase): new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_tensor_model_parallel_world_size", return_value=2) - @patch("vllm.config.get_current_vllm_config") + @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp): + def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, + mock_tp): mock_tp.world_size = 2 + vllm_config = MagicMock() speculative_config = MagicMock() + model_config = MagicMock() speculative_config.num_speculative_tokens = 4 vllm_config.speculative_config = speculative_config + model_config.dtype = torch.float16 + vllm_config.model_config = model_config + get_current_vllm_config.return_value = vllm_config num_heads = 256 head_size = 1024 @@ -622,4 +628,4 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[0], B) self.assertEqual(result.shape[1], N) - self.assertEqual(result.shape[2], HD) \ No newline at end of file + self.assertEqual(result.shape[2], HD) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index dcdd627..a386f63 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -483,10 +483,12 @@ class AscendMLAImpl(MLAAttentionImpl): self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 self.prefill_mask = None # Adapt torch air graph mode with spec decoding. - speculative_config = get_current_vllm_config().speculative_config + speculative_config = vllm_config.speculative_config if speculative_config is not None: self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 @@ -681,14 +683,10 @@ class AscendMLAImpl(MLAAttentionImpl): device=q_nope.device) if self.prefill_mask is None: self.prefill_mask = torch.triu( - torch.ones(512, - 512, + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, device=q_nope.device, - dtype=q_nope.dtype), - 1) # 512: mask only support 512 - if attn_metadata.num_prefills > 1: - self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat( - attn_metadata.num_prefills, 1, 1) + dtype=q_nope.dtype), 1) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe,