From 5ec96fd46cdf8e6355c3b534cb3b8d29bcf4acc3 Mon Sep 17 00:00:00 2001
From: LookAround0301
Date: Fri, 14 Nov 2025 08:43:37 +0800
Subject: [PATCH] [long_seq_Feat] support chunk prefill (#4158)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
1. Optimize Qwen GQA attention in attention_v1.
2. Refactor DeepSeek MLA: all-gather q -> all-gather kv.
3. Refactor the model runner for chunked prefill and remove unused code.

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

---------

Signed-off-by: LookAround
Signed-off-by: Delphine-Nic
Co-authored-by: Delphine-Nic
---
 tests/ut/attention/test_mla_v1.py     |   3 -
 vllm_ascend/attention/attention_v1.py | 373 ++++++--------
 vllm_ascend/attention/mla_v1.py       | 700 +++++++++-----------
 vllm_ascend/attention/utils.py        |  24 -
 vllm_ascend/worker/model_runner_v1.py | 245 +--------
 vllm_ascend/worker/npu_input_batch.py |  15 -
 6 files changed, 419 insertions(+), 941 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 2ff545ff..102b6feb 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -484,9 +484,6 @@ class TestAscendMLAImpl(TestBase):
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
         chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
-        chunk_ctx.max_chunk_num = 1
-        chunk_ctx.mask_for_non_zero_chunk = [True]
-        chunk_ctx.local_chunked_kv_lens = [[[[8]]]]
         prefill_meta = MagicMock()
         prefill_meta.chunked_context = chunk_ctx

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d18f5a3b..50f23a7b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -44,7 +44,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         extract_req_dcp_by_chunk_pcp,
                                          filter_chunked_req_indices,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
@@ -169,10 +168,10 @@ class AscendMetadataForPrefill:
     @dataclass
     class ChunkedContextMetadata:
         actual_chunk_seq_lengths: list[int]
-        mask_for_non_zero_chunk: Optional[list[bool]] = None
-        max_chunk_num: int = 0
-        local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
-            Optional[list[int]]]]]]]] = None
+        actual_seq_lengths_kv: list[int]
+        starts: torch.Tensor
+        chunked_req_mask: Optional[list[bool]] = None
+        local_context_lens_allranks: Optional[list[list[int]]] = None
         cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
         kv_inverse_idx_for_chunk: Optional[list[int]] = None

@@ -286,25 +285,7 @@ class AscendAttentionMetadataBuilder:
             AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

         scheduler_config = vllm_config.scheduler_config
-        self.block_size = vllm_config.cache_config.block_size
-        self.max_blocks = (vllm_config.model_config.max_model_len +
-                           self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
-                # Max sure there is enough for 8 full length request or at least
-                # 4 pages of cache per request
-                max(8 * self.model_config.max_model_len,
-                    4 * scheduler_config.max_num_seqs * self.block_size),
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool: @@ -385,6 +366,8 @@ class AscendAttentionMetadataBuilder: prefill_metadata = None decode_metadata = None if common_long_seq_metadata is not None: + num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp + assert num_computed_tokens_of_pcp_dcp is not None chunked_context_metadata = None if num_prefills > 0: query_lens = query_lens[num_decode_tokens:] @@ -394,18 +377,39 @@ class AscendAttentionMetadataBuilder: pcp_size = get_prefill_context_model_parallel_world_size( ) if prefill_context_parallel_enable() else 1 if self.chunked_prefill_enabled and max_context_len_cpu > 0: + local_context_lens_allranks = torch.tensor( + num_computed_tokens_of_pcp_dcp + )[num_decodes:num_reqs].to( + self.device).to(dtype=torch.int32) + local_chunked_kv_lens_rank = local_context_lens_allranks[:, + self + . + pcp_rank, + self + . + dcp_rank] + actual_seq_lengths_kv = torch.cumsum( + local_chunked_kv_lens_rank, dim=0).tolist() + chunked_req_mask = self._get_chunked_req_mask( + local_context_lens_allranks) + local_chunk_starts = torch.zeros( + (len(local_context_lens_allranks)), + dtype=torch.int32, + device=self.device) cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk kv_inverse_idx_for_chunk = torch.argsort( cp_kv_recover_idx_for_chunk.to(torch.float32) ) if cp_kv_recover_idx_for_chunk is not None else None + chunked_context_metadata = \ AscendMetadataForPrefill.ChunkedContextMetadata( actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), - mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk, - local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens, + actual_seq_lengths_kv=actual_seq_lengths_kv, + chunked_req_mask=chunked_req_mask, + starts=local_chunk_starts, + local_context_lens_allranks=local_context_lens_allranks, cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, - kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, - max_chunk_num=common_long_seq_metadata.max_chunk_num + kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk ) attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens @@ -445,8 +449,6 @@ class AscendAttentionMetadataBuilder: actual_seq_lengths_q=torch.cumsum(query_lens, dim=0)) if num_decodes > 0: - num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp - assert num_computed_tokens_of_pcp_dcp is not None num_computed_tokens_array = np.array( num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = num_computed_tokens_array[: @@ -483,6 +485,19 @@ class AscendAttentionMetadataBuilder: decode_meta=decode_metadata) return attn_metadata + def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]: + """ + given 4-d list [req][pcp][dcp], return: + 1. 
if each req has any chunk (list[bool]) + """ + assert local_context_lens_allranks is not None + if len(local_context_lens_allranks) == 0: + return [] + chunked_req_mask = [(req.sum() > 0).item() + for req in local_context_lens_allranks + if req is not None] + return chunked_req_mask + def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, @@ -1205,11 +1220,11 @@ class AscendAttentionBackendImpl(AttentionImpl): self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] attn_lse_full_chunk = attn_lse_full_chunk[ self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] + assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape seq_len = attn_metadata.query_lens.detach().clone() filtered_indices = filter_chunked_req_indices( - seq_len, - attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk) + seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask) attn_output_prefill_filtered = current_attn_output_prefill[ filtered_indices, :, :] @@ -1221,18 +1236,23 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_output_filtered = self._npu_attn_out_lse_update( attn_lse_prefill_filtered, attn_lse_full_chunk, attn_output_prefill_filtered, attn_output_full_chunk) + current_attn_output_prefill[ filtered_indices, :, :] = attn_output_filtered.to( current_attn_output_prefill.dtype) def _prefill_query_all_gather(self, attn_metadata, prefill_query): - prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(), - 0) \ - if self.pcp_size > 1 else prefill_query - prefill_query_all = torch.index_select(prefill_query_all, + if self.dcp_size > 1: + prefill_query = get_dcp_group().all_gather(prefill_query, 1) + + if self.pcp_size > 1: + prefill_query = get_pcp_group().all_gather(prefill_query, 0) + + prefill_query_all = torch.index_select(prefill_query, 0, attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \ - if self.pcp_size > 1 else prefill_query_all + if self.pcp_size > 1 else prefill_query + return prefill_query_all def _compute_prefill_context(self, query: torch.Tensor, @@ -1243,217 +1263,132 @@ class AscendAttentionBackendImpl(AttentionImpl): assert attn_metadata.prefill is not None assert attn_metadata.prefill.chunked_context is not None prefill_metadata = attn_metadata.prefill - local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens - mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk - max_chunk_num = prefill_metadata.chunked_context.max_chunk_num + local_chunked_kv_lens = prefill_metadata.chunked_context.local_context_lens_allranks + assert local_chunked_kv_lens is not None - assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0 + local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, + self.dcp_rank] - iters = max_chunk_num - # Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank] - context_starts_rank = None - - prefix_output_list = [] - prefix_lse_list = [] - for i in range(iters): - key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk( - attn_metadata, kv_cache, context_starts_rank, i, - local_chunked_kv_lens, prefill_metadata, query) - - # 2. 
Attention computation - if seq_lens_current_chunk_rank is None or torch.all( - seq_lens_current_chunk_rank == 0).item(): - prefix_output = torch.full( - (query.size(0), self.num_heads, self.head_size), - fill_value=0, - dtype=query.dtype, - device=query.device) - prefix_lse = torch.full((query.size(0), self.num_heads, 1), - fill_value=0, - dtype=torch.float32, - device=query.device) - else: - - actual_seq_lengths_kv = torch.cumsum( - seq_lens_current_chunk_rank, dim=0).tolist() - prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score( - query, - key, - value, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="TND", # - atten_mask=None, - scale=self.scale, - sparse_mode=0, - antiquant_mode=0, - antiquant_scale=None, - softmax_lse_flag=True, - actual_seq_lengths_kv=actual_seq_lengths_kv, - actual_seq_lengths=attn_metadata.prefill.chunked_context. - actual_chunk_seq_lengths) - prefix_output_list.append(prefix_output) - prefix_lse_list.append(prefix_lse) - - # 3. update attn-out & lse - prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks( - prefix_output_list, prefix_lse_list) - - self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output, - prefix_lse) - - return prefix_output, prefix_lse - - def _update_attn_out_lse_in_chunks(self, prefix_output_list, - prefix_lse_list): - # update output and lse - if len(prefix_output_list) > 1: - prefix_output, prefix_lse = self._update_out_and_lse( - torch.stack(prefix_output_list, dim=0), - torch.stack(prefix_lse_list, dim=0)) + key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, + local_chunked_kv_lens_rank, query) + if self.dcp_size > 1: + num_heads = self.num_heads * self.dcp_size else: - prefix_output = prefix_output_list[0] - prefix_lse = prefix_lse_list[0] + num_heads = self.num_heads + + prefix_chunk_output = torch.full( + (query.size(0), num_heads, self.head_size), + fill_value=0, + dtype=query.dtype, + device=query.device) + prefix_chunk_lse = torch.full((query.size(0), num_heads, 1), + fill_value=-torch.inf, + dtype=torch.float32, + device=query.device) + + if not torch.all(local_chunked_kv_lens_rank == 0).item(): + prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + atten_mask=None, + scale=self.scale, + sparse_mode=0, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + actual_seq_lengths_kv=prefill_metadata.chunked_context. + actual_seq_lengths_kv, + actual_seq_lengths=attn_metadata.prefill.chunked_context. 
+ actual_chunk_seq_lengths) + + prefix_output, prefix_lse = self._update_chunk_attn_out_lse( + prefix_chunk_output, prefix_chunk_lse) + return prefix_output, prefix_lse - def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output, - prefix_lse): + def _update_chunk_attn_out_lse(self, prefix_chunk_output, + prefix_chunk_lse): # CP dimension all_gather and fusion - if self.pcp_size > 1: - # filter non-zero chunk part of prefix_output - current_seq_lens = attn_metadata.query_lens.detach().clone() - current_seq_lens.mul_(self.pcp_size) # q_full - current_seq_lens_cpu = current_seq_lens.cpu() - filtered_indices = filter_chunked_req_indices( - current_seq_lens_cpu, - attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk) - prefix_output_filtered = prefix_output[filtered_indices, :, :] - prefix_lse_filtered = prefix_lse[filtered_indices, :, :] + chunk_attn_out_lse = torch.cat([prefix_chunk_output, prefix_chunk_lse], + dim=-1) - out_lse_local = torch.cat( - [prefix_output_filtered, prefix_lse_filtered], dim=-1) + if self.dcp_size > 1: + chunk_attn_out_lse = chunk_attn_out_lse.permute([1, 2, + 0]).contiguous() + attn_out_lse_all2all = torch.empty_like(chunk_attn_out_lse) + dist.all_to_all_single(attn_out_lse_all2all, + chunk_attn_out_lse, + group=self.dcp_group) + attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1]) + if self.pcp_size > 1: + chunk_attn_out_lse = attn_out_lse_all2all.contiguous() + + attn_out_lse_list = list( + torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1)) + + if self.pcp_size > 1: attn_out_lse_list = [ - torch.empty_like(out_lse_local) for _ in range(self.pcp_size) + torch.empty_like(chunk_attn_out_lse) + for _ in range(self.pcp_size) ] dist.all_gather(attn_out_lse_list, - out_lse_local, + chunk_attn_out_lse, group=self.pcp_group) - attn_out_lse_allgather = torch.stack( - attn_out_lse_list, - dim=0) # [pcp, batch_size, num_heads, head_size+1] - attn_out_allgather, attn_lse_allgather = torch.split( - attn_out_lse_allgather, [self.head_size, 1], dim=-1) - prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse( - attn_out_allgather, attn_lse_allgather) + if self.dcp_size > 1 and self.pcp_size > 1: + attn_out_lse_list_pcp_dcp = [] + for s in attn_out_lse_list: + attn_out_lse_list_split = list( + torch.chunk(s, self.dcp_size, dim=1)) + attn_out_lse_list_pcp_dcp += attn_out_lse_list_split + attn_out_lse_list = attn_out_lse_list_pcp_dcp - prefix_output[filtered_indices, :, :] = prefix_output_filtered.to( - prefix_output.dtype) - prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to( - prefix_lse.dtype) + attn_out_lse_allgather = torch.stack( + attn_out_lse_list, + dim=0) # [pcp, batch_size, num_heads, head_size+1] + attn_out_allgather, attn_lse_allgather = torch.split( + attn_out_lse_allgather, [self.head_size, 1], dim=-1) - def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank, - i, local_chunked_kv_lens, prefill_metadata, query): + prefix_output, prefix_lse = self._update_out_and_lse( + attn_out_allgather, attn_lse_allgather) + return prefix_output, prefix_lse + + def _load_kv_for_chunk(self, attn_metadata, kv_cache, + local_chunked_kv_lens_rank, query): cache_key = kv_cache[0] cache_value = kv_cache[1] num_heads = cache_key.size(2) head_size = kv_cache[0].size(-1) - # 1. 
Load current query's history key-value - seq_lens_current_chunk = attn_metadata.query_lens.detach().clone() - num_requests = len(seq_lens_current_chunk) - # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases - context_starts_rank = torch.zeros( - num_requests, dtype=torch.int32, device=query.device - ) if context_starts_rank is None else context_starts_rank - # Calculate tokens each rank should process per request - seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk, - dtype=torch.int32, - device=query.device) - total_toks = 0 - for req_idx in range(num_requests): - if i >= len(local_chunked_kv_lens[req_idx]): - continue - n_computed_acc = local_chunked_kv_lens[req_idx][i] - total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank] - seq_lens_current_chunk_rank[req_idx] = n_computed_acc[ - self.pcp_rank][self.dcp_rank] - if total_toks > 0: - key = torch.empty(total_toks, - num_heads, - head_size, - dtype=query.dtype, - device=query.device) - value = torch.empty(total_toks, - num_heads, - head_size, - dtype=query.dtype, - device=query.device) + total_toks = local_chunked_kv_lens_rank.sum() + key = torch.empty(total_toks, + num_heads, + head_size, + dtype=query.dtype, + device=query.device) + value = torch.empty(total_toks, + num_heads, + head_size, + dtype=query.dtype, + device=query.device) + if total_toks > 0: torch_npu.atb.npu_paged_cache_load( cache_key, cache_value, attn_metadata.prefill.block_tables, - seq_lens_current_chunk_rank.to(query.device), - seq_starts= - context_starts_rank, # slot offsets of current chunk in current iteration + local_chunked_kv_lens_rank, + seq_starts=attn_metadata.prefill.chunked_context. + starts, # slot offsets of current chunk in current iteration key=key, value=value, ) - else: - # If current rank has no tokens to process, create empty tensors - key = torch.empty(0, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device) - value = torch.empty(0, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device) - seq_lens_current_chunk_rank = torch.zeros( - (len(seq_lens_current_chunk), ), - dtype=torch.int32, - device=query.device) - for req_idx in range(num_requests): - # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases - if i >= len(local_chunked_kv_lens[req_idx]): - continue - context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][ - self.pcp_rank][self.dcp_rank] - if self.dcp_size > 1: - req_dcp_sizes = extract_req_dcp_by_chunk_pcp( - local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank) - - assert len(req_dcp_sizes) == num_requests and all( - len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes) - total_toks = np.sum(np.array(req_dcp_sizes)) - kv_local = torch.cat([key, value], dim=-1) - head_dim = kv_local.size(-1) - kv_full = torch.empty((total_toks, num_heads, head_dim), - device=query.device, - dtype=query.dtype) - - kv_full_list = [None for _ in range(self.dcp_size)] - dist.all_gather_object(kv_full_list, - kv_local, - group=self.dcp_group) - kv_full_list = [ - kv for kv in kv_full_list if kv is not None and kv.numel() > 0 - ] - - if len(kv_full_list) > 0: - kv_full = torch.cat(kv_full_list, dim=0) - key, value = kv_full.split([head_size, head_size], dim=-1) - if total_toks == 0: - return key, value, None - seq_lens_current_chunk_rank = torch.tensor( - np.sum(np.array(req_dcp_sizes), axis=1), - dtype=torch.int32, - device=query.device) # [reqs] 
- return key, value, seq_lens_current_chunk_rank + return key, value def forward( self, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 64f72c65..0650f3e3 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple, import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F import torch_npu from torch import nn from vllm.attention.backends.abstract import (AttentionBackend, @@ -35,14 +34,11 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState - -# isort: off -from vllm_ascend.attention.utils import ( - AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp, - filter_chunked_req_indices, maybe_save_kv_layer_to_connector, - split_decodes_and_prefills, trans_rope_weight, transdata, - wait_for_kv_layer_from_connector) -# isort: on +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + maybe_save_kv_layer_to_connector, + split_decodes_and_prefills, + trans_rope_weight, transdata, + wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch @@ -122,10 +118,13 @@ class AscendMLAPrefillMetadata: workspace: torch.Tensor chunk_seq_lens: torch.Tensor chunk_seq_lens_npu: torch.Tensor - mask_for_non_zero_chunk: Optional[list[bool]] = None - max_chunk_num: int = 0 - local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[ - Optional[list[int]]]]]]]] = None + # for mla DCP & PCP + padded_chunk_seq_lens_npu: torch.Tensor = None + padded_local_chunk_seq_lens: Optional[list[list[int]]] = None + local_context_lens_allranks: Optional[list[list[int]]] = None + padded_local_cu_seq_lens: torch.Tensor = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + chunk_size: Optional[int] = None attn_mask: torch.Tensor query_lens: torch.Tensor @@ -140,7 +139,6 @@ class AscendMLAPrefillMetadata: sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: Optional[AscendPCPMetadata] = None - cp_kv_recover_idx_for_chunk: Optional[list[int]] = None @dataclass @@ -285,6 +283,9 @@ class AscendMLAMetadataBuilder: self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 + self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size if prefill_context_parallel_enable( + ) else 1 + self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', 0) max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) @@ -292,16 +293,6 @@ class AscendMLAMetadataBuilder: self.decode_threshold, dtype=torch.uint8, device=device) - self.seq_mask_pcp_buf = torch.empty(max_num_seqs * - self.decode_threshold, - self.pcp_size, - dtype=torch.uint8, - device=device) - self.seq_mask_dcp_buf = torch.empty(max_num_seqs * - self.decode_threshold, - self.dcp_size, - dtype=torch.uint8, - device=device) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -366,10 +357,6 @@ class AscendMLAMetadataBuilder: num_actual_tokens_pcp_padded = 
long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None - cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None - local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None - mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None - max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) @@ -468,19 +455,75 @@ class AscendMLAMetadataBuilder: dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) - chunked_context_metadata = \ + + if self.dcp_size * self.pcp_size > 1: + if num_computed_tokens_of_pcp_dcp is not None: + local_context_lens_allranks = torch.tensor( + num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs] + ).reshape(-1, self.dcp_size * self.pcp_size) + # Note(qcs): The max local context lengths + # padded to `cp_local_block_size`. + padded_local_context_lens_cpu = (cdiv( + context_lens_cpu, + self.cp_virtual_block_size, + ) * self.cp_local_block_size) + padded_local_max_context_chunk_across_ranks = (cdiv( + max_context_chunk, + self.cp_virtual_block_size, + ) * self.cp_local_block_size) + local_chunk_starts = ( + torch.arange(num_chunks, + dtype=torch.int32).unsqueeze(1).expand( + -1, num_prefills) * + padded_local_max_context_chunk_across_ranks) + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + + padded_local_max_context_chunk_across_ranks, + ) + padded_local_chunk_seq_lens = (local_chunk_ends - + local_chunk_starts).clamp( + min=0) + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum( + padded_local_chunk_seq_lens, + dim=1, + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + chunked_context_metadata = \ AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens.npu(), - workspace=self.chunked_prefill_workspace, - local_chunked_kv_lens=local_chunked_kv_lens, - mask_for_non_zero_chunk=mask_for_non_zero_chunk, - max_chunk_num=max_chunk_num, - ) + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, + ) + else: + chunked_context_metadata = \ + AscendMLAPrefillMetadata.ChunkedContextMetadata( + 
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + ) prefill_input_positions = input_positions[tokens_start:] cos = self.cos_cache[ prefill_input_positions].unsqueeze( # type: ignore @@ -502,7 +545,7 @@ class AscendMLAMetadataBuilder: sin=sin, cos=cos, pcp_metadata=pcp_metadata, - cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk) + ) decode_metadata = None if num_decodes > 0: @@ -516,7 +559,7 @@ class AscendMLAMetadataBuilder: block_table = block_table[:num_decodes, ...] # For pcp + spec decode, we flatten seq_lens and block_table # to avoid irregular spec_attn_mask shape - if self.pcp_size > 1: + if self.pcp_size > 1 and self.decode_threshold > 1: block_table = block_table.repeat_interleave( self.decode_threshold, dim=0) seq_lens_list = seq_lens.tolist() @@ -921,26 +964,8 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse - local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens - mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk - max_chunk_num = prefill_metadata.chunked_context.max_chunk_num - if self.pcp_size * self.dcp_size > 1: - assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0 - - if self.pcp_size > 1: - prefix_output = torch.zeros(q_nope.shape[0], - self.num_heads, - self.v_head_dim, - dtype=q_nope.dtype, - device=q_nope.device) - prefix_lse = torch.zeros(self.num_heads, - q_pe.shape[0], - dtype=torch.float32, - device=q_pe.device) iters = len(prefill_metadata.chunked_context.seq_tot) - if self.pcp_size * self.dcp_size > 1: - iters = max_chunk_num current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) @@ -948,305 +973,97 @@ class AscendMLAImpl(MLAAttentionImpl): cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) - # token -> request mapping for building per-token masks when CP>1 - seq_len1 = torch.tensor(prefill_metadata.query_lens, - dtype=torch.int32, - device=q_nope.device) - seq_len1.mul_( - self.pcp_size) # q_full: already padded, divisible by cp_size - - # Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed. - mask_local = None - if attn_metadata is not None and attn_metadata.prefill is not None and \ - attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None: - mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask - else: - mask_local = self.prefill_mask - if mask_local is None: - mask_local = torch.triu( - torch.ones(512, - 512, - device=q_nope.device, - dtype=q_nope.dtype), 1) - self.prefill_mask = mask_local - - # Keep the causal mask; do not override to all-ones. 
- context_starts_rank = None - for i in range(iters): - if self.pcp_size * self.dcp_size > 1: - ## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension - num_requests = len(seq_len1) - assert num_requests == len(local_chunked_kv_lens) - # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases - context_starts_rank = torch.zeros( - num_requests, dtype=torch.int32, device=q_nope.device - ) if context_starts_rank is None else context_starts_rank + toks = prefill_metadata.chunked_context.seq_tot[i] + # chunk_seq_lens will be padded when pcp&dcp + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) + kv_c_normed = torch.empty(toks, + num_heads, + latent_kv_dim, + dtype=q_nope.dtype, + device=q_nope.device) + k_pe = torch.empty(toks, + num_heads, + rope_dim, + dtype=q_nope.dtype, + device=q_nope.device) - ## Calculate tokens each rank should process per request - seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32) - total_toks = 0 - - for req_idx in range(num_requests): - if i >= len(local_chunked_kv_lens[req_idx]): - continue - n_computed_acc = local_chunked_kv_lens[req_idx][i] - total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank] - seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][ - self.dcp_rank] - - if total_toks > 0: - kv_c_normed = torch.empty(total_toks, - num_heads, - latent_kv_dim, - dtype=q_nope.dtype, - device=q_nope.device) - k_pe = torch.empty(total_toks, - num_heads, - rope_dim, - dtype=q_nope.dtype, - device=q_nope.device) - - torch_npu.atb.npu_paged_cache_load( - cache_kv_c, - cache_k_pe, - prefill_metadata.block_table, - seq_len2_rank.to(q_nope.device), - seq_starts= - context_starts_rank, # slot offsets of current chunk in current iteration - key=kv_c_normed, - value=k_pe, - ) - seq_len2 = seq_len2_rank.to(q_nope.device) - else: - # If current rank has no tokens to process, create empty tensors - kv_c_normed = torch.empty(0, - num_heads, - latent_kv_dim, - dtype=q_nope.dtype, - device=q_nope.device) - k_pe = torch.empty(0, - num_heads, - rope_dim, - dtype=q_nope.dtype, - device=q_nope.device) - seq_len2 = torch.zeros((len(seq_len1), ), - dtype=torch.int32, - device=q_nope.device) - seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()]) - for req_idx in range(num_requests): - # Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases - if i >= len(local_chunked_kv_lens[req_idx]): - continue - context_starts_rank[req_idx] += local_chunked_kv_lens[ - req_idx][i][self.pcp_rank][self.dcp_rank] - else: - # Original logic: ChunkPrefill-only mode - toks = prefill_metadata.chunked_context.seq_tot[i] - context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + if self.dcp_size * self.pcp_size > 1: + context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[ i] - context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ - i] - seq_len = torch.stack([current_seq_len, context_seq_len]) - kv_c_normed = torch.empty(toks, - num_heads, - latent_kv_dim, - dtype=q_nope.dtype, - device=q_nope.device) - k_pe = torch.empty(toks, - num_heads, - rope_dim, - dtype=q_nope.dtype, - device=q_nope.device) + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + 
context_seq_len_npu, + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) - torch_npu.atb.npu_paged_cache_load( - cache_kv_c, - cache_k_pe, - prefill_metadata.block_table, - context_seq_len_npu, - seq_starts=prefill_metadata.chunked_context.starts[i], - key=kv_c_normed, - value=k_pe, + cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1) + if self.dcp_size > 1: + cache_kv_c_k_pe = get_dcp_group().all_gather( + cache_kv_c_k_pe, 0) + + if self.pcp_size > 1: + cache_kv_c_k_pe = get_pcp_group().all_gather( + cache_kv_c_k_pe, 0) + + if self.dcp_size * self.pcp_size > 1: + allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed, k_pe = self._reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + padded_local_chunk_seq_lens_lst=prefill_metadata. + chunked_context.padded_local_chunk_seq_lens[i], + local_context_lens_allranks=prefill_metadata. + chunked_context.local_context_lens_allranks, + sum_seq_len=prefill_metadata.chunked_context. + cu_seq_lens_lst[i][-1], + max_seq_len=prefill_metadata.chunked_context. + max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks, ) kv_c_normed = kv_c_normed.squeeze() - if self.dcp_size > 1: - # DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks - # Step 1: DCP all_gather latent - kv_c_k_pe_local = torch.cat( - [kv_c_normed, k_pe.squeeze()], - dim=-1) # [local_toks, latent_dim + rope_dim] - - # Step 2: use all_gather_into_tensor_uneven (gather + cat) - req_dcp_sizes = extract_req_dcp_by_chunk_pcp( - local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank - ) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp] - assert len(req_dcp_sizes) == num_requests and all( - len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes) - total_toks = np.sum(np.array(req_dcp_sizes)) - latent_rope_dim = kv_c_k_pe_local.size(-1) - kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim), - device=kv_c_k_pe_local.device, - dtype=kv_c_k_pe_local.dtype) - - kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)] - dist.all_gather_object(kv_c_k_pe_full_list, - kv_c_k_pe_local, - group=self.dcp_group) - kv_c_k_pe_full_list = [ - kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list - if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0 - ] - if len(kv_c_k_pe_full_list) > 0: - kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0) - if len(kv_c_k_pe_full.shape) == 1: - assert total_toks == 1 - kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0) - assert kv_c_k_pe_full.shape[ - 0] == total_toks and kv_c_k_pe_full.shape[ - 1] == latent_rope_dim - kv_c_normed_full, k_pe_full = torch.split( - kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1) - - # Step 3: process complete sequence with TP projection to get current rank's head slice - # Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation. 
- if total_toks == 0: - continue - kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view( - -1, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1)) - - seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes), - axis=1), - dtype=torch.int32, - device=q_nope.device) # [reqs] - seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()]) - else: - # Non-DCP mode: use TP-split projection - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) if self.pcp_size > 1: - # Case that no kv_cache has been stored on this CP rank, no need to do following computation. - if torch.all(seq_len2 == 0).item(): - continue - # PCP mode: first compute this rank's contribution to the chunk - if i == 0: - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=v, - mask=mask_local, - seqlen=seq_len, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=prefix_output, - softmax_lse=prefix_lse) - continue - - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=v, - mask=mask_local, - seqlen=seq_len, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=prefix_output, - prev_lse=prefix_lse, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_default", - output=prefix_output, - softmax_lse=prefix_lse) - + mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask else: - assert not torch.all(context_seq_len == 0).item() - # compute this chunk block then update prefix tensors to keep shapes consistent - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=v, - mask=mask_local, - seqlen=seq_len, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=prefix_output, - prev_lse=prefix_lse, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_default", - output=prefix_output, - softmax_lse=prefix_lse) - - # CP dimension all_gather and fusion - if self.pcp_size > 1: - # filter non-zero chunk part of prefix_output - seq_len1_cpu = seq_len1.cpu() - filtered_indices = filter_chunked_req_indices( - seq_len1_cpu, mask_for_non_zero_chunk) - prefix_output_filtered = prefix_output[filtered_indices, :, :] - prefix_lse_filtered = prefix_lse[:, filtered_indices] - - # normalize prefix LSE to [bs, heads, 1] for stable updates - prefix_lse_filtered_bt = prefix_lse_filtered.permute( - 1, 0).unsqueeze(-1).contiguous( - ) if prefix_lse_filtered is not None else None - out_lse_local = torch.cat( - [prefix_output_filtered, prefix_lse_filtered_bt], dim=-1) - out_lse_list = [ - torch.empty_like(out_lse_local) for _ in range(self.pcp_size) - ] - 
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group) - prefix_output_filtered = None - prefix_lse_filtered_bt = None - for r in range(self.pcp_size): - out_lse_r = out_lse_list[r] - if torch.all(out_lse_r == 0).item(): - continue - out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1], - dim=-1) - token_mask = torch.ones([out_r.size(0)], - dtype=torch.uint8, - device=out_r.device) - prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse( - prefix_output_filtered, prefix_lse_filtered_bt, out_r, - lse_r, token_mask) - # convert lse back to [heads, bs] - assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None - prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute( - 1, 0).contiguous() - - prefix_output[filtered_indices, :, :] = prefix_output_filtered.to( - prefix_output.dtype) - prefix_lse[:, filtered_indices] = prefix_lse_filtered.to( - prefix_lse.dtype) - + mask = self.prefill_mask + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) return prefix_output, prefix_lse def _forward_prefill( @@ -1814,8 +1631,7 @@ class AscendMLAImpl(MLAAttentionImpl): head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask - - output_head, head_lse = self._attention_with_mask_and_nomask( + output_head, lse_head = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx), k_nope=k_nope, @@ -1827,7 +1643,7 @@ class AscendMLAImpl(MLAAttentionImpl): attn_nomask_seqlens=head_attn_nomask_seqlens, mask=mask) - output_tail, tail_lse = self._attention_with_mask_and_nomask( + output_tail, lse_tail = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_tail_idx), q_pe=torch.index_select(q_pe, 0, q_tail_idx), k_nope=k_nope, @@ -1840,86 +1656,15 @@ class AscendMLAImpl(MLAAttentionImpl): mask=mask) q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx - output = torch.index_select( + attn_output = torch.index_select( torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) + attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1), + 1, q_full_idx) - # Synchronize and reorder LSE for subsequent chunked context accumulation - attn_lse = torch.cat([head_lse, tail_lse], dim=1) - attn_lse = attn_lse[:, q_full_idx] + output, _ = self._compute_prefill_context( \ + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - # Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed - if attn_metadata.prefill is not None and \ - attn_metadata.prefill.chunked_context is not None: - # q all_gather - q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0) - q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0) - q_nope_full = torch.index_select( - q_nope_full, 0, - attn_metadata.prefill.cp_kv_recover_idx_for_chunk) - q_pe_full = torch.index_select( - q_pe_full, 0, - 
attn_metadata.prefill.cp_kv_recover_idx_for_chunk) - attn_output_pre = output.view(num_tokens, self.num_heads, - self.v_head_dim) - attn_output_pre_full, attn_lse_full = self._compute_prefill_context( - q_nope_full, - q_pe_full, - kv_c_and_k_pe_cache, - self.qk_rope_head_dim, - attn_metadata, - None, - None, - ) - # reorder back && extract output + lse result of each cp rank - inverse_idx = torch.argsort( - attn_metadata.prefill.cp_kv_recover_idx_for_chunk) - attn_output_pre_full = torch.index_select(attn_output_pre_full, 0, - inverse_idx) - attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx) - attn_output_pre_new = attn_output_pre_full[ - self.pcp_rank * num_tokens:(self.pcp_rank + 1) * - num_tokens, :, :] - attn_lse_new = attn_lse_full[:, self.pcp_rank * - num_tokens:(self.pcp_rank + 1) * - num_tokens] - - # update(output_origin, output_new) - assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape - seq_len = torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32) - mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk - filtered_indices = filter_chunked_req_indices( - seq_len, mask_for_non_zero_chunk) - attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :] - attn_lse_filtered = attn_lse[:, filtered_indices] - attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :] - attn_lse_new = attn_lse_new[:, filtered_indices] - - # normalize prefix LSE to [bs, heads, 1] for stable updates - attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1) - attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1) - token_mask = torch.ones([attn_lse_new.size(0)], - dtype=torch.uint8, - device=attn_lse_new.device) - attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse( - attn_output_pre_filtered, attn_lse_filtered, - attn_output_pre_new, attn_lse_new, token_mask) - # convert lse back to [heads, bs] - attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute( - 1, 0).contiguous() - - attn_output_pre[ - filtered_indices, :, :] = attn_output_pre_filtered.to( - attn_output_pre.dtype) - attn_lse[:, - filtered_indices] = attn_lse_filtered.to(attn_lse.dtype) - - attn_output_pre = attn_output_pre.to(q_nope.dtype) - output = attn_output_pre.reshape( - [num_tokens, self.num_heads * self.v_head_dim]) - else: - output = output.reshape( - [num_tokens, self.num_heads * self.v_head_dim]) + output = output.reshape([num_tokens, self.num_heads * self.v_head_dim]) return output @@ -2164,32 +1909,75 @@ class AscendMLAImpl(MLAAttentionImpl): return attn_out_lse_list - # TODO use update op to replace this - def _update_out_and_lse( + def _reorg_kvcache( self, - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, - mask: torch.Tensor = None, - ): - if out is None: - out = block_out.to(torch.float32) - lse = block_lse - else: - if mask is None: - mask = torch.ones([block_out.size(0)], - dtype=torch.uint8, - device=block_out.device) - out_mask = mask[:, None, None].expand_as(block_out) - lse_mask = mask[:, None, None].expand_as(block_lse) - block_out = block_out.to(torch.float32) - out_without_update = out.clone() - lse_without_update = lse.clone() - - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - # mask - out = torch.where(out_mask, out, out_without_update) - lse = torch.where(lse_mask, lse, lse_without_update) - return out, lse + allgatered_kv_c_normed: torch.Tensor, + 
allgatered_k_pe: torch.Tensor, + padded_local_chunk_seq_lens_lst: list[int], + local_context_lens_allranks: list[list[int]], + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg and unpad kvcache after cp local gather to tp layout for attn kernel. + e.g. + kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...] + kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...] + allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ..., + T0_4, T0_5, pad, pad, T1_2, pad, ...] + -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5, + T1_0, T1_1, T1_2, ...] + Args: + padded_local_chunk_seq_lens_lst: local chunk context lengths + under current CP rank. + local_context_lens_allranks: local context lengths on each CP rank. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: the local padded max context chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. + """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for padded_local_chunk_seq_len, local_context_lens in zip( + padded_local_chunk_seq_lens_lst, local_context_lens_allranks): + cur_seq_len = 0 + for rank, local_context_len in enumerate(local_context_lens): + # Note(qcs): We split the context into multiple chunks, + # depending on the size of the workspace. + # local_context in dcp0: |-----------------| + # local_context in dcp1: |--------------| + # n*padded_local_chunk: |-----|-----|-----| + # local_chunk_len in dcp1: |-----|-----|--| + # so we need update the last chunk length in dcp1. + local_chunk_len = min( + max(0, local_context_len - chunk_idx * chunk_size), + padded_local_chunk_seq_len, + ) + if local_chunk_len != 0: + kv_c_segment = allgatered_kv_c_normed[rank * toks + + src_token_idx:rank * + toks + + src_token_idx + + local_chunk_len] + k_pe_segment = allgatered_k_pe[rank * toks + + src_token_idx:rank * toks + + src_token_idx + + local_chunk_len] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += local_chunk_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += padded_local_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index b998ee6a..a2f71de7 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -20,13 +20,6 @@ class AscendPrefillContextParallelMetadata: num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None - local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[ - list[int]]]]]]]] = None - - mask_for_non_zero_chunk: Optional[List[bool]] = None - - max_chunk_num: int = 0 - q_head_idx_tensor: torch.Tensor = None q_tail_idx_tensor: torch.Tensor = None @@ -115,23 +108,6 @@ class AscendCommonAttentionMetadata: AscendPrefillContextParallelMetadata] = None -def extract_req_dcp_by_chunk_pcp(lst, - chunk_idx, - dcp_size, - pcp_rank, - fill_value=0): - num_reqs = len(lst) - results: List[List[int]] = [] - for i in range(num_reqs): - if 
len(lst[i]) == 0 or chunk_idx >= len(lst[i]): - # empty req or this req has no corresponding chunk, fill 0 - results.append([fill_value] * dcp_size) - continue - dcp_values = lst[i][chunk_idx][pcp_rank] - results.append(dcp_values) - return results - - def filter_chunked_req_indices( seq_len: torch.Tensor, mask_for_non_zero_chunk: Optional[List[bool]], diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a332ac79..29720464 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -29,7 +29,7 @@ from copy import deepcopy from dataclasses import dataclass from multiprocessing import Manager from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, - Tuple, Union, cast) + Union, cast) import numpy as np import numpy.typing as npt @@ -763,7 +763,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): backward_kwargs["mm_features"] = new_req_data.mm_features # Create request state - PCP/DCP tracking will be computed below - req_state = CachedRequestState( + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, prompt_embeds=new_req_data.prompt_embeds, @@ -774,42 +774,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, - local_chunked_kv_lens=None, **backward_kwargs, ) - # Compute PCP/DCP tracking fields for chunked prefill - self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs - if self.pcp_size * self.dcp_size > 1: - num_computed_tokens = new_req_data.num_computed_tokens - if num_computed_tokens > 0: - # Initialize with starting rank 0 - temp_start_rank_dict = {req_id: (0, 0)} - - # Compute token distribution for initial tokens - current_distribution = self.get_split_computed_tokens( - np.array([num_computed_tokens]), - request_ids=[req_id], - request_start_rank_dict=temp_start_rank_dict, - cp_kv_cache_interleave_size=self.parallel_config. - cp_kv_cache_interleave_size, - )[0] - - # Update next_pcp_dcp_start_rank - req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[ - req_id][0] - req_state.token_blank_in_last_blk = temp_start_rank_dict[ - req_id][1] - - req_state.local_chunked_kv_lens = [ - copy.deepcopy(current_distribution) - ] - else: - # No computed tokens yet - req_state.local_chunked_kv_lens = [] - - self.requests[req_id] = req_state - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(self.requests[req_id]) @@ -826,44 +793,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. 
- prev_num_computed_tokens = req_state.num_computed_tokens req_state.num_computed_tokens = num_computed_tokens - # Compute PCP/DCP tracking fields for chunked prefill - if self.pcp_size * self.dcp_size > 1: - # If this is the first chunk, initialize tracking fields - if req_state.local_chunked_kv_lens is None: - req_state.local_chunked_kv_lens = [] - - # Compute tokens added in this chunk (not cumulative) - chunk_tokens = num_computed_tokens - prev_num_computed_tokens - - if chunk_tokens > 0: - # Create a temporary dict with this request's starting rank - temp_start_rank_dict = { - req_id: (req_state.next_pcp_dcp_start_rank, - req_state.token_blank_in_last_blk) - } - - # Compute distribution for this chunk only - chunk_distribution = self.get_split_computed_tokens( - np.array([chunk_tokens]), - request_ids=[req_id], - request_start_rank_dict=temp_start_rank_dict, - cp_kv_cache_interleave_size=self.parallel_config. - cp_kv_cache_interleave_size, - )[0] - - # Update next_pcp_dcp_start_rank for this request - req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[ - req_id][0] - req_state.token_blank_in_last_blk = temp_start_rank_dict[ - req_id][1] - - # Append this chunk's distribution to accumulation list - req_state.local_chunked_kv_lens.append( - copy.deepcopy(chunk_distribution)) - if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- @@ -908,10 +839,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.block_table.append_row( new_block_ids, req_index) - # Update PCP/DCP tracking fields in input_batch - self.input_batch.local_chunked_kv_lens[ - req_index] = req_state.local_chunked_kv_lens - # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. 
if not is_last_rank: @@ -1477,7 +1404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output, decode_threshold=self.reorder_batch_threshold) - def generate_kv_idx(self, tokens, scheduler_output): + def generate_kv_idx(self, scheduler_output): if not self.pcp_size > 1: return self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)] @@ -1548,12 +1475,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_computed_tokens_cpu[req_indices], arange, ) - self.generate_kv_idx(tokens, scheduler_output) + self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) if self.pcp_size > 1: + if not self.vllm_config.model_config.use_mla: + self.generate_kv_idx(scheduler_output) tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( tokens) num_scheduled_tokens = np.array(tokens, dtype=np.int32) @@ -1905,18 +1834,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_reqs, scheduler_output.total_num_scheduled_tokens, scheduler_output.num_scheduled_tokens) - # prepare pcp meta data - # For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens - # to correctly calculate chunk_len in _generate_pcp_metadata - if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1: - # In chunked prefill, seq_lens_for_chunk should be the current chunk size - seq_lens_for_chunk = torch.from_numpy( - num_scheduled_tokens[:num_reqs]) - else: - # Normal mode: use cumulative sequence lengths - seq_lens_for_chunk = seq_lens_cpu long_seq_metadata = self._generate_pcp_metadata( - total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu) + total_num_scheduled_tokens) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -2842,6 +2761,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens) + self.query_lens = torch.from_numpy(num_scheduled_tokens) num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) @@ -2855,8 +2775,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) - long_seq_metadata = self._generate_pcp_metadata( - num_tokens, self.seq_lens_cpu, self.seq_lens_cpu) + long_seq_metadata = self._generate_pcp_metadata(num_tokens) if long_seq_metadata is not None: pcp_world_size = get_pcp_group( ).world_size if prefill_context_parallel_enable() else 1 @@ -4411,7 +4330,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): all_positions_tensor.float().argsort().long(), non_blocking=True) return pcp_tokens, positions, unpad_mask - def _get_pcp_local_seq_lens( + def _get_cp_local_seq_lens( self, seq_lens: torch.Tensor, pcp_world_size: int = 1, @@ -4439,139 +4358,20 @@ class NPUModelRunner(LoRAModelRunnerMixin): [-1, pcp_world_size, dcp_world_size]) return dcp_local_seq_lens - def get_split_computed_tokens( - self, - num_computed_tokens: np.ndarray, - request_ids: Optional[List[str]] = None, - request_start_rank_dict: Dict[str, tuple[ - int, int]] = {}, # tuple: start_rank, tokens_blank_in_this_block - cp_kv_cache_interleave_size: int = 1 - ) -> list[Optional[list[Optional[list[int]]]]]: - """Splits computed token counts across dcp and sp dimensions for distributed allocation. 
-
-        Args:
-            num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
-            request_ids: Request IDs to track state
-            request_start_rank_dict: Dict mapping req_id to the starting rank for this chunk.
-                Will be updated with next starting rank after distribution.
-
-        Returns:
-            List of [pcp_size][dcp_size] distribution for each request
-        """
-        self.pcp_world_size = get_pcp_group(
-        ).world_size if prefill_context_parallel_enable() else 1
-        self.dcp_world_size = get_dcp_group().world_size
-        num_requests = len(num_computed_tokens)
-        assert request_start_rank_dict is not None and request_ids is not None and len(
-            request_ids) == num_requests
-        local_chunked_kv_lens = [[[0] * self.dcp_world_size
-                                  for _ in range(self.pcp_world_size)]
-                                 for _ in range(num_requests)]
-        total_ranks = self.pcp_world_size * self.dcp_world_size
-
-        for req_idx, (req_id, total_tokens) in enumerate(
-                zip(request_ids, num_computed_tokens)):
-            if total_tokens <= 0:
-                continue
-
-            # Get starting rank for this chunk
-            start_rank = 0
-            tokens_blank = 0
-            if request_start_rank_dict is not None:
-                start_rank, tokens_blank = request_start_rank_dict.get(
-                    req_id, (0, 0))
-
-            if tokens_blank > 0:  # need to continue writing in the last block of previous chunk
-                consumed_tokens = min(tokens_blank, total_tokens)
-                total_tokens -= consumed_tokens
-                tokens_blank -= consumed_tokens
-                pcp_idx = start_rank // self.dcp_world_size
-                dcp_idx = start_rank % self.dcp_world_size
-                local_chunked_kv_lens[req_idx][pcp_idx][
-                    dcp_idx] += consumed_tokens
-                if tokens_blank == 0:
-                    start_rank = (start_rank + 1) % total_ranks
-                if total_tokens == 0:
-                    request_start_rank_dict[req_id] = (start_rank,
-                                                       tokens_blank)
-                    continue
-
-            virtual_size = total_ranks * cp_kv_cache_interleave_size
-            base = int(total_tokens) // virtual_size
-
-            # Distribute base tokens to all ranks
-            for rank_idx in range(total_ranks):
-                pcp_idx = rank_idx // self.dcp_world_size
-                dcp_idx = rank_idx % self.dcp_world_size
-                local_chunked_kv_lens[req_idx][pcp_idx][
-                    dcp_idx] += base * cp_kv_cache_interleave_size
-
-            remainder = int(total_tokens) % virtual_size
-            if remainder == 0:
-                request_start_rank_dict[req_id] = (start_rank, tokens_blank)
-                continue
-            remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
-            assert remain_blocks > 0
-
-            # Distribute remainder tokens starting from start_rank
-            for i in range(remain_blocks):
-                rank = (start_rank + i) % total_ranks
-                pcp_idx = rank // self.dcp_world_size
-                dcp_idx = rank % self.dcp_world_size
-                if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0:  # not last block or divisible
-                    local_chunked_kv_lens[req_idx][pcp_idx][
-                        dcp_idx] += 1 * cp_kv_cache_interleave_size
-                    tokens_blank = 0
-                else:  # if last block and undivisible
-                    local_chunked_kv_lens[req_idx][pcp_idx][
-                        dcp_idx] += remainder % cp_kv_cache_interleave_size
-                    tokens_blank = cp_kv_cache_interleave_size - (
-                        remainder % cp_kv_cache_interleave_size)
-            start_rank = (start_rank + remain_blocks - 1) % total_ranks
-            if tokens_blank == 0:
-                start_rank = (start_rank + 1) % total_ranks
-
-            # Update next starting rank for this request
-            request_start_rank_dict[req_id] = (start_rank, tokens_blank)
-
-        return cast(List[Optional[List[Optional[List[int]]]]],
-                    local_chunked_kv_lens)
-
-    def _get_chunked_req_mask_and_max_chunk(
-            self,
-            local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
-                Optional[list[int]]]]]]]] = None
-    ) -> Tuple[List[bool], int]:
-        """
-        given 4-d list [req][chunk][pcp][dcp], return:
-        1. if each req has any chunk (list[bool])
-        2. max chunk num along all reqs (int)
-        """
-        assert local_chunked_kv_lens is not None
-        if len(local_chunked_kv_lens) == 0:
-            return ([], 0)
-        mask_for_non_zero_chunk = [
-            len(req) > 0 for req in local_chunked_kv_lens if req is not None
-        ]
-        max_chunk_num = max(
-            (len(req) for req in local_chunked_kv_lens if req is not None),
-            default=0)
-        return mask_for_non_zero_chunk, max_chunk_num
-
-    def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
-                               seq_lens_origin):
+    def _generate_pcp_metadata(self, total_num_scheduled_tokens):
         # In dummy run num_reqs == 0, update it from seq_lens
-        num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
+        num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
-        local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
-            num_decodes:num_reqs]
-        mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
-            local_chunked_kv_lens)
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
+            decode_context_lens = self.input_batch.num_tokens[:num_decodes]
+            prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
+                num_decodes:num_reqs]
+            context_lens = np.concatenate(
+                [decode_context_lens, prefill_context_lens])
             num_computed_tokens_of_pcp_dcp = torch.zeros(
                 [
                     num_reqs * self.decode_threshold, self.pcp_size,
@@ -4584,8 +4384,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             for decode_idx in range(self.decode_threshold):
                 num_computed_tokens_of_pcp_dcp[
                     self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
-                    self._get_pcp_local_seq_lens(
-                        seq_lens_origin - decode_idx,
+                    self._get_cp_local_seq_lens(
+                        torch.tensor(context_lens),
                         self.pcp_size,
                         self.dcp_size,
                         self.parallel_config.cp_kv_cache_interleave_size,
@@ -4593,10 +4393,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
-                numpy(),
-                local_chunked_kv_lens=local_chunked_kv_lens,
-                mask_for_non_zero_chunk=mask_for_non_zero_chunk,
-                max_chunk_num=max_chunk_num)
+                numpy())
         if self.pcp_size > 1:
             q_head_idx, q_tail_idx = [], []
             kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4607,7 +4404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             kv_req_offset = 0
             q_head_chunk_id = self.pcp_rank
             q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
-            for i, seq_len in enumerate(seq_lens):
+            for i, seq_len in enumerate(self.query_lens):
                 if i < num_decodes:
                     continue
                 chunk_len = seq_len // 2
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index bda2851e..846a4b29 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -73,12 +73,6 @@ class CachedRequestState:
     lora_request: Optional[LoRARequest] = None
     prompt_embeds: Optional[torch.Tensor] = None
 
-    # pcp/dcp param
-    local_chunked_kv_lens: Optional[list[Optional[list[Optional[
-        list[int]]]]]] = None  # Records computed tokens for each chunk
-    next_pcp_dcp_start_rank: int = 0  # Tracks next starting rank for round-robin distribution
-    token_blank_in_last_blk: int = 0  # if the last block is not full, how many future tokens can be stored
-
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             self.prompt_token_ids, self.prompt_embeds)
@@ -319,10 +313,6 @@ class InputBatch:
         self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
 
-        # pcp/dcp parameters
-        self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
-            list[int]]]]]]] = [None] * max_num_reqs
-
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
@@ -395,9 +385,6 @@ class InputBatch:
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
         self.block_table.add_row(request.block_ids, req_index)
 
-        # Add PCP/DCP tracking fields
-        self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
-
         if sampling_params := request.sampling_params:
             if (self.is_spec_decode
                     and is_spec_decode_unsupported(sampling_params)):
@@ -693,8 +680,6 @@ class InputBatch:
                                                              last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
-            self.local_chunked_kv_lens[
-                empty_index] = self.local_chunked_kv_lens[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
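
Note for reviewers: the per-chunk bookkeeping removed above (local_chunked_kv_lens, next_pcp_dcp_start_rank, token_blank_in_last_blk) is replaced by deriving per-rank context lengths on demand from num_computed_tokens_of_pcp_dcp. As a rough mental model only (a minimal standalone sketch, not code from this patch; the helper name split_context_len and the example sizes are made up), the removed get_split_computed_tokens docstring describes dealing a request's context tokens out to the pcp_size * dcp_size ranks round-robin, in blocks of cp_kv_cache_interleave_size tokens:

# Illustrative sketch only -- not part of the patch.
def split_context_len(context_len: int,
                      pcp_size: int,
                      dcp_size: int,
                      interleave_size: int = 1) -> list[list[int]]:
    """Return a [pcp_size][dcp_size] table of how many context tokens each
    rank holds when tokens are dealt out round-robin, one interleave_size
    block at a time, over the flattened (pcp, dcp) rank grid."""
    lens = [[0] * dcp_size for _ in range(pcp_size)]
    total_ranks = pcp_size * dcp_size
    rank, remaining = 0, context_len
    while remaining > 0:
        take = min(interleave_size, remaining)
        lens[rank // dcp_size][rank % dcp_size] += take
        remaining -= take
        rank = (rank + 1) % total_ranks
    return lens

# Example: 10 context tokens, pcp_size=2, dcp_size=2, interleave_size=2
# -> [[4, 2], [2, 2]]; rank (pcp=0, dcp=0) holds tokens 0-1 and 8-9.

The batched tensor version of this layout is what _get_cp_local_seq_lens produces for the attention metadata builders.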