From 71866d531151b36a9af4db36ba8b3b74162f7028 Mon Sep 17 00:00:00 2001
From: Apocalypse
Date: Tue, 11 Nov 2025 09:18:02 +0800
Subject: [PATCH] [feature] chunkprefill support pcp&dcp (#3801)

### What this PR does / why we need it?
Chunked prefill now supports the long-sequence parallelism features PCP & DCP
(prefill context parallelism and decode context parallelism).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed, together with local self-tests.

- vLLM version: v0.11.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

---------

Signed-off-by: Apocalypse990923-qshi
Signed-off-by: Delphine-Nic
Co-authored-by: Delphine-Nic
Co-authored-by: Delphine-Nic <3834144971@qq.com>
---
 tests/ut/attention/test_attention_v1.py |   1 +
 tests/ut/attention/test_mla_v1.py       |   3 +
 vllm_ascend/attention/attention_v1.py   | 529 ++++++++++++++++++++----
 vllm_ascend/attention/mla_v1.py         | 515 ++++++++++++++++++++---
 vllm_ascend/attention/utils.py          |  52 ++-
 vllm_ascend/worker/block_table.py       |  30 +-
 vllm_ascend/worker/model_runner_v1.py   | 301 +++++++++++++-
 vllm_ascend/worker/npu_input_batch.py   |  15 +
 8 files changed, 1276 insertions(+), 170 deletions(-)

diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 9c732f61..e8fff182 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -83,6 +83,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         self.mock_vllm_config.compilation_config.cudagraph_mode = None
         self.mock_vllm_config.scheduler_config.max_num_seqs = 10
         self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
+        self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         self.mock_device = 'cpu:0'
         self.builder = AscendAttentionMetadataBuilder(None, None,
                                                       self.mock_vllm_config,
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 102b6feb..2ff545ff 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -484,6 +484,9 @@ class TestAscendMLAImpl(TestBase):
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
         chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
+        chunk_ctx.max_chunk_num = 1
+        chunk_ctx.mask_for_non_zero_chunk = [True]
+        chunk_ctx.local_chunked_kv_lens = [[[[8]]]]
         prefill_meta = MagicMock()
         prefill_meta.chunked_context = chunk_ctx
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 098e77c5..a0cc20f0 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -37,6 +37,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         extract_req_dcp_by_chunk_pcp,
+                                         filter_chunked_req_indices,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
@@ -52,6 +54,7 @@ if prefill_context_parallel_enable():
     get_prefill_context_model_parallel_rank,
     get_prefill_context_model_parallel_world_size
 )
+
 # isort: on
 
 
@@ -155,9 +158,23 @@ class AscendPCPMetadata:
 
 @dataclass
 class AscendMetadataForPrefill:
+
+    @dataclass
+    class ChunkedContextMetadata:
+        actual_chunk_seq_lengths: list[int]
+        mask_for_non_zero_chunk: Optional[list[bool]] = None
+        max_chunk_num: int = 0
+        local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
+            Optional[list[int]]]]]]]] = None
+        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
+
kv_inverse_idx_for_chunk: Optional[list[int]] = None + """ Prefill Specific Metadata for Ascend""" pcp_metadata: Optional[AscendPCPMetadata] = None pcp_allgather_restore_idx: Optional[List[int]] = None + chunked_context: Optional[ChunkedContextMetadata] = None + block_tables: torch.Tensor = None + actual_seq_lengths_q: torch.Tensor = None @dataclass @@ -165,6 +182,7 @@ class AscendMetadataForDecode: """ Decode Specific Metadata for Ascend""" num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None batch_seq_mask: torch.Tensor = None + block_tables: torch.Tensor = None @dataclass @@ -237,13 +255,10 @@ class AscendAttentionMetadataBuilder: self.max_num_blocks_per_req = cdiv( self.model_config.max_model_len, AscendAttentionBackend.get_supported_block_size()[0]) - decode_max_num_seqs = getattr(vllm_config.scheduler_config, - 'decode_max_num_seqs', 0) - max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs, - decode_max_num_seqs) - self.batch_seq_mask_buf = torch.empty(max_num_seqs, - dtype=torch.uint8, - device=device) + self.batch_seq_mask_buf = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.uint8, + device=device) self.pcp_size = get_prefill_context_model_parallel_world_size( ) if prefill_context_parallel_enable() else 1 self.pcp_rank = get_prefill_context_model_parallel_rank( @@ -263,6 +278,27 @@ class AscendAttentionMetadataBuilder: AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool: return False @@ -279,9 +315,8 @@ class AscendAttentionMetadataBuilder: num_reqs + 1] - decode_threshold = 1 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_actual_tokens @@ -302,6 +337,7 @@ class AscendAttentionMetadataBuilder: query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] + num_computed_tokens_cpu = (seq_lens - query_lens) if attn_state == AscendAttentionState.DecodeOnly and \ common_attn_metadata.num_input_tokens > num_actual_tokens: @@ -338,16 +374,35 @@ class AscendAttentionMetadataBuilder: attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ) + common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata prefill_metadata = None - if num_prefills > 0: - pcp_metadata = None - common_long_seq_metadata = 
common_attn_metadata.prefill_context_parallel_metadata - if common_long_seq_metadata is not None: + decode_metadata = None + if common_long_seq_metadata is not None: + chunked_context_metadata = None + if num_prefills > 0: + query_lens = query_lens[num_decode_tokens:] + context_lens_cpu = num_computed_tokens_cpu[ + num_decodes:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk + kv_inverse_idx_for_chunk = torch.argsort( + cp_kv_recover_idx_for_chunk.to(torch.float32) + ) if cp_kv_recover_idx_for_chunk is not None else None + chunked_context_metadata = \ + AscendMetadataForPrefill.ChunkedContextMetadata( + actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), + mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk, + local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens, + cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, + kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, + max_chunk_num=common_long_seq_metadata.max_chunk_num + ) attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens - pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 if pcp_size > 1: attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist() @@ -355,6 +410,7 @@ class AscendAttentionMetadataBuilder: head_attn_nomask_seqlens[1], dim=0).tolist() tail_attn_nomask_seqlens = torch.cumsum( tail_attn_nomask_seqlens[1], dim=0).tolist() + pcp_metadata = AscendPCPMetadata( q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, @@ -371,16 +427,17 @@ class AscendAttentionMetadataBuilder: tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask) - prefill_metadata = AscendMetadataForPrefill( - pcp_metadata=pcp_metadata, - pcp_allgather_restore_idx=common_long_seq_metadata. - pcp_allgather_restore_idx - if common_long_seq_metadata is not None else None) - decode_metadata = None - if num_decodes > 0: - common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata - if common_long_seq_metadata is not None: + prefill_metadata = AscendMetadataForPrefill( + pcp_metadata=pcp_metadata, + pcp_allgather_restore_idx=common_long_seq_metadata. + pcp_allgather_restore_idx + if common_long_seq_metadata is not None else None, + chunked_context=chunked_context_metadata, + block_tables=block_table[num_decodes:], + actual_seq_lengths_q=torch.cumsum(query_lens, dim=0)) + + if num_decodes > 0: num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None num_computed_tokens_array = np.array( @@ -397,7 +454,7 @@ class AscendAttentionMetadataBuilder: num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask. 
shape[0]], - ) + block_tables=block_table[:num_decodes]) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -751,7 +808,8 @@ class AscendAttentionBackendImpl(AttentionImpl): k_mask: torch.Tensor, v_mask: torch.Tensor, kv_seqlens_mask: List[int], - mask: torch.Tensor) -> torch.Tensor: + mask: torch.Tensor, + attn_metadata) -> torch.Tensor: # nomask Attention if k_nomask is not None: attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score( @@ -786,30 +844,42 @@ class AscendAttentionBackendImpl(AttentionImpl): softmax_lse_flag=True, actual_seq_lengths_kv=kv_seqlens_mask, actual_seq_lengths=q_seqlens) - # update output = attn_out_mask + attn_lse = attn_lse_mask if k_nomask is not None: - T = attn_out_mask.shape[0] - N = attn_out_mask.shape[1] - D = attn_out_mask.shape[2] + if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None: + output = self._npu_attn_out_lse_update(attn_lse_mask, + attn_lse_nomask, + attn_out_mask, + attn_out_nomask) + attn_lse = None + else: + output, attn_lse = self._update_out_and_lse( + torch.stack([attn_out_nomask, attn_out_mask], dim=0), + torch.stack([attn_lse_nomask, attn_lse_mask], dim=0)) - attn_out_mask, attn_lse_mask = self._out_lse_reshape( - attn_out_mask, attn_lse_mask) - attn_out_nomask, attn_lse_nomask = self._out_lse_reshape( - attn_out_nomask, attn_lse_nomask) - attn_out_mask = attn_out_mask.to(torch.float32) - attn_out_nomask = attn_out_nomask.to(torch.float32) - attn_lse_mask = attn_lse_mask.to(torch.float32) - attn_lse_nomask = attn_lse_nomask.to(torch.float32) - - attn_output = [attn_out_nomask, attn_out_mask] - attn_lse = [attn_lse_nomask, attn_lse_mask] - update_type = 0 - output, _ = torch_npu.npu_attention_update(attn_lse, attn_output, - update_type) - output = output.view(T, N, D) + return output, attn_lse + def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask, + attn_out_mask, attn_out_nomask): + T = attn_out_mask.shape[0] + N = attn_out_mask.shape[1] + D = attn_out_mask.shape[2] + attn_out_mask, attn_lse_mask = self._out_lse_reshape( + attn_out_mask, attn_lse_mask) + attn_out_nomask, attn_lse_nomask = self._out_lse_reshape( + attn_out_nomask, attn_lse_nomask) + attn_out_mask = attn_out_mask.to(torch.float32) + attn_out_nomask = attn_out_nomask.to(torch.float32) + attn_lse_mask = attn_lse_mask.to(torch.float32) + attn_lse_nomask = attn_lse_nomask.to(torch.float32) + attn_output = [attn_out_nomask, attn_out_mask] + attn_lse = [attn_lse_nomask, attn_lse_mask] + update_type = 0 + output, _ = torch_npu.npu_attention_update(attn_lse, attn_output, + update_type) + output = output.view(T, N, D) return output def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor, @@ -831,7 +901,7 @@ class AscendAttentionBackendImpl(AttentionImpl): mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask # 1. Attention calculation in the first half of Q in load balancing - output_head = self._attention_with_nomask_and_mask( + output_heads, lse_heads = self._attention_with_nomask_and_mask( q=torch.index_select(query, 0, q_head_idx), q_seqlens=attn_mask_seqlens, k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx) @@ -842,12 +912,13 @@ class AscendAttentionBackendImpl(AttentionImpl): k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx), v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx), kv_seqlens_mask=attn_mask_seqlens, - mask=mask) + mask=mask, + attn_metadata=attn_metadata) # 2. 
the Attention calculation in the latter half of Q in load balancing # pcp_rank0: Q3*KV0~KV2 + Q3*KV3 # pcp_rank1: Q2*KV0~KV1 + Q2*KV2 - output_tail = self._attention_with_nomask_and_mask( + output_tails, lse_tails = self._attention_with_nomask_and_mask( q=torch.index_select(query, 0, q_tail_idx), q_seqlens=attn_mask_seqlens, k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx), @@ -856,13 +927,17 @@ class AscendAttentionBackendImpl(AttentionImpl): k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx), v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx), kv_seqlens_mask=attn_mask_seqlens, - mask=mask) + mask=mask, + attn_metadata=attn_metadata) - # 3. Combine the output of the first half and second half. q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx output = torch.index_select( - torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) - return output + torch.cat([output_heads, output_tails], dim=0), 0, q_full_idx) + attn_lse = None + if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None: + attn_lse = torch.index_select( + torch.cat([lse_heads, lse_tails], dim=0), 0, q_full_idx) + return output, attn_lse def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor: @@ -928,7 +1003,7 @@ class AscendAttentionBackendImpl(AttentionImpl): 'softmax_lse_flag': True, 'block_table': - attn_metadata.block_tables, + attn_metadata.decode_meta.block_tables, 'block_size': self.key_cache.shape[1], 'actual_seq_lengths_kv': @@ -1029,8 +1104,24 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_out = self._npu_attention_update(attn_out_lse_list) return attn_out + def _update_out_and_lse(self, out_list: torch.Tensor, + lse_list: torch.Tensor) -> torch.Tensor: + """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i) + Args: + out_list: shape = [N, batch_size, num_heads, head_size] + lse_list: shape = [N, batch_size, num_heads, 1] + Returns: + out_final: shape = [batch_size, num_heads, head_size] + lse_final: shape = [batch_size, num_heads, 1] + """ + lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False) + out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, + dim=0) + return out_final, lse_final + def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, attn_metadata: AscendMetadata, + value: torch.Tensor, kv_cache: Tuple[torch.Tensor], + attn_metadata: AscendMetadata, output: torch.Tensor) -> torch.Tensor: assert attn_metadata is not None has_decode = attn_metadata.num_decodes > 0 @@ -1043,32 +1134,320 @@ class AscendAttentionBackendImpl(AttentionImpl): decode_query, attn_metadata) output[:num_decode_tokens] = output_decode if has_prefill: - prefill_query = query[num_decode_tokens:] + assert attn_metadata.prefill is not None + num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size + prefill_query = query[ + num_decode_tokens:num_actual_tokens_pcp_padded] key = key[self.pcp_size * num_decode_tokens:] value = value[self.pcp_size * num_decode_tokens:] if self.pcp_size > 1: - output_prefill = self._forward_prefill_cp( + # Scenario of Enabling PCP or PCP&DCP + attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp( prefill_query, key, value, attn_metadata) else: - max_prefill_seq_len = attn_metadata.seq_lens[ - attn_metadata.num_decode_tokens:].max().item() - if attn_metadata.attn_mask is not None: - attn_metadata.attn_mask = attn_metadata.attn_mask[: - max_prefill_seq_len, : - 
max_prefill_seq_len] - else: - ValueError("Attn_metadata.attn_mask is required") - seq_lens_back = attn_metadata.seq_lens - attn_metadata.seq_lens = attn_metadata.seq_lens[ - attn_metadata.num_decode_tokens:] - output_prefill = self._forward_prefill_no_cache( - prefill_query, key, value, attn_metadata, - output[num_decode_tokens:], prefill_query.shape[0]) - attn_metadata.seq_lens = seq_lens_back - output[num_decode_tokens:output_prefill.shape[0] + - num_decode_tokens] = output_prefill + # Scenario of Enabling DCP Individually + attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score( + prefill_query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + atten_mask=attn_metadata.attn_mask, + scale=self.scale, + sparse_mode=3, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + actual_seq_lengths_kv=attn_metadata.prefill. + actual_seq_lengths_q, + actual_seq_lengths=attn_metadata.prefill. + actual_seq_lengths_q) + + self._process_chunk_prefill(attn_output_prefill, attn_lse_prefill, + kv_cache, prefill_query, attn_metadata) + output[num_decode_tokens:attn_output_prefill.shape[0] + + num_decode_tokens] = attn_output_prefill return output + def _process_chunk_prefill(self, current_attn_output_prefill, + current_attn_lse_prefill, kv_cache, + prefill_query, attn_metadata): + if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None: + prefill_query_all = self._prefill_query_all_gather( + attn_metadata, prefill_query) + attn_output_full_chunk, attn_lse_full_chunk = self._compute_prefill_context( + prefill_query_all, kv_cache, attn_metadata) + self._update_chunk_attn_out_lse_with_current_attn_out_lse( + current_attn_output_prefill, current_attn_lse_prefill, + attn_output_full_chunk, attn_lse_full_chunk, prefill_query, + attn_metadata) + + def _update_chunk_attn_out_lse_with_current_attn_out_lse( + self, current_attn_output_prefill, current_attn_lse_prefill, + attn_output_full_chunk, attn_lse_full_chunk, prefill_query, + attn_metadata): + if self.pcp_size > 1: + inverse_idx = attn_metadata.prefill.chunked_context.kv_inverse_idx_for_chunk + attn_output_full_chunk = torch.index_select( + attn_output_full_chunk, 0, inverse_idx) + attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0, + inverse_idx) + num_tokens = prefill_query.size(0) + attn_output_full_chunk = attn_output_full_chunk[ + self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] + attn_lse_full_chunk = attn_lse_full_chunk[ + self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] + assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape + seq_len = attn_metadata.query_lens.detach().clone() + filtered_indices = filter_chunked_req_indices( + seq_len, + attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk) + + attn_output_prefill_filtered = current_attn_output_prefill[ + filtered_indices, :, :] + attn_lse_prefill_filtered = current_attn_lse_prefill[ + filtered_indices, :, :] + attn_output_full_chunk = attn_output_full_chunk[filtered_indices, :, :] + attn_lse_full_chunk = attn_lse_full_chunk[filtered_indices, :, :] + + attn_output_filtered = self._npu_attn_out_lse_update( + attn_lse_prefill_filtered, attn_lse_full_chunk, + attn_output_prefill_filtered, attn_output_full_chunk) + current_attn_output_prefill[ + filtered_indices, :, :] = attn_output_filtered.to( + 
current_attn_output_prefill.dtype) + + def _prefill_query_all_gather(self, attn_metadata, prefill_query): + prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(), + 0) \ + if self.pcp_size > 1 else prefill_query + prefill_query_all = torch.index_select(prefill_query_all, + 0, + attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \ + if self.pcp_size > 1 else prefill_query_all + return prefill_query_all + + def _compute_prefill_context(self, query: torch.Tensor, + kv_cache: Tuple[torch.Tensor], + attn_metadata: AscendMetadata): + assert len(kv_cache) > 1 + assert attn_metadata is not None + assert attn_metadata.prefill is not None + assert attn_metadata.prefill.chunked_context is not None + prefill_metadata = attn_metadata.prefill + local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens + mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk + max_chunk_num = prefill_metadata.chunked_context.max_chunk_num + + assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0 + + iters = max_chunk_num + # Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank] + context_starts_rank = None + + prefix_output_list = [] + prefix_lse_list = [] + for i in range(iters): + key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk( + attn_metadata, kv_cache, context_starts_rank, i, + local_chunked_kv_lens, prefill_metadata, query) + + # 2. Attention computation + if seq_lens_current_chunk_rank is None or torch.all( + seq_lens_current_chunk_rank == 0).item(): + prefix_output = torch.full( + (query.size(0), self.num_heads, self.head_size), + fill_value=0, + dtype=query.dtype, + device=query.device) + prefix_lse = torch.full((query.size(0), self.num_heads, 1), + fill_value=0, + dtype=torch.float32, + device=query.device) + else: + + actual_seq_lengths_kv = torch.cumsum( + seq_lens_current_chunk_rank, dim=0).tolist() + prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", # + atten_mask=None, + scale=self.scale, + sparse_mode=0, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + actual_seq_lengths_kv=actual_seq_lengths_kv, + actual_seq_lengths=attn_metadata.prefill.chunked_context. + actual_chunk_seq_lengths) + prefix_output_list.append(prefix_output) + prefix_lse_list.append(prefix_lse) + + # 3. 
update attn-out & lse + prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks( + prefix_output_list, prefix_lse_list) + + self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output, + prefix_lse) + + return prefix_output, prefix_lse + + def _update_attn_out_lse_in_chunks(self, prefix_output_list, + prefix_lse_list): + # update output and lse + if len(prefix_output_list) > 1: + prefix_output, prefix_lse = self._update_out_and_lse( + torch.stack(prefix_output_list, dim=0), + torch.stack(prefix_lse_list, dim=0)) + else: + prefix_output = prefix_output_list[0] + prefix_lse = prefix_lse_list[0] + return prefix_output, prefix_lse + + def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output, + prefix_lse): + # CP dimension all_gather and fusion + if self.pcp_size > 1: + # filter non-zero chunk part of prefix_output + current_seq_lens = attn_metadata.query_lens.detach().clone() + current_seq_lens.mul_(self.pcp_size) # q_full + current_seq_lens_cpu = current_seq_lens.cpu() + filtered_indices = filter_chunked_req_indices( + current_seq_lens_cpu, + attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk) + prefix_output_filtered = prefix_output[filtered_indices, :, :] + prefix_lse_filtered = prefix_lse[filtered_indices, :, :] + + out_lse_local = torch.cat( + [prefix_output_filtered, prefix_lse_filtered], dim=-1) + attn_out_lse_list = [ + torch.empty_like(out_lse_local) for _ in range(self.pcp_size) + ] + dist.all_gather(attn_out_lse_list, + out_lse_local, + group=self.pcp_group) + + attn_out_lse_allgather = torch.stack( + attn_out_lse_list, + dim=0) # [pcp, batch_size, num_heads, head_size+1] + attn_out_allgather, attn_lse_allgather = torch.split( + attn_out_lse_allgather, [self.head_size, 1], dim=-1) + prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse( + attn_out_allgather, attn_lse_allgather) + + prefix_output[filtered_indices, :, :] = prefix_output_filtered.to( + prefix_output.dtype) + prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to( + prefix_lse.dtype) + + def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank, + i, local_chunked_kv_lens, prefill_metadata, query): + + cache_key = kv_cache[0] + cache_value = kv_cache[1] + num_heads = cache_key.size(2) + head_size = kv_cache[0].size(-1) + + # 1. 
Load current query's history key-value + seq_lens_current_chunk = attn_metadata.query_lens.detach().clone() + num_requests = len(seq_lens_current_chunk) + # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases + context_starts_rank = torch.zeros( + num_requests, dtype=torch.int32, device=query.device + ) if context_starts_rank is None else context_starts_rank + # Calculate tokens each rank should process per request + seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk, + dtype=torch.int32, + device=query.device) + total_toks = 0 + for req_idx in range(num_requests): + if i >= len(local_chunked_kv_lens[req_idx]): + continue + n_computed_acc = local_chunked_kv_lens[req_idx][i] + total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank] + seq_lens_current_chunk_rank[req_idx] = n_computed_acc[ + self.pcp_rank][self.dcp_rank] + if total_toks > 0: + key = torch.empty(total_toks, + num_heads, + head_size, + dtype=query.dtype, + device=query.device) + value = torch.empty(total_toks, + num_heads, + head_size, + dtype=query.dtype, + device=query.device) + + torch_npu.atb.npu_paged_cache_load( + cache_key, + cache_value, + attn_metadata.prefill.block_tables, + seq_lens_current_chunk_rank.to(query.device), + seq_starts= + context_starts_rank, # slot offsets of current chunk in current iteration + key=key, + value=value, + ) + else: + # If current rank has no tokens to process, create empty tensors + key = torch.empty(0, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + value = torch.empty(0, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + seq_lens_current_chunk_rank = torch.zeros( + (len(seq_lens_current_chunk), ), + dtype=torch.int32, + device=query.device) + for req_idx in range(num_requests): + # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases + if i >= len(local_chunked_kv_lens[req_idx]): + continue + context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][ + self.pcp_rank][self.dcp_rank] + if self.dcp_size > 1: + req_dcp_sizes = extract_req_dcp_by_chunk_pcp( + local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank) + + assert len(req_dcp_sizes) == num_requests and all( + len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes) + total_toks = np.sum(np.array(req_dcp_sizes)) + kv_local = torch.cat([key, value], dim=-1) + head_dim = kv_local.size(-1) + kv_full = torch.empty((total_toks, num_heads, head_dim), + device=query.device, + dtype=query.dtype) + + kv_full_list = [None for _ in range(self.dcp_size)] + dist.all_gather_object(kv_full_list, + kv_local, + group=self.dcp_group) + kv_full_list = [ + kv for kv in kv_full_list if kv is not None and kv.numel() > 0 + ] + + if len(kv_full_list) > 0: + kv_full = torch.cat(kv_full_list, dim=0) + key, value = kv_full.split([head_size, head_size], dim=-1) + if total_toks == 0: + return key, value, None + seq_lens_current_chunk_rank = torch.tensor( + np.sum(np.array(req_dcp_sizes), axis=1), + dtype=torch.int32, + device=query.device) # [reqs] + return key, value, seq_lens_current_chunk_rank + def forward( self, layer: AttentionLayer, @@ -1162,7 +1541,7 @@ class AscendAttentionBackendImpl(AttentionImpl): if self.pcp_size * self.dcp_size > 1: intermediate_output = self._forward_pcp_dcp( - query, key, value, attn_metadata, output) + query, key, value, kv_cache, attn_metadata, output) elif attn_type == AttentionType.ENCODER_ONLY: # 
TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. cum_seq_len = attn_metadata.query_start_loc[1:].tolist() @@ -1185,7 +1564,7 @@ class AscendAttentionBackendImpl(AttentionImpl): intermediate_output = self._forward_prefill_no_cache( query, key, value, attn_metadata, output, num_tokens) elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: + AscendAttentionState.PrefillCacheHit: intermediate_output = self._forward_prefill_cache_hit( query, attn_metadata, output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 6d5c1397..314d5a55 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple, import numpy as np import torch import torch.distributed as dist +import torch.nn.functional as F import torch_npu from torch import nn from vllm.attention.backends.abstract import (AttentionBackend, @@ -27,11 +28,14 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - maybe_save_kv_layer_to_connector, - split_decodes_and_prefills, - trans_rope_weight, transdata, - wait_for_kv_layer_from_connector) + +# isort: off +from vllm_ascend.attention.utils import ( + AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp, + filter_chunked_req_indices, maybe_save_kv_layer_to_connector, + split_decodes_and_prefills, trans_rope_weight, transdata, + wait_for_kv_layer_from_connector) +# isort: on from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch @@ -111,6 +115,10 @@ class AscendMLAPrefillMetadata: workspace: torch.Tensor chunk_seq_lens: torch.Tensor chunk_seq_lens_npu: torch.Tensor + mask_for_non_zero_chunk: Optional[list[bool]] = None + max_chunk_num: int = 0 + local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[ + Optional[list[int]]]]]]]] = None attn_mask: torch.Tensor query_lens: torch.Tensor @@ -125,6 +133,7 @@ class AscendMLAPrefillMetadata: sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: Optional[AscendPCPMetadata] = None + cp_kv_recover_idx_for_chunk: Optional[list[int]] = None @dataclass @@ -347,6 +356,10 @@ class AscendMLAMetadataBuilder: num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None + cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None + local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None + mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None + max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) @@ -359,14 +372,15 @@ class AscendMLAMetadataBuilder: device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - input_positions = common_attn_metadata.positions[: - 
num_actual_tokens].long( - ) + if num_actual_tokens_pcp_padded is None: num_actual_tokens_pcp_padded = num_actual_tokens slot_mapping = common_attn_metadata.slot_mapping[: num_actual_tokens_pcp_padded] + input_positions = common_attn_metadata.positions[: + num_actual_tokens_pcp_padded].long( + ) if self.cos_cache is None: self.cos_cache = model.model.layers[ @@ -408,7 +422,8 @@ class AscendMLAMetadataBuilder: tail_attn_nomask_seqlens=common_long_seq_metadata. tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask, + pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask + if long_seq_metadata else None, pcp_allgather_restore_idx=long_seq_metadata. pcp_allgather_restore_idx if long_seq_metadata else None) @@ -452,6 +467,9 @@ class AscendMLAMetadataBuilder: chunk_seq_lens=chunk_seq_lens, chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, + local_chunked_kv_lens=local_chunked_kv_lens, + mask_for_non_zero_chunk=mask_for_non_zero_chunk, + max_chunk_num=max_chunk_num, ) prefill_input_positions = input_positions[tokens_start:] cos = self.cos_cache[ @@ -474,7 +492,7 @@ class AscendMLAMetadataBuilder: sin=sin, cos=cos, pcp_metadata=pcp_metadata, - ) + cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk) decode_metadata = None if num_decodes > 0: @@ -887,8 +905,26 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse + local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens + mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk + max_chunk_num = prefill_metadata.chunked_context.max_chunk_num + if self.pcp_size * self.dcp_size > 1: + assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0 + + if self.pcp_size > 1: + prefix_output = torch.zeros(q_nope.shape[0], + self.num_heads, + self.v_head_dim, + dtype=q_nope.dtype, + device=q_nope.device) + prefix_lse = torch.zeros(self.num_heads, + q_pe.shape[0], + dtype=torch.float32, + device=q_pe.device) iters = len(prefill_metadata.chunked_context.seq_tot) + if self.pcp_size * self.dcp_size > 1: + iters = max_chunk_num current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) @@ -896,60 +932,305 @@ class AscendMLAImpl(MLAAttentionImpl): cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) + # token -> request mapping for building per-token masks when CP>1 + seq_len1 = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32, + device=q_nope.device) + seq_len1.mul_( + self.pcp_size) # q_full: already padded, divisible by cp_size + + # Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed. + mask_local = None + if attn_metadata is not None and attn_metadata.prefill is not None and \ + attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None: + mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask + else: + mask_local = self.prefill_mask + if mask_local is None: + mask_local = torch.triu( + torch.ones(512, + 512, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = mask_local + + # Keep the causal mask; do not override to all-ones. 
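+        # `context_starts_rank` is a per-request cursor into this rank's
+        # slice of the paged KV cache: it starts at zero and, after each
+        # chunk iteration, advances by the number of historical tokens this
+        # (pcp_rank, dcp_rank) pair holds for that chunk, i.e.
+        # local_chunked_kv_lens[req_idx][chunk_idx][pcp_rank][dcp_rank],
+        # so npu_paged_cache_load reads each chunk at the right offset.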
+ context_starts_rank = None + for i in range(iters): - toks = prefill_metadata.chunked_context.seq_tot[i] + if self.pcp_size * self.dcp_size > 1: + ## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension + num_requests = len(seq_len1) + assert num_requests == len(local_chunked_kv_lens) + # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases + context_starts_rank = torch.zeros( + num_requests, dtype=torch.int32, device=q_nope.device + ) if context_starts_rank is None else context_starts_rank - context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ - i] - context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ - i] - seq_len = torch.stack([current_seq_len, context_seq_len]) - kv_c_normed = torch.empty(toks, - num_heads, - latent_kv_dim, - dtype=q_nope.dtype, - device=q_nope.device) - k_pe = torch.empty(toks, - num_heads, - rope_dim, - dtype=q_nope.dtype, - device=q_nope.device) + ## Calculate tokens each rank should process per request + seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32) + total_toks = 0 - torch_npu.atb.npu_paged_cache_load( - cache_kv_c, - cache_k_pe, - prefill_metadata.block_table, - context_seq_len_npu, - seq_starts=prefill_metadata.chunked_context.starts[i], - key=kv_c_normed, - value=k_pe, - ) + for req_idx in range(num_requests): + if i >= len(local_chunked_kv_lens[req_idx]): + continue + n_computed_acc = local_chunked_kv_lens[req_idx][i] + total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank] + seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][ + self.dcp_rank] + + if total_toks > 0: + kv_c_normed = torch.empty(total_toks, + num_heads, + latent_kv_dim, + dtype=q_nope.dtype, + device=q_nope.device) + k_pe = torch.empty(total_toks, + num_heads, + rope_dim, + dtype=q_nope.dtype, + device=q_nope.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + seq_len2_rank.to(q_nope.device), + seq_starts= + context_starts_rank, # slot offsets of current chunk in current iteration + key=kv_c_normed, + value=k_pe, + ) + seq_len2 = seq_len2_rank.to(q_nope.device) + else: + # If current rank has no tokens to process, create empty tensors + kv_c_normed = torch.empty(0, + num_heads, + latent_kv_dim, + dtype=q_nope.dtype, + device=q_nope.device) + k_pe = torch.empty(0, + num_heads, + rope_dim, + dtype=q_nope.dtype, + device=q_nope.device) + seq_len2 = torch.zeros((len(seq_len1), ), + dtype=torch.int32, + device=q_nope.device) + seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()]) + for req_idx in range(num_requests): + # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases + if i >= len(local_chunked_kv_lens[req_idx]): + continue + context_starts_rank[req_idx] += local_chunked_kv_lens[ + req_idx][i][self.pcp_rank][self.dcp_rank] + else: + # Original logic: ChunkPrefill-only mode + toks = prefill_metadata.chunked_context.seq_tot[i] + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) + + kv_c_normed = torch.empty(toks, + num_heads, + latent_kv_dim, + dtype=q_nope.dtype, + device=q_nope.device) + k_pe = torch.empty(toks, + num_heads, + rope_dim, + dtype=q_nope.dtype, + device=q_nope.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + 
prefill_metadata.block_table, + context_seq_len_npu, + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) kv_c_normed = kv_c_normed.squeeze() - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=v, - mask=self.prefill_mask, - seqlen=seq_len, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=prefix_output, - prev_lse=prefix_lse, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_default", - output=prefix_output, - softmax_lse=prefix_lse) + if self.dcp_size > 1: + # DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks + # Step 1: DCP all_gather latent + kv_c_k_pe_local = torch.cat( + [kv_c_normed, k_pe.squeeze()], + dim=-1) # [local_toks, latent_dim + rope_dim] + + # Step 2: use all_gather_into_tensor_uneven (gather + cat) + req_dcp_sizes = extract_req_dcp_by_chunk_pcp( + local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank + ) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp] + assert len(req_dcp_sizes) == num_requests and all( + len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes) + total_toks = np.sum(np.array(req_dcp_sizes)) + latent_rope_dim = kv_c_k_pe_local.size(-1) + kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim), + device=kv_c_k_pe_local.device, + dtype=kv_c_k_pe_local.dtype) + + kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)] + dist.all_gather_object(kv_c_k_pe_full_list, + kv_c_k_pe_local, + group=self.dcp_group) + kv_c_k_pe_full_list = [ + kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list + if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0 + ] + if len(kv_c_k_pe_full_list) > 0: + kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0) + if len(kv_c_k_pe_full.shape) == 1: + assert total_toks == 1 + kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0) + assert kv_c_k_pe_full.shape[ + 0] == total_toks and kv_c_k_pe_full.shape[ + 1] == latent_rope_dim + kv_c_normed_full, k_pe_full = torch.split( + kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1) + + # Step 3: process complete sequence with TP projection to get current rank's head slice + # Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation. + if total_toks == 0: + continue + kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1)) + + seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes), + axis=1), + dtype=torch.int32, + device=q_nope.device) # [reqs] + seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()]) + else: + # Non-DCP mode: use TP-split projection + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + + if self.pcp_size > 1: + # Case that no kv_cache has been stored on this CP rank, no need to do following computation. 
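+                # Ring-style accumulation across chunks: chunk 0 initializes
+                # (prefix_output, prefix_lse) via calc_type_first_ring, and
+                # every later chunk feeds the running (pre_out, prev_lse)
+                # back into npu_ring_mla, which merges the new partial
+                # softmax result into the accumulated output in one call.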
+ if torch.all(seq_len2 == 0).item(): + continue + # PCP mode: first compute this rank's contribution to the chunk + if i == 0: + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask_local, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=prefix_output, + softmax_lse=prefix_lse) + continue + + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask_local, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + + else: + assert not torch.all(context_seq_len == 0).item() + # compute this chunk block then update prefix tensors to keep shapes consistent + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask_local, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + + # CP dimension all_gather and fusion + if self.pcp_size > 1: + # filter non-zero chunk part of prefix_output + seq_len1_cpu = seq_len1.cpu() + filtered_indices = filter_chunked_req_indices( + seq_len1_cpu, mask_for_non_zero_chunk) + prefix_output_filtered = prefix_output[filtered_indices, :, :] + prefix_lse_filtered = prefix_lse[:, filtered_indices] + + # normalize prefix LSE to [bs, heads, 1] for stable updates + prefix_lse_filtered_bt = prefix_lse_filtered.permute( + 1, 0).unsqueeze(-1).contiguous( + ) if prefix_lse_filtered is not None else None + out_lse_local = torch.cat( + [prefix_output_filtered, prefix_lse_filtered_bt], dim=-1) + out_lse_list = [ + torch.empty_like(out_lse_local) for _ in range(self.pcp_size) + ] + dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group) + prefix_output_filtered = None + prefix_lse_filtered_bt = None + for r in range(self.pcp_size): + out_lse_r = out_lse_list[r] + if torch.all(out_lse_r == 0).item(): + continue + out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1], + dim=-1) + token_mask = torch.ones([out_r.size(0)], + dtype=torch.uint8, + device=out_r.device) + prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse( + prefix_output_filtered, prefix_lse_filtered_bt, out_r, + lse_r, token_mask) + # convert lse back to [heads, bs] + assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None + prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute( + 1, 0).contiguous() + + prefix_output[filtered_indices, :, :] = prefix_output_filtered.to( + prefix_output.dtype) + prefix_lse[:, filtered_indices] = prefix_lse_filtered.to( + prefix_lse.dtype) + return prefix_output, prefix_lse def _forward_prefill( @@ -1516,7 +1797,7 @@ class AscendMLAImpl(MLAAttentionImpl): tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens mask = 
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask - output_head = self._attention_with_mask_and_nomask( + output_head, head_lse = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx), k_nope=k_nope, @@ -1528,7 +1809,7 @@ class AscendMLAImpl(MLAAttentionImpl): attn_nomask_seqlens=head_attn_nomask_seqlens, mask=mask) - output_tail = self._attention_with_mask_and_nomask( + output_tail, tail_lse = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_tail_idx), q_pe=torch.index_select(q_pe, 0, q_tail_idx), k_nope=k_nope, @@ -1544,7 +1825,83 @@ class AscendMLAImpl(MLAAttentionImpl): output = torch.index_select( torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) - output = output.reshape([num_tokens, self.num_heads * self.v_head_dim]) + # Synchronize and reorder LSE for subsequent chunked context accumulation + attn_lse = torch.cat([head_lse, tail_lse], dim=1) + attn_lse = attn_lse[:, q_full_idx] + + # Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed + if attn_metadata.prefill is not None and \ + attn_metadata.prefill.chunked_context is not None: + # q all_gather + q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0) + q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0) + q_nope_full = torch.index_select( + q_nope_full, 0, + attn_metadata.prefill.cp_kv_recover_idx_for_chunk) + q_pe_full = torch.index_select( + q_pe_full, 0, + attn_metadata.prefill.cp_kv_recover_idx_for_chunk) + attn_output_pre = output.view(num_tokens, self.num_heads, + self.v_head_dim) + attn_output_pre_full, attn_lse_full = self._compute_prefill_context( + q_nope_full, + q_pe_full, + kv_c_and_k_pe_cache, + self.qk_rope_head_dim, + attn_metadata, + None, + None, + ) + # reorder back && extract output + lse result of each cp rank + inverse_idx = torch.argsort( + attn_metadata.prefill.cp_kv_recover_idx_for_chunk) + attn_output_pre_full = torch.index_select(attn_output_pre_full, 0, + inverse_idx) + attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx) + attn_output_pre_new = attn_output_pre_full[ + self.pcp_rank * num_tokens:(self.pcp_rank + 1) * + num_tokens, :, :] + attn_lse_new = attn_lse_full[:, self.pcp_rank * + num_tokens:(self.pcp_rank + 1) * + num_tokens] + + # update(output_origin, output_new) + assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape + seq_len = torch.tensor(attn_metadata.prefill.query_lens, + dtype=torch.int32) + mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk + filtered_indices = filter_chunked_req_indices( + seq_len, mask_for_non_zero_chunk) + attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :] + attn_lse_filtered = attn_lse[:, filtered_indices] + attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :] + attn_lse_new = attn_lse_new[:, filtered_indices] + + # normalize prefix LSE to [bs, heads, 1] for stable updates + attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1) + attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1) + token_mask = torch.ones([attn_lse_new.size(0)], + dtype=torch.uint8, + device=attn_lse_new.device) + attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse( + attn_output_pre_filtered, attn_lse_filtered, + attn_output_pre_new, attn_lse_new, token_mask) + # convert lse back to [heads, bs] + attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute( + 1, 
0).contiguous() + + attn_output_pre[ + filtered_indices, :, :] = attn_output_pre_filtered.to( + attn_output_pre.dtype) + attn_lse[:, + filtered_indices] = attn_lse_filtered.to(attn_lse.dtype) + + attn_output_pre = attn_output_pre.to(q_nope.dtype) + output = attn_output_pre.reshape( + [num_tokens, self.num_heads * self.v_head_dim]) + else: + output = output.reshape( + [num_tokens, self.num_heads * self.v_head_dim]) return output @@ -1588,7 +1945,7 @@ class AscendMLAImpl(MLAAttentionImpl): # nomask if kv_nomask_idx.shape[0] == 0: - return attn_output + return attn_output, attn_lse k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx) value_nomask = torch.index_select(value, 0, kv_nomask_idx) @@ -1611,7 +1968,7 @@ class AscendMLAImpl(MLAAttentionImpl): calc_type="calc_type_default", output=attn_output, softmax_lse=attn_lse) - return attn_output + return attn_output, attn_lse def _forward_decode_pcp_dcp( self, @@ -1788,3 +2145,33 @@ class AscendMLAImpl(MLAAttentionImpl): attn_out_lse_list = attn_out_lse_list_pcp_dcp return attn_out_lse_list + + # TODO use update op to replace this + def _update_out_and_lse( + self, + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + mask: torch.Tensor = None, + ): + if out is None: + out = block_out.to(torch.float32) + lse = block_lse + else: + if mask is None: + mask = torch.ones([block_out.size(0)], + dtype=torch.uint8, + device=block_out.device) + out_mask = mask[:, None, None].expand_as(block_out) + lse_mask = mask[:, None, None].expand_as(block_lse) + block_out = block_out.to(torch.float32) + out_without_update = out.clone() + lse_without_update = lse.clone() + + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + # mask + out = torch.where(out_mask, out, out_without_update) + lse = torch.where(lse_mask, lse, lse_without_update) + return out, lse diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 48118c05..b998ee6a 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -14,10 +14,19 @@ from vllm.forward_context import ForwardContext, get_forward_context class AscendPrefillContextParallelMetadata: pcp_allgather_restore_idx: torch.Tensor = None + cp_kv_recover_idx_for_chunk: torch.Tensor = None + num_actual_tokens_pcp_padded: Optional[int] = None num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None + local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[ + list[int]]]]]]]] = None + + mask_for_non_zero_chunk: Optional[List[bool]] = None + + max_chunk_num: int = 0 + q_head_idx_tensor: torch.Tensor = None q_tail_idx_tensor: torch.Tensor = None @@ -46,7 +55,7 @@ class AscendCommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - + For many of the tensors we keep both GPU and CPU versions. 
""" @@ -106,6 +115,47 @@ class AscendCommonAttentionMetadata: AscendPrefillContextParallelMetadata] = None +def extract_req_dcp_by_chunk_pcp(lst, + chunk_idx, + dcp_size, + pcp_rank, + fill_value=0): + num_reqs = len(lst) + results: List[List[int]] = [] + for i in range(num_reqs): + if len(lst[i]) == 0 or chunk_idx >= len(lst[i]): + # empty req or this req has no corresponding chunk, fill 0 + results.append([fill_value] * dcp_size) + continue + dcp_values = lst[i][chunk_idx][pcp_rank] + results.append(dcp_values) + return results + + +def filter_chunked_req_indices( + seq_len: torch.Tensor, + mask_for_non_zero_chunk: Optional[List[bool]], +) -> torch.Tensor: + """ + filter the reqs which are doing real chunk_prefill. + + Args: + seq_len: contains multi-req length: [req0_len, req1_len, ...] + mask_for_non_zero_chunk: [True, False, True, False, ...] + Returns: + filtered_indices: the real chunked req's indices + """ + assert mask_for_non_zero_chunk is not None and len(seq_len) == len( + mask_for_non_zero_chunk) + offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0) + filtered_indices = torch.cat([ + torch.arange(offsets[i], offsets[i] + seq_len[i]) + for i in range(len(mask_for_non_zero_chunk)) + if mask_for_non_zero_chunk[i] + ]) + return filtered_indices + + def split_decodes_and_prefills( common_attn_metadata: AscendCommonAttentionMetadata, decode_threshold: int = 1, diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index d8333abd..6c35fcfa 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -77,14 +77,6 @@ class BlockTable: self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() - self.slot_mapping = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int32, - device=self.device) try: self.pcp_world_size = get_pcp_group( ).world_size if prefill_context_parallel_enable() else 1 @@ -98,6 +90,20 @@ class BlockTable: self.dcp_rank = 0 self.pcp_world_size = 1 self.pcp_rank = 0 + + self.slot_mapping_cpu = torch.zeros( + self.max_num_batched_tokens + + 2 * self.pcp_world_size * self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros( + self.max_num_batched_tokens + + 2 * self.pcp_world_size * self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.kernel_sizes = kernel_sizes self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size @@ -148,7 +154,7 @@ class BlockTable: if self.dcp_world_size * self.pcp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is - # always stored on the GPU whose dcp_rank equals i % cp_world_size: + # always stored on the GPU whose dcp_rank equals i % pcp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. @@ -268,12 +274,12 @@ class MultiGroupBlockTable: # must be multiplied by dcp_world_size. 
try: dcp_world_size = get_dcp_group().world_size - cp_world_size = get_pcp_group( + pcp_world_size = get_pcp_group( ).world_size if prefill_context_parallel_enable() else 1 except AssertionError: # DCP might not be initialized in testing dcp_world_size = 1 - cp_world_size = 1 + pcp_world_size = 1 if kernel_sizes is None: kernel_sizes = [[0]] * len(block_sizes) @@ -291,7 +297,7 @@ class MultiGroupBlockTable: block_size, max_num_reqs, max( cdiv(max_model_len, - block_size * dcp_world_size * cp_world_size), + block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, pin_memory, device, kernel_size_list, cp_kv_cache_interleave_size) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d28c1a7a..167345bf 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -29,7 +29,7 @@ from copy import deepcopy from dataclasses import dataclass from multiprocessing import Manager from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, - Union, cast) + Tuple, Union, cast) import numpy as np import numpy.typing as npt @@ -471,13 +471,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): device="cpu", pin_memory=True) self.seq_lens_np = self.seq_lens_cpu.numpy() - self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) + self.pcp_allgather_restore_idx = torch.zeros( + self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ + [] for _ in range(self.pcp_size) + ] + self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32) - self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) + self.pcp_padded_slot_mapping = torch.zeros( + self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, + dtype=torch.int32, + device=self.device) self.num_actual_tokens_pcp_padded = 0 if self.speculative_config and self.pcp_size > 1: self.input_ids_pcp_full = torch.zeros(self.max_num_tokens, @@ -739,7 +745,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): backward_kwargs = {} backward_kwargs["mm_features"] = new_req_data.mm_features - self.requests[req_id] = CachedRequestState( + # Create request state - PCP/DCP tracking will be computed below + req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, prompt_embeds=new_req_data.prompt_embeds, @@ -750,9 +757,42 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, + local_chunked_kv_lens=None, **backward_kwargs, ) + # Compute PCP/DCP tracking fields for chunked prefill + self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs + if self.pcp_size * self.dcp_size > 1: + num_computed_tokens = new_req_data.num_computed_tokens + if num_computed_tokens > 0: + # Initialize with starting rank 0 + temp_start_rank_dict = {req_id: (0, 0)} + + # Compute token distribution for initial tokens + current_distribution = self.get_split_computed_tokens( + np.array([num_computed_tokens]), + request_ids=[req_id], + request_start_rank_dict=temp_start_rank_dict, + cp_kv_cache_interleave_size=self.parallel_config. 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d28c1a7a..167345bf 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -29,7 +29,7 @@ from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
 from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
-                    Union, cast)
+                    Tuple, Union, cast)
 
 import numpy as np
 import numpy.typing as npt
@@ -471,13 +471,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu",
                                          pin_memory=True)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
-        self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
-                                                     dtype=torch.int32,
-                                                     device=self.device)
+        self.pcp_allgather_restore_idx = torch.zeros(
+            self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
+            dtype=torch.int32,
+            device=self.device)
+        self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
+            [] for _ in range(self.pcp_size)
+        ]
+
         self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
-        self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
-                                                   dtype=torch.int32,
-                                                   device=self.device)
+        self.pcp_padded_slot_mapping = torch.zeros(
+            self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
+            dtype=torch.int32,
+            device=self.device)
         self.num_actual_tokens_pcp_padded = 0
         if self.speculative_config and self.pcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
@@ -739,7 +745,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             backward_kwargs = {}
             backward_kwargs["mm_features"] = new_req_data.mm_features
 
-            self.requests[req_id] = CachedRequestState(
+            # Create the request state; PCP/DCP tracking is computed below.
+            req_state = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 prompt_embeds=new_req_data.prompt_embeds,
@@ -750,9 +757,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
                 lora_request=new_req_data.lora_request,
+                local_chunked_kv_lens=None,
                 **backward_kwargs,
             )
 
+            # Compute PCP/DCP tracking fields for chunked prefill.
+            self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
+            if self.pcp_size * self.dcp_size > 1:
+                num_computed_tokens = new_req_data.num_computed_tokens
+                if num_computed_tokens > 0:
+                    # Initialize with starting rank 0.
+                    temp_start_rank_dict = {req_id: (0, 0)}
+
+                    # Compute the token distribution for the initial tokens.
+                    current_distribution = self.get_split_computed_tokens(
+                        np.array([num_computed_tokens]),
+                        request_ids=[req_id],
+                        request_start_rank_dict=temp_start_rank_dict,
+                        cp_kv_cache_interleave_size=self.parallel_config.
+                        cp_kv_cache_interleave_size,
+                    )[0]
+
+                    # Update next_pcp_dcp_start_rank.
+                    req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
+                        req_id][0]
+                    req_state.token_blank_in_last_blk = temp_start_rank_dict[
+                        req_id][1]
+
+                    req_state.local_chunked_kv_lens = [
+                        copy.deepcopy(current_distribution)
+                    ]
+                else:
+                    # No computed tokens yet.
+                    req_state.local_chunked_kv_lens = []
+
+            self.requests[req_id] = req_state
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(self.requests[req_id])
@@ -769,8 +809,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             # Update the cached states.
+            prev_num_computed_tokens = req_state.num_computed_tokens
             req_state.num_computed_tokens = num_computed_tokens
 
+            # Compute PCP/DCP tracking fields for chunked prefill.
+            if self.pcp_size * self.dcp_size > 1:
+                # If this is the first chunk, initialize the tracking fields.
+                if req_state.local_chunked_kv_lens is None:
+                    req_state.local_chunked_kv_lens = []
+
+                # Tokens added in this chunk (not cumulative).
+                chunk_tokens = num_computed_tokens - prev_num_computed_tokens
+
+                if chunk_tokens > 0:
+                    # A temporary dict holding this request's starting rank.
+                    temp_start_rank_dict = {
+                        req_id: (req_state.next_pcp_dcp_start_rank,
+                                 req_state.token_blank_in_last_blk)
+                    }
+
+                    # Compute the distribution for this chunk only.
+                    chunk_distribution = self.get_split_computed_tokens(
+                        np.array([chunk_tokens]),
+                        request_ids=[req_id],
+                        request_start_rank_dict=temp_start_rank_dict,
+                        cp_kv_cache_interleave_size=self.parallel_config.
+                        cp_kv_cache_interleave_size,
+                    )[0]
+
+                    # Update next_pcp_dcp_start_rank for this request.
+                    req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
+                        req_id][0]
+                    req_state.token_blank_in_last_blk = temp_start_rank_dict[
+                        req_id][1]
+
+                    # Append this chunk's distribution to the accumulation list.
+                    req_state.local_chunked_kv_lens.append(
+                        copy.deepcopy(chunk_distribution))
+
             if not is_last_rank:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
@@ -815,6 +891,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
+            # Update PCP/DCP tracking fields in the input batch.
+            self.input_batch.local_chunked_kv_lens[
+                req_index] = req_state.local_chunked_kv_lens
+
             # For the last rank, we don't need to update the token_ids_cpu
             # because the sampled tokens are already cached.
             if not is_last_rank:
@@ -979,6 +1059,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        if self.dcp_size > 1:
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
 
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
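Both code paths above (new requests and resumed/continuing requests) maintain the same invariant: `local_chunked_kv_lens` accumulates one `[pcp][dcp]` distribution per scheduled chunk, while `(next_pcp_dcp_start_rank, token_blank_in_last_blk)` carries the round-robin cursor between chunks. A simplified, self-contained sketch, assuming `cp_kv_cache_interleave_size == 1` (so no partial-block carry); `split_chunk` is a stand-in for `get_split_computed_tokens`, not the method itself:

```python
from typing import List, Tuple

def split_chunk(tokens: int, pcp: int, dcp: int,
                start_rank: int) -> Tuple[List[List[int]], int]:
    """Deal `tokens` one by one over pcp*dcp ranks, starting at start_rank."""
    dist = [[0] * dcp for _ in range(pcp)]
    for i in range(tokens):
        rank = (start_rank + i) % (pcp * dcp)
        dist[rank // dcp][rank % dcp] += 1
    return dist, (start_rank + tokens) % (pcp * dcp)

local_chunked_kv_lens: List[List[List[int]]] = []
next_start_rank = 0
for chunk_tokens in (5, 3):  # two scheduled chunks of one request
    dist, next_start_rank = split_chunk(chunk_tokens, pcp=2, dcp=2,
                                        start_rank=next_start_rank)
    local_chunked_kv_lens.append(dist)

print(local_chunked_kv_lens)  # [[[2, 1], [1, 1]], [[0, 1], [1, 1]]]
print(next_start_rank)        # 0
```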
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": return self.attn_mask_builder.get_pooling_mask(self.device) @@ -1378,6 +1460,49 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output, decode_threshold=self.reorder_batch_threshold) + def generate_kv_idx(self, tokens, scheduler_output): + if not self.pcp_size > 1: + return + self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)] + + for i, req_id in enumerate(self.input_batch.req_ids): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + is_prefill = self.input_batch.num_computed_tokens_cpu[ + i] < self.input_batch.num_prompt_tokens[i] + if is_prefill: + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, + 2 * self.pcp_size) * (2 * self.pcp_size) + full_indices = list( + range(self.max_num_tokens * self.pcp_size * self.dcp_size + + self.pcp_size * self.dcp_size * self.max_num_reqs)) + chunk_size = num_cp_padded_scheduled_tokens // (2 * + self.pcp_size) + num_added_recover_tokens = len( + self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size + for rank in range(self.pcp_size): + self.cp_kv_recover_idx_for_chunk[rank].extend( + full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * + chunk_size + num_added_recover_tokens]) + self.cp_kv_recover_idx_for_chunk[rank].extend( + full_indices[num_cp_padded_scheduled_tokens - + (rank + 1) * chunk_size + + num_added_recover_tokens: + num_cp_padded_scheduled_tokens - + rank * chunk_size + + num_added_recover_tokens]) + + cp_kv_recover_idx_for_chunk = torch.from_numpy( + np.concatenate( + self.cp_kv_recover_idx_for_chunk)).to(device=self.device) + cp_kv_recover_idx_for_chunk.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( + torch.float32).argsort().to(torch.int32) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -1406,7 +1531,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_computed_tokens_cpu[req_indices], arange, ) - + self.generate_kv_idx(tokens, scheduler_output) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( @@ -1610,15 +1735,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): ] num_tokens_np = np.array(num_tokens, dtype=np.int32) num_reqs = self.input_batch.num_reqs - if self.pcp_size == 1: - discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np - else: + if self.pcp_size > 1: # while pcp > 1, we need the original num_scheduled_tokens before split # to calculate discard_requests_mask + tokens_original = [ + scheduler_output.num_scheduled_tokens[i] for i in req_ids + ] original_seq_lens_np = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + - np.array(list(scheduler_output.num_scheduled_tokens.values()))) + np.array(tokens_original, dtype=np.int32)) discard_requests_mask = original_seq_lens_np < num_tokens_np + else: + discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) self.discard_request_indices.np[:self.num_discarded_requests] = ( @@ -1762,8 +1891,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output.num_scheduled_tokens) # prepare pcp meta data + # For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens + # to correctly calculate 
@@ -1762,8 +1891,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             scheduler_output.num_scheduled_tokens)
 
         # prepare pcp meta data
+        # For chunked prefill, use num_scheduled_tokens instead of cumulative
+        # seq_lens to correctly calculate chunk_len in _generate_pcp_metadata.
+        if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
+            # In chunked prefill, seq_lens_for_chunk is the current chunk size.
+            seq_lens_for_chunk = torch.from_numpy(
+                num_scheduled_tokens[:num_reqs])
+        else:
+            # Normal mode: use the cumulative sequence lengths.
+            seq_lens_for_chunk = seq_lens_cpu
         long_seq_metadata = self._generate_pcp_metadata(
-            total_num_scheduled_tokens, seq_lens_cpu)
+            total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -2690,7 +2828,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                      dtype=torch.int32,
                                      device=self.device)
         long_seq_metadata = self._generate_pcp_metadata(
-            num_tokens, self.seq_lens_cpu)
+            num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
         if long_seq_metadata is not None:
             pcp_world_size = get_pcp_group(
             ).world_size if prefill_context_parallel_enable() else 1
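A small numeric illustration of why the branch above exists (toy numbers): with chunked prefill the pcp ranks split only the tokens scheduled this step, while `_get_pcp_local_seq_lens` still needs the cumulative lengths (`seq_lens_origin`) to address the KV that is already written.

```python
# Toy numbers, illustration only: a request with 48 tokens of KV already
# computed schedules 16 more tokens in this step.
num_computed_tokens = 48
num_scheduled_tokens = 16
pcp_size = 2

seq_len = num_computed_tokens + num_scheduled_tokens  # 64, cumulative
chunk_len = num_scheduled_tokens                      # 16, this step only

# What each pcp rank processes now comes from the chunk, not the total:
print(chunk_len // pcp_size)  # 8 query tokens per pcp rank
print(seq_len)                # 64 is still needed to locate existing KV
```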
@@ -4266,23 +4404,149 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             [-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
+    def get_split_computed_tokens(
+            self,
+            num_computed_tokens: np.ndarray,
+            request_ids: Optional[List[str]] = None,
+            request_start_rank_dict: Dict[str, tuple[
+                int, int]] = {},  # tuple: (start_rank, tokens_blank_in_this_block)
+            cp_kv_cache_interleave_size: int = 1
+    ) -> list[Optional[list[Optional[list[int]]]]]:
+        """Splits computed token counts across the pcp and dcp dimensions for
+        distributed allocation.
+
+        Args:
+            num_computed_tokens: Number of tokens for each request (the
+                current chunk, not cumulative).
+            request_ids: Request IDs used to track state.
+            request_start_rank_dict: Dict mapping req_id to the starting rank
+                for this chunk. Will be updated with the next starting rank
+                after distribution.
+
+        Returns:
+            A [pcp_size][dcp_size] distribution for each request.
+        """
+        self.pcp_world_size = get_pcp_group(
+        ).world_size if prefill_context_parallel_enable() else 1
+        self.dcp_world_size = get_dcp_group().world_size
+        num_requests = len(num_computed_tokens)
+        assert request_start_rank_dict is not None and request_ids is not None and len(
+            request_ids) == num_requests
+        local_chunked_kv_lens = [[[0] * self.dcp_world_size
+                                  for _ in range(self.pcp_world_size)]
+                                 for _ in range(num_requests)]
+        total_ranks = self.pcp_world_size * self.dcp_world_size
+
+        for req_idx, (req_id, total_tokens) in enumerate(
+                zip(request_ids, num_computed_tokens)):
+            if total_tokens <= 0:
+                continue
+
+            # Get the starting rank for this chunk.
+            start_rank = 0
+            tokens_blank = 0
+            if request_start_rank_dict is not None:
+                start_rank, tokens_blank = request_start_rank_dict.get(
+                    req_id, (0, 0))
+
+            if tokens_blank > 0:  # continue writing the last block of the previous chunk
+                consumed_tokens = min(tokens_blank, total_tokens)
+                total_tokens -= consumed_tokens
+                tokens_blank -= consumed_tokens
+                pcp_idx = start_rank // self.dcp_world_size
+                dcp_idx = start_rank % self.dcp_world_size
+                local_chunked_kv_lens[req_idx][pcp_idx][
+                    dcp_idx] += consumed_tokens
+                if tokens_blank == 0:
+                    start_rank = (start_rank + 1) % total_ranks
+                if total_tokens == 0:
+                    request_start_rank_dict[req_id] = (start_rank,
+                                                       tokens_blank)
+                    continue
+
+            virtual_size = total_ranks * cp_kv_cache_interleave_size
+            base = int(total_tokens) // virtual_size
+
+            # Distribute the base tokens to all ranks.
+            for rank_idx in range(total_ranks):
+                pcp_idx = rank_idx // self.dcp_world_size
+                dcp_idx = rank_idx % self.dcp_world_size
+                local_chunked_kv_lens[req_idx][pcp_idx][
+                    dcp_idx] += base * cp_kv_cache_interleave_size
+
+            remainder = int(total_tokens) % virtual_size
+            if remainder == 0:
+                request_start_rank_dict[req_id] = (start_rank, tokens_blank)
+                continue
+            remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
+            assert remain_blocks > 0
+
+            # Distribute the remainder tokens starting from start_rank.
+            for i in range(remain_blocks):
+                rank = (start_rank + i) % total_ranks
+                pcp_idx = rank // self.dcp_world_size
+                dcp_idx = rank % self.dcp_world_size
+                if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0:  # not the last block, or evenly divisible
+                    local_chunked_kv_lens[req_idx][pcp_idx][
+                        dcp_idx] += 1 * cp_kv_cache_interleave_size
+                    tokens_blank = 0
+                else:  # the last block, and not evenly divisible
+                    local_chunked_kv_lens[req_idx][pcp_idx][
+                        dcp_idx] += remainder % cp_kv_cache_interleave_size
+                    tokens_blank = cp_kv_cache_interleave_size - (
+                        remainder % cp_kv_cache_interleave_size)
+            start_rank = (start_rank + remain_blocks - 1) % total_ranks
+            if tokens_blank == 0:
+                start_rank = (start_rank + 1) % total_ranks
+
+            # Update the next starting rank for this request.
+            request_start_rank_dict[req_id] = (start_rank, tokens_blank)
+
+        return cast(List[Optional[List[Optional[List[int]]]]],
+                    local_chunked_kv_lens)
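To make the bookkeeping above concrete, here is a self-contained trace of the same dealing rules with a hypothetical `split` stand-in (not the method itself): interleave-sized blocks go round-robin over all pcp*dcp ranks, and `(start_rank, blank)` remembers a partially filled block so the next chunk tops it up first. Assuming pcp=2, dcp=2, `cp_kv_cache_interleave_size=2`, and chunks of 10 then 7 tokens:

```python
from math import ceil

def split(tokens, pcp, dcp, interleave, start_rank, blank):
    total_ranks = pcp * dcp
    dist = [[0] * dcp for _ in range(pcp)]

    def put(rank, n):
        dist[rank // dcp][rank % dcp] += n

    if blank:  # top up the partially filled block from the previous chunk
        used = min(blank, tokens)
        put(start_rank, used)
        tokens, blank = tokens - used, blank - used
        if blank == 0:
            start_rank = (start_rank + 1) % total_ranks
    base, rem = divmod(tokens, total_ranks * interleave)
    for r in range(total_ranks):
        put(r, base * interleave)  # evenly distributed full rounds
    blocks = ceil(rem / interleave)
    for i in range(blocks):
        r = (start_rank + i) % total_ranks
        is_last_partial = i == blocks - 1 and rem % interleave != 0
        put(r, rem % interleave if is_last_partial else interleave)
        blank = interleave - rem % interleave if is_last_partial else 0
    if blocks:
        start_rank = (start_rank + blocks - 1) % total_ranks
        if blank == 0:
            start_rank = (start_rank + 1) % total_ranks
    return dist, start_rank, blank

state = (0, 0)
for n in (10, 7):
    dist, *state = split(n, 2, 2, 2, *state)
    print(dist, state)
# [[4, 2], [2, 2]] [1, 0]  -> cursor moves to rank 1, no partial block
# [[1, 2], [2, 2]] [0, 1]  -> rank 0 ends with 1 free slot in its last block
```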
+    def _get_chunked_req_mask_and_max_chunk(
+            self,
+            local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
+                Optional[list[int]]]]]]]] = None
+    ) -> Tuple[List[bool], int]:
+        """
+        Given a 4-d list [req][chunk][pcp][dcp], return:
+        1. whether each request has any recorded chunk (list[bool])
+        2. the maximum chunk count across all requests (int)
+        """
+        assert local_chunked_kv_lens is not None
+        if len(local_chunked_kv_lens) == 0:
+            return ([], 0)
+        mask_for_non_zero_chunk = [
+            req is not None and len(req) > 0 for req in local_chunked_kv_lens
+        ]
+        max_chunk_num = max(
+            (len(req) for req in local_chunked_kv_lens if req is not None),
+            default=0)
+        return mask_for_non_zero_chunk, max_chunk_num
+
+    def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
+                               seq_lens_origin):
         num_reqs = self.input_batch.num_reqs
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
+        local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
+            num_decodes:num_reqs]
+        mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
+            local_chunked_kv_lens)
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
-                    seq_lens,
+                    seq_lens_origin,
                     self.pcp_size,
                     self.dcp_size,
                     self.parallel_config.cp_kv_cache_interleave_size,
                 ).numpy(),
-            )
+                local_chunked_kv_lens=local_chunked_kv_lens,
+                mask_for_non_zero_chunk=mask_for_non_zero_chunk,
+                max_chunk_num=max_chunk_num)
         if self.pcp_size > 1:
             q_head_idx, q_tail_idx = [], []
             kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4393,6 +4657,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             }
         long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
                                                                                      num_actual_tokens_pcp_padded]
+        long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
         long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
         long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
         long_seq_metadata.q_full_idx = self.q_full_idx
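A toy illustration of the mask/max-chunk reduction above: the accumulated `[req][chunk][pcp][dcp]` distributions collapse into a per-request "has chunks" mask plus the deepest chunk count in the batch, which the attention backends use to size their per-chunk loops.

```python
# Hand-built [req][chunk][pcp][dcp] history for three prefill requests.
local_chunked_kv_lens = [
    [[[4, 2], [2, 2]], [[1, 2], [2, 2]]],  # req 0: two chunks recorded
    [],                                    # req 1: first chunk, no history
    [[[3, 3], [3, 3]]],                    # req 2: one chunk recorded
]
mask_for_non_zero_chunk = [len(req) > 0 for req in local_chunked_kv_lens]
max_chunk_num = max((len(req) for req in local_chunked_kv_lens), default=0)
print(mask_for_non_zero_chunk, max_chunk_num)  # [True, False, True] 2
```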
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index 846a4b29..bda2851e 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -73,6 +73,12 @@ class CachedRequestState:
     lora_request: Optional[LoRARequest] = None
     prompt_embeds: Optional[torch.Tensor] = None
 
+    # pcp/dcp parameters
+    local_chunked_kv_lens: Optional[list[Optional[list[Optional[
+        list[int]]]]]] = None  # Records the computed tokens for each chunk.
+    next_pcp_dcp_start_rank: int = 0  # Next starting rank for the round-robin distribution.
+    token_blank_in_last_blk: int = 0  # Free token slots left in the last, partially filled block.
+
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             self.prompt_token_ids, self.prompt_embeds)
@@ -313,6 +319,10 @@ class InputBatch:
         self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
 
+        # pcp/dcp parameters
+        self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
+            list[int]]]]]]] = [None] * max_num_reqs
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
@@ -385,6 +395,9 @@ class InputBatch:
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
         self.block_table.add_row(request.block_ids, req_index)
 
+        # Add the PCP/DCP tracking fields.
+        self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
+
         if sampling_params := request.sampling_params:
             if (self.is_spec_decode
                     and is_spec_decode_unsupported(sampling_params)):
@@ -680,6 +693,8 @@ class InputBatch:
                 last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
+            self.local_chunked_kv_lens[
+                empty_index] = self.local_chunked_kv_lens[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]