diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py
index 2794b281..890a4794 100644
--- a/tests/ut/attention/test_attention_cp.py
+++ b/tests/ut/attention/test_attention_cp.py
@@ -1,8 +1,9 @@
+from functools import wraps
 from typing import List
 from unittest.mock import MagicMock, patch
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator
+from vllm.distributed.parallel_state import GroupCoordinator, all_gather_fake
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
@@ -10,33 +11,57 @@ from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                 AscendMetadataForPrefill)
 
 
+def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0):
+    """
+    Decorator to patch common distributed group mocks with configuration.
+
+    Args:
+        dcp_size: DCP world size (default: 1)
+        dcp_rank: DCP rank (default: 0)
+        pcp_size: PCP world size (default: 1)
+        pcp_rank: PCP rank (default: 0)
+    """
+
+    def decorator(func):
+
+        @wraps(func)
+        @patch('torch.distributed.all_to_all_single')
+        @patch('vllm.distributed.parallel_state._PCP')
+        def wrapper(self, mock_pcp, mock_all_to_all_single, *args, **kwargs):
+            mock_pcp.world_size = pcp_size
+            mock_pcp.rank_in_group = pcp_rank
+            mock_pcp.device_group = MagicMock()
+            mock_pcp.all_gather = MagicMock()
+            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
+                input_, dim, pcp_size, "mock")
+
+            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
+                input)
+
+            return func(self, mock_all_to_all_single, mock_pcp, *args,
+                        **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
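The decorator above wires `all_gather_fake` (imported from vllm's `parallel_state`) in as the `all_gather` side effect. That helper's exact behavior is not shown in this diff; the sketch below is only an assumption of what such a single-process fake typically does — replicate the local tensor `world_size` times along the gather dimension — under a hypothetical name `fake_all_gather` so the downstream shape arithmetic in the tests works out.

```python
# Hedged sketch of a single-process all_gather stand-in with the same call
# shape the decorator assumes for all_gather_fake(input_, dim, world_size, group).
# The real helper lives in vllm.distributed.parallel_state and may differ.
import torch


def fake_all_gather(input_: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    # Every "rank" contributes the same local tensor, so the gathered result
    # is world_size copies concatenated along `dim`.
    return torch.cat([input_] * world_size, dim=dim)


if __name__ == "__main__":
    local = torch.randn(2, 4, 9)
    gathered = fake_all_gather(local, dim=0, world_size=3, group_name="mock")
    assert gathered.shape == (6, 4, 9)
```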
 class TestAscendAttentionCPImpl(TestBase):
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
     @patch('vllm.distributed.parallel_state._PCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
-              mock_get_pcp_group):
+    def setUp(self, mock_get_dcp_size, mock_dcp, mock_pcp):
         mock_dcp.world_size = 2
         mock_dcp.rank_in_group = 0
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 2
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
+        mock_dcp.device_group = MagicMock()
 
         mock_pcp.world_size = 2
         mock_pcp.rank_in_group = 0
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 2
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
+        mock_pcp.device_group = MagicMock()
 
         self.layer = MagicMock()
         self.layer.layer_name = "test_layer"
@@ -237,10 +265,14 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata = MagicMock()
         attn_metadata.prefill = MagicMock()
         attn_metadata.prefill.chunked_context = MagicMock()
-        attn_metadata.prefill.chunked_context.local_context_lens_allranks = torch.tensor(
-            [[[256, 256], [256, 256]]])
+        local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]])
+        attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks
         attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
             0, 2, (1024, ), dtype=torch.bool)
+        attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:,
+                                                                                              0,
+                                                                                              0].sum(
+                                                                                              )
 
         def mock_load_kv_for_chunk(attn_metadata, kv_cache,
                                    local_chunked_kv_lens_rank, query,
@@ -264,8 +296,17 @@ class TestAscendAttentionCPImpl(TestBase):
                                                    batch_size, num_heads, 1)
 
-        result_output, result_lse = self.impl._compute_prefill_context(
+        context_output = self.impl._compute_prefill_context(
             query, kv_cache, attn_metadata)
+        local_context_output = torch.cat(context_output,
+                                         dim=-1).permute([1, 2,
+                                                          0]).contiguous()
+        global_context_output = self.impl._gather_global_context_output(
+            local_context_output)
+        global_context_output = global_context_output.permute([2, 0, 1
+                                                               ]).contiguous()
+        result_output, result_lse = self.impl._update_global_context_output(
+            global_context_output)
 
         self.assertEqual(result_output.shape[0], batch_size)
         self.assertEqual(result_output.shape[1], self.impl.num_heads)
@@ -687,37 +728,25 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.assertIsInstance(out_final, torch.Tensor)
         self.assertIsInstance(lse_final, torch.Tensor)
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('torch.cat')
-    @patch('torch.distributed.all_to_all_single')
-    @patch('torch.distributed.all_gather')
-    @patch('torch.stack')
-    @patch('torch.split')
-    def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1(
-            self, mock_split, mock_stack, mock_all_gather,
-            mock_all_to_all_single, mock_cat, mock_pcp, mock_get_pcp_group):
+    @patch_distributed_groups(dcp_size=2, pcp_size=3)
+    def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
+                                                 mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
         self.impl.dcp_size = 2
         self.impl.pcp_size = 3
         self.impl.head_size = 8
-        # Mock output
-        mock_cat.return_value = torch.randn(2, 4, 9)
-        mock_all_to_all_single.return_value = torch.randn(4, 9, 2)
-        mock_all_gather.return_value = [(2, 4, 9), (2, 4, 9), (2, 4, 9)]
-        mock_stack.return_value = torch.randn(6, 2, 2, 9)
-        mock_split.return_value = (torch.randn(6, 2, 2,
-                                               8), torch.randn(6, 2, 2, 1))
-        mock_pcp_group = MagicMock()
-        mock_pcp_group.all_gather.return_value = torch.randn(6, 4, 9)
-        mock_get_pcp_group.return_value = mock_pcp_group
 
         # Call the method under test
-        output, lse = self.impl._update_chunk_attn_out_lse(
-            prefix_chunk_output, prefix_chunk_lse)
+        chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
+                               dim=-1).permute([1, 2, 0]).contiguous()
+        global_context_output = self.impl._gather_global_context_output(
+            chunk_data)
+        global_context_output = global_context_output.permute([2, 0, 1
+                                                               ]).contiguous()
+        output, lse = self.impl._update_global_context_output(
+            global_context_output)
 
         # Assert the method call
         self.assertIsInstance(output, torch.Tensor)
@@ -725,21 +754,12 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.assertIsInstance(lse, torch.Tensor)
         self.assertEqual(output.shape, (2, 2, 8))
         self.assertEqual(lse.shape, (2, 2, 1))
 
-        self.assertEqual(mock_cat.call_count, 1)
         mock_all_to_all_single.assert_called_once()
-        self.assertEqual(mock_get_pcp_group.call_count, 1)
+        mock_pcp.all_gather.assert_called_once()
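These tests funnel per-rank `(output, lse)` pairs into `_update_global_context_output`, which ends in `_update_out_and_lse`; that helper's body is not part of this diff. The following is only a reference sketch of the standard log-sum-exp merge such a helper is expected to perform over the `N = pcp_size * dcp_size` partial results (the function name `merge_out_and_lse` is hypothetical).

```python
# Reference sketch (not the vllm-ascend implementation) of merging partial
# attention outputs with their log-sum-exp (lse) statistics.
import torch


def merge_out_and_lse(outs: torch.Tensor, lses: torch.Tensor):
    """outs: [N, S, H, D], lses: [N, S, H, 1] -> ([S, H, D], [S, H, 1])."""
    lse_max = lses.max(dim=0, keepdim=True).values       # for numerical stability
    weights = torch.exp(lses - lse_max)                  # [N, S, H, 1]
    denom = weights.sum(dim=0)                           # [S, H, 1]
    merged_out = (weights * outs).sum(dim=0) / denom     # [S, H, D]
    merged_lse = lse_max.squeeze(0) + torch.log(denom)   # [S, H, 1]
    return merged_out, merged_lse


outs = torch.randn(6, 2, 2, 8)
lses = torch.randn(6, 2, 2, 1)
out, lse = merge_out_and_lse(outs, lses)
assert out.shape == (2, 2, 8) and lse.shape == (2, 2, 1)
```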
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP')
-    @patch('torch.cat')
-    @patch('torch.chunk')
-    @patch('torch.stack')
-    @patch('torch.split')
-    @patch('torch.distributed.all_to_all_single')
-    @patch('torch.distributed.all_gather')
-    def test_update_chunk_attn_out_lse_dcp_greater_than_1_only(
-            self, mock_all_gather, mock_all_to_all_single, mock_split,
-            mock_stack, mock_chunk, mock_cat, mock_pcp, mock_pcp_group):
+    @patch_distributed_groups(dcp_size=2)
+    def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
+                                                 mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -748,20 +768,15 @@
         self.impl.dcp_size = 2
         self.impl.pcp_size = 1
         self.impl.head_size = 8
-        # Mock output
-        mock_cat.return_value = torch.randn(2, 4, 9)
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-        mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)]
-        mock_stack.return_value = torch.randn(2, 2, 2, 9)
-        mock_split.return_value = [
-            torch.randn(2, 2, 2, 8),
-            torch.randn(2, 2, 2, 1)
-        ]
-
         # Call the method under test
-        output, lse = self.impl._update_chunk_attn_out_lse(
-            prefix_chunk_output, prefix_chunk_lse)
+        chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
+                               dim=-1).permute([1, 2, 0]).contiguous()
+        global_context_output = self.impl._gather_global_context_output(
+            chunk_data)
+        global_context_output = global_context_output.permute([2, 0, 1
+                                                               ]).contiguous()
+        output, lse = self.impl._update_global_context_output(
+            global_context_output)
 
         # Assert the method call
         self.assertIsInstance(output, torch.Tensor)
@@ -769,24 +784,12 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.assertIsInstance(lse, torch.Tensor)
         self.assertEqual(output.shape, (2, 2, 8))
         self.assertEqual(lse.shape, (2, 2, 1))
 
-        self.assertEqual(mock_cat.call_count, 1)
         mock_all_to_all_single.assert_called_once()
-        mock_all_gather.assert_not_called()
+        mock_pcp.all_gather.assert_not_called()
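A short note on the `all_to_all_single` stand-in used by the decorator in these tests: the real collective (with equal splits) is shape-preserving, so replacing it with a plain copy keeps the downstream permute/view/split arithmetic valid without an initialized process group; the tests then assert only shapes and call counts, not numerics. A minimal sketch of that assumption:

```python
# The decorator stubs torch.distributed.all_to_all_single with output.copy_(input).
# A real all_to_all_single with equal splits returns a tensor of the same shape,
# so a copy is enough for shape-only unit tests run in a single process.
import torch

chunk = torch.randn(4, 9, 2)        # [H, D+1, S] layout used by the tests above
out = torch.empty_like(chunk)
out.copy_(chunk)                    # stand-in for the collective exchange
assert out.shape == chunk.shape     # same shape a real exchange would return
```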
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP')
-    @patch('torch.cat')
-    @patch('torch.stack')
-    @patch('torch.split')
-    @patch('torch.distributed.all_to_all_single')
-    @patch('torch.distributed.all_gather')
-    @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
-    )
-    def test_update_chunk_attn_out_lse_pcp_greater_than_1_only(
-            self, mock_update_out_and_lse, mock_all_gather,
-            mock_all_to_all_single, mock_split, mock_stack, mock_cat, mock_pcp,
-            mock_get_pcp_group):
+    @patch_distributed_groups(pcp_size=2)
+    def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
+                                                 mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -795,30 +798,21 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.impl.dcp_size = 1
         self.impl.pcp_size = 2
         self.impl.head_size = 8
-        # Mock output
-        mock_cat.return_value = torch.randn(2, 4, 9)
-        mock_pcp_group = MagicMock()
-        mock_pcp_group.all_gather.return_value = torch.randn(4, 4, 9)
-        mock_get_pcp_group.return_value = mock_pcp_group
-        mock_stack.return_value = torch.randn(2, 2, 4, 9)
-        mock_split.return_value = [
-            torch.randn(2, 2, 4, 8),
-            torch.randn(2, 2, 4, 1)
-        ]
-        mock_update_out_and_lse.return_value = torch.randn(2, 4,
-                                                           8), torch.randn(
-                                                               2, 4, 1)
         # Call the method under test
-        output, lse = self.impl._update_chunk_attn_out_lse(
-            prefix_chunk_output, prefix_chunk_lse)
+        chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
+                               dim=-1).permute([1, 2, 0]).contiguous()
+        global_context_output = self.impl._gather_global_context_output(
+            chunk_data)
+        global_context_output = global_context_output.permute([2, 0, 1
+                                                               ]).contiguous()
+        output, lse = self.impl._update_global_context_output(
+            global_context_output)
 
         # Assert the method call
         self.assertIsInstance(output, torch.Tensor)
         self.assertIsInstance(lse, torch.Tensor)
         self.assertEqual(output.shape, (2, 4, 8))
         self.assertEqual(lse.shape, (2, 4, 1))
 
-        self.impl._update_out_and_lse.assert_called_once()
-        self.assertEqual(mock_cat.call_count, 1)
         mock_all_to_all_single.assert_not_called()
-        mock_get_pcp_group.assert_called_once()
+        mock_pcp.all_gather.assert_called_once()
diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py
index c0906724..3919848e 100644
--- a/vllm_ascend/attention/attention_cp.py
+++ b/vllm_ascend/attention/attention_cp.py
@@ -40,7 +40,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.utils import weak_ref_tensors
+from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors
 
 
 class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
@@ -152,6 +152,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
                                                        dcp_rank]
             actual_seq_lengths_kv = torch.cumsum(
                 local_chunked_kv_lens_rank, dim=0).tolist()
+            local_total_toks = local_chunked_kv_lens_rank.sum()
             chunked_req_mask = self._get_chunked_req_mask(
                 local_context_lens_allranks)
             local_chunk_starts = torch.zeros(
@@ -181,7 +182,8 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
                 cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
                 kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
                 batch_chunk_seq_mask=batch_chunk_seq_mask,
-                chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
+                chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
+                local_total_toks=local_total_toks.item()
             )
         attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
@@ -372,6 +374,25 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
     def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
                             value: torch.Tensor,
                             attn_metadata: AscendMetadata) -> torch.Tensor:
+
+        data_head, data_tail = self._forward_prefill_cp_pre(
+            query, key, value, attn_metadata)
+
+        output_head, lse_head = self._forward_prefill_cp_attn(
+            data_head, True, attn_metadata)
+        output_tail, lse_tail = self._forward_prefill_cp_attn(
+            data_tail, False, attn_metadata)
+
+        output, attn_lse = self._forward_prefill_cp_post(
+            [output_head, output_tail],
+            [lse_head, lse_tail],
+            attn_metadata,
+        )
+        return output, attn_lse
+
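The pre/attn/post split above relies on the usual PCP load-balancing layout: the query is cut into `2 * pcp_size` chunks and each rank works on one chunk from the front and the mirrored chunk from the back, so causal work is roughly even across ranks (the removed comment further down spells this out for two ranks). The real index tensors come from the PCP metadata builder; the sketch below only illustrates the chunk assignment with made-up sizes.

```python
# Chunk-index bookkeeping behind the head/tail split: rank r handles chunk r
# (head) and chunk 2*pcp_size-1-r (tail). With pcp_size=2 this reproduces the
# pairing described in the original comment (rank0 -> Q0/Q3, rank1 -> Q1/Q2).
pcp_size = 2
num_chunks = 2 * pcp_size

for rank in range(pcp_size):
    head_chunk = rank
    tail_chunk = num_chunks - 1 - rank
    print(f"pcp_rank{rank}: head chunk {head_chunk}, tail chunk {tail_chunk}")
```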
+    def _forward_prefill_cp_pre(self, query: torch.Tensor, key: torch.Tensor,
+                                value: torch.Tensor,
+                                attn_metadata: AscendMetadata) -> torch.Tensor:
         assert attn_metadata is not None
         assert attn_metadata.prefill is not None
         assert attn_metadata.prefill.pcp_metadata is not None
@@ -382,48 +403,53 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
         kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
         kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
+        q_head = torch.index_select(query, 0, q_head_idx)
+        q_tail = torch.index_select(query, 0, q_tail_idx)
+        k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) \
+            if self.pcp_rank > 0 else None
+        v_head_nomask = torch.index_select(value, 0, kv_with_q_head_nomask_idx) \
+            if self.pcp_rank > 0 else None
+        k_head_mask = torch.index_select(key, 0, kv_with_q_head_mask_idx)
+        v_head_mask = torch.index_select(value, 0, kv_with_q_head_mask_idx)
+        k_tail_nomask = torch.index_select(key, 0, kv_with_q_tail_nomask_idx)
+        v_tail_nomask = torch.index_select(value, 0, kv_with_q_tail_nomask_idx)
+        k_tail_mask = torch.index_select(key, 0, kv_with_q_tail_mask_idx)
+        v_tail_mask = torch.index_select(value, 0, kv_with_q_tail_mask_idx)
+        return {
+            "q": q_head,
+            "k_nomask": k_head_nomask,
+            "v_nomask": v_head_nomask,
+            "k_mask": k_head_mask,
+            "v_mask": v_head_mask,
+        }, {
+            "q": q_tail,
+            "k_nomask": k_tail_nomask,
+            "v_nomask": v_tail_nomask,
+            "k_mask": k_tail_mask,
+            "v_mask": v_tail_mask,
+        },
+
+    def _forward_prefill_cp_attn(self, data, is_head, attn_metadata):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
-        head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
-        tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
+        nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
+            if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
         mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
-
-        # 1. Attention calculation in the first half of Q in load balancing
-        output_heads, lse_heads = self._attention_with_nomask_and_mask(
-            q=torch.index_select(query, 0, q_head_idx),
+        output, lse = self._attention_with_nomask_and_mask(
+            **data,
             q_seqlens=attn_mask_seqlens,
-            k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
-            if self.pcp_rank > 0 else None,
-            v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
-            if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens,
-            k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
-            v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens,
-            mask=mask,
-            attn_metadata=attn_metadata)
-
-        # 2. the Attention calculation in the latter half of Q in load balancing
-        # pcp_rank0: Q3*KV0~KV2 + Q3*KV3
-        # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
-        output_tails, lse_tails = self._attention_with_nomask_and_mask(
-            q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens,
-            k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
-            v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens,
-            k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
-            v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
+            kv_seqlens_nomask=nomask_seqlens,
             kv_seqlens_mask=attn_mask_seqlens,
             mask=mask,
             attn_metadata=attn_metadata)
+        return output, lse
 
+    def _forward_prefill_cp_post(self, outputs, lses, attn_metadata):
         q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
-        output = torch.index_select(
-            torch.cat([output_heads, output_tails], dim=0), 0, q_full_idx)
+        output = torch.index_select(torch.cat(outputs, dim=0), 0, q_full_idx)
         attn_lse = None
         if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
-            attn_lse = torch.index_select(
-                torch.cat([lse_heads, lse_tails], dim=0), 0, q_full_idx)
+            attn_lse = torch.index_select(torch.cat(lses, dim=0), 0,
+                                          q_full_idx)
         return output, attn_lse
 
     def _out_lse_reshape(self, attn_out: torch.Tensor,
@@ -598,19 +624,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
                 dim=0)
         return out_final, lse_final
 
-    def _process_chunk_prefill(self, current_attn_output_prefill,
-                               current_attn_lse_prefill, kv_cache,
-                               prefill_query, attn_metadata):
-        if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
-            prefill_query_all = self._prefill_query_all_gather(
-                attn_metadata, prefill_query)
-            attn_output_full_chunk, attn_lse_full_chunk = self._compute_prefill_context(
-                prefill_query_all, kv_cache, attn_metadata)
-            self._update_chunk_attn_out_lse_with_current_attn_out_lse(
-                current_attn_output_prefill, current_attn_lse_prefill,
-                attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
-                attn_metadata)
-
     def _update_chunk_attn_out_lse_with_current_attn_out_lse(
             self, current_attn_output_prefill, current_attn_lse_prefill,
             attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
@@ -646,18 +659,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
                 current_attn_output_prefill.dtype)
 
     def _prefill_query_all_gather(self, attn_metadata, prefill_query):
-        if self.dcp_size > 1:
-            prefill_query = get_dcp_group().all_gather(prefill_query, 1)
-
         if self.pcp_size > 1:
             prefill_query = get_pcp_group().all_gather(prefill_query, 0)
-
-        prefill_query_all = torch.index_select(prefill_query,
-                                               0,
-                                               attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
-            if self.pcp_size > 1 else prefill_query
-
-        return prefill_query_all
+        prefill_query = torch.index_select(
+            prefill_query, 0, attn_metadata.prefill.chunked_context.
+            cp_kv_recover_idx_for_chunk)
+        if self.dcp_size > 1:
+            prefill_query = get_dcp_group().all_gather(prefill_query, 1)
+        return prefill_query
 
     def _compute_prefill_context(self, query: torch.Tensor,
                                  kv_cache: Tuple[torch.Tensor],
@@ -672,8 +681,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
             local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
                                                                self.dcp_rank]
-        total_toks = local_chunked_kv_lens_rank.sum()
-
+        total_toks = prefill_metadata.chunked_context.local_total_toks
         key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
                                              local_chunked_kv_lens_rank, query,
                                              total_toks)
@@ -682,16 +690,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         else:
             num_heads = self.num_heads
 
-        prefix_chunk_output = torch.full(
-            (query.size(0), num_heads, self.head_size),
-            fill_value=0,
-            dtype=query.dtype,
-            device=query.device)
-        prefix_chunk_lse = torch.full((query.size(0), num_heads, 1),
-                                      fill_value=-torch.inf,
-                                      dtype=torch.float32,
-                                      device=query.device)
-
+        prefix_chunk_output, prefix_chunk_lse = None, None
         if total_toks > 0:
             prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
                 query,
@@ -711,59 +710,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
                 actual_seq_lengths=attn_metadata.prefill.chunked_context.
                 actual_chunk_seq_lengths)
             batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
-            out_mask = batch_chunk_seq_mask[:, None, None].expand_as(
-                prefix_chunk_output)
-            prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output)
             lse_mask = batch_chunk_seq_mask[:, None,
                                             None].expand_as(prefix_chunk_lse)
             prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
                                            prefix_chunk_lse)
 
-        prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
-            prefix_chunk_output, prefix_chunk_lse)
-
-        return prefix_output, prefix_lse
-
-    def _update_chunk_attn_out_lse(self, prefix_chunk_output,
-                                   prefix_chunk_lse):
-        # CP dimension all_gather and fusion
-        chunk_attn_out_lse = torch.cat([prefix_chunk_output, prefix_chunk_lse],
-                                       dim=-1)
-
-        if self.dcp_size > 1:
-            chunk_attn_out_lse = chunk_attn_out_lse.permute([1, 2,
-                                                             0]).contiguous()
-            attn_out_lse_all2all = torch.empty_like(chunk_attn_out_lse)
-            dist.all_to_all_single(attn_out_lse_all2all,
-                                   chunk_attn_out_lse,
-                                   group=self.dcp_group)
-            chunk_attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
-
-        if self.pcp_size > 1:
-            # AllGather out&lse within CP group
-            chunk_attn_out_lse = get_pcp_group().all_gather(
-                chunk_attn_out_lse.contiguous(), dim=0)
-
-        B_total, H_total, D_plus_1 = chunk_attn_out_lse.shape
-        S = B_total // self.pcp_size
-        H = H_total // self.dcp_size
-        D = self.head_size
-        assert D_plus_1 == D + 1
-        # [PCP, S, DCP, H, D+1]
-        x = chunk_attn_out_lse.view(self.pcp_size, S, self.dcp_size, H,
-                                    D_plus_1)
-        # [PCP, DCP, S, H, D+1]
-        x = x.permute(0, 2, 1, 3, 4).contiguous()
-        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
-        x = x.view(-1, S, H, D_plus_1)
-        # Split out lse.
-        # [N, S, H, D], [N, S, H, 1]
-        attn_out_allgather, attn_lse_allgather = torch.split(x, [D, 1], dim=-1)
-
-        prefix_output, prefix_lse = self._update_out_and_lse(
-            attn_out_allgather, attn_lse_allgather)
-
-        return prefix_output, prefix_lse
+        return prefix_chunk_output, prefix_chunk_lse
 
     def _load_kv_for_chunk(self, attn_metadata, kv_cache,
                            local_chunked_kv_lens_rank, query, total_toks):
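Note the change above: the explicit zeroing of `prefix_chunk_output` is dropped and only the lse mask is kept. That is sufficient because a masked position carries `lse = -inf`, so it receives weight `exp(-inf) = 0` in the log-sum-exp merge and whatever values remain in its output never contribute. A small self-contained check of that reasoning (using the same weighting as the merge sketch earlier):

```python
# Why masking only the lse is enough: the -inf slice gets zero weight, so its
# (unmasked, possibly garbage) output is multiplied away in the merge.
import torch

valid_out = torch.randn(1, 2, 2, 8)
garbage_out = torch.randn(1, 2, 2, 8)            # never zeroed out
valid_lse = torch.randn(1, 2, 2, 1)
masked_lse = torch.full((1, 2, 2, 1), float("-inf"))

outs = torch.cat([valid_out, garbage_out], dim=0)
lses = torch.cat([valid_lse, masked_lse], dim=0)

lse_max = lses.max(dim=0, keepdim=True).values
weights = torch.exp(lses - lse_max)              # 0.0 for the masked slice
merged = (weights * outs).sum(dim=0) / weights.sum(dim=0)
assert torch.allclose(merged, valid_out[0])
```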
@@ -850,6 +802,45 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         return key, value
 
+    def _gather_global_context_output(self, local_context_attn_output):
+        if self.dcp_size > 1:
+            dcp_context_attn_output = torch.empty_like(
+                local_context_attn_output)
+            dist.all_to_all_single(dcp_context_attn_output,
+                                   local_context_attn_output,
+                                   group=self.dcp_group)
+        else:
+            dcp_context_attn_output = local_context_attn_output
+
+        if self.pcp_size > 1:
+            # AllGather out&lse within CP group
+            global_context_attn_output = get_pcp_group().all_gather(
+                dcp_context_attn_output, dim=-1)
+        else:
+            global_context_attn_output = dcp_context_attn_output
+
+        return global_context_attn_output
+
+    def _update_global_context_output(self, global_context_output):
+        B_total, H_total, D_plus_1 = global_context_output.shape
+        S = B_total // self.pcp_size
+        H = H_total // self.dcp_size
+        D = self.head_size
+        assert D_plus_1 == D + 1
+        # [PCP, S, DCP, H, D+1]
+        x = global_context_output.view(self.pcp_size, S, self.dcp_size, H,
+                                       D_plus_1)
+        # [PCP, DCP, S, H, D+1]
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
+        x = x.view(-1, S, H, D_plus_1)
+        # Split out lse
+        attn_out_allgather, attn_lse_allgather = torch.split(
+            x, [D, 1], dim=-1)  # [N, S, H, D], [N, S, H, 1]
+        context_output, context_lse = self._update_out_and_lse(
+            attn_out_allgather, attn_lse_allgather)
+        return context_output, context_lse
+
     def forward_impl(
         self,
         query: torch.Tensor,
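A shape walk-through of the reshuffle in `_update_global_context_output`, with the toy sizes the unit tests use (pcp_size=3, dcp_size=2, S=2 local tokens, H=2 local heads, D=8): the gathered tensor arrives as `[PCP*S, DCP*H, D+1]` and is regrouped so the `pcp*dcp` partial results line up on a leading axis before the lse merge.

```python
import torch

pcp_size, dcp_size, S, H, D = 3, 2, 2, 2, 8
gathered = torch.randn(pcp_size * S, dcp_size * H, D + 1)    # [6, 4, 9]

x = gathered.view(pcp_size, S, dcp_size, H, D + 1)    # [PCP, S, DCP, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()             # [PCP, DCP, S, H, D+1]
x = x.view(-1, S, H, D + 1)                           # [N, S, H, D+1], N = 6
out, lse = torch.split(x, [D, 1], dim=-1)             # [6, 2, 2, 8], [6, 2, 2, 1]
assert out.shape == (6, S, H, D) and lse.shape == (6, S, H, 1)
```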
@@ -870,15 +861,38 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
             output[:num_decode_tokens] = output_decode
         if has_prefill:
             assert attn_metadata.prefill is not None
+            # chunked prefill vars init
+            has_chunked_context = attn_metadata.prefill.chunked_context is not None
+            # Note(qcs): we use multi-stream for computation-communication overlap
+            # when enabling chunked prefill.
+            # current part
+            #   current_stream: init -- pre -- head attn ------------------ tail attn -- post -- update
+            # context part                                                           -/
+            #   current_stream:              -- context attn --                  -/
+            #   COMM_STREAM:    \-- all_gather Q --/             \-- a2a ag output --/
+
+            # qkv init
             num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
             prefill_query = query[
                 num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
             key = key[self.pcp_size * num_decode_tokens:].contiguous()
             value = value[self.pcp_size * num_decode_tokens:].contiguous()
+
+            if has_chunked_context:
+                # all_gather q for chunked prefill // overlap with the computation of the current chunk
+                cp_chunkedprefill_comm_stream().wait_stream(
+                    torch.npu.current_stream())
+                with torch_npu.npu.stream(cp_chunkedprefill_comm_stream()):
+                    prefill_query_all = self._prefill_query_all_gather(
+                        attn_metadata, prefill_query.clone())
+
             if self.pcp_size > 1:
                 # Scenario of Enabling PCP or PCP&DCP
-                attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
+                # prepare qkv and compute the head part // overlap with the all_gather of q
+                data_head, data_tail = self._forward_prefill_cp_pre(
                     prefill_query, key, value, attn_metadata)
+                output_head, lse_head = self._forward_prefill_cp_attn(
+                    data_head, True, attn_metadata)
             else:
                 # Scenario of Enabling DCP Individually
                 attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score(
@@ -899,8 +913,46 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
                     actual_seq_lengths=attn_metadata.prefill.
                     actual_seq_lengths_q)
 
-            self._process_chunk_prefill(attn_output_prefill, attn_lse_prefill,
-                                        kv_cache, prefill_query, attn_metadata)
+            if has_chunked_context:
+                torch.npu.current_stream().wait_stream(
+                    cp_chunkedprefill_comm_stream())
+                # computation of context
+                context_output = self._compute_prefill_context(
+                    prefill_query_all, kv_cache, attn_metadata)
+                # Note(qcs): (output, lse) -> [Seq, Head_num, Head_dim+1] -> [Head_num, Head_dim+1, Seq]
+                local_context_output = torch.cat(
+                    context_output, dim=-1).permute([1, 2, 0]).contiguous()
+
+                # all2all and all_gather output&lse // overlap with the computation of the current chunk
+                cp_chunkedprefill_comm_stream().wait_stream(
+                    torch.npu.current_stream())
+                with torch_npu.npu.stream(cp_chunkedprefill_comm_stream()):
+                    global_context_output = self._gather_global_context_output(
+                        local_context_output)
+
+            if self.pcp_size > 1:
+                # compute the tail part and reorg output&lse // overlap with the communication of the context output
+                output_tail, lse_tail = self._forward_prefill_cp_attn(
+                    data_tail, False, attn_metadata)
+
+                attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp_post(
+                    [output_head, output_tail],
+                    [lse_head, lse_tail],
+                    attn_metadata,
+                )
+
+            if has_chunked_context:
+                # update the output of current chunk with context part
+                torch.npu.current_stream().wait_stream(
+                    cp_chunkedprefill_comm_stream())
+                global_context_output = global_context_output.permute(
+                    [2, 0, 1]).contiguous()
+                context_output, context_lse = self._update_global_context_output(
+                    global_context_output)
+                self._update_chunk_attn_out_lse_with_current_attn_out_lse(
+                    attn_output_prefill, attn_lse_prefill, context_output,
+                    context_lse, prefill_query, attn_metadata)
+
             output[num_decode_tokens:attn_output_prefill.shape[0] +
                    num_decode_tokens] = attn_output_prefill
         return output
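`forward_impl` above launches the q all-gather and the context-output collectives on `cp_chunkedprefill_comm_stream()` while attention kernels keep running on the default stream, synchronizing with `wait_stream` at each handoff. The sketch below shows only the generic two-stream handshake pattern, written against `torch.cuda` so it runs anywhere (the patch uses the `torch_npu.npu` equivalents); it is not the vllm-ascend code, and the collective is stood in by a `clone()`.

```python
import torch


def overlapped_step(x: torch.Tensor, comm_stream: "torch.cuda.Stream"):
    comm_stream.wait_stream(torch.cuda.current_stream())    # inputs are ready
    with torch.cuda.stream(comm_stream):
        gathered = x.clone()           # placeholder for the communication
    y = x * 2.0                        # overlapped compute on the default stream
    torch.cuda.current_stream().wait_stream(comm_stream)    # result is ready
    return y, gathered


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    y, g = overlapped_step(torch.randn(4, device="cuda"), stream)
```

This is also why the patch hands `prefill_query.clone()` to the side stream: the side stream must not read a buffer the default stream may still be rewriting.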
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index ae5b5733..e5e0bfee 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -63,6 +63,7 @@ class AscendMetadataForPrefill:
     cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
     kv_inverse_idx_for_chunk: Optional[list[int]] = None
     batch_chunk_seq_mask: Optional[list[bool]] = None
+    local_total_toks: Optional[int] = None
     """ Prefill Specific Metadata for Ascend"""
     pcp_metadata: Optional[AscendPCPMetadata] = None
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 51b87cfe..e959335e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
 _WEIGHT_PREFETCH_METHOD = None
 _GLOBAL_STREAM = None
 _SHARED_EXPERTS_CALCULATION_STREAM = None
+_CP_CHUNKEDPREFILL_COMM_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
@@ -340,6 +341,13 @@ def shared_experts_calculation_stream() -> torch.npu.Stream:
     return _SHARED_EXPERTS_CALCULATION_STREAM
 
 
+def cp_chunkedprefill_comm_stream() -> torch.npu.Stream:
+    global _CP_CHUNKEDPREFILL_COMM_STREAM
+    if _CP_CHUNKEDPREFILL_COMM_STREAM is None:
+        _CP_CHUNKEDPREFILL_COMM_STREAM = torch_npu.npu.Stream()
+    return _CP_CHUNKEDPREFILL_COMM_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2b21955e..7510b175 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1771,9 +1771,6 @@ class NPUModelRunner(GPUModelRunner):
                 kv_cache_group_id].get_device_tensor()
             slot_mapping = self.input_batch.block_table[
                 kv_cache_group_id].slot_mapping
-            self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
-                                                 dtype=torch.int32,
-                                                 device=self.device)
             long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
                 num_tokens, self.query_lens, self.attn_mask, self.input_batch)
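One likely motivation for the new `local_total_toks` metadata field (an assumption, not stated in the patch): the per-rank token total is materialized once with `.item()` at metadata-build time, so the `if total_toks > 0:` branch in `_compute_prefill_context` compares a plain Python int instead of a device tensor inside the forward pass.

```python
# Sketch of the design point: sum once at build time, branch cheaply at forward time.
import torch

local_chunked_kv_lens_rank = torch.tensor([256, 256])

# metadata-build time (once per step): materialize the scalar on the host
local_total_toks = local_chunked_kv_lens_rank.sum().item()

# forward time: a cheap Python comparison instead of a tensor comparison
if local_total_toks > 0:
    pass
```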