diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 2d46212c..a7597af8 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -6,8 +6,9 @@ from vllm.distributed.parallel_state import GroupCoordinator from tests.ut.base import TestBase from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.common_cp import CPChunkedContextMetadata from vllm_ascend.attention.mla_cp import AscendMlaCPImpl -from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata +from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata def get_pcp_split_info(pcp_rank, pcp_size, seq_lens): @@ -127,7 +128,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes, out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) - chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata( + chunked_context_metadata = CPChunkedContextMetadata( cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True), starts=local_chunk_starts.to(non_blocking=True), seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), @@ -144,16 +145,15 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes, chunk_size=padded_local_max_context_chunk_across_ranks, ) else: - chunked_context_metadata = ( - AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True), - starts=chunk_starts.to(non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens, - workspace=None, - )) + chunked_context_metadata = (ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True), + starts=chunk_starts.to(non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens, + workspace=None, + )) return chunked_context_metadata diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index b07a9a55..06c5dc6d 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -14,7 +14,8 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend, AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, - AscendMLAPrefillMetadata) + AscendMLAPrefillMetadata, + ChunkedContextMetadata) from vllm_ascend.attention.utils import AscendCommonAttentionMetadata @@ -76,27 +77,15 @@ class TestAscendMLAPrefillMetadata(TestBase): max_seq_lens = [2, 2] workspace = torch.randn(2, 4) chunk_seq_lens = torch.tensor([2, 2]) - padded_chunk_seq_lens_npu = torch.tensor([2, 2]) - padded_local_chunk_seq_lens = [[2], [2]] - local_context_lens_allranks = [[1, 1], [1, 1]] - padded_local_cu_seq_lens = torch.tensor([0, 2, 4]) - cu_seq_lens_lst = [[0, 2], [2, 4]] - chunk_size = 2 - chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata( + chunked_context = ChunkedContextMetadata( cu_seq_lens=cu_seq_lens, starts=starts, seq_tot=seq_tot, max_seq_lens=max_seq_lens, workspace=workspace, chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens, - padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu, - padded_local_chunk_seq_lens=padded_local_chunk_seq_lens, - local_context_lens_allranks=local_context_lens_allranks, - padded_local_cu_seq_lens=padded_local_cu_seq_lens, - 
cu_seq_lens_lst=cu_seq_lens_lst, - chunk_size=chunk_size) + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLAPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -119,17 +108,6 @@ class TestAscendMLAPrefillMetadata(TestBase): self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, chunk_seq_lens) - self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu, - padded_chunk_seq_lens_npu) - self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens, - padded_local_chunk_seq_lens) - self.assertEqual(metadata.chunked_context.local_context_lens_allranks, - local_context_lens_allranks) - self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens, - padded_local_cu_seq_lens) - self.assertEqual(metadata.chunked_context.cu_seq_lens_lst, - cu_seq_lens_lst) - self.assertEqual(metadata.chunked_context.chunk_size, chunk_size) class TestAscendMLADecodeMetadata(TestBase): @@ -218,11 +196,9 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) - def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size, - mock_dcp, mock_get_dcp_group, - mock_pcp, mock_get_pcp_group): + def test_ascend_mla_metadata_builder_default(self, mock_dcp, + mock_get_dcp_group, mock_pcp, + mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -262,8 +238,6 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertEqual( builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) - self.assertEqual(builder.dcp_size, mock_dcp.world_size) - self.assertEqual(builder.pcp_size, mock_pcp.world_size) @patch('vllm.distributed.parallel_state.get_pcp_group') @patch('vllm.distributed.parallel_state._PCP', @@ -271,10 +245,7 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) - def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size, - mock_dcp, + def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group): @@ -324,11 +295,8 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) def test_ascend_mla_metadata_builder_build_full_graph( - self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp, - mock_get_pcp_group): + self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -387,10 +355,8 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - 
@patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) - def test_reorder_batch(self, mock_get_dcp_size, mock_dcp, - mock_get_dcp_group, mock_pcp, mock_get_pcp_group): + def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp, + mock_get_pcp_group): ascend_config = MagicMock() mock_vllm_config = MagicMock() @@ -448,10 +414,7 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) - def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size, - mock_dcp, + def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group): @@ -496,11 +459,8 @@ class TestAscendMLAMetadataBuilder(TestBase): @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) - def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size, - mock_dcp, mock_get_dcp_group, - mock_pcp, + def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp, + mock_get_dcp_group, mock_pcp, mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 @@ -566,17 +526,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.kv_cache_spec.head_size = 128 self.kv_cache_spec.num_heads = 32 - @patch("vllm_ascend.attention.mla_v1.get_pcp_group") - @patch( - "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" - ) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros) @patch("torch.Tensor.npu", new=lambda self: self) @patch("torch.npu.is_available") def test_build_prefix_no_cache_metadata(self, mock_npu_available, - mock_zeros, mock_get_ascend_config, - mock_dcp_world_size, + mock_zeros, mock_dcp_world_size, mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 @@ -633,17 +590,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.mla_v1.get_pcp_group") - @patch( - "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" - ) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros) @patch("torch.Tensor.npu", new=lambda self: self) @patch("torch.npu.is_available") def test_build_chunked_prefix_metadata(self, mock_npu_available, - mock_zeros, mock_get_ascend_config, - mock_dcp_world_size, + mock_zeros, mock_dcp_world_size, mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 @@ -701,13 +655,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) 
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.mla_v1.get_pcp_group") - @patch( - "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" - ) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_build_decode_only_metadata(self, mock_get_ascend_config, - mock_dcp_world_size, + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_build_decode_only_metadata(self, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -757,13 +708,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.mla_v1.get_pcp_group") - @patch( - "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" - ) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config, - mock_dcp_world_size, + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -814,13 +762,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.mla_v1.get_pcp_group") - @patch( - "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" - ) - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def test_build_for_graph_capture_prefill(self, mock_get_ascend_config, - mock_dcp_world_size, + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_build_for_graph_capture_prefill(self, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -868,16 +813,10 @@ class TestAscendMLAImpl(TestBase): new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) @patch('vllm.distributed.parallel_state._TP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_tensor_model_parallel_world_size", - return_value=2) @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") - @patch("vllm_ascend.attention.mla_v1.get_ascend_config") - def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, - mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp): + def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp): mock_tp.world_size = 2 mock_tp.rank_in_group = MagicMock() mock_tp.device_group = MagicMock() diff --git a/vllm_ascend/attention/common_cp.py b/vllm_ascend/attention/common_cp.py new file mode 100644 index 00000000..ccf39344 --- /dev/null +++ b/vllm_ascend/attention/common_cp.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass 
+class AscendPCPMetadata:
+    q_head_idx: torch.Tensor = None
+    q_tail_idx: torch.Tensor = None
+    kv_with_q_head_nomask_idx: torch.Tensor = None
+    kv_with_q_head_mask_idx: torch.Tensor = None
+    kv_with_q_tail_nomask_idx: torch.Tensor = None
+    kv_with_q_tail_mask_idx: torch.Tensor = None
+    attn_mask_seqlens: torch.Tensor = None
+    head_attn_nomask_seqlens: torch.Tensor = None
+    tail_attn_nomask_seqlens: torch.Tensor = None
+    q_full_idx: torch.Tensor = None
+    pcp_prefill_mask: torch.Tensor = None
+    pcp_allgather_restore_idx: Optional[list[int]] = None
+
+
+@dataclass
+class CPChunkedContextMetadata:
+    # New for MLA (compared to FlashAttention)
+    # For handling chunked prefill
+    cu_seq_lens: torch.Tensor
+    starts: torch.Tensor
+    seq_tot: list[int]
+    max_seq_lens: list[int]
+    workspace: torch.Tensor
+    chunk_seq_lens: torch.Tensor
+    chunk_seq_lens_npu: torch.Tensor
+    # for mla DCP & PCP
+    padded_chunk_seq_lens_npu: torch.Tensor = None
+    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
+    local_context_lens_allranks: Optional[list[list[int]]] = None
+    padded_local_cu_seq_lens: torch.Tensor = None
+    cu_seq_lens_lst: Optional[list[list[int]]] = None
+    chunk_size: Optional[int] = None
diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py
index ae24557a..0a3aed14 100644
--- a/vllm_ascend/attention/mla_cp.py
+++ b/vllm_ascend/attention/mla_cp.py
@@ -5,31 +5,32 @@ import torch
 import torch.distributed as dist
 import torch_npu
 from torch import nn
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size,
                               get_pcp_group)
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.utils.math_utils import cdiv, round_down
+from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import MLAAttentionSpec
 
+# isort: off
 from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           AscendMLAImpl, AscendMLAMetadata,
                                           AscendMLAMetadataBuilder,
                                           AscendMLAPrefillMetadata,
                                           DecodeMLAPreprocessResult,
                                           PrefillMLAPreprocessResult)
+#isort: on
+
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
-                                         split_decodes_and_prefills,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.attention.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, reach_layer_for_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -75,354 +76,173 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                                               dtype=torch.uint8,
                                               device=device)
 
-    def build(
+    def set_num_actual_tokens(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
+        if long_seq_metadata is None:
+            raise AssertionError("long_seq_metadata should not be None.")
+
+        self.num_actual_tokens = max(
+            long_seq_metadata.num_actual_tokens_pcp_padded,
+            common_attn_metadata.num_actual_tokens)
+
+    def build_cp_metadata(
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
         model: nn.Module,
-    ) -> AscendMLAMetadata:
-        num_reqs = 
common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc = common_attn_metadata.query_start_loc - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + ) -> AscendPCPMetadata | None: + common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata + assert common_long_seq_metadata is not None + return AscendPCPMetadata( + q_head_idx=common_long_seq_metadata.q_head_idx_tensor, + q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, + kv_with_q_head_nomask_idx=common_long_seq_metadata. + kv_with_q_head_nomask_idx_tensor, + kv_with_q_head_mask_idx=common_long_seq_metadata. + kv_with_q_head_mask_idx_tensor, + kv_with_q_tail_nomask_idx=common_long_seq_metadata. + kv_with_q_tail_nomask_idx_tensor, + kv_with_q_tail_mask_idx=common_long_seq_metadata. + kv_with_q_tail_mask_idx_tensor, + attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens, + head_attn_nomask_seqlens=common_long_seq_metadata. + head_attn_nomask_seqlens, + tail_attn_nomask_seqlens=common_long_seq_metadata. + tail_attn_nomask_seqlens, + q_full_idx=common_long_seq_metadata.q_full_idx, + pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask, + pcp_allgather_restore_idx=common_long_seq_metadata. + pcp_allgather_restore_idx) + + def build_chunked_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ): + chunked_context_metadata = super().build_chunked_metadata( + common_prefix_len, common_attn_metadata, model) + if chunked_context_metadata is None: + return None + long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata - - if long_seq_metadata is None: - raise AssertionError("long_seq_metadata should not be None.") - - num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded - if num_actual_tokens_pcp_padded is None: - num_actual_tokens_pcp_padded = num_actual_tokens - # In dcp only spec decode graph padding case, - # num_actual_tokens_pcp_padded may be less than num_actual_tokens - num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded, - num_actual_tokens) + assert long_seq_metadata is not None num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None + local_context_lens_allranks = torch.tensor( + num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape( + -1, self.dcp_size * self.pcp_size) + # Note(qcs): The max local context lengths + # padded to `cp_local_block_size`. 
+ padded_local_context_lens_cpu = (cdiv( + self.context_lens_cpu, + self.cp_virtual_block_size, + ) * self.cp_local_block_size) + padded_local_max_context_chunk_across_ranks = (cdiv( + self.max_context_chunk, + self.cp_virtual_block_size, + ) * self.cp_local_block_size) + local_chunk_starts = (torch.arange( + self.num_chunks, dtype=torch.int32).unsqueeze(1).expand( + -1, self.num_prefills) * + padded_local_max_context_chunk_across_ranks) + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + padded_local_max_context_chunk_across_ranks, + ) + padded_local_chunk_seq_lens = (local_chunk_ends - + local_chunk_starts).clamp(min=0) + padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks, + self.num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum( + padded_local_chunk_seq_lens, + dim=1, + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + chunked_metadata = CPChunkedContextMetadata( + cu_seq_lens=chunked_context_metadata.cu_seq_lens, + starts=local_chunk_starts.pin_memory().to(self.device, + non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunked_context_metadata.max_seq_lens, + chunk_seq_lens=self.chunk_seq_lens, + chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu, + workspace=chunked_context_metadata.workspace, + padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu. + pin_memory().to(self.device, non_blocking=True), + cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, + ) + return chunked_metadata - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens - - # Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. - device = self.device - - # If graph_pad_size > -1, mean is running in fullgraph mode. - graph_pad_size = common_attn_metadata.graph_pad_size - # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. 
- if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch: - block_table = ( - common_attn_metadata.block_table_tensor[:graph_pad_size]) - else: - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens_pcp_padded] - input_positions = common_attn_metadata.positions[: - num_actual_tokens_pcp_padded].long( - ) - - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - - query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - query_lens = query_seq_lens_cpu[:num_reqs] - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (seq_lens - query_lens) - + def set_prefill_block_table( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + ): # For pcp + spec decode, we flatten seq_lens and block_table # to avoid irregular spec_attn_mask shape - num_decodes_flatten = query_lens[:num_decodes].sum().item() - block_table = common_attn_metadata.block_table_tensor[: - num_decodes_flatten - + num_prefills] + self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum( + ).item() + self.block_table = common_attn_metadata.block_table_tensor[:self. + num_decodes_flatten + + self. + num_prefills] - prefill_metadata = None - chunked_context_metadata = None - if num_prefills > 0: - pcp_metadata = AscendMLAPrefillMetadata.AscendPCPMetadata( - q_head_idx=long_seq_metadata.q_head_idx_tensor, - q_tail_idx=long_seq_metadata.q_tail_idx_tensor, - kv_with_q_head_nomask_idx=long_seq_metadata. - kv_with_q_head_nomask_idx_tensor, - kv_with_q_head_mask_idx=long_seq_metadata. - kv_with_q_head_mask_idx_tensor, - kv_with_q_tail_nomask_idx=long_seq_metadata. - kv_with_q_tail_nomask_idx_tensor, - kv_with_q_tail_mask_idx=long_seq_metadata. - kv_with_q_tail_mask_idx_tensor, - attn_mask_seqlens=long_seq_metadata.attn_mask_seqlens, - head_attn_nomask_seqlens=long_seq_metadata. - head_attn_nomask_seqlens, - tail_attn_nomask_seqlens=long_seq_metadata. - tail_attn_nomask_seqlens, - q_full_idx=long_seq_metadata.q_full_idx, - pcp_prefill_mask=long_seq_metadata.pcp_prefill_mask, - pcp_allgather_restore_idx=long_seq_metadata. - pcp_allgather_restore_idx) + def set_decode_block_table( + self, common_attn_metadata: AscendCommonAttentionMetadata): + self.block_table = self.block_table[:self.num_decodes_flatten, ...] - reqs_start = num_decodes # prefill_start - tokens_start = num_decode_tokens - max_query_len = query_lens[reqs_start:].max().item() - max_seq_lens = seq_lens[reqs_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + def build_prefill_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLAPrefillMetadata: + prefill_metadata = super().build_prefill_metadata( + common_prefix_len, common_attn_metadata, model) + prefill_metadata.pcp_metadata = self.build_cp_metadata( + common_prefix_len, common_attn_metadata, model) + prefill_metadata.block_table = self.block_table[ + self.num_decodes_flatten:, ...] 
+ return prefill_metadata - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - if self.chunked_prefill_enabled and max_context_len_cpu > 0: - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - max_context_chunk = round_down(max_context_chunk, - self.block_size) + def build_decode_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLADecodeMetadata: + decode_metadata = super().build_decode_metadata( + common_prefix_len, common_attn_metadata, model) - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata + assert long_seq_metadata is not None + num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp + assert num_computed_tokens_of_pcp_dcp is not None + # [bs, pcp_size, dcp_size] + num_computed_tokens_of_cp_dcp_array = np.array( + num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten] - local_context_lens_allranks = torch.tensor( - num_computed_tokens_of_pcp_dcp[num_decodes_flatten:] - ).reshape(-1, self.dcp_size * self.pcp_size) - # Note(qcs): The max local context lengths - # padded to `cp_local_block_size`. - padded_local_context_lens_cpu = (cdiv( - context_lens_cpu, - self.cp_virtual_block_size, - ) * self.cp_local_block_size) - padded_local_max_context_chunk_across_ranks = (cdiv( - max_context_chunk, - self.cp_virtual_block_size, - ) * self.cp_local_block_size) - local_chunk_starts = ( - torch.arange(num_chunks, - dtype=torch.int32).unsqueeze(1).expand( - -1, num_prefills) * - padded_local_max_context_chunk_across_ranks) - local_chunk_ends = torch.min( - padded_local_context_lens_cpu.unsqueeze(0), - local_chunk_starts + - padded_local_max_context_chunk_across_ranks, - ) - padded_local_chunk_seq_lens = (local_chunk_ends - - local_chunk_starts).clamp(min=0) - padded_local_cu_chunk_seq_lens_cpu = torch.zeros( - num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum( - padded_local_chunk_seq_lens, - dim=1, - out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], - dtype=torch.int32, - ) - chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.pin_memory().to( - device, non_blocking=True), - starts=local_chunk_starts.pin_memory().to( - device, non_blocking=True), - seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens.npu(), - workspace=self.chunked_prefill_workspace, - padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu( - ), - padded_local_chunk_seq_lens=padded_local_chunk_seq_lens. - tolist(), - local_context_lens_allranks=local_context_lens_allranks. 
- tolist(), - padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu - .pin_memory().to(device, non_blocking=True), - cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), - chunk_size=padded_local_max_context_chunk_across_ranks, - ) - - prefill_input_positions = input_positions[tokens_start:] - assert self.cos_cache is not None - assert self.sin_cache is not None - cos = self.cos_cache[prefill_input_positions].unsqueeze( - 1).unsqueeze(2) - sin = self.sin_cache[prefill_input_positions].unsqueeze( - 1).unsqueeze(2) - prefill_metadata = AscendMLAPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens[reqs_start:].to(torch.int32), - seq_lens=seq_lens, - context_lens=seq_lens[reqs_start:], - input_positions=prefill_input_positions, - block_table=block_table[reqs_start:, ...], - max_query_len=max_query_len, - max_seq_lens=max_seq_lens, - query_start_loc=prefill_query_start_loc, - chunked_context=chunked_context_metadata, - sin=sin, - cos=cos, - pcp_metadata=pcp_metadata, - ) - prefill_metadata.block_table = \ - block_table[num_decodes_flatten:, ...] - - decode_metadata = None - if num_decodes > 0: - cos, sin = get_cos_and_sin_mla() - # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + - 1].tolist() - max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = seq_lens[:num_decodes] - input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes_flatten, ...] - # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if graph_pad_size > num_decodes and \ - self.speculative_config.disable_padded_drafter_batch: - block_table = block_table[:graph_pad_size, ...] - seq_lens_list = seq_lens.tolist() - - # [bs, pcp_size, dcp_size] - num_computed_tokens_of_cp_dcp_array = np.array( - num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten] - - cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, - self.dcp_rank] - cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) - batch_seq_mask = (cp_seq_len == 0) - self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - batch_seq_mask, non_blocking=True) - batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]] - cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) - - if graph_pad_size > num_reqs: - if self.speculative_config.disable_padded_drafter_batch: - num_reqs_pad_size = graph_pad_size - num_reqs - actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q) - seq_lens_list = seq_lens_list + [0] * (graph_pad_size - - num_decodes) - num_block_pad_size = graph_pad_size - block_table.shape[0] - if num_block_pad_size > 0: - block_table_padding = torch.zeros( - (num_block_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat( - [block_table, block_table_padding], dim=0) - else: - num_token_pad_size = graph_pad_size - num_decode_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - num_block_table_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - - num_decodes) - seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, slot_padding]) - block_table_padding = torch.zeros( - (num_block_table_pad_size, ) 
+ block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - position_padding = torch.zeros( - num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) - actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q, - common_attn_metadata) - - # TODO: After the fullgraph supports MTP, the if branch needs to deleted - assert self.cos_cache is not None - assert self.sin_cache is not None - if cos is None and sin is None: - cos = self.cos_cache[ - input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos, - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) - else: - cos[:num_decode_tokens, - ...] = self.cos_cache[input_positions].unsqueeze( - 1).unsqueeze(2) - sin[:num_decode_tokens, - ...] = self.sin_cache[input_positions].unsqueeze( - 1).unsqueeze(2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin[:num_decode_tokens, ...], - cos=cos[:num_decode_tokens, ...], - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) - - return self.metadata_cls( # type: ignore - num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_input_tokens=common_attn_metadata.num_input_tokens, - num_actual_tokens=num_actual_tokens, - query_lens=query_lens.tolist(), - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - attn_mask=common_attn_metadata.attn_mask, - attn_state=common_attn_metadata.attn_state, - prefill=prefill_metadata, - decode=decode_metadata, - query_start_loc=query_start_loc, - block_tables=block_table, - seq_lens=seq_lens, - ) + cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, + self.dcp_rank] + cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) + batch_seq_mask = (cp_seq_len == 0) + self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( + batch_seq_mask, non_blocking=True) + batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]] + cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) + decode_metadata.cp_seq_len = cp_seq_len + decode_metadata.batch_seq_mask = batch_seq_mask + return decode_metadata class AscendMlaCPImpl(AscendMLAImpl): diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 6e47b98e..b77682db 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -9,9 +9,6 @@ from torch import nn from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import (get_decode_context_model_parallel_rank, - 
get_decode_context_model_parallel_world_size, - get_pcp_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod @@ -22,6 +19,8 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.common_cp import (AscendPCPMetadata, + CPChunkedContextMetadata) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, enable_cp, maybe_save_kv_layer_to_connector, @@ -76,44 +75,22 @@ class AscendMLABackend(AttentionBackend): return AscendMLAImpl +@dataclass +class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor + + @dataclass class AscendMLAPrefillMetadata: """ Prefill Specific Metadata for Ascend""" - - @dataclass - class ChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - chunk_seq_lens: torch.Tensor - chunk_seq_lens_npu: torch.Tensor - # for mla DCP & PCP - padded_chunk_seq_lens_npu: torch.Tensor = None - padded_local_chunk_seq_lens: Optional[list[list[int]]] = None - local_context_lens_allranks: Optional[list[list[int]]] = None - padded_local_cu_seq_lens: torch.Tensor = None - cu_seq_lens_lst: Optional[list[list[int]]] = None - chunk_size: Optional[int] = None - - @dataclass - class AscendPCPMetadata: - q_head_idx: torch.Tensor = None - q_tail_idx: torch.Tensor = None - kv_with_q_head_nomask_idx: torch.Tensor = None - kv_with_q_head_mask_idx: torch.Tensor = None - kv_with_q_tail_nomask_idx: torch.Tensor = None - kv_with_q_tail_mask_idx: torch.Tensor = None - attn_mask_seqlens: torch.Tensor = None - head_attn_nomask_seqlens: torch.Tensor = None - tail_attn_nomask_seqlens: torch.Tensor = None - q_full_idx: torch.Tensor = None - pcp_prefill_mask: torch.Tensor = None - pcp_allgather_restore_idx: Optional[list[int]] = None - attn_mask: torch.Tensor query_lens: torch.Tensor seq_lens: list[int] @@ -123,7 +100,8 @@ class AscendMLAPrefillMetadata: block_table: torch.Tensor max_query_len: int max_seq_lens: int - chunked_context: Optional[ChunkedContextMetadata] = None + chunked_context: Optional[ChunkedContextMetadata + | CPChunkedContextMetadata] = None sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: Optional[AscendPCPMetadata] = None @@ -262,21 +240,21 @@ class AscendMLAMetadataBuilder: self.cos_cache = None self.sin_cache = None - self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 - self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 - self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size - self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size - decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', - 0) - max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) - self.batch_seq_mask_buf = 
torch.empty(max_num_seqs * - self.decode_threshold, - dtype=torch.uint8, - device=device) + self.chunk_seq_lens: torch.Tensor = None + self.cu_seq_lens_cpu: torch.Tensor = None + self.num_chunks: torch.Tensor = None + self.max_context_chunk = 0 + self.num_decodes = 0 + self.num_prefills = 0 + self.num_decode_tokens = 0 + self.num_prefill_tokens = 0 + self.context_lens_cpu: torch.Tensor = None + self.num_actual_tokens: Optional[int] = None + self.block_table: torch.Tensor = None + self.slot_mapping: torch.Tensor = None + self.graph_pad_size = 0 + self.query_lens: torch.Tensor = None + self.seq_lens: torch.Tensor = None def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -396,6 +374,12 @@ class AscendMLAMetadataBuilder: actual_seq_lengths_q = actual_seq_lengths_q + interpolated return actual_seq_lengths_q + def set_num_actual_tokens( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + ): + self.num_actual_tokens = common_attn_metadata.num_actual_tokens + def build( self, common_prefix_len: int, @@ -403,41 +387,18 @@ class AscendMLAMetadataBuilder: model: nn.Module, ) -> AscendMLAMetadata: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata - num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens - - # Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. - device = self.device - - # If graph_pad_size > -1, mean is running in fullgraph mode. - graph_pad_size = common_attn_metadata.graph_pad_size - # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch: - block_table = ( - common_attn_metadata.block_table_tensor[:graph_pad_size]) - else: - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - - if num_actual_tokens_pcp_padded is None: - num_actual_tokens_pcp_padded = num_actual_tokens + self.set_num_actual_tokens(common_attn_metadata) + assert self.num_decodes + self.num_prefills == num_reqs + assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens # NOTE: Currently, MTP-fullgraph is incompatibility pcp - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens_pcp_padded] - input_positions = common_attn_metadata.positions[: - num_actual_tokens_pcp_padded].long( - ) + self.slot_mapping = common_attn_metadata.slot_mapping[:self. 
+ num_actual_tokens] if self.cos_cache is None: self.cos_cache = model.model.layers[ @@ -451,210 +412,277 @@ class AscendMLAMetadataBuilder: self.model_config.dtype) # type: ignore query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - query_lens = query_seq_lens_cpu[:num_reqs] - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (seq_lens - query_lens) + self.query_lens = query_seq_lens_cpu[:num_reqs] + self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + + self.set_prefill_block_table(common_attn_metadata) prefill_metadata = None - chunked_context_metadata = None - if num_prefills > 0: - pcp_metadata = None - - reqs_start = num_decodes # prefill_start - tokens_start = num_decode_tokens - max_query_len = query_lens[reqs_start:].max().item() - max_seq_lens = seq_lens[reqs_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] - - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - if self.chunked_prefill_enabled and max_context_len_cpu > 0: - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - max_context_chunk = round_down(max_context_chunk, - self.block_size) - - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - - chunked_context_metadata = ( - AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.pin_memory().to( - device, non_blocking=True), - starts=chunk_starts.pin_memory().to(device, - non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens.npu(), - workspace=self.chunked_prefill_workspace, - )) - prefill_input_positions = input_positions[tokens_start:] - cos = self.cos_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - prefill_metadata = AscendMLAPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens[reqs_start:].to(torch.int32), - seq_lens=seq_lens, - context_lens=seq_lens[reqs_start:], - input_positions=prefill_input_positions, - block_table=block_table[reqs_start:, ...], - max_query_len=max_query_len, - max_seq_lens=max_seq_lens, - query_start_loc=prefill_query_start_loc, - chunked_context=chunked_context_metadata, - sin=sin, - cos=cos, - pcp_metadata=pcp_metadata, - ) + if self.num_prefills > 0: + prefill_metadata = self.build_prefill_metadata( + common_prefix_len, common_attn_metadata, model) decode_metadata = None - if num_decodes > 0: - cos, sin = get_cos_and_sin_mla() - # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + - 1].tolist() - max_seq_lens = seq_lens[:num_decodes].max().item() - 
seq_lens = seq_lens[:num_decodes] - input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes, ...] - # NOTE: Currently, MTP-fullgraph is incompatibility pcp - # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if graph_pad_size > num_decodes and \ - self.speculative_config.disable_padded_drafter_batch: - block_table = block_table[:graph_pad_size, ...] - seq_lens_list = seq_lens.tolist() - - cp_seq_len, batch_seq_mask = None, None - - if graph_pad_size > num_reqs: - if self.speculative_config.disable_padded_drafter_batch: - num_reqs_pad_size = graph_pad_size - num_reqs - actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q) - seq_lens_list = seq_lens_list + [0] * (graph_pad_size - - num_decodes) - num_block_pad_size = graph_pad_size - block_table.shape[0] - if num_block_pad_size > 0: - block_table_padding = torch.zeros( - (num_block_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat( - [block_table, block_table_padding], dim=0) - else: - num_token_pad_size = graph_pad_size - num_decode_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - num_block_table_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - - num_decodes) - seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, slot_padding]) - block_table_padding = torch.zeros( - (num_block_table_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - position_padding = torch.zeros( - num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) - actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q, - common_attn_metadata) - - # TODO: After the fullgraph supports MTP, the if branch needs to deleted - assert self.cos_cache is not None - assert self.sin_cache is not None - if cos is None and sin is None: - cos = self.cos_cache[ - input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos, - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) - else: - cos[:num_decode_tokens, - ...] = self.cos_cache[input_positions].unsqueeze( - 1).unsqueeze(2) - sin[:num_decode_tokens, - ...] 
= self.sin_cache[input_positions].unsqueeze( - 1).unsqueeze(2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin[:num_decode_tokens, ...], - cos=cos[:num_decode_tokens, ...], - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + if self.num_decodes > 0: + decode_metadata = self.build_decode_metadata( + common_prefix_len, common_attn_metadata, model) return self.metadata_cls( # type: ignore - num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, + num_actual_tokens_pcp_padded=self.num_actual_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, - num_actual_tokens=num_actual_tokens, - query_lens=query_lens.tolist(), - slot_mapping=slot_mapping, + num_actual_tokens=self.num_actual_tokens, + query_lens=self.query_lens.tolist(), + slot_mapping=self.slot_mapping, head_dim=self.model_config.get_head_size(), - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, + num_decodes=self.num_decodes, + num_decode_tokens=self.num_decode_tokens, + num_prefills=self.num_prefills, attn_mask=common_attn_metadata.attn_mask, attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, query_start_loc=query_start_loc, - block_tables=block_table, - seq_lens=seq_lens, + block_tables=self.block_table, + seq_lens=self.seq_lens, ) + def build_chunked_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ): + if not self.chunked_prefill_enabled: + return None + num_reqs = common_attn_metadata.num_reqs + + num_computed_tokens_cpu = (self.seq_lens - self.query_lens) + reqs_start = self.num_decodes # prefill_start + + self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = self.context_lens_cpu.max().item() + if not max_context_len_cpu > 0: + return None + num_prefills_with_context_cpu = (self.context_lens_cpu + > 0).sum().item() + self.max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + self.max_context_chunk = round_down(self.max_context_chunk, + self.block_size) + + assert self.max_context_chunk > 0 + self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk) + chunk_starts = torch.arange(self.num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self.num_prefills) * self.max_context_chunk + chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), + chunk_starts + self.max_context_chunk) + self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, + self.num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(self.chunk_seq_lens, + dim=1, + out=self.cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + return ChunkedContextMetadata( + cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to( + self.device, non_blocking=True), + starts=chunk_starts.pin_memory().to(self.device, + non_blocking=True), + seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=self.chunk_seq_lens, + chunk_seq_lens_npu=self.chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + ) + + def set_prefill_block_table( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + ): + # If graph_pad_size > 
-1, mean is running in fullgraph mode. + self.graph_pad_size = common_attn_metadata.graph_pad_size + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. + if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch: + self.block_table = ( + common_attn_metadata.block_table_tensor[:self.graph_pad_size]) + else: + self.block_table = ( + common_attn_metadata.block_table_tensor[:common_attn_metadata. + num_reqs]) + + def set_decode_block_table( + self, common_attn_metadata: AscendCommonAttentionMetadata): + self.block_table = self.block_table[:self.num_decodes, ...] + + def build_prefill_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLAPrefillMetadata: + query_start_loc = common_attn_metadata.query_start_loc + + # NOTE: Currently, MTP-fullgraph is incompatibility pcp + input_positions = common_attn_metadata.positions[:self. + num_actual_tokens].long( + ) + + chunked_context_metadata = self.build_chunked_metadata( + common_prefix_len, common_attn_metadata, model) + reqs_start = self.num_decodes # prefill_start + tokens_start = self.num_decode_tokens + max_query_len = self.query_lens[reqs_start:].max().item() + max_seq_lens = self.seq_lens[reqs_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + return AscendMLAPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=self.query_lens[reqs_start:].to(torch.int32), + seq_lens=self.seq_lens, + context_lens=self.seq_lens[reqs_start:], + input_positions=prefill_input_positions, + block_table=self.block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + def build_decode_metadata( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLADecodeMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + input_positions = common_attn_metadata.positions[:self. + num_actual_tokens].long( + ) + + cos, sin = get_cos_and_sin_mla() + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes + + 1].tolist() + max_seq_lens = self.seq_lens[:self.num_decodes].max().item() + self.seq_lens = self.seq_lens[:self.num_decodes] + input_positions = input_positions[:self.num_decode_tokens] + + self.set_decode_block_table(common_attn_metadata) + + # NOTE: Currently, MTP-fullgraph is incompatibility pcp + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. + if self.graph_pad_size > self.num_decodes and \ + self.speculative_config.disable_padded_drafter_batch: + self.block_table = self.block_table[:self.graph_pad_size, ...] 
+ seq_lens_list = self.seq_lens.tolist() + + cp_seq_len, batch_seq_mask = None, None + + if self.graph_pad_size > num_reqs: + if self.speculative_config.disable_padded_drafter_batch: + num_reqs_pad_size = self.graph_pad_size - num_reqs + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q) + seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - + self.num_decodes) + num_block_pad_size = self.graph_pad_size - self.block_table.shape[ + 0] + if num_block_pad_size > 0: + block_table_padding = torch.zeros( + (num_block_pad_size, ) + self.block_table.shape[1:], + dtype=self.block_table.dtype, + device=self.block_table.device) + self.block_table = torch.cat( + [self.block_table, block_table_padding], dim=0) + else: + num_token_pad_size = self.graph_pad_size - self.num_decode_tokens + num_reqs_pad_size = ( + self.graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + num_block_table_pad_size = ( + self.graph_pad_size // + common_attn_metadata.decode_token_per_req - + self.num_decodes) + seq_lens_list = self.seq_lens.tolist() + [0 + ] * num_reqs_pad_size + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=self.slot_mapping.dtype, + device=self.slot_mapping.device) + self.slot_mapping = torch.cat( + [self.slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_block_table_pad_size, ) + self.block_table.shape[1:], + dtype=self.block_table.dtype, + device=self.block_table.device) + self.block_table = torch.cat( + [self.block_table, block_table_padding], dim=0) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + common_attn_metadata) + + # TODO: After the fullgraph supports MTP, the if branch needs to deleted + assert self.cos_cache is not None + assert self.sin_cache is not None + if cos is None and sin is None: + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=self.block_table, + seq_lens=self.seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos, + cp_seq_len=cp_seq_len, + batch_seq_mask=batch_seq_mask) + else: + cos[:self.num_decode_tokens, + ...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze( + 2) + sin[:self.num_decode_tokens, + ...] 
= self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
+                2)
+
+            decode_metadata = AscendMLADecodeMetadata(
+                input_positions=input_positions,
+                block_table=self.block_table,
+                seq_lens=self.seq_lens,
+                seq_lens_list=seq_lens_list,
+                max_seq_lens=max_seq_lens,
+                attn_mask=common_attn_metadata.spec_attn_mask,
+                actual_seq_lengths_q=actual_seq_lengths_q,
+                sin=sin[:self.num_decode_tokens, ...],
+                cos=cos[:self.num_decode_tokens, ...],
+                cp_seq_len=cp_seq_len,
+                batch_seq_mask=batch_seq_mask)
+        return decode_metadata
+
     def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 40175e5c..db0cc99d 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -87,7 +87,7 @@ class AscendPrefillContextParallelMetadata:
 
     cp_kv_recover_idx_for_chunk: torch.Tensor = None
 
-    num_actual_tokens_pcp_padded: Optional[int] = None
+    num_actual_tokens_pcp_padded: int = 0
 
     num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None