diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 4a29dd9..4868e6e 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -456,3 +456,170 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[2], self.impl.v_head_dim)
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()
+
+    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
+    def test_mla_preprocess(self, magic_npu_fetch):
+        magic_npu_fetch.return_value = MagicMock()
+        batch_size = 4
+        seq_len = 8
+        hidden_size = 1024
+        hidden_states = torch.randn(batch_size * seq_len, hidden_size)
+
+        kv_cache = MagicMock()
+
+        attn_metadata = MagicMock()
+        attn_metadata.num_decodes = 2
+        attn_metadata.num_prefills = 2
+        attn_metadata.num_decode_tokens = 2
+        attn_metadata.num_actual_tokens = 4
+        num_prefill_tokens = 2
+        attn_metadata.slot_mapping = torch.arange(4)
+        attn_metadata.decode.cos = torch.randn(2, 64)
+        attn_metadata.decode.sin = torch.randn(2, 64)
+        attn_metadata.prefill.cos = torch.randn(2, 64)
+        attn_metadata.prefill.sin = torch.randn(2, 64)
+
+        self.impl.q_a_proj = MagicMock()
+        self.impl.q_a_layernorm = MagicMock()
+        self.impl.q_a_layernorm.return_value = torch.randn(
+            attn_metadata.num_actual_tokens, self.impl.num_heads,
+            self.impl.qk_rope_head_dim)
+        self.impl.kv_a_proj_with_mqa = MagicMock()
+        self.impl.kv_a_proj_with_mqa.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
+        ]
+        self.impl.q_proj = MagicMock()
+        self.impl.q_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_head_dim)
+        ]
+        self.impl.kv_b_proj = MagicMock()
+        self.impl.kv_b_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
+        ]
+        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
+        self.impl.exec_kv_decode = MagicMock()
+        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
+        self.impl.exec_kv_prefill = MagicMock()
+        self.impl.exec_kv_prefill.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_rope_head_dim),
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank)
+        ]
+        self.impl._q_proj_and_k_up_proj = MagicMock()
+        self.impl._q_proj_and_k_up_proj.return_value = [
+            MagicMock(), MagicMock()
+        ]
+        self.impl.num_kv_heads = self.impl.num_heads
+
+        decode_res, prefill_res = self.impl._mla_preprocess(
+            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+
+        self.assertIsNotNone(decode_res)
+        self.assertIsNotNone(prefill_res)
+
+    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
+    def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
+        B = 2
+        N = self.impl.num_kv_heads
+        D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
+        kv_no_split = torch.randn(B, N, D)
+        self.impl.enable_kv_nz = None
+        self.impl.kv_a_layernorm.weight = MagicMock()
+        self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
+        cos = MagicMock()
+        sin = MagicMock()
+        slots = MagicMock()
+        kv_cache = [MagicMock(), MagicMock()]
+
+        mock_kv_rmsnorm_rope_cache.return_value = [
+            None, None,
+            torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
+            torch.randn(B, N, 1, self.impl.kv_lora_rank)
+        ]
+
+        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
+                                                 kv_cache, slots)
+
+        self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
+        self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
+
+    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
+    def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
+        B = 2
+        N = self.impl.num_kv_heads
+        D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
+        kv_no_split = torch.randn(B, N, D)
+        self.impl.enable_kv_nz = None
+        self.impl.kv_a_layernorm.weight = MagicMock()
+        self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
+        cos = MagicMock()
+        sin = MagicMock()
+        slots = MagicMock()
+        kv_cache = [MagicMock(), MagicMock()]
+
+        mock_kv_rmsnorm_rope_cache.return_value = [
+            torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
+            torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
+        ]
+
+        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
+                                                kv_cache, slots)
+
+        self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
+        self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
+
+    @patch("torch.npu.stream")
+    @patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_forward_decode(self, mock_npu_fused_infer_attention_score,
+                            mock_get_multistream_comm_context,
+                            mock_npu_stream):
+        B = 2
+        N = self.impl.num_kv_heads
+        BS = 100
+        HD = self.impl.v_head_dim
+        self.impl.kv_lora_rank = 256
+        self.impl.spec_token_num = 1
+        self.impl._v_up_proj = MagicMock()
+        self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
+        q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
+        q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
+        k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
+        attn_metadata = MagicMock()
+        attn_metadata.attn_state = AscendAttentionState.SpecDecoding
+        attn_metadata.decode = MagicMock()
+        attn_metadata.decode.actual_seq_lengths_q = MagicMock()
+        attn_metadata.decode.seq_lens_list = MagicMock()
+        self.impl.enable_kv_nz = True
+
+        mock_npu_fused_infer_attention_score.return_value = [
+            torch.randn(B, N, self.impl.kv_lora_rank), None
+        ]
+        mock_get_multistream_comm_context.return_value = None
+
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                           attn_metadata)
+
+        self.assertEqual(result.shape[0], B)
+        self.assertEqual(result.shape[1], N)
+        self.assertEqual(result.shape[2], HD)
+
+        self.impl.enable_kv_nz = False
+        attn_metadata.attn_state = None
+        mock_return_value = MagicMock()
+        mock_get_multistream_comm_context.return_value = mock_return_value
+        mock_return_value.before_comm_event = MagicMock()
+        mock_return_value.comm_stream = MagicMock()
+        mock_npu_stream.return_value = MagicMock()
+
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                           attn_metadata)
+
+        self.assertEqual(result.shape[0], B)
+        self.assertEqual(result.shape[1], N)
+        self.assertEqual(result.shape[2], HD)
\ No newline at end of file