from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context

from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
                               get_ascend_device_type)


def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
    """Decide whether paged attention applies to ``runtime_shape``.

    Returns ``False`` when any of the following holds:

    * a speculative config is set on ``vllm_config``;
    * the Ascend device type is ``AscendDeviceType.A5``;
    * the cudagraph mode is not ``CUDAGraphMode.FULL_DECODE_ONLY``.

    Otherwise returns whether ``runtime_shape`` is listed in the Ascend
    config's ``pa_shape_list``.
    """
    if vllm_config.speculative_config is not None:
        return False
    if get_ascend_device_type() == AscendDeviceType.A5:
        return False
    # Imported at call time, as in the original code.
    from vllm.config.compilation import CUDAGraphMode
    cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
    if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
        return False

    return runtime_shape in get_ascend_config().pa_shape_list


@lru_cache(maxsize=1)
def enable_cp() -> bool:
    """Return whether prefill or decode context parallelism is enabled.

    Reads ``parallel_config`` from the current vLLM config and returns
    ``True`` when either context-parallel size is greater than 1. The
    function takes no arguments, so ``lru_cache`` computes the result once
    and reuses it on every later call.
    """
    parallel_config = get_current_vllm_config().parallel_config
    return (parallel_config.prefill_context_parallel_size > 1
            or parallel_config.decode_context_parallel_size > 1)


@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
    """Optional prefill-context-parallel metadata.

    Instances are carried by ``AscendCommonAttentionMetadata`` in its
    ``prefill_context_parallel_metadata`` field. Every field defaults to
    ``None``; the meaning of each field is defined by the code that fills
    it in, which is not part of this module.
    """

    pcp_allgather_restore_idx: Optional[torch.Tensor] = None

    cp_kv_recover_idx_for_chunk: Optional[torch.Tensor] = None

    num_actual_tokens_pcp_padded: Optional[int] = None

    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None

    q_head_idx_tensor: Optional[torch.Tensor] = None

    q_tail_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_head_nomask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_head_mask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_tail_nomask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_tail_mask_idx_tensor: Optional[torch.Tensor] = None

    attn_mask_seqlens: Optional[torch.Tensor] = None

    head_attn_nomask_seqlens: Optional[torch.Tensor] = None

    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None

    q_full_idx: Optional[torch.Tensor] = None

    pcp_prefill_mask: Optional[torch.Tensor] = None


@dataclass
class AscendCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both NPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    seq_lens: torch.Tensor
    """Same as seq_lens_cpu, for compatibility with some new attn metadata
    (such as GDN)."""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""

    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Max token number of request in batch"""

    decode_token_per_req: int
    """decode token number per request"""

    block_table_tensor: torch.Tensor

    slot_mapping: torch.Tensor

    actual_seq_lengths_q: list[int]

    positions: Optional[torch.Tensor] = None

    attn_mask: Optional[torch.Tensor] = None

    spec_attn_mask: Optional[torch.Tensor] = None

    attn_state: Any = None

    graph_pad_size: int = -1

    # num_input_tokens refers to total number of tokens including
    # padding tokens. It is used to handle some padding operations.
    num_input_tokens: int = 0

    prefill_context_parallel_metadata: Optional[
        AscendPrefillContextParallelMetadata] = None

    causal: bool = True

    # TODO: Remove it when vLLM no longer uses this function.
    def unpadded(self, num_actual_tokens: int,
                 num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
        """Return a copy truncated to the real requests and tokens.

        Per-request fields are cut to ``num_actual_reqs`` entries
        (``num_actual_reqs + 1`` for the ``query_start_loc`` pair) and
        per-token fields to ``num_actual_tokens`` entries. The remaining
        fields are passed through unchanged, except ``graph_pad_size``,
        which is reset to ``-1``, and ``num_input_tokens``, which is set to
        ``num_actual_tokens``.

        Args:
            num_actual_tokens: Number of tokens to keep.
            num_actual_reqs: Number of requests to keep.

        Returns:
            A new ``AscendCommonAttentionMetadata`` built from the slices.
        """
        # This is only used by EAGLE for now. It will also be used for
        # enforce_eager in the future.
        # `positions` defaults to None, so only slice it when it is set.
        positions = (self.positions[:num_actual_tokens]
                     if self.positions is not None else None)
        return AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs +
                                                         1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
            num_computed_tokens_cpu=self.
            num_computed_tokens_cpu[:num_actual_reqs],
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=self.max_query_len,
            decode_token_per_req=self.decode_token_per_req,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
            positions=positions,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
            num_input_tokens=num_actual_tokens,
            prefill_context_parallel_metadata=self.
            prefill_context_parallel_metadata,
        )


def filter_chunked_req_indices(
    seq_len: torch.Tensor,
    mask_for_non_zero_chunk: Optional[List[bool]],
) -> torch.Tensor:
    """
    filter the reqs which are doing real chunk_prefill.

    Each request ``i`` owns the index range
    ``[offset_i, offset_i + seq_len[i])``, where ``offset_i`` is the sum of
    the lengths of all earlier requests. The ranges of the requests whose
    mask entry is true are concatenated in request order.

    Args:
        seq_len: contains multi-req length: [req0_len, req1_len, ...]
        mask_for_non_zero_chunk: [True, False, True, False, ...]
    Returns:
        filtered_indices: the real chunked req's indices. An empty int64
        tensor when no request is selected.
    """
    assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
        mask_for_non_zero_chunk)
    # Exclusive prefix sum: offsets[i] is where request i starts.
    offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
    selected = [
        torch.arange(offsets[i], offsets[i] + seq_len[i])
        for i, keep in enumerate(mask_for_non_zero_chunk) if keep
    ]
    if not selected:
        # torch.cat rejects an empty list; nothing selected means no indices.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(selected)


def split_decodes_and_prefills(
    common_attn_metadata: AscendCommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Assuming a reordered batch, finds the boundary between prefill and decode
    requests.

    Args:
        common_attn_metadata: AscendCommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    # Fast path: no request exceeds the threshold, so all are decodes.
    if max_query_len <= decode_threshold:
        return num_reqs, 0, num_tokens, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    # argmax over the 0/1 flags yields the index of the first prefill request.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[first_prefill].item()
    num_prefill_tokens = num_tokens - num_decode_tokens
    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def wait_for_kv_layer_from_connector(layer_name: str):
    """Ask the KV-transfer connector to wait for ``layer_name`` to load.

    Does nothing when there is no KV transfer group, when the group is not a
    v1 group, or when the current forward context has no attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # TODO: assert ascendMetadata
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand one layer's KV cache to the KV-transfer connector for saving.

    Does nothing when there is no KV transfer group, when the group is not a
    v1 group, or when the current forward context has no attention metadata.
    Otherwise calls ``connector.save_kv_layer`` with the layer name, the
    cache tensors and the attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # TODO: assert ascendMetadata
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)


def round_up(val: int, align: int) -> int:
    """Round ``val`` up to a multiple of ``align``.

    Returns 0 when ``align`` is 0 rather than dividing by zero.
    """
    if align == 0:
        return 0
    # -(val // -align) is ceiling division using only integer arithmetic.
    return -(val // -align) * align


def trans_rope_weight(weight, rope_dim):
    """Regroup the trailing ``rope_dim`` rows of ``weight`` along dim -2.

    The last ``rope_dim`` entries along dim -2 are reordered so that the
    even-indexed ones come first, followed by the odd-indexed ones. The
    entries before them keep their order. With ``rope_dim == 0`` the weight
    is returned unchanged apart from being made contiguous.

    Returns:
        A contiguous tensor with the same shape as ``weight``.
    """
    if rope_dim == 0:
        return weight.contiguous()
    nope_part = weight[..., :-rope_dim, :]
    rope_part = weight[..., -rope_dim:, :]
    reordered_rope_part = torch.cat(
        (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
    return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()


def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Pad a 2-D matrix to block multiples and regroup it by column block.

    The matrix is zero-padded at the bottom and right so that its row count
    is a multiple of ``block_size[0]`` and its column count a multiple of
    ``block_size[1]``. With ``r`` and ``c`` the padded sizes, it is viewed as
    ``(r // b0, b0, c // b1, b1)``, permuted so the column-block axis leads,
    and flattened to ``(c // b1, r, b1)``.

    Args:
        nd_mat: Matrix indexed as ``(rows, cols)``.
        block_size: ``(row_block, col_block)`` sizes.

    Returns:
        Tensor of shape ``(c // block_size[1], r, block_size[1])``.
    """
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    r_pad = r - nd_mat.shape[0]
    c_pad = c - nd_mat.shape[1]
    # F.pad takes pad pairs starting from the LAST dimension, so the column
    # padding comes first and the row padding second.
    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
    nz_mat = torch.permute(
        torch.reshape(
            nd_mat,
            (r // block_size[0], block_size[0], c // block_size[1],
             block_size[1]),
        ),
        [2, 0, 1, 3],
    )
    nz_mat = torch.reshape(
        nz_mat,
        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
    return nz_mat