From 6029bea4802e964b753ba42a8bb8bad96adf271b Mon Sep 17 00:00:00 2001
From: zengzengran
Date: Mon, 15 Dec 2025 18:41:38 +0800
Subject: [PATCH] [UT] Add PCP/DCP UT (#4949)

### What this PR does / why we need it?
Add unit tests for the DCP (decode context parallel) and PCP (prefill
context parallel) attention paths.

- vLLM version: v0.12.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: zengran
---
 tests/ut/attention/test_attention_cp.py | 321 +++++++++++++++++++
 tests/ut/attention/test_mla_cp.py       | 403 ++++++++++++++++++++++++
 tests/ut/attention/test_mla_v1.py       |  52 ++-
 3 files changed, 768 insertions(+), 8 deletions(-)
 create mode 100644 tests/ut/attention/test_attention_cp.py
 create mode 100755 tests/ut/attention/test_mla_cp.py
 mode change 100644 => 100755 tests/ut/attention/test_mla_v1.py

diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py
new file mode 100644
index 00000000..3a3af95d
--- /dev/null
+++ b/tests/ut/attention/test_attention_cp.py
@@ -0,0 +1,321 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
+
+
+class TestAscendAttentionCPImpl(TestBase):
+
+    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
+              mock_get_pcp_group):
+        mock_dcp.world_size = 2
+        mock_dcp.rank_in_group = 0
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 2
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
+        mock_pcp.world_size = 2
+        mock_pcp.rank_in_group = 0
+        pcp_group = MagicMock(spec=GroupCoordinator)
+        pcp_group.rank_in_group = 0
+        pcp_group.world_size = 2
+        pcp_group.device_group = MagicMock()
+        mock_get_pcp_group.return_value = pcp_group
+
+        self.layer = MagicMock()
+        self.layer.layer_name = "test_layer"
+        self.layer._k_scale_float = 1.0
+        self.layer._v_scale_float = 1.0
+
+        self.attention_type = MagicMock()
+        self.attention_type.DECODER = "decoder"
+        self.attention_type.ENCODER = "encoder"
+
+        self.attn_metadata = MagicMock()
+        self.attn_metadata.return_value = "1"
+
+        self.layer_no_quant = MagicMock(
+            spec=['layer_name', '_k_scale_float', '_v_scale_float'])
+        self.layer_no_quant.layer_name = "test_layer"
+        self.layer_no_quant._k_scale_float = 1.0
+        self.layer_no_quant._v_scale_float = 1.0
+
+        self.impl = AscendAttentionCPImpl(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+            num_kv_heads=8,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype="float16",
+            logits_soft_cap=None,
+            attn_type=self.attention_type.DECODER,
+            kv_sharing_target_layer_name=None)
+
+    def test_init(self):
+        self.assertEqual(self.impl.pcp_size, 2)
+        self.assertEqual(self.impl.pcp_rank, 0)
+        self.assertEqual(self.impl.dcp_size, 2)
+        self.assertEqual(self.impl.dcp_rank, 0)
+
+    def test_forward_prefill_cp(self):
+        query = torch.randn(2, 4, 128)
+        key = torch.randn(4, 1, 128)
+        value = torch.randn(4, 1, 128)
+
+        def mock_attention_with_nomask_and_mask(q, k_mask, **kwargs):
+            mock_output = 
torch.randn_like(q) + mock_lse = torch.randn_like(k_mask) + return mock_output, mock_lse + + self.impl._attention_with_nomask_and_mask = MagicMock() + self.impl._attention_with_nomask_and_mask.side_effect = mock_attention_with_nomask_and_mask + + attn_metadata = MagicMock() + attn_metadata.prefill = MagicMock() + attn_metadata.prefill.pcp_metadata.q_head_idx = torch.tensor([0]) + attn_metadata.prefill.pcp_metadata.q_tail_idx = torch.tensor([1]) + attn_metadata.prefill.pcp_metadata.q_full_idx = torch.tensor([0, 1]) + attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx = torch.tensor( + [0]) + attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx = torch.tensor( + [0]) + attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor( + [0]) + + output, attn_lse = self.impl._forward_prefill_cp( + query, key, value, attn_metadata) + + self.assertEqual(output.shape[0], 2) + self.assertEqual(output.shape[1], 4) + self.assertEqual(output.shape[2], 128) + + @patch('vllm_ascend.attention.attention_cp.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP') + @patch("torch_npu.npu_fused_infer_attention_score") + @patch("torch.distributed.all_gather") + @patch("torch.distributed.all_to_all_single") + @patch('vllm_ascend.attention.attention_cp.get_forward_context') + def test_forward_decode_pcp_dcp(self, mock_get_forward_context, + mock_all_to_all_single, mock_all_gather, + mock_npu_fused_infer_attention_score, + mock_dcp, mock_get_dcp_group): + + def mock_dcp_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + mock_dcp.world_size = 2 + mock_dcp.rank_in_group = 0 + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.rank_in_group = 0 + dcp_group.world_size = 2 + dcp_group.device_group = MagicMock() + dcp_group.all_gather = mock_dcp_all_gather_func + mock_get_dcp_group.return_value = dcp_group + + query = torch.randn(2, 4, 128) + self.impl.key_cache = torch.randn(100, 128, 1, 128) + self.impl.value_cache = torch.randn(100, 128, 1, 128) + + def mock_npu_attention_update(attn_out_lse_list): + mock_output = torch.randn(attn_out_lse_list[0].shape[0], + attn_out_lse_list[0].shape[1], + attn_out_lse_list[0].shape[2] - 1) + return mock_output + + self.impl._npu_attention_update = MagicMock() + self.impl._npu_attention_update.side_effect = mock_npu_attention_update + + mock_get_forward_context.return_value = MagicMock(capturing=False) + + mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( + input) + + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + def mock_npu_fused_infer_attention_score_func(query, k_nope, value, + **common_kwargs): + mock_output = torch.randn_like(query) + mock_lse = torch.randn(query.shape[0], query.shape[1], 1) + return mock_output, mock_lse + + mock_npu_fused_infer_attention_score.side_effect = mock_npu_fused_infer_attention_score_func + + attn_metadata = MagicMock() + attn_metadata.decode_meta = MagicMock() + attn_metadata.decode_meta.batch_seq_mask = torch.tensor( + [1, 0], dtype=torch.bool) + + output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) + + self.assertEqual(output.shape[0], 2) + self.assertEqual(output.shape[1], 4) + self.assertEqual(output.shape[2], 128) + + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP') + @patch('vllm_ascend.attention.attention_cp.get_dcp_group') + 
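
The decode test above replaces `_npu_attention_update` with a stub. For reference, a minimal pure-PyTorch sketch of the log-sum-exp (LSE) merge such an update step performs when combining per-rank partial attention outputs; the real kernel runs on NPU and consumes packed out+LSE tensors, and all names below are illustrative:

```python
# Merging partial attention outputs from context-parallel ranks via
# log-sum-exp. Each rank attends only to its own KV shard and reports
# (out_i, lse_i); the correct global output is a softmax(lse)-weighted
# sum of the partial outputs over ranks.
import torch


def merge_partial_attention(outs: list, lses: list) -> torch.Tensor:
    """outs[i]: [B, N, D] partial output; lses[i]: [B, N, 1] its LSE."""
    out = torch.stack(outs)        # [R, B, N, D]
    lse = torch.stack(lses)        # [R, B, N, 1]
    w = torch.softmax(lse, dim=0)  # per-rank weight, stable internally
    return (w * out).sum(dim=0)    # [B, N, D]


# Two ranks that saw identical shards merge back to the same output.
o, l = torch.randn(2, 4, 128), torch.zeros(2, 4, 1)
assert torch.allclose(merge_partial_attention([o, o], [l, l]), o, atol=1e-6)
```
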
@patch('vllm.distributed.parallel_state._DCP') + def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group, + mock_pcp, mock_get_pcp_group): + query = torch.randn(2, 4, 128) + + def mock_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.all_gather = mock_all_gather_func + mock_get_dcp_group.return_value = dcp_group + + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.all_gather = mock_all_gather_func + mock_get_pcp_group.return_value = pcp_group + + attn_metadata = MagicMock() + attn_metadata.prefill = MagicMock() + attn_metadata.prefill.chunked_context = MagicMock() + attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk = torch.tensor( + [1, 2, 3, 0]) + output = self.impl._prefill_query_all_gather(attn_metadata, query) + + self.assertEqual(output.shape[0], 4) + self.assertEqual(output.shape[1], 8) + self.assertEqual(output.shape[2], 128) + + @patch('torch.ops.npu.npu_fused_infer_attention_score') + def test_compute_prefill_context(self, mock_npu_attention): + + block_num = 100 + block_size = 128 + kv_num_heads = 1 + head_size = 128 + kv_cache = (torch.randn(block_num, block_size, kv_num_heads, + head_size), + torch.randn(block_num, block_size, kv_num_heads, + head_size)) + + batch_size = 1024 + self.impl.head_size = head_size + self.impl.num_heads = 4 + num_heads = self.impl.num_heads * self.impl.dcp_size + query = torch.randn(batch_size, num_heads, head_size) + + attn_metadata = MagicMock() + attn_metadata.prefill = MagicMock() + attn_metadata.prefill.chunked_context = MagicMock() + attn_metadata.prefill.chunked_context.local_context_lens_allranks = torch.tensor( + [[[256, 256], [256, 256]]]) + attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint( + 0, 2, (1024, ), dtype=torch.bool) + + def mock_load_kv_for_chunk(attn_metadata, kv_cache, + local_chunked_kv_lens_rank, query, + total_toks): + return torch.randn(total_toks, kv_num_heads, + head_size), torch.randn(total_toks, + kv_num_heads, head_size) + + self.impl._load_kv_for_chunk = MagicMock() + self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk + + mock_npu_attention.return_value = torch.randn(batch_size, num_heads, + head_size), torch.randn( + batch_size, + num_heads, 1) + + result_output, result_lse = self.impl._compute_prefill_context( + query, kv_cache, attn_metadata) + + self.assertEqual(result_output.shape[0], batch_size) + self.assertEqual(result_output.shape[1], self.impl.num_heads) + self.assertEqual(result_output.shape[2], head_size) + self.assertEqual(result_lse.shape[0], batch_size) + self.assertEqual(result_lse.shape[1], self.impl.num_heads) + self.assertEqual(result_lse.shape[2], 1) + + @patch('torch_npu.atb.npu_paged_cache_load') + def test_load_kv_for_chunk(self, mock_npu_paged_cache_load): + block_num = 100 + block_size = 128 + num_heads = 1 + head_size = 128 + + kv_cache = (torch.randn(block_num, block_size, num_heads, head_size), + torch.randn(block_num, block_size, num_heads, head_size)) + query = torch.randn(4, 8, 128) + total_toks = 256 + local_chunked_kv_lens_rank = torch.randn(total_toks) + + attn_metadata = MagicMock() + + key, value = self.impl._load_kv_for_chunk(attn_metadata, kv_cache, + local_chunked_kv_lens_rank, + query, total_toks) + + self.assertEqual(key.shape[0], total_toks) + self.assertEqual(key.shape[1], num_heads) + self.assertEqual(key.shape[2], head_size) + self.assertEqual(value.shape[0], total_toks) + self.assertEqual(value.shape[1], 
num_heads) + self.assertEqual(value.shape[2], head_size) + + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP') + @patch('torch_npu._npu_reshape_and_cache') + def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp, + mock_get_pcp_group): + num_tokens = 4 + block_num = 100 + block_size = 128 + num_heads = 1 + head_size = 128 + self.impl.head_size = head_size + + kv_cache = (torch.randn(block_num, block_size, num_heads, head_size), + torch.randn(block_num, block_size, num_heads, head_size)) + + attn_metadata = MagicMock() + attn_metadata.num_decode_tokens = 1 + attn_metadata.num_decodes = 1 + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = torch.randn(2) + attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size + attn_metadata.prefill = MagicMock() + attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor( + [0, 3, 1, 2, 0, 0, 0, 0]) + + key = torch.randn(num_tokens, num_heads, head_size) + value = torch.randn(num_tokens, num_heads, head_size) + + def mock_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.all_gather = mock_all_gather_func + mock_get_pcp_group.return_value = pcp_group + + key, value = self.impl.reshape_and_cache(key, value, kv_cache, + attn_metadata) + self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size) + self.assertEqual(key.shape[1], num_heads) + self.assertEqual(key.shape[2], head_size) + self.assertEqual(value.shape[0], num_tokens * self.impl.pcp_size) + self.assertEqual(value.shape[1], num_heads) + self.assertEqual(value.shape[2], head_size) diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py new file mode 100755 index 00000000..52286225 --- /dev/null +++ b/tests/ut/attention/test_mla_cp.py @@ -0,0 +1,403 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.distributed.parallel_state import GroupCoordinator + +from tests.ut.base import TestBase +from vllm_ascend.ascend_config import init_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.mla_cp import AscendMlaCPImpl + + +class TestAscendMLAImpl(TestBase): + + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + @patch('vllm.distributed.parallel_state._TP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_tensor_model_parallel_world_size", + return_value=2) + @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") + @patch("vllm_ascend.attention.mla_v1.get_ascend_config") + def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, + mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp): + mock_tp.world_size = 2 + mock_tp.rank_in_group = MagicMock() + mock_tp.device_group = MagicMock() + mock_dcp.world_size = 2 + mock_dcp.rank_in_group = MagicMock() + mock_dcp.device_group = MagicMock() + mock_pcp.world_size = 2 + mock_pcp.rank_in_group = MagicMock() + mock_pcp.device_group = MagicMock() + vllm_config = MagicMock() + speculative_config = MagicMock() + model_config = MagicMock() + speculative_config.num_speculative_tokens = 4 + vllm_config.speculative_config = speculative_config + model_config.dtype = 
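
Both test files emulate two-rank collectives inside a single process. That recurring stubbing pattern, factored out as a sketch (the helper names are mine, not part of the PR):

```python
# Single-process stand-ins for the 2-rank collectives these tests mock:
# all_gather is emulated by duplicating the local tensor, and
# all_to_all_single degenerates to a copy when every rank holds the
# same data.
from unittest.mock import MagicMock

import torch
from vllm.distributed.parallel_state import GroupCoordinator


def make_fake_group(world_size: int = 2) -> MagicMock:
    group = MagicMock(spec=GroupCoordinator)
    group.world_size = world_size
    group.rank_in_group = 0
    group.device_group = MagicMock()
    # all_gather(tensor, dim): every "rank" contributes the same tensor.
    group.all_gather = lambda t, dim: torch.cat([t] * world_size, dim=dim)
    return group


def fake_all_to_all_single(output, input, *args, **kwargs):
    # With identical per-rank data, all_to_all_single is just a copy.
    output.copy_(input)


g = make_fake_group()
assert g.all_gather(torch.ones(2, 3), dim=0).shape == (4, 3)
```
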
torch.float16 + vllm_config.model_config = model_config + get_current_vllm_config.return_value = vllm_config + vllm_config.additional_config = {"refresh": True} + init_ascend_config(vllm_config) + + num_heads = 256 + head_size = 1024 + scale = 0.1 + num_kv_heads = 8 + kv_cache_dtype = "auto" + + kv_a_layernorm = MagicMock() + kv_a_layernorm.weight = torch.randn(96) + kv_a_layernorm.variance_epsilon = 1e-6 + kwargs = { + "kv_lora_rank": 32, + "qk_nope_head_dim": 64, + "qk_rope_head_dim": 32, + "qk_head_dim": 96, + "v_head_dim": 128, + "q_lora_rank": 64, + "q_proj": MagicMock(), + "q_b_proj": MagicMock(), + "kv_b_proj": MagicMock(), + "o_proj": MagicMock(), + "kv_a_proj_with_mqa": MagicMock(), + "fused_qkv_a_proj": MagicMock(), + "kv_a_layernorm": kv_a_layernorm, + "rotary_emb": MagicMock(), + } + + self.impl = AscendMlaCPImpl(num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=kv_cache_dtype, + blocksparse_params=None, + logits_soft_cap=None, + attn_type=None, + kv_sharing_target_layer_name=None, + **kwargs) + + def test_init(self): + self.assertEqual(self.impl.num_heads, 256) + self.assertEqual(self.impl.head_size, 1024) + self.assertEqual(self.impl.scale, 0.1) + self.assertEqual(self.impl.num_kv_heads, 8) + self.assertEqual(self.impl.kv_cache_dtype, "auto") + self.assertEqual(self.impl.kv_lora_rank, 32) + self.assertEqual(self.impl.qk_nope_head_dim, 64) + self.assertEqual(self.impl.qk_rope_head_dim, 32) + self.assertEqual(self.impl.qk_head_dim, 96) + self.assertEqual(self.impl.v_head_dim, 128) + self.assertIsNotNone(self.impl.q_proj) + self.assertIsNotNone(self.impl.kv_b_proj) + self.assertIsNotNone(self.impl.o_proj) + self.assertIsNotNone(self.impl.kv_a_proj_with_mqa) + self.assertIsNotNone(self.impl.kv_a_layernorm) + self.assertEqual(self.impl.num_queries_per_kv, 32) + self.assertEqual(self.impl.pcp_size, 2) + self.assertEqual(self.impl.dcp_size, 2) + + @patch('vllm_ascend.attention.mla_cp.get_dcp_group') + @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") + @patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch") + def test_mla_preprocess_dcp(self, magic_npu_fetch, + mock_maybe_all_gather_and_maybe_unpad, + mock_get_dcp_group): + + self.impl.num_kv_heads = 1 + self.impl.num_heads = 16 + self.impl.qk_rope_head_dim = 64 + self.impl.kv_lora_rank = 512 + self.impl.q_lora_rank = 1536 + self.impl.dcp_size = 2 + self.impl.pcp_size = 2 + block_num = 10 + block_size = 128 + batch_size = 2 + hidden_size = 1024 + hidden_states = torch.randn(batch_size, hidden_size) + + kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads, + self.impl.kv_lora_rank) + kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads, + self.impl.qk_rope_head_dim) + kv_cache = (kv_cache0, kv_cache1) + + mock_dcp_group = MagicMock() + + def mock_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + mock_dcp_group.all_gather = mock_all_gather_func + mock_get_dcp_group.return_value = mock_dcp_group + + attn_metadata = MagicMock() + attn_metadata.num_decodes = 2 + attn_metadata.num_prefills = 0 + attn_metadata.num_prefill_tokens = 0 + attn_metadata.num_decode_tokens = 2 + attn_metadata.num_actual_tokens = 2 + attn_metadata.slot_mapping = torch.arange(4) + attn_metadata.decode.cos = torch.randn(2, 64) + attn_metadata.decode.sin = torch.randn(2, 64) + + self.impl.q_a_layernorm = MagicMock() + self.impl.q_a_layernorm.return_value = torch.randn( + 
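
Both `_mla_preprocess` tests below stub `fused_qkv_a_proj` with a tensor of width `qk_rope_head_dim + kv_lora_rank + q_lora_rank`. A sketch of how a fused output of that width is conventionally split back into its three parts; the split order here is an assumption for illustration, not taken from this PR:

```python
# Splitting a fused QKV-A projection output (assumed layout:
# q_lora | kv_lora | rope) back into the compressed query, the
# compressed KV, and the rotary key component.
import torch

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
qkv = torch.randn(2, q_lora_rank + kv_lora_rank + qk_rope_head_dim)

q_c, kv_lora = qkv.split([q_lora_rank, kv_lora_rank + qk_rope_head_dim],
                         dim=-1)
kv_c, k_pe = kv_lora.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
assert q_c.shape[-1] == 1536
assert kv_c.shape[-1] == 512 and k_pe.shape[-1] == 64
```
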
attn_metadata.num_actual_tokens, self.impl.q_lora_rank) + self.impl.kv_a_proj_with_mqa = MagicMock() + self.impl.kv_a_proj_with_mqa.return_value = [ + torch.randn(batch_size, self.impl.num_heads, + self.impl.qk_rope_head_dim + self.impl.kv_lora_rank) + ] + self.impl.fused_qkv_a_proj = MagicMock() + self.impl.fused_qkv_a_proj.return_value = [ + torch.randn( + attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim + + self.impl.kv_lora_rank + self.impl.q_lora_rank) + ] + + self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x) + self.impl.exec_kv_decode = MagicMock() + self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()] + + self.impl._q_proj_and_k_up_proj = MagicMock() + self.impl._q_proj_and_k_up_proj.return_value = [ + torch.randn(attn_metadata.num_decodes, self.impl.num_heads, + self.impl.kv_lora_rank), + torch.randn(attn_metadata.num_decodes, self.impl.num_heads, + self.impl.qk_rope_head_dim) + ] + + magic_npu_fetch.return_value = MagicMock() + mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x + + decode_res, prefill_res = self.impl._mla_preprocess( + "mock_layer", + hidden_states, + kv_cache, + attn_metadata, + need_gather_q_kv=False) + + self.assertIsNotNone(decode_res) + self.assertIsNone(prefill_res) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('vllm_ascend.attention.mla_cp.get_pcp_group') + @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") + @patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch") + def test_mla_preprocess_pcp(self, magic_npu_fetch, + mock_maybe_all_gather_and_maybe_unpad, + mock_get_pcp_group, + mock_npu_reshape_and_cache): + self.impl.num_kv_heads = 1 + self.impl.num_heads = 16 + self.impl.qk_rope_head_dim = 64 + self.impl.kv_lora_rank = 512 + self.impl.q_lora_rank = 1536 + self.impl.dcp_size = 2 + self.impl.pcp_size = 2 + block_num = 10 + block_size = 128 + batch_size = 2 + hidden_size = 1024 + hidden_states = torch.randn(batch_size, hidden_size) + + kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads, + self.impl.kv_lora_rank) + kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads, + self.impl.qk_rope_head_dim) + kv_cache = (kv_cache0, kv_cache1) + + mock_pcp_group = MagicMock() + + def mock_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + mock_pcp_group.all_gather = mock_all_gather_func + mock_get_pcp_group.return_value = mock_pcp_group + + attn_metadata = MagicMock() + attn_metadata.num_decodes = 0 + attn_metadata.num_prefills = 2 + attn_metadata.num_prefill_tokens = 2 + attn_metadata.num_decode_tokens = 0 + attn_metadata.num_actual_tokens = 2 + attn_metadata.num_actual_tokens_pcp_padded = 4 + attn_metadata.prefill.pcp_metadata = MagicMock() + attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.arange( + 4) + attn_metadata.slot_mapping = torch.arange(4) + attn_metadata.prefill.cos = torch.randn(2, 64) + attn_metadata.prefill.sin = torch.randn(2, 64) + + self.impl.q_a_layernorm = MagicMock() + self.impl.q_a_layernorm.return_value = torch.randn( + attn_metadata.num_actual_tokens, self.impl.q_lora_rank) + self.impl.kv_a_proj_with_mqa = MagicMock() + self.impl.kv_a_proj_with_mqa.return_value = [ + torch.randn(batch_size, self.impl.num_heads, + self.impl.qk_rope_head_dim + self.impl.kv_lora_rank) + ] + self.impl.fused_qkv_a_proj = MagicMock() + self.impl.fused_qkv_a_proj.return_value = [ + torch.randn( + attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim + + self.impl.kv_lora_rank + 
self.impl.q_lora_rank) + ] + + self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x) + self.impl.exec_kv_decode = MagicMock() + self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()] + + self.impl._q_proj_and_k_up_proj = MagicMock() + self.impl._q_proj_and_k_up_proj.return_value = [ + torch.randn(attn_metadata.num_decodes, self.impl.num_heads, + self.impl.kv_lora_rank), + torch.randn(attn_metadata.num_decodes, self.impl.num_heads, + self.impl.qk_rope_head_dim) + ] + + magic_npu_fetch.return_value = MagicMock() + mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x + + self.impl.kv_a_layernorm = MagicMock() + self.impl.kv_a_layernorm.return_value = torch.randn( + attn_metadata.num_prefill_tokens, self.impl.num_kv_heads, + self.impl.kv_lora_rank) + + self.impl.q_proj = MagicMock() + self.impl.q_proj.return_value = [ + torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads, + self.impl.qk_head_dim) + ] + self.impl.kv_b_proj = MagicMock() + self.impl.kv_b_proj.return_value = [ + torch.randn(attn_metadata.num_prefill_tokens * self.impl.pcp_size, + self.impl.num_heads, + self.impl.v_head_dim + self.impl.qk_nope_head_dim) + ] + self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x) + self.impl.exec_kv_decode = MagicMock() + self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()] + self.impl.exec_kv_prefill = MagicMock() + self.impl.exec_kv_prefill.return_value = [ + torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads, + self.impl.qk_rope_head_dim), + torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads, + self.impl.kv_lora_rank) + ] + + decode_res, prefill_res = self.impl._mla_preprocess( + "mock_layer", + hidden_states, + kv_cache, + attn_metadata, + need_gather_q_kv=False) + self.assertIsNone(decode_res) + self.assertIsNotNone(prefill_res) + + @patch("torch.distributed.all_gather") + @patch("torch.distributed.all_to_all_single") + def test_process_attn_out_lse(self, mock_all_to_all_single, + mock_all_gather): + self.impl.dcp_size = 2 + self.impl.pcp_size = 2 + + B = 2 + N = self.impl.num_heads + self.impl.kv_lora_rank = 512 + + attn_output = torch.randn(B, N, self.impl.kv_lora_rank) + softmax_lse = torch.randn(B, N, 1) + + mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( + input) + + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + decode_metadata = MagicMock() + decode_metadata.actual_seq_lengths_q = MagicMock() + decode_metadata.seq_lens_list = MagicMock() + decode_metadata.batch_seq_mask = torch.tensor([True, False], + dtype=torch.bool) + + result = self.impl._process_attn_out_lse(attn_output, softmax_lse, + decode_metadata) + + self.assertEqual(result[0].shape[0], B) + self.assertEqual(result[0].shape[1], N / self.impl.dcp_size) + self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1) + + @patch("torch.distributed.all_gather") + @patch("torch.distributed.all_to_all_single") + @patch('vllm_ascend.attention.mla_cp.get_forward_context') + @patch("torch_npu.atb.npu_multi_head_latent_attention") + @patch('torch_npu.npu_attention_update') + def test_forward_decode_pcp_dcp(self, mock_npu_attention_update, + mock_npu_multi_head_latent_attention, + mock_get_forward_context, + mock_all_to_all_single, mock_all_gather): + self.impl.dcp_size = 2 + self.impl.pcp_size = 2 + self.impl.num_kv_heads = 1 + self.impl.num_heads 
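
`test_process_attn_out_lse` below asserts a trailing dimension of `kv_lora_rank + 1`: the per-head LSE rides along with the attention output so both can cross ranks in a single collective. A minimal sketch of that pack/unpack convention (function names are illustrative):

```python
# Pack/unpack convention behind the kv_lora_rank + 1 shape: the LSE is
# concatenated onto the attention output before the all-to-all, then
# stripped off again on the receiving side.
import torch


def pack_out_lse(attn_out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # attn_out: [B, N, D], lse: [B, N, 1] -> packed: [B, N, D + 1]
    return torch.cat([attn_out, lse], dim=-1)


def unpack_out_lse(packed: torch.Tensor):
    # packed: [B, N, D + 1] -> ([B, N, D], [B, N, 1])
    return packed[..., :-1], packed[..., -1:]


out, lse = torch.randn(2, 8, 512), torch.randn(2, 8, 1)
packed = pack_out_lse(out, lse)
assert packed.shape == (2, 8, 513)
o2, l2 = unpack_out_lse(packed)
assert torch.equal(o2, out) and torch.equal(l2, lse)
```
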
= 16 + self.impl.kv_lora_rank = 64 + self.impl.qk_nope_head_dim = 64 + self.impl.spec_token_num = 1 + B = 2 + N = self.impl.num_heads * self.impl.dcp_size + BS = 128 + NB = 100 + + q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim) + q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim) + k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank) + k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.SpecDecoding + attn_metadata.decode = MagicMock() + attn_metadata.decode.actual_seq_lengths_q = MagicMock() + attn_metadata.decode.seq_lens_list = MagicMock() + attn_metadata.decode.batch_seq_mask = torch.tensor([False, False], + dtype=torch.bool) + + self.impl.enable_kv_nz = True + + mock_npu_attention_update.return_value = (torch.randn( + B, self.impl.num_heads, self.impl.kv_lora_rank), None) + mock_npu_multi_head_latent_attention.return_value = [ + torch.randn(B, N, self.impl.kv_lora_rank), + torch.randn(B, N, 1) + ] + mock_get_forward_context.return_value = MagicMock(capturing=False) + + mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( + input) + + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + self.impl._v_up_proj = MagicMock() + self.impl._v_up_proj.return_value = torch.randn( + B, self.impl.v_head_dim) + + result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe, + BS, attn_metadata) + + self.assertEqual(result.shape[0], B) + self.assertEqual(result.shape[1], self.impl.v_head_dim) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py old mode 100644 new mode 100755 index 2c74c446..5061ff37 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -75,6 +75,12 @@ class TestAscendMLAPrefillMetadata(TestBase): max_seq_lens = [2, 2] workspace = torch.randn(2, 4) chunk_seq_lens = torch.tensor([2, 2]) + padded_chunk_seq_lens_npu = torch.tensor([2, 2]) + padded_local_chunk_seq_lens = [[2], [2]] + local_context_lens_allranks = [[1, 1], [1, 1]] + padded_local_cu_seq_lens = torch.tensor([0, 2, 4]) + cu_seq_lens_lst = [[0, 2], [2, 4]] + chunk_size = 2 chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata( cu_seq_lens=cu_seq_lens, @@ -83,7 +89,13 @@ class TestAscendMLAPrefillMetadata(TestBase): max_seq_lens=max_seq_lens, workspace=workspace, chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens) + chunk_seq_lens_npu=chunk_seq_lens, + padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu, + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens, + local_context_lens_allranks=local_context_lens_allranks, + padded_local_cu_seq_lens=padded_local_cu_seq_lens, + cu_seq_lens_lst=cu_seq_lens_lst, + chunk_size=chunk_size) metadata = AscendMLAPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -106,6 +118,17 @@ class TestAscendMLAPrefillMetadata(TestBase): self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, chunk_seq_lens) + self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu, + padded_chunk_seq_lens_npu) + self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens, + padded_local_chunk_seq_lens) + self.assertEqual(metadata.chunked_context.local_context_lens_allranks, + local_context_lens_allranks) + 
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens, + padded_local_cu_seq_lens) + self.assertEqual(metadata.chunked_context.cu_seq_lens_lst, + cu_seq_lens_lst) + self.assertEqual(metadata.chunked_context.chunk_size, chunk_size) class TestAscendMLADecodeMetadata(TestBase): @@ -117,10 +140,17 @@ class TestAscendMLADecodeMetadata(TestBase): max_seq_lens = 4 seq_lens_list = [2, 3] attn_mask = None + cp_seq_len = torch.tensor([2, 3]) + batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]]) - metadata = AscendMLADecodeMetadata(input_positions, block_table, - seq_lens, max_seq_lens, - seq_lens_list, attn_mask) + metadata = AscendMLADecodeMetadata(input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + max_seq_lens=max_seq_lens, + seq_lens_list=seq_lens_list, + attn_mask=attn_mask, + cp_seq_len=cp_seq_len, + batch_seq_mask=batch_seq_mask) self.assertIs(metadata.input_positions, input_positions) self.assertIs(metadata.block_table, block_table) @@ -128,6 +158,8 @@ class TestAscendMLADecodeMetadata(TestBase): self.assertEqual(metadata.max_seq_lens, max_seq_lens) self.assertEqual(metadata.seq_lens_list, seq_lens_list) self.assertIsNone(attn_mask) + self.assertIs(metadata.cp_seq_len, cp_seq_len) + self.assertIs(metadata.batch_seq_mask, batch_seq_mask) class TestAscendMLAMetadata(TestBase): @@ -200,17 +232,19 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.enable_chunked_prefill = False mock_device = 'cpu' - mock_dcp.world_size = 1 + mock_dcp.world_size = 2 + mock_dcp.rank_in_group = 0 dcp_group = MagicMock(spec=GroupCoordinator) dcp_group.rank_in_group = 0 - dcp_group.world_size = 1 + dcp_group.world_size = 2 dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group - mock_pcp.world_size = 1 + mock_pcp.world_size = 2 + mock_pcp.rank_in_group = 0 pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.rank_in_group = 0 - pcp_group.world_size = 1 + pcp_group.world_size = 2 pcp_group.device_group = MagicMock() mock_get_pcp_group.return_value = pcp_group @@ -227,6 +261,8 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertEqual( builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) + self.assertEqual(builder.dcp_size, mock_dcp.world_size) + self.assertEqual(builder.pcp_size, mock_pcp.world_size) @patch('vllm.distributed.parallel_state.get_pcp_group') @patch('vllm.distributed.parallel_state._PCP',
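
The chunked-context fixtures above pair per-chunk lengths with cumulative (`cu_`) offsets, e.g. `chunk_seq_lens = [2, 2]` alongside `padded_local_cu_seq_lens = [0, 2, 4]`. A one-liner showing how such offsets derive from the lengths:

```python
# Cumulative sequence-length offsets from per-chunk lengths, matching
# the fixture values in TestAscendMLAPrefillMetadata
# (chunk_seq_lens = [2, 2] -> offsets [0, 2, 4]).
import torch

chunk_seq_lens = torch.tensor([2, 2])
cu_seq_lens = torch.cat(
    [torch.zeros(1, dtype=torch.long),
     torch.cumsum(chunk_seq_lens, dim=0)])
assert cu_seq_lens.tolist() == [0, 2, 4]
```
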