From 49838d4bec1dc1675915939669a00a74c2e8c17f Mon Sep 17 00:00:00 2001 From: Feng Liu <46866849+ader47@users.noreply.github.com> Date: Mon, 22 Dec 2025 11:06:19 +0800 Subject: [PATCH] [Perf] vectorize PCP/DCP loops in attention_cp.py (#4944) ### What this PR does / why we need it? - Add explicit .contiguous() after permute/view to ensure mem-friendly layout - Replace nested PCP/DCP Python loops with fully vectorized tensor operations - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: F.Liu Co-authored-by: F.Liu --- tests/ut/attention/test_attention_cp.py | 74 ++++++++---- vllm_ascend/attention/attention_cp.py | 154 +++++++++++------------- 2 files changed, 119 insertions(+), 109 deletions(-) diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 781581c3..2794b281 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -106,6 +106,8 @@ class TestAscendAttentionCPImpl(TestBase): self.assertEqual(output.shape[1], 4) self.assertEqual(output.shape[2], 128) + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP') @patch('vllm_ascend.attention.attention_cp.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP') @patch("torch_npu.npu_fused_infer_attention_score") @@ -115,9 +117,10 @@ class TestAscendAttentionCPImpl(TestBase): def test_forward_decode_pcp_dcp(self, mock_get_forward_context, mock_all_to_all_single, mock_all_gather, mock_npu_fused_infer_attention_score, - mock_dcp, mock_get_dcp_group): + mock_dcp, mock_get_dcp_group, mock_pcp, + mock_pcp_group): - def mock_dcp_all_gather_func(tensor, dim): + def mock_all_gather_func(tensor, dim): return torch.cat([tensor, tensor], dim=dim) mock_dcp.world_size = 2 @@ -126,17 +129,27 @@ class TestAscendAttentionCPImpl(TestBase): dcp_group.rank_in_group = 0 dcp_group.world_size = 2 dcp_group.device_group = MagicMock() - dcp_group.all_gather = mock_dcp_all_gather_func + dcp_group.all_gather = mock_all_gather_func mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 2 + mock_pcp.rank_in_group = 0 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 2 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.all_gather = mock_all_gather_func + mock_pcp_group.return_value = pcp_group + query = torch.randn(2, 4, 128) self.impl.key_cache = torch.randn(100, 128, 1, 128) self.impl.value_cache = torch.randn(100, 128, 1, 128) def mock_npu_attention_update(attn_out_lse_list): - mock_output = torch.randn(attn_out_lse_list[0].shape[0], - attn_out_lse_list[0].shape[1], - attn_out_lse_list[0].shape[2] - 1) + mock_output = torch.randn( + attn_out_lse_list.shape[0] // mock_pcp.world_size, + attn_out_lse_list.shape[1] // mock_dcp.world_size, + attn_out_lse_list.shape[2] - 1) return mock_output self.impl._npu_attention_update = MagicMock() @@ -147,11 +160,11 @@ class TestAscendAttentionCPImpl(TestBase): mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( input) - def mock_all_gather_func(tensor_list, tensor, group=None): + def mock_all_gather_func1(tensor_list, tensor, group=None): tensor_list[0] = tensor tensor_list[1] = tensor.clone() - mock_all_gather.side_effect = mock_all_gather_func + mock_all_gather.side_effect = mock_all_gather_func1 def mock_npu_fused_infer_attention_score_func(query, k_nope, value, **common_kwargs): @@ -202,8 +215,9 @@ class TestAscendAttentionCPImpl(TestBase): self.assertEqual(output.shape[1], 8) self.assertEqual(output.shape[2], 128) + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') @patch('torch.ops.npu.npu_fused_infer_attention_score') - def test_compute_prefill_context(self, mock_npu_attention): + def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group): block_num = 100 block_size = 128 @@ -238,6 +252,13 @@ class TestAscendAttentionCPImpl(TestBase): self.impl._load_kv_for_chunk = MagicMock() self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk + def mock_all_gather_func(tensor, dim): + return torch.cat([tensor, tensor], dim=dim) + + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.all_gather = mock_all_gather_func + mock_pcp_group.return_value = pcp_group + mock_npu_attention.return_value = torch.randn(batch_size, num_heads, head_size), torch.randn( batch_size, @@ -666,6 +687,9 @@ class TestUpdateNpuAttnOutLse(TestBase): self.assertIsInstance(out_final, torch.Tensor) self.assertIsInstance(lse_final, torch.Tensor) + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('torch.cat') @patch('torch.distributed.all_to_all_single') @patch('torch.distributed.all_gather') @@ -673,7 +697,7 @@ class TestUpdateNpuAttnOutLse(TestBase): @patch('torch.split') def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1( self, mock_split, mock_stack, mock_all_gather, - mock_all_to_all_single, mock_cat): + mock_all_to_all_single, mock_cat, mock_pcp, mock_get_pcp_group): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -687,6 +711,9 @@ class TestUpdateNpuAttnOutLse(TestBase): mock_stack.return_value = torch.randn(6, 2, 2, 9) mock_split.return_value = (torch.randn(6, 2, 2, 8), torch.randn(6, 2, 2, 1)) + mock_pcp_group = MagicMock() + mock_pcp_group.all_gather.return_value = torch.randn(6, 4, 9) + mock_get_pcp_group.return_value = mock_pcp_group # Call the method under test output, lse = self.impl._update_chunk_attn_out_lse( @@ -700,10 +727,10 @@ class TestUpdateNpuAttnOutLse(TestBase): self.assertEqual(mock_cat.call_count, 1) mock_all_to_all_single.assert_called_once() - mock_stack.assert_called_once() - mock_split.assert_called_once() - self.assertEqual(mock_all_gather.call_count, 1) + self.assertEqual(mock_get_pcp_group.call_count, 1) + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP') @patch('torch.cat') @patch('torch.chunk') @patch('torch.stack') @@ -712,7 +739,7 @@ class TestUpdateNpuAttnOutLse(TestBase): @patch('torch.distributed.all_gather') def test_update_chunk_attn_out_lse_dcp_greater_than_1_only( self, mock_all_gather, mock_all_to_all_single, mock_split, - mock_stack, mock_chunk, mock_cat): + mock_stack, mock_chunk, mock_cat, mock_pcp, mock_pcp_group): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -723,7 +750,8 @@ class TestUpdateNpuAttnOutLse(TestBase): # Mock output mock_cat.return_value = torch.randn(2, 4, 9) - mock_all_to_all_single.return_value = torch.randn(2, 4, 9) + mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( + input) mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)] mock_stack.return_value = torch.randn(2, 2, 2, 9) mock_split.return_value = [ @@ -743,11 +771,10 @@ class TestUpdateNpuAttnOutLse(TestBase): self.assertEqual(mock_cat.call_count, 1) mock_all_to_all_single.assert_called_once() - mock_chunk.assert_called_once() - mock_stack.assert_called_once() - mock_split.assert_called_once() mock_all_gather.assert_not_called() + @patch('vllm_ascend.attention.attention_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP') @patch('torch.cat') @patch('torch.stack') @patch('torch.split') @@ -758,7 +785,8 @@ class TestUpdateNpuAttnOutLse(TestBase): ) def test_update_chunk_attn_out_lse_pcp_greater_than_1_only( self, mock_update_out_and_lse, mock_all_gather, - mock_all_to_all_single, mock_split, mock_stack, mock_cat): + mock_all_to_all_single, mock_split, mock_stack, mock_cat, mock_pcp, + mock_get_pcp_group): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -769,7 +797,9 @@ class TestUpdateNpuAttnOutLse(TestBase): # Mock output mock_cat.return_value = torch.randn(2, 4, 9) - mock_all_gather.return_value = [(2, 4, 9), (2, 4, 9)] + mock_pcp_group = MagicMock() + mock_pcp_group.all_gather.return_value = torch.randn(4, 4, 9) + mock_get_pcp_group.return_value = mock_pcp_group mock_stack.return_value = torch.randn(2, 2, 4, 9) mock_split.return_value = [ torch.randn(2, 2, 4, 8), @@ -791,6 +821,4 @@ class TestUpdateNpuAttnOutLse(TestBase): self.assertEqual(mock_cat.call_count, 1) mock_all_to_all_single.assert_not_called() - mock_stack.assert_called_once() - mock_split.assert_called_once() - mock_all_gather.assert_called_once() + mock_get_pcp_group.assert_called_once() diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py index 95ac7fcb..86a71cae 100644 --- a/vllm_ascend/attention/attention_cp.py +++ b/vllm_ascend/attention/attention_cp.py @@ -428,26 +428,36 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) return attn_out, attn_lse - def _npu_attention_update( - self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor: + def _npu_attention_update(self, + attn_out_lse: torch.Tensor) -> torch.Tensor: + B_total, H_total, D_plus_1 = attn_out_lse.shape + S = B_total // self.pcp_size + H = H_total // self.dcp_size + D = self.head_size update_type = 0 - - batch = attn_out_lse_list[0].shape[0] - num_heads = attn_out_lse_list[0].shape[1] - head_dim = attn_out_lse_list[0].shape[2] - 1 - - attn_out_split_cp = [] - attn_lse_split_cp = [] - - for i in attn_out_lse_list: - attn_out_allgather, attn_lse_allgather = self._out_lse_reshape( - *torch.split(i, [self.head_size, 1], dim=-1)) - attn_out_split_cp.append(attn_out_allgather) - attn_lse_split_cp.append(attn_lse_allgather) + assert D_plus_1 == D + 1 + # [PCP, S, DCP, H, D+1] + x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1) + # [PCP, DCP, S, H, D+1] + x = x.permute(0, 2, 1, 3, 4).contiguous() + # Flatten [N, S, H, D+1], N = pcp_size * dcp_size + x = x.view(-1, S, H, D_plus_1) + # Split out lse + # [N, S, H, D], [N, S, H, 1] + out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) + # out: [N, S, H, D] -> [N, S*H, D] + # lse: [N, S, H, 1] -> [N, S*H] + out_flat = out_flat.flatten(1, 2) + lse_flat = lse_flat.squeeze(-1).flatten(1) + # unbind to list + # [S*H, D] + out_list = out_flat.unbind(0) + # [S*H] + lse_list = lse_flat.unbind(0) attn_out, attn_lse = torch_npu.npu_attention_update( - attn_lse_split_cp, attn_out_split_cp, update_type) - attn_out = attn_out.view(batch, num_heads, head_dim) + lse_list, out_list, update_type) + attn_out = attn_out.view(S, H, D) return attn_out @@ -539,17 +549,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( query, k_nope, value, **common_kwargs) - out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None, - None].expand_as( - attn_out) - attn_out = torch.where(out_mask, 0, attn_out) - lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None, None].expand_as( attn_lse) attn_lse = torch.where(lse_mask, -torch.inf, attn_lse) - - attn_out_lse_list = [] # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1) if self.dcp_size > 1: @@ -559,30 +562,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=self.dcp_group) - # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1] - attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1]) - if self.pcp_size > 1: - attn_out_lse = attn_out_lse_all2all.contiguous() - attn_out_lse_list = list( - torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1)) + attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) if self.pcp_size > 1: # AllGather out&lse within CP group - attn_out_lse_list = [ - torch.empty_like(attn_out_lse) for _ in range(self.pcp_size) - ] - dist.all_gather(attn_out_lse_list, - attn_out_lse, - group=self.pcp_group) - if self.dcp_size > 1 and self.pcp_size > 1: - attn_out_lse_list_pcp_dcp = [] - for s in attn_out_lse_list: - attn_out_lse_list_split = list( - torch.chunk(s, self.dcp_size, dim=1)) - attn_out_lse_list_pcp_dcp += attn_out_lse_list_split - attn_out_lse_list = attn_out_lse_list_pcp_dcp - # Update out&lse - attn_out = self._npu_attention_update(attn_out_lse_list) + attn_out_lse = get_pcp_group().all_gather( + attn_out_lse.contiguous(), dim=0) + + attn_out = self._npu_attention_update(attn_out_lse) return attn_out def _update_out_and_lse(self, out_list: torch.Tensor, @@ -739,35 +726,28 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): dist.all_to_all_single(attn_out_lse_all2all, chunk_attn_out_lse, group=self.dcp_group) - attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1]) - if self.pcp_size > 1: - chunk_attn_out_lse = attn_out_lse_all2all.contiguous() - - attn_out_lse_list = list( - torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1)) + chunk_attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) if self.pcp_size > 1: - attn_out_lse_list = [ - torch.empty_like(chunk_attn_out_lse) - for _ in range(self.pcp_size) - ] - dist.all_gather(attn_out_lse_list, - chunk_attn_out_lse, - group=self.pcp_group) + # AllGather out&lse within CP group + chunk_attn_out_lse = get_pcp_group().all_gather( + chunk_attn_out_lse.contiguous(), dim=0) - if self.dcp_size > 1 and self.pcp_size > 1: - attn_out_lse_list_pcp_dcp = [] - for s in attn_out_lse_list: - attn_out_lse_list_split = list( - torch.chunk(s, self.dcp_size, dim=1)) - attn_out_lse_list_pcp_dcp += attn_out_lse_list_split - attn_out_lse_list = attn_out_lse_list_pcp_dcp - - attn_out_lse_allgather = torch.stack( - attn_out_lse_list, - dim=0) # [pcp, batch_size, num_heads, head_size+1] - attn_out_allgather, attn_lse_allgather = torch.split( - attn_out_lse_allgather, [self.head_size, 1], dim=-1) + B_total, H_total, D_plus_1 = chunk_attn_out_lse.shape + S = B_total // self.pcp_size + H = H_total // self.dcp_size + D = self.head_size + assert D_plus_1 == D + 1 + # [PCP, S, DCP, H, D+1] + x = chunk_attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, + D_plus_1) + # [PCP, DCP, S, H, D+1] + x = x.permute(0, 2, 1, 3, 4).contiguous() + # Flatten [N, S, H, D+1], N = pcp_size * dcp_size + x = x.view(-1, S, H, D_plus_1) + # Split out lse. + # [N, S, H, D], [N, S, H, 1] + attn_out_allgather, attn_lse_allgather = torch.split(x, [D, 1], dim=-1) prefix_output, prefix_lse = self._update_out_and_lse( attn_out_allgather, attn_lse_allgather) @@ -842,19 +822,21 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): pcp_allgather_restore_idx) key, value = all_kv.split([self.head_size, self.head_size], dim=-1) + prefill_key = key[self.pcp_size * + num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded] + prefill_value = value[self.pcp_size * + num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded] + slot_mapping = attn_metadata.slot_mapping[ + self.pcp_size * num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded] + torch_npu._npu_reshape_and_cache(key=prefill_key, + value=prefill_value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slot_mapping) - torch_npu._npu_reshape_and_cache( - key=key[self.pcp_size * num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded], - value=value[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded], - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=attn_metadata. - slot_mapping[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded]) return key, value def forward_impl( @@ -879,9 +861,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): assert attn_metadata.prefill is not None num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size prefill_query = query[ - num_decode_tokens:num_actual_tokens_pcp_padded] - key = key[self.pcp_size * num_decode_tokens:] - value = value[self.pcp_size * num_decode_tokens:] + num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() + key = key[self.pcp_size * num_decode_tokens:].contiguous() + value = value[self.pcp_size * num_decode_tokens:].contiguous() if self.pcp_size > 1: # Scenario of Enabling PCP or PCP&DCP attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(