diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index 7607d27a..2d46212c 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -451,10 +451,10 @@ class TestAscendMLAImpl(TestBase):
         self.assertIsNone(decode_res)
         self.assertIsNotNone(prefill_res)
 
-    @patch("torch.distributed.all_gather")
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("torch.distributed.all_to_all_single")
-    def test_process_attn_out_lse(self, mock_all_to_all_single,
-                                  mock_all_gather):
+    def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
 
@@ -468,11 +468,10 @@ class TestAscendMLAImpl(TestBase):
         mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
             input)
 
-        def mock_all_gather_func(tensor_list, tensor, group=None):
-            tensor_list[0] = tensor
-            tensor_list[1] = tensor.clone()
+        def make_all_gather(ws):
+            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
 
-        mock_all_gather.side_effect = mock_all_gather_func
+        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
 
         decode_metadata = MagicMock()
         decode_metadata.actual_seq_lengths_q = MagicMock()
@@ -483,11 +482,12 @@ class TestAscendMLAImpl(TestBase):
         result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
                                                  decode_metadata)
 
-        self.assertEqual(result[0].shape[0], B)
-        self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
-        self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)
+        self.assertEqual(result.shape[0], B * self.impl.pcp_size)
+        self.assertEqual(result.shape[1], N)
+        self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
 
-    @patch("torch.distributed.all_gather")
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("torch.distributed.all_to_all_single")
     @patch('vllm_ascend.attention.mla_cp.get_forward_context')
     @patch("torch_npu.atb.npu_multi_head_latent_attention")
@@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
     def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
                                     mock_npu_multi_head_latent_attention,
                                     mock_get_forward_context,
-                                    mock_all_to_all_single, mock_all_gather):
+                                    mock_all_to_all_single, mock_pcp):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
         self.impl.num_kv_heads = 1
@@ -534,11 +534,10 @@ class TestAscendMLAImpl(TestBase):
         mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
             input)
 
-        def mock_all_gather_func(tensor_list, tensor, group=None):
-            tensor_list[0] = tensor
-            tensor_list[1] = tensor.clone()
+        def make_all_gather(ws):
+            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
 
-        mock_all_gather.side_effect = mock_all_gather_func
+        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
 
         self.impl._v_up_proj = MagicMock()
         self.impl._v_up_proj.return_value = torch.randn(
@@ -562,9 +561,6 @@ class TestAscendMLAImpl(TestBase):
         def mock_all_gather(ws):
             return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
 
-        mock_dcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
-        mock_pcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
-
         def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
                            head_num, kv_head_num, pre_out, prev_lse, qk_scale,
                            kernel_type, mask_type, input_layout, calc_type,
@@ -624,6 +620,10 @@ class TestAscendMLAImpl(TestBase):
                 torch.ones(10, 10, dtype=torch.float16), 1)
         for test_case in test_cases:
             pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
+            mock_dcp.all_gather = MagicMock(
+                side_effect=mock_all_gather(dcp_size))
+            mock_pcp.all_gather = MagicMock(
+                side_effect=mock_all_gather(pcp_size))
             assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
             nums_context_per_rank = []
             for num_all_rank_context in nums_all_rank_context:
@@ -804,11 +804,10 @@ class TestAscendMLAImpl(TestBase):
                                    attn_lse_split_cp[0])
 
         mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
-        attn_out_lse_list = [
-            torch.randn(NUM_TOKENS, num_heads, head_dim)
-            for _ in range(self.impl.pcp_size * self.impl.dcp_size)
-        ]
-        out = self.impl._npu_attention_update(attn_out_lse_list)
+        attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
+                                   self.impl.dcp_size * num_heads,
+                                   head_dim)
+        out = self.impl._npu_attention_update(attn_out_lse)
         self.impl.dcp_size = 1
         self.impl.pcp_size = 1
         assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
@@ -908,12 +907,14 @@ class TestAscendMLAImpl(TestBase):
            mock_npu_ring_mla.reset_mock()
 
     @patch("torch.distributed.all_to_all_single")
-    @patch("torch.distributed.all_gather")
-    def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_gather,
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
                                                mock_all_to_all):
         B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim  # total: [4, 4, 8]
         test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
         for test_case in test_cases:
+            print(test_case)
             self.impl.dcp_size = test_case[0]
             self.impl.pcp_size = test_case[1]
             # Inputs
@@ -928,26 +929,17 @@ class TestAscendMLAImpl(TestBase):
             mock_all_to_all.side_effect = mock_all_to_all_side_effect
 
-            def mock_all_gather_side_effect(tensor_list, tensor, group=None):
-                for i in range(len(tensor_list)):
-                    tensor_list[i].copy_(tensor)
+            def mock_all_gather(ws):
+                return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
 
-            mock_all_gather.side_effect = mock_all_gather_side_effect
+            mock_pcp.all_gather = MagicMock(
+                side_effect=mock_all_gather(self.impl.pcp_size))
 
             result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
                                                      decode_meta)
-
-            self.assertIsInstance(result, list)
-            if self.impl.dcp_size == 1 and self.impl.pcp_size == 1:
-                self.assertEqual(len(result), 0)
-            else:
-                self.assertEqual(len(result),
-                                 self.impl.dcp_size * self.impl.pcp_size)  # 4
-
-            for tensor in result:
-                self.assertEqual(tensor.dtype, torch.float32)
-                self.assertEqual(tensor.shape,
-                                 (B, H // self.impl.dcp_size, D + 1))
+            # [PCP * S, DCP * H, D + 1]
+            self.assertIsInstance(result, torch.Tensor)
+            assert result.shape == (B * self.impl.pcp_size, H, D + 1)
 
         self.impl.dcp_size = 1
         self.impl.pcp_size = 1
 
diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py
index 980f3c50..c8e753ad 100644
--- a/vllm_ascend/attention/mla_cp.py
+++ b/vllm_ascend/attention/mla_cp.py
@@ -1,4 +1,4 @@
-from typing import ClassVar, List, Optional, Tuple, TypeVar
+from typing import ClassVar, Optional, Tuple, TypeVar
 
 import numpy as np
 import torch
@@ -1120,26 +1120,37 @@ class AscendMlaCPImpl(AscendMLAImpl):
             lse=softmax_lse)
 
         # Update out&lse
-        attn_out_lse_list = self._process_attn_out_lse(attn_output,
-                                                       softmax_lse,
-                                                       decode_meta)
-        attn_output = self._npu_attention_update(attn_out_lse_list)
+        attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse,
+                                                  decode_meta)
+        attn_output = self._npu_attention_update(attn_out_lse)
         return self._v_up_proj(attn_output)
 
-    def _npu_attention_update(
-            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
-        attn_out_split_cp = []
-        attn_lse_split_cp = []
-
-        for attn_out_lse in attn_out_lse_list:
-            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
-                *torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
-            attn_out_split_cp.append(attn_out_allgather)
-            attn_lse_split_cp.append(attn_lse_allgather)
-        attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
-                                                     attn_out_split_cp, 0)
-        attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
-                                 self.kv_lora_rank)
+    def _npu_attention_update(self,
+                              attn_out_lse: torch.Tensor) -> torch.Tensor:
+        # [PCP * S, DCP * H, D+1]
+        B_total, H_total, D_plus_1 = attn_out_lse.shape
+        S = B_total // self.pcp_size
+        H = H_total // self.dcp_size
+        D = self.kv_lora_rank
+        assert D_plus_1 == D + 1
+        # [PCP, S, DCP, H, D+1]
+        x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
+        # [PCP, DCP, S, H, D+1]
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        # Flatten to [N, S, H, D+1], N = pcp_size * dcp_size
+        x = x.view(-1, S, H, D_plus_1)
+        # Split out and lse
+        out_flat, lse_flat = torch.split(x, [D, 1],
+                                         dim=-1)  # [N, S, H, D], [N, S, H, 1]
+        # out: [N, S, H, D] -> [N, S*H, D]
+        # lse: [N, S, H, 1] -> [N, S*H]
+        out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
+        lse_flat = lse_flat.flatten(1, -1)  # [N, S*H]
+        # Unbind into per-rank lists
+        out_list = out_flat.unbind(0)  # [S*H, D]
+        lse_list = lse_flat.unbind(0)  # [S*H]
+        attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
+        attn_out = attn_out.view(-1, H, D)
         return attn_out
 
     def _out_lse_reshape(self, attn_out: torch.Tensor,
@@ -1155,8 +1166,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
         attn_output: torch.Tensor,
         softmax_lse: torch.Tensor,
         decode_meta: AscendMLADecodeMetadata,
-    ) -> List[torch.Tensor]:
-        attn_out_lse_list = []
+    ) -> torch.Tensor:
         out_mask = decode_meta.batch_seq_mask[:, None,
                                               None].expand_as(attn_output)
         attn_output = torch.where(out_mask, 0, attn_output)
@@ -1175,30 +1185,14 @@ class AscendMlaCPImpl(AscendMLAImpl):
             dist.all_to_all_single(attn_out_lse_all2all,
                                    attn_out_lse,
                                    group=self.dcp_group)
-            # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
-            attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            if self.pcp_size > 1:
-                attn_out_lse = attn_out_lse_all2all.contiguous()
-            attn_out_lse_list = list(
-                torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
+            attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
 
         if self.pcp_size > 1:
-            # AllGather out&lse within PCP group
-            attn_out_lse_list = [
-                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
-            ]
-            dist.all_gather(attn_out_lse_list,
-                            attn_out_lse,
-                            group=self.pcp_group)
-            if self.dcp_size > 1 and self.pcp_size > 1:
-                attn_out_lse_list_pcp_dcp = []
-                for s in attn_out_lse_list:
-                    attn_out_lse_list_split = list(
-                        torch.chunk(s, self.dcp_size, dim=1))
-                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
-                attn_out_lse_list = attn_out_lse_list_pcp_dcp
+            # AllGather out&lse within CP group
+            attn_out_lse = get_pcp_group().all_gather(
+                attn_out_lse.contiguous(), dim=0)
 
-        return attn_out_lse_list
+        return attn_out_lse
 
     def _reorg_kvcache(
         self,
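Note: below is a minimal CPU-only sketch of the layout handling introduced by the new _npu_attention_update, useful for checking the view/permute/split/unbind chain outside an NPU environment. The pcp_size, dcp_size, S, H and D values are illustrative only, and the final torch_npu.npu_attention_update call (NPU-only) is replaced by shape assertions.

import torch

pcp_size, dcp_size = 2, 2
S, H, D = 4, 8, 16  # tokens per rank, heads per DCP rank, kv_lora_rank (illustrative)

# _process_attn_out_lse now returns a single tensor shaped [PCP * S, DCP * H, D + 1]
attn_out_lse = torch.randn(pcp_size * S, dcp_size * H, D + 1)

# [PCP, S, DCP, H, D+1] -> [PCP, DCP, S, H, D+1] -> [N, S, H, D+1], N = PCP * DCP
x = attn_out_lse.view(pcp_size, S, dcp_size, H, D + 1)
x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, S, H, D + 1)

# Split the fused tensor back into attention output and lse, flatten, unbind per rank
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
out_list = out_flat.flatten(1, 2).unbind(0)   # N tensors of shape [S*H, D]
lse_list = lse_flat.flatten(1, -1).unbind(0)  # N tensors of shape [S*H]

assert len(out_list) == pcp_size * dcp_size
assert out_list[0].shape == (S * H, D) and lse_list[0].shape == (S * H,)
# On NPU these lists feed: attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)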