# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from vllm.config import VllmConfig
from vllm.config.model import ModelDType
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         AscendPrefillContextParallelMetadata)

_ATTENTION_MASK_BUILDER = None


def get_attn_mask_builder(device: torch.device):
    """Get the attention mask builder, of which only one instance exists."""
    global _ATTENTION_MASK_BUILDER
    if _ATTENTION_MASK_BUILDER is None:
        _ATTENTION_MASK_BUILDER = AttentionMaskBuilder(device)
    return _ATTENTION_MASK_BUILDER


def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_computed_tokens_cpu: torch.Tensor,
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
    decode_token_per_req: int,
    actual_seq_lengths_q: list[int],
    positions: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    spec_attn_mask: torch.Tensor | None = None,
    attn_state: Any | None = None,
    is_only_prefill: bool = False,
    graph_pad_size: int = -1,
    num_input_tokens: int = 0,
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
    | None = None,
) -> dict[str, Any]:
    """Build attention metadata for Ascend NPUs."""
    # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
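    # One metadata object is built per KV-cache group; every layer in a group
    # shares that object, so the returned dict maps layer name -> metadata.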
    max_query_len = int(query_start_loc_cpu.max())
    attn_metadata: dict[str, Any] = {}
    kv_cache_groups = kv_cache_config.kv_cache_groups
    for i, kv_cache_spec in enumerate(kv_cache_groups):
        block_table = block_tables[i]
        slot_mapping = slot_mappings[i]
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=seq_lens_cpu[:num_reqs],
            seq_lens=seq_lens[:num_reqs],
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            decode_token_per_req=decode_token_per_req,
            block_table_tensor=block_table,
            slot_mapping=slot_mapping,
            actual_seq_lengths_q=actual_seq_lengths_q,
            positions=positions,
            attn_mask=attn_mask,
            spec_attn_mask=spec_attn_mask,
            attn_state=attn_state,
            is_only_prefill=is_only_prefill,
            graph_pad_size=graph_pad_size,
            num_input_tokens=num_input_tokens,
            prefill_context_parallel_metadata=prefill_context_parallel_metadata,
        )
        attn_metadata_builder = attn_metadata_builders[i]
        metadata = attn_metadata_builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,  # type: ignore
        )
        for layer_name in kv_cache_spec.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata


def build_attn_state(
    vllm_config: VllmConfig,
    seq_lens_np: np.ndarray,
    num_reqs: int,
    num_scheduled_tokens: np.ndarray,
    num_valid_tokens: np.ndarray,
):
    """Build the attention state for the NPU attention backend."""
    if vllm_config.model_config.runner_type == "pooling":
        if isinstance(
                vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
                EncoderOnlyAttentionSpec,
        ):
            attn_state = AscendAttentionState.PrefillNoCache
        else:
            attn_state = AscendAttentionState.PrefillCacheHit
    elif np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
        attn_state = AscendAttentionState.PrefillNoCache
    # We assume this is the decode stage, where prefill occurs but only one
    # token per request misses the cache.
    elif np.all(num_scheduled_tokens == 1):
        attn_state = AscendAttentionState.DecodeOnly
        if (vllm_config.speculative_config
                and vllm_config.speculative_config.method == 'mtp'):
            # SpecDecoding now supports seq_len=1 and seq_len=2. In the
            # prefill-decode disaggregation scenario, SpecDecoding needs to
            # support seq_len=1.
            attn_state = AscendAttentionState.SpecDecoding
    # Speculative decoding.
    elif np.all(num_valid_tokens == 1):
        if (vllm_config.speculative_config
                and vllm_config.speculative_config.method == 'mtp'):
            attn_state = AscendAttentionState.SpecDecoding
        else:
            attn_state = AscendAttentionState.ChunkedPrefill
    # Splitfuse.
    elif vllm_config.scheduler_config.enable_chunked_prefill:
        attn_state = AscendAttentionState.ChunkedPrefill
    else:
        attn_state = AscendAttentionState.PrefillCacheHit
    return attn_state


def make_attention_mask(
    vllm_config: VllmConfig,
    attn_state: AscendAttentionState,
    dtype: ModelDType | torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Make the attention mask for the NPU attention backend."""
    attn_mask_builder = get_attn_mask_builder(device)
    # pcp (prefill context parallel) situation.
    if attn_mask_builder is None:
        raise ValueError("Attn mask builder is None")
    # Pooling situation.
    if vllm_config.model_config.runner_type == "pooling":
        return attn_mask_builder.get_attn_mask(2048, torch.bool)
    # TODO(Ronald1995): consider pcp.
    if vllm_config.model_config.use_mla:
        # MLA prefill.
        if attn_state != AscendAttentionState.DecodeOnly:
            return attn_mask_builder.get_mla_mask(dtype)
    return attn_mask_builder.get_splitfuse_attn_mask()
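

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module's API: it exercises the branch
# conditions used by `build_attn_state` with plain numpy stand-ins, so the
# classification logic can be sanity-checked without constructing a
# VllmConfig. The helper name `_demo_attn_state_conditions` is hypothetical.
# ---------------------------------------------------------------------------
def _demo_attn_state_conditions() -> None:
    num_reqs = 3
    seq_lens_np = np.array([8, 8, 8])

    # PrefillNoCache: every request schedules as many tokens as its sequence
    # length, i.e. nothing is served from the KV cache.
    num_scheduled_tokens = np.array([8, 8, 8])
    assert np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens)

    # DecodeOnly: exactly one token is scheduled per request.
    num_scheduled_tokens = np.array([1, 1, 1])
    assert np.all(num_scheduled_tokens == 1)

    # ChunkedPrefill (with the speculative path disabled): each request has
    # one valid token but more than one scheduled token.
    num_scheduled_tokens = np.array([4, 4, 4])
    num_valid_tokens = np.array([1, 1, 1])
    assert not np.all(num_scheduled_tokens == 1)
    assert np.all(num_valid_tokens == 1)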