# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from collections.abc import Sequence
from typing import Any

import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         AscendPrefillContextParallelMetadata)

_ATTENTION_MASK_BUILDER = None


def get_attn_mask_builder(device: torch.device):
    """Get the attention mask builder, of which only one instance exists."""
    global _ATTENTION_MASK_BUILDER
    if _ATTENTION_MASK_BUILDER is None:
        _ATTENTION_MASK_BUILDER = AttentionMaskBuilder(device)
    return _ATTENTION_MASK_BUILDER


def build_attn_metadata(
    *,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    max_query_len: int,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
    # Extra attributes for Ascend NPUs.
    seq_lens_np: np.ndarray | None = None,
    num_computed_tokens_cpu: torch.Tensor | None = None,
    positions: torch.Tensor | None = None,
    attn_state: Any | None = None,
    graph_pad_size: int = -1,
    num_input_tokens: int = 0,
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
    | None = None,
) -> dict[str, Any]:
    """Build attention metadata for Ascend NPUs."""
    # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
    # seq_lens_np is only used for Ascend NPUs and may be None in the
    # spec_decode case; fill it with max_seq_len so that
    # attn_metadata_builder.build does not raise an error.
    if seq_lens_np is None:
        seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
    seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
    # The torch_npu._reshape_and_cache operator requires slot_mappings to
    # be torch.int32.
    slot_mappings = slot_mappings.to(torch.int32)
    attn_metadata: dict[str, Any] = {}
    kv_cache_groups = kv_cache_config.kv_cache_groups
    for i, kv_cache_spec in enumerate(kv_cache_groups):
        block_table = block_tables[i]
        slot_mapping = slot_mappings[i]
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens=seq_lens[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            block_table_tensor=block_table,
            slot_mapping=slot_mapping,
            positions=positions,
            attn_state=attn_state,
            graph_pad_size=graph_pad_size,
            num_input_tokens=num_input_tokens,
            prefill_context_parallel_metadata=prefill_context_parallel_metadata,
            max_seq_len=max_seq_len,
        )
        attn_metadata_builder = attn_metadata_builders[i]
        metadata = attn_metadata_builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,  # type: ignore
        )
        # Every layer in the same KV cache group shares the same metadata.
        for layer_name in kv_cache_spec.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata


def build_attn_state(
    vllm_config: VllmConfig,
    seq_lens_np: np.ndarray,
    num_reqs,
    num_scheduled_tokens,
    num_valid_tokens,
):
    """Build the attention state for the NPU attention backend."""
    if vllm_config.model_config.runner_type == "pooling":
        if isinstance(
                vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
                EncoderOnlyAttentionSpec,
        ):
            attn_state = AscendAttentionState.PrefillNoCache
        else:
            attn_state = AscendAttentionState.PrefillCacheHit
    elif np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
        attn_state = AscendAttentionState.PrefillNoCache
    # We assume it is the decode stage, where prefill occurred but only one
    # token missed the cache.
    elif np.all(num_scheduled_tokens == 1):
        attn_state = AscendAttentionState.DecodeOnly
        if (vllm_config.speculative_config
                and vllm_config.speculative_config.method == "mtp"):
            # SpecDecoding now supports seq_len=1 and seq_len=2. In the
            # prefill-decode disaggregation scenario, SpecDecoding needs to
            # support seq_len=1.
            attn_state = AscendAttentionState.SpecDecoding
    # Speculative decoding.
    elif np.all(num_valid_tokens == 1):
        if (vllm_config.speculative_config
                and vllm_config.speculative_config.method == "mtp"):
            attn_state = AscendAttentionState.SpecDecoding
        else:
            attn_state = AscendAttentionState.ChunkedPrefill
    # splitfuse
    elif vllm_config.scheduler_config.enable_chunked_prefill:
        attn_state = AscendAttentionState.ChunkedPrefill
    else:
        attn_state = AscendAttentionState.PrefillCacheHit
    return attn_state
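

# The sketch below is an illustration only and is not used by vllm-ascend:
# it mirrors the config-independent predicates of build_attn_state so the
# state mapping can be checked with plain numpy arrays. The pooling, "mtp"
# speculative-decoding, and chunked-prefill branches depend on vllm_config
# and are intentionally omitted here; _classify_sketch is a hypothetical
# helper, not part of this module's API.
if __name__ == "__main__":

    def _classify_sketch(seq_lens_np, num_scheduled_tokens, num_valid_tokens):
        num_reqs = len(seq_lens_np)
        if np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
            # Every token of every sequence is computed in this step.
            return AscendAttentionState.PrefillNoCache
        if np.all(num_scheduled_tokens == 1):
            # Exactly one scheduled token per request: plain decode.
            return AscendAttentionState.DecodeOnly
        if np.all(num_valid_tokens == 1):
            return AscendAttentionState.ChunkedPrefill
        return AscendAttentionState.PrefillCacheHit

    # Hypothetical pure-prefill batch: the scheduled tokens cover the full
    # sequences, so no KV cache is reused.
    prefill = _classify_sketch(
        np.array([5, 7]), np.array([5, 7]), np.array([5, 7]))
    # Hypothetical pure-decode batch: one new token per request.
    decode = _classify_sketch(
        np.array([9, 12]), np.array([1, 1]), np.array([1, 1]))
    print(prefill, decode)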