From 4a3663327b545c94aa28e09d6cc28a661371ed8f Mon Sep 17 00:00:00 2001
From: wujinyuan1
Date: Mon, 5 Jan 2026 17:41:12 +0800
Subject: [PATCH] [Refactor]7/N Extract common code to common_cp (#5490)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason: Eliminate the code duplicated between two files (mla_cp.py and
attention_cp.py) by extracting it into common_cp.py.

vLLM version: 0.13.0rc3
vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

vLLM version: release/v0.13.0
vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/5326c89803566a131c928f7fdd2100b75c981a42

---------

Signed-off-by: wujinyuan1
Signed-off-by: wujinyuan1
Co-authored-by: wujinyuan1
---
 tests/ut/attention/test_attention_cp.py      |  46 +++---
 tests/ut/attention/test_mla_cp.py            |  36 +++--
 vllm_ascend/attention/attention_v1.py        |  10 +-
 vllm_ascend/attention/common_cp.py           |  40 -----
 .../attention/context_parallel/__init__.py   |   0
 .../{ => context_parallel}/attention_cp.py   |  81 ++---------
 .../attention/context_parallel/common_cp.py  | 137 ++++++++++++++++++
 .../{ => context_parallel}/mla_cp.py         | 108 +++-----------
 vllm_ascend/attention/mla_v1.py              |  49 ++++---
 vllm_ascend/attention/utils.py               |  46 ------
 10 files changed, 252 insertions(+), 301 deletions(-)
 delete mode 100644 vllm_ascend/attention/common_cp.py
 create mode 100644 vllm_ascend/attention/context_parallel/__init__.py
 rename vllm_ascend/attention/{ => context_parallel}/attention_cp.py (93%)
 create mode 100644 vllm_ascend/attention/context_parallel/common_cp.py
 rename vllm_ascend/attention/{ => context_parallel}/mla_cp.py (88%)

diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py
index 3df7cb16..cc518fda 100644
--- a/tests/ut/attention/test_attention_cp.py
+++ b/tests/ut/attention/test_attention_cp.py
@@ -5,9 +5,11 @@ import torch

 from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
-from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
-from vllm_ascend.attention.attention_v1 import (AscendMetadata,
-                                                AscendMetadataForPrefill)
+from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.context_parallel.attention_cp import \
+    AscendAttentionCPImpl
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendMetadataForPrefill, AscendPCPMetadata)


 class TestAscendAttentionCPImpl(TestBase):
@@ -82,25 +84,22 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 4)
         self.assertEqual(output.shape[2], 128)

+    @patch('torch_npu.npu_attention_update')
     @patch("torch_npu.npu_fused_infer_attention_score")
-    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
+    @patch(
+        'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+    )
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
                                     mock_get_forward_context,
-                                    mock_npu_fused_infer_attention_score):
-        query = torch.randn(2, 4, 128)
-        self.impl.key_cache = torch.randn(100, 128, 1, 128)
-        self.impl.value_cache = torch.randn(100, 128, 1, 128)
+                                    mock_npu_fused_infer_attention_score,
+                                    mock_npu_attention_update):
+        query = torch.randn(2, 4, 64)
+        self.impl.key_cache = torch.randn(100, 64, 1, 64)
+        self.impl.value_cache = torch.randn(100, 64, 1, 64)

-        def 
mock_npu_attention_update(attn_out_lse_list): - mock_output = torch.randn( - attn_out_lse_list.shape[0] // mock_pcp.world_size, - attn_out_lse_list.shape[1] // mock_dcp.world_size, - attn_out_lse_list.shape[2] - 1) - return mock_output - - self.impl._npu_attention_update = MagicMock() - self.impl._npu_attention_update.side_effect = mock_npu_attention_update + # Mock output + mock_npu_attention_update.return_value = (torch.randn(2 * 4, 64), None) mock_get_forward_context.return_value = MagicMock(capturing=False) @@ -116,12 +115,11 @@ class TestAscendAttentionCPImpl(TestBase): attn_metadata.decode_meta = MagicMock() attn_metadata.decode_meta.batch_seq_mask = torch.tensor( [1, 0], dtype=torch.bool) - output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) self.assertEqual(output.shape[0], 2) self.assertEqual(output.shape[1], 4) - self.assertEqual(output.shape[2], 128) + self.assertEqual(output.shape[2], 64) @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_prefill_query_all_gather(self): @@ -249,7 +247,7 @@ class TestAscendAttentionCPImpl(TestBase): attn_metadata.slot_mapping = torch.randn(2) attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size attn_metadata.prefill = MagicMock() - attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor( + attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor( [0, 3, 1, 2, 0, 0, 0, 0]) key = torch.randn(num_tokens, num_heads, head_size) @@ -336,7 +334,7 @@ class TestUpdateNpuAttnOutLse(TestBase): attn_metadata.num_actual_tokens = self.q_total_tokens prefill_metadata = AscendMetadataForPrefill() - pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata() + pcp_metadata = AscendPCPMetadata() pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum @@ -409,7 +407,7 @@ class TestUpdateNpuAttnOutLse(TestBase): @patch('torch.ops.npu.npu_fused_infer_attention_score') @patch( - 'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse' + 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse' ) def test_attention_with_nomask_and_mask_chunk( self, mock_update_out_and_lse, @@ -457,7 +455,7 @@ class TestUpdateNpuAttnOutLse(TestBase): @patch('torch.ops.npu.npu_fused_infer_attention_score') @patch( - 'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' + 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' ) def test_attention_with_nomask_and_mask_nochunk( self, mock_npu_attn_out_lse_update, @@ -505,7 +503,7 @@ class TestUpdateNpuAttnOutLse(TestBase): self.assertEqual(attn_lse, None) @patch( - 'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' + 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' ) def test_update_chunk_attn_out_lse_with_current_attn_out_lse( self, mock_npu_attn_out_lse_update): diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index e85f4041..10358547 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -7,8 +7,9 @@ from tests.ut.attention.utils import patch_distributed_groups from tests.ut.base import TestBase from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState 
-from vllm_ascend.attention.common_cp import CPChunkedContextMetadata -from vllm_ascend.attention.mla_cp import AscendMlaCPImpl +from vllm_ascend.attention.context_parallel.common_cp import ( + CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse) +from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata @@ -441,14 +442,14 @@ class TestAscendMLAImpl(TestBase): decode_metadata.batch_seq_mask = torch.tensor([True, False], dtype=torch.bool) - result = self.impl._process_attn_out_lse(attn_output, softmax_lse, - decode_metadata) + result = _process_attn_out_lse(attn_output, softmax_lse, + decode_metadata.batch_seq_mask) self.assertEqual(result.shape[0], B * self.impl.pcp_size) self.assertEqual(result.shape[1], N) self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1) - @patch('vllm_ascend.attention.mla_cp.get_forward_context') + @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context') @patch("torch_npu.atb.npu_multi_head_latent_attention") @patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) @@ -725,7 +726,15 @@ class TestAscendMLAImpl(TestBase): assert torch.allclose(lse, expected_lse) @patch('torch_npu.npu_attention_update') - def test_npu_attention_update_with_dcp_pcp(self, + @patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_npu_attention_update_with_dcp_pcp(self, mock_dcp, + mock_get_dcp_group, mock_pcp, + mock_get_pcp_group, mock_npu_attention_update): NUM_TOKENS = 10 # fixed test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)] @@ -752,10 +761,19 @@ class TestAscendMLAImpl(TestBase): attn_lse_split_cp[0]) mock_npu_attention_update.side_effect = mock_npu_attention_update_effect + + mock_pcp_group = MagicMock() + mock_pcp_group.world_size = self.impl.pcp_size + mock_get_pcp_group.return_value = mock_pcp_group + + mock_dcp.world_size = self.impl.dcp_size + mock_dcp_group = MagicMock() + # mock_dcp_group.world_size = self.impl.dcp_size + mock_get_dcp_group.return_value = mock_dcp_group attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS, self.impl.dcp_size * num_heads, head_dim) - out = self.impl._npu_attention_update(attn_out_lse) + out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse) self.impl.dcp_size = 1 self.impl.pcp_size = 1 assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank) @@ -873,8 +891,8 @@ class TestAscendMLAImpl(TestBase): decode_meta = MagicMock() decode_meta.batch_seq_mask = batch_seq_mask - result = self.impl._process_attn_out_lse(attn_output, softmax_lse, - decode_meta) + result = _process_attn_out_lse(attn_output, softmax_lse, + batch_seq_mask) # [PCP * S, DCP * H, D + 1] self.assertIsInstance(result, torch.Tensor) assert result.shape == (B * self.impl.pcp_size, H, D + 1) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 54999de2..cbdb7da3 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -34,10 +34,10 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.core.sched.output import SchedulerOutput from 
vllm.v1.kv_cache_interface import AttentionSpec +from vllm_ascend.attention.context_parallel.common_cp import ( + AscendMetadataForDecode, AscendMetadataForPrefill) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - AscendMetadataForDecode, - AscendMetadataForPrefill, enable_cp, - split_decodes_and_prefills, + enable_cp, split_decodes_and_prefills, using_paged_attention) from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, @@ -63,7 +63,7 @@ class AscendAttentionBackend(AttentionBackend): @staticmethod def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: if enable_cp(): - from vllm_ascend.attention.attention_cp import \ + from vllm_ascend.attention.context_parallel.attention_cp import \ AscendAttentionCPImpl return AscendAttentionCPImpl return AscendAttentionBackendImpl @@ -71,7 +71,7 @@ class AscendAttentionBackend(AttentionBackend): @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: if enable_cp(): - from vllm_ascend.attention.attention_cp import \ + from vllm_ascend.attention.context_parallel.attention_cp import \ AscendAttentionCPMetadataBuilder return AscendAttentionCPMetadataBuilder return AscendAttentionMetadataBuilder diff --git a/vllm_ascend/attention/common_cp.py b/vllm_ascend/attention/common_cp.py deleted file mode 100644 index ccf39344..00000000 --- a/vllm_ascend/attention/common_cp.py +++ /dev/null @@ -1,40 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import torch - - -@dataclass -class AscendPCPMetadata: - q_head_idx: torch.Tensor = None - q_tail_idx: torch.Tensor = None - kv_with_q_head_nomask_idx: torch.Tensor = None - kv_with_q_head_mask_idx: torch.Tensor = None - kv_with_q_tail_nomask_idx: torch.Tensor = None - kv_with_q_tail_mask_idx: torch.Tensor = None - attn_mask_seqlens: torch.Tensor = None - head_attn_nomask_seqlens: torch.Tensor = None - tail_attn_nomask_seqlens: torch.Tensor = None - q_full_idx: torch.Tensor = None - pcp_prefill_mask: torch.Tensor = None - pcp_allgather_restore_idx: Optional[list[int]] = None - - -@dataclass -class CPChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - chunk_seq_lens: torch.Tensor - chunk_seq_lens_npu: torch.Tensor - # for mla DCP & PCP - padded_chunk_seq_lens_npu: torch.Tensor = None - padded_local_chunk_seq_lens: Optional[list[list[int]]] = None - local_context_lens_allranks: Optional[list[list[int]]] = None - padded_local_cu_seq_lens: torch.Tensor = None - cu_seq_lens_lst: Optional[list[list[int]]] = None - chunk_size: Optional[int] = None diff --git a/vllm_ascend/attention/context_parallel/__init__.py b/vllm_ascend/attention/context_parallel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py similarity index 93% rename from vllm_ascend/attention/attention_cp.py rename to vllm_ascend/attention/context_parallel/attention_cp.py index 3919848e..8d477909 100644 --- a/vllm_ascend/attention/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -33,9 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl, AscendAttentionMetadataBuilder, AscendMetadata) +from vllm_ascend.attention.context_parallel.common_cp import ( + 
AscendMetadataForDecode, AscendMetadataForPrefill, AscendPCPMetadata, + _npu_attention_update, _process_attn_out_lse) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - AscendMetadataForDecode, - AscendMetadataForPrefill, filter_chunked_req_indices, split_decodes_and_prefills) from vllm_ascend.compilation.acl_graph import (get_graph_params, @@ -196,7 +197,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): tail_attn_nomask_seqlens = torch.cumsum( tail_attn_nomask_seqlens[1], dim=0).tolist() - pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata( + pcp_metadata = AscendPCPMetadata( q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, kv_with_q_head_nomask_idx=common_long_seq_metadata. @@ -211,13 +212,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): head_attn_nomask_seqlens=head_attn_nomask_seqlens, tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask) + pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask, + pcp_allgather_restore_idx=common_long_seq_metadata. + pcp_allgather_restore_idx) prefill_metadata = AscendMetadataForPrefill( pcp_metadata=pcp_metadata, - pcp_allgather_restore_idx=common_long_seq_metadata. - pcp_allgather_restore_idx - if common_long_seq_metadata is not None else None, chunked_context=chunked_context_metadata, block_tables=block_table[num_decodes:], actual_seq_lengths_q=torch.cumsum(query_lens, dim=0)) @@ -460,39 +460,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) return attn_out, attn_lse - def _npu_attention_update(self, - attn_out_lse: torch.Tensor) -> torch.Tensor: - B_total, H_total, D_plus_1 = attn_out_lse.shape - S = B_total // self.pcp_size - H = H_total // self.dcp_size - D = self.head_size - update_type = 0 - assert D_plus_1 == D + 1 - # [PCP, S, DCP, H, D+1] - x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1) - # [PCP, DCP, S, H, D+1] - x = x.permute(0, 2, 1, 3, 4).contiguous() - # Flatten [N, S, H, D+1], N = pcp_size * dcp_size - x = x.view(-1, S, H, D_plus_1) - # Split out lse - # [N, S, H, D], [N, S, H, 1] - out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) - # out: [N, S, H, D] -> [N, S*H, D] - # lse: [N, S, H, 1] -> [N, S*H] - out_flat = out_flat.flatten(1, 2) - lse_flat = lse_flat.squeeze(-1).flatten(1) - # unbind to list - # [S*H, D] - out_list = out_flat.unbind(0) - # [S*H] - lse_list = lse_flat.unbind(0) - - attn_out, attn_lse = torch_npu.npu_attention_update( - lse_list, out_list, update_type) - attn_out = attn_out.view(S, H, D) - - return attn_out - def _forward_decode_pcp_dcp(self, query: torch.Tensor, attn_metadata: AscendMetadata) -> torch.Tensor: assert self.key_cache is not None @@ -580,33 +547,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): else: attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( query, k_nope, value, **common_kwargs) - - out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None, - None].expand_as( - attn_out) - attn_out = torch.where(out_mask, 0, attn_out) - - lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None, - None].expand_as( - attn_lse) - attn_lse = torch.where(lse_mask, -torch.inf, attn_lse) - # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] - attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1) 
- if self.dcp_size > 1: - # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] - attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() - attn_out_lse_all2all = torch.empty_like(attn_out_lse) - dist.all_to_all_single(attn_out_lse_all2all, - attn_out_lse, - group=self.dcp_group) - attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) - - if self.pcp_size > 1: - # AllGather out&lse within CP group - attn_out_lse = get_pcp_group().all_gather( - attn_out_lse.contiguous(), dim=0) - - attn_out = self._npu_attention_update(attn_out_lse) + attn_out_lse = _process_attn_out_lse( + attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask) + attn_out = _npu_attention_update(self.head_size, attn_out_lse) return attn_out def _update_out_and_lse(self, out_list: torch.Tensor, @@ -780,7 +723,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size all_kv = get_pcp_group().all_gather( kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) - pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None + assert attn_metadata.prefill is not None + assert attn_metadata.prefill.pcp_metadata is not None + pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx) key, value = all_kv.split([self.head_size, self.head_size], diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py new file mode 100644 index 00000000..0652a9a5 --- /dev/null +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist +import torch_npu +from vllm.distributed import (get_dcp_group, + get_decode_context_model_parallel_world_size, + get_pcp_group) + + +@dataclass +class AscendPCPMetadata: + q_head_idx: torch.Tensor = None + q_tail_idx: torch.Tensor = None + kv_with_q_head_nomask_idx: torch.Tensor = None + kv_with_q_head_mask_idx: torch.Tensor = None + kv_with_q_tail_nomask_idx: torch.Tensor = None + kv_with_q_tail_mask_idx: torch.Tensor = None + attn_mask_seqlens: torch.Tensor = None + head_attn_nomask_seqlens: torch.Tensor = None + tail_attn_nomask_seqlens: torch.Tensor = None + q_full_idx: torch.Tensor = None + pcp_prefill_mask: torch.Tensor = None + pcp_allgather_restore_idx: Optional[list[int]] = None + + +@dataclass +class CPChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor + # for mla DCP & PCP + padded_chunk_seq_lens_npu: torch.Tensor = None + padded_local_chunk_seq_lens: Optional[list[list[int]]] = None + local_context_lens_allranks: Optional[list[list[int]]] = None + padded_local_cu_seq_lens: torch.Tensor = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + chunk_size: Optional[int] = None + + +@dataclass +class AscendMetadataForPrefill: + + @dataclass + class ChunkedContextMetadata: + actual_chunk_seq_lengths: torch.Tensor + actual_seq_lengths_kv: torch.Tensor + starts: torch.Tensor + chunk_seq_mask_filtered_indices: torch.Tensor + chunked_req_mask: Optional[list[bool]] = None + local_context_lens_allranks: 
Optional[list[list[int]]] = None + cp_kv_recover_idx_for_chunk: Optional[list[int]] = None + kv_inverse_idx_for_chunk: Optional[list[int]] = None + batch_chunk_seq_mask: Optional[list[bool]] = None + local_total_toks: Optional[int] = None + + """ Prefill Specific Metadata for Ascend""" + pcp_metadata: Optional[AscendPCPMetadata] = None + chunked_context: Optional[ChunkedContextMetadata] = None + block_tables: torch.Tensor = None + actual_seq_lengths_q: torch.Tensor = None + + +@dataclass +class AscendMetadataForDecode: + """ Decode Specific Metadata for Ascend""" + num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None + batch_seq_mask: torch.Tensor = None + block_tables: torch.Tensor = None + + +def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor, + batch_seq_mask: torch.Tensor) -> torch.Tensor: + pcp_size = get_pcp_group().world_size + dcp_size = get_decode_context_model_parallel_world_size() + dcp_group = get_dcp_group().device_group if dcp_size > 1 else None + out_mask = batch_seq_mask[:, None, None].expand_as(attn_output) + attn_output = torch.where(out_mask, 0, attn_output) + lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse) + softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse) + softmax_lse = softmax_lse.to(torch.float32) + attn_output = attn_output.to(torch.float32) + # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] + attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1) + if dcp_size > 1: + # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] + attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() + attn_out_lse_all2all = torch.empty_like(attn_out_lse) + dist.all_to_all_single(attn_out_lse_all2all, + attn_out_lse, + group=dcp_group) + attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) + + if pcp_size > 1: + # AllGather out&lse within CP group + attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), + dim=0) + + return attn_out_lse + + +def _npu_attention_update(head_size, + attn_out_lse: torch.Tensor) -> torch.Tensor: + pcp_size = get_pcp_group().world_size + dcp_size = get_decode_context_model_parallel_world_size() + # [PCP * S, DCP * H, D+1] + B_total, H_total, D_plus_1 = attn_out_lse.shape + S = B_total // pcp_size + H = H_total // dcp_size + D = head_size + assert D_plus_1 == D + 1 + # [PCP, S, DCP, H, D+1] + x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1) + # [PCP, DCP, S, H, D+1] + x = x.permute(0, 2, 1, 3, 4).contiguous() + # Flatten [N, S, H, D+1], N = pcp_size * dcp_size + x = x.view(-1, S, H, D_plus_1) + # Split out lse + out_flat, lse_flat = torch.split(x, [D, 1], + dim=-1) # [N, S, H, D], [N, S, H, 1] + # out: [N, S, H, D] -> [N, S*H, D] + # lse: [N, S, H, 1] -> [N, S*H] + out_flat = out_flat.flatten(1, 2) # [N, S*H, D] + lse_flat = lse_flat.flatten(1, -1) # [N, S*H] + # unbind to list + out_list = out_flat.unbind(0) # [S*H, D] + lse_list = lse_flat.unbind(0) # [S*H] + attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0) + attn_out = attn_out.view(-1, H, D) + return attn_out diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py similarity index 88% rename from vllm_ascend/attention/mla_cp.py rename to vllm_ascend/attention/context_parallel/mla_cp.py index e0ce1b41..6c2425c1 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -2,7 +2,6 @@ from typing import Optional, Tuple, TypeVar import numpy 
as np import torch -import torch.distributed as dist import torch_npu from vllm.config import VllmConfig from vllm.distributed import (get_dcp_group, @@ -15,17 +14,17 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec # isort: off -from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, - AscendMLAImpl, AscendMLAMetadata, - AscendMLAMetadataBuilder, - AscendMLAPrefillMetadata, - DecodeMLAPreprocessResult, - PrefillMLAPreprocessResult) +from vllm_ascend.attention.mla_v1 import ( + AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, + AscendMLAMetadataBuilder, AscendMLAPrefillMetadata, + DecodeMLAPreprocessResult, PrefillMLAPreprocessResult, + BUILD_METADATA_STEP_PREFILL) #isort: on from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata) -from vllm_ascend.attention.common_cp import (AscendPCPMetadata, - CPChunkedContextMetadata) +from vllm_ascend.attention.context_parallel.common_cp import ( + AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse, + _npu_attention_update) from vllm_ascend.compilation.acl_graph import (get_draft_graph_params, get_graph_params, update_graph_params_workspaces) @@ -89,6 +88,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): if long_seq_metadata is None: raise AssertionError("long_seq_metadata should not be None.") + # In dcp only spec decode graph padding case, + # num_actual_tokens_pcp_padded may be less than num_actual_tokens self.num_actual_tokens = max( long_seq_metadata.num_actual_tokens_pcp_padded, common_attn_metadata.num_actual_tokens) @@ -187,21 +188,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): ) return chunked_metadata - def set_prefill_block_table( - self, - common_attn_metadata: AscendCommonAttentionMetadata, - ): - # For pcp + spec decode, we flatten seq_lens and block_table - # to avoid irregular spec_attn_mask shape + def get_block_table_size( + self, common_attn_metadata: AscendCommonAttentionMetadata, + build_metadata_step: int): self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum( ).item() - self.block_table = common_attn_metadata.block_table_tensor[:self. - num_decodes_flatten - + self. - num_prefills] - - def set_decode_block_table(self): - self.block_table = self.block_table[:self.num_decodes_flatten, ...] 
+ if build_metadata_step == BUILD_METADATA_STEP_PREFILL: + # For pcp + spec decode, we flatten seq_lens and block_table + # to avoid irregular spec_attn_mask shape + return self.num_decodes_flatten + self.num_prefills + else: + return self.num_decodes_flatten def build_prefill_metadata( self, @@ -637,39 +634,11 @@ class AscendMlaCPImpl(AscendMLAImpl): lse=softmax_lse) # Update out&lse - attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse, - decode_meta) - attn_output = self._npu_attention_update(attn_out_lse) + attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, + decode_meta.batch_seq_mask) + attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse) return self._v_up_proj(attn_output) - def _npu_attention_update(self, - attn_out_lse: torch.Tensor) -> torch.Tensor: - # [PCP * S, DCP * H, D+1] - B_total, H_total, D_plus_1 = attn_out_lse.shape - S = B_total // self.pcp_size - H = H_total // self.dcp_size - D = self.kv_lora_rank - assert D_plus_1 == D + 1 - # [PCP, S, DCP, H, D+1] - x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1) - # [PCP, DCP, S, H, D+1] - x = x.permute(0, 2, 1, 3, 4).contiguous() - # Flatten [N, S, H, D+1], N = pcp_size * dcp_size - x = x.view(-1, S, H, D_plus_1) - # Split out lse - out_flat, lse_flat = torch.split(x, [D, 1], - dim=-1) # [N, S, H, D], [N, S, H, 1] - # out: [N, S, H, D] -> [N, S*H, D] - # lse: [N, S, H, 1] -> [N, S*H] - out_flat = out_flat.flatten(1, 2) # [N, S*H, D] - lse_flat = lse_flat.flatten(1, -1) # [N, S*H] - # unbind to list - out_list = out_flat.unbind(0) # [S*H, D] - lse_list = lse_flat.unbind(0) # [S*H] - attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0) - attn_out = attn_out.view(-1, H, D) - return attn_out - def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor: attn_out = attn_out.contiguous().view( @@ -678,39 +647,6 @@ class AscendMlaCPImpl(AscendMLAImpl): attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) return attn_out, attn_lse - def _process_attn_out_lse( - self, - attn_output: torch.Tensor, - softmax_lse: torch.Tensor, - decode_meta: AscendMLADecodeMetadata, - ) -> torch.Tensor: - out_mask = decode_meta.batch_seq_mask[:, None, - None].expand_as(attn_output) - attn_output = torch.where(out_mask, 0, attn_output) - lse_mask = decode_meta.batch_seq_mask[:, None, - None].expand_as(softmax_lse) - softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse) - - softmax_lse = softmax_lse.to(torch.float32) - attn_output = attn_output.to(torch.float32) - # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] - attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1) - if self.dcp_size > 1: - # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] - attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() - attn_out_lse_all2all = torch.empty_like(attn_out_lse) - dist.all_to_all_single(attn_out_lse_all2all, - attn_out_lse, - group=self.dcp_group) - attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) - - if self.pcp_size > 1: - # AllGather out&lse within CP group - attn_out_lse = get_pcp_group().all_gather( - attn_out_lse.contiguous(), dim=0) - - return attn_out_lse - def _reorg_kvcache( self, kv_c_normed: torch.Tensor, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index b6f90c71..6d6b0db5 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -19,8 +19,8 @@ from vllm.v1.kv_cache_interface import 
AttentionSpec, MLAAttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.common_cp import (AscendPCPMetadata, - CPChunkedContextMetadata) +from vllm_ascend.attention.context_parallel.common_cp import ( + AscendPCPMetadata, CPChunkedContextMetadata) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, enable_cp, maybe_save_kv_layer_to_connector, @@ -46,6 +46,8 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 +BUILD_METADATA_STEP_PREFILL = 0 +BUILD_METADATA_STEP_DECODE = 1 class AscendMLABackend(AttentionBackend): @@ -61,7 +63,8 @@ class AscendMLABackend(AttentionBackend): @staticmethod def get_builder_cls(): if enable_cp(): - from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder + from vllm_ascend.attention.context_parallel.mla_cp import \ + AscendMlaCPMetadataBuilder return AscendMlaCPMetadataBuilder return AscendMLAMetadataBuilder @@ -73,7 +76,8 @@ class AscendMLABackend(AttentionBackend): @staticmethod def get_impl_cls() -> Type["MLAAttentionImpl"]: if enable_cp(): - from vllm_ascend.attention.mla_cp import AscendMlaCPImpl + from vllm_ascend.attention.context_parallel.mla_cp import \ + AscendMlaCPImpl return AscendMlaCPImpl return AscendMLAImpl @@ -418,7 +422,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.query_lens = query_seq_lens_cpu[:num_reqs] self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - self.set_prefill_block_table(common_attn_metadata) + self.graph_pad_size = common_attn_metadata.graph_pad_size + block_table_size = self.get_block_table_size( + common_attn_metadata, BUILD_METADATA_STEP_PREFILL) + self.block_table = common_attn_metadata.block_table_tensor[: + block_table_size] prefill_metadata = None if self.num_prefills > 0: @@ -499,23 +507,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): workspace=self.chunked_prefill_workspace, ) - def set_prefill_block_table( - self, - common_attn_metadata: AscendCommonAttentionMetadata, - ): - # If graph_pad_size > -1, mean is running in fullgraph mode. - self.graph_pad_size = common_attn_metadata.graph_pad_size - # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch: - self.block_table = ( - common_attn_metadata.block_table_tensor[:self.graph_pad_size]) - else: - self.block_table = ( - common_attn_metadata.block_table_tensor[:common_attn_metadata. - num_reqs]) - - def set_decode_block_table(self): - self.block_table = self.block_table[:self.num_decodes, ...] + def get_block_table_size( + self, common_attn_metadata: AscendCommonAttentionMetadata, + build_metadata_step: int): + if build_metadata_step == BUILD_METADATA_STEP_PREFILL: + # If graph_pad_size > -1, mean is running in fullgraph mode. + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. 
+ if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch: + return self.graph_pad_size + return common_attn_metadata.num_reqs + return self.num_decodes def build_prefill_metadata( self, @@ -574,7 +575,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.seq_lens = self.seq_lens[:self.num_decodes] input_positions = input_positions[:self.num_decode_tokens] - self.set_decode_block_table() + block_table_size = self.get_block_table_size( + common_attn_metadata, BUILD_METADATA_STEP_DECODE) + self.block_table = self.block_table[:block_table_size] # NOTE: Currently, MTP-fullgraph is incompatibility pcp # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index e5e0bfee..af353838 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -35,52 +35,6 @@ def enable_cp(): or prefill_config.decode_context_parallel_size > 1 -@dataclass -class AscendMetadataForPrefill: - - @dataclass - class AscendPCPMetadata: - q_head_idx: torch.Tensor = None - q_tail_idx: torch.Tensor = None - kv_with_q_head_nomask_idx: torch.Tensor = None - kv_with_q_head_mask_idx: torch.Tensor = None - kv_with_q_tail_nomask_idx: torch.Tensor = None - kv_with_q_tail_mask_idx: torch.Tensor = None - attn_mask_seqlens: torch.Tensor = None - head_attn_nomask_seqlens: torch.Tensor = None - tail_attn_nomask_seqlens: torch.Tensor = None - q_full_idx: torch.Tensor = None - pcp_prefill_mask: torch.Tensor = None - - @dataclass - class ChunkedContextMetadata: - actual_chunk_seq_lengths: torch.Tensor - actual_seq_lengths_kv: torch.Tensor - starts: torch.Tensor - chunk_seq_mask_filtered_indices: torch.Tensor - chunked_req_mask: Optional[list[bool]] = None - local_context_lens_allranks: Optional[list[list[int]]] = None - cp_kv_recover_idx_for_chunk: Optional[list[int]] = None - kv_inverse_idx_for_chunk: Optional[list[int]] = None - batch_chunk_seq_mask: Optional[list[bool]] = None - local_total_toks: Optional[int] = None - - """ Prefill Specific Metadata for Ascend""" - pcp_metadata: Optional[AscendPCPMetadata] = None - pcp_allgather_restore_idx: Optional[List[int]] = None - chunked_context: Optional[ChunkedContextMetadata] = None - block_tables: torch.Tensor = None - actual_seq_lengths_q: torch.Tensor = None - - -@dataclass -class AscendMetadataForDecode: - """ Decode Specific Metadata for Ascend""" - num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None - batch_seq_mask: torch.Tensor = None - block_tables: torch.Tensor = None - - @dataclass # class AscendCommonLongSequenceMetadata: class AscendPrefillContextParallelMetadata: