Files
xc-llm-ascend/vllm_ascend/worker/v2/attn_utils.py
weijinqian0 95e8a52156 [Refactor] Move the metadata from attention_v1 to utils (in preparation for extracting common_cp) and make Ascendmetadata inherit from its parent class. (#5203)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Remove the pcp-related code from attention_v1.
2. Establish the inheritance relationship of CommonAttentionMetadata.

TODO
1. extract common_cp
2. move cp metadata to common_cp.
3. remove commonAttentionMetadata for aclgraph.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
2025-12-23 00:10:52 +08:00

160 lines
6.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.config.model import ModelDType
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
# Module-wide cache for the single AttentionMaskBuilder; filled on first use.
_ATTENTION_MASK_BUILDER = None


def get_attn_mask_builder(device: torch.device):
    """Return the shared attention mask builder, creating it on first use.

    Only one builder is ever created per process. It is constructed with the
    ``device`` passed on the first call; later calls return that same
    instance and do not look at ``device``.

    Args:
        device: Device handed to ``AttentionMaskBuilder`` when the shared
            instance is first created.

    Returns:
        The module-wide ``AttentionMaskBuilder`` instance.
    """
    global _ATTENTION_MASK_BUILDER
    if _ATTENTION_MASK_BUILDER is not None:
        return _ATTENTION_MASK_BUILDER
    _ATTENTION_MASK_BUILDER = AttentionMaskBuilder(device)
    return _ATTENTION_MASK_BUILDER
def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_computed_tokens_cpu: torch.Tensor,
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
    decode_token_per_req: int,
    actual_seq_lengths_q: list[int],
    positions: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    spec_attn_mask: torch.Tensor | None = None,
    attn_state: Any | None = None,
    graph_pad_size: int = -1,
    num_input_tokens: int = 0,
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
    | None = None,
) -> dict[str, Any]:
    """Build per-layer attention metadata for Ascend NPUs.

    For every kv-cache group in ``kv_cache_config.kv_cache_groups`` an
    ``AscendCommonAttentionMetadata`` is assembled from the arguments and
    passed to the builder at the same index in ``attn_metadata_builders``.
    The object returned by that builder is then registered under every layer
    name of the group.

    Args:
        attn_metadata_builders: One builder per kv-cache group, indexed in
            the same order as ``kv_cache_config.kv_cache_groups``.
        num_reqs: Number of requests; also used to truncate ``seq_lens`` and
            ``seq_lens_cpu`` to their first ``num_reqs`` entries.
        num_tokens: Forwarded as ``num_actual_tokens``.
        query_start_loc_gpu: Forwarded as ``query_start_loc``.
        query_start_loc_cpu: Forwarded unchanged; its maximum is also used
            as ``max_query_len``.
        seq_lens: Forwarded after slicing to ``[:num_reqs]``.
        seq_lens_cpu: Forwarded after slicing to ``[:num_reqs]``; the maximum
            of the unsliced tensor is used as ``max_seq_len``.
        num_computed_tokens_cpu: Forwarded unchanged.
        block_tables: Indexed per kv-cache group; entry ``i`` becomes
            ``block_table_tensor`` for group ``i``.
        slot_mappings: Indexed per kv-cache group; entry ``i`` becomes
            ``slot_mapping`` for group ``i``.
        kv_cache_config: Supplies the kv-cache groups and their layer names.
        decode_token_per_req: Forwarded unchanged.
        actual_seq_lengths_q: Forwarded unchanged.
        positions: Forwarded unchanged.
        attn_mask: Forwarded unchanged.
        spec_attn_mask: Forwarded unchanged.
        attn_state: Forwarded unchanged.
        graph_pad_size: Forwarded unchanged.
        num_input_tokens: Forwarded unchanged.
        prefill_context_parallel_metadata: Forwarded unchanged.

    Returns:
        A dict mapping each layer name to the metadata object built for the
        kv-cache group that layer belongs to. Layers of the same group share
        one metadata object.
    """
    # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
    # NOTE(review): this is the max over the start-location tensor itself,
    # not over per-request lengths; presumably an intentional upper bound --
    # confirm.
    longest_query = int(query_start_loc_cpu.max())
    # NOTE(review): taken over the whole tensor, while the metadata below
    # only receives the first ``num_reqs`` entries -- confirm this is wanted.
    longest_seq = int(seq_lens_cpu.max())

    per_layer_metadata: dict[str, Any] = {}
    for group_idx, cache_group in enumerate(kv_cache_config.kv_cache_groups):
        group_block_table = block_tables[group_idx]
        group_slot_mapping = slot_mappings[group_idx]
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=seq_lens_cpu[:num_reqs],
            seq_lens=seq_lens[:num_reqs],
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=longest_query,
            decode_token_per_req=decode_token_per_req,
            block_table_tensor=group_block_table,
            slot_mapping=group_slot_mapping,
            actual_seq_lengths_q=actual_seq_lengths_q,
            positions=positions,
            attn_mask=attn_mask,
            spec_attn_mask=spec_attn_mask,
            attn_state=attn_state,
            graph_pad_size=graph_pad_size,
            num_input_tokens=num_input_tokens,
            prefill_context_parallel_metadata=prefill_context_parallel_metadata,
            max_seq_len=longest_seq,
        )
        group_metadata = attn_metadata_builders[group_idx].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,  # type: ignore
        )
        # Every layer of the group points at the same metadata object.
        per_layer_metadata.update(
            dict.fromkeys(cache_group.layer_names, group_metadata))
    return per_layer_metadata
def build_attn_state(
    vllm_config: VllmConfig,
    seq_lens_np: np.ndarray,
    num_reqs,
    num_scheduled_tokens,
    num_valid_tokens,
):
    """Pick the ``AscendAttentionState`` for the current batch.

    The checks run in this order and the first match wins:

    1. Pooling runner: ``PrefillNoCache`` when the first kv-cache group uses
       an ``EncoderOnlyAttentionSpec``, otherwise ``PrefillCacheHit``.
    2. ``seq_lens_np[:num_reqs]`` equals ``num_scheduled_tokens``:
       ``PrefillNoCache``.
    3. Every entry of ``num_scheduled_tokens`` is 1: ``SpecDecoding`` when
       the speculative method is ``'mtp'``, otherwise ``DecodeOnly``.
    4. Every entry of ``num_valid_tokens`` is 1: ``SpecDecoding`` when the
       speculative method is ``'mtp'``, otherwise ``ChunkedPrefill``.
    5. Chunked prefill enabled in the scheduler config: ``ChunkedPrefill``.
    6. Anything else: ``PrefillCacheHit``.

    Args:
        vllm_config: Source of the model, kv-cache, speculative and
            scheduler settings read above.
        seq_lens_np: Sequence lengths; only the first ``num_reqs`` entries
            are compared.
        num_reqs: Number of requests in the batch.
        num_scheduled_tokens: Compared element-wise against the sequence
            lengths and against 1.
        num_valid_tokens: Compared element-wise against 1.

    Returns:
        The selected ``AscendAttentionState`` member.
    """

    def _uses_mtp() -> bool:
        # Evaluated lazily so the speculative config is only read on the
        # branches that need it.
        spec_cfg = vllm_config.speculative_config
        return bool(spec_cfg and spec_cfg.method == 'mtp')

    if vllm_config.model_config.runner_type == "pooling":
        # NOTE(review): relies on ``vllm_config`` exposing
        # ``kv_cache_config`` -- confirm against the caller.
        first_group_spec = (
            vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec)
        if isinstance(first_group_spec, EncoderOnlyAttentionSpec):
            return AscendAttentionState.PrefillNoCache
        return AscendAttentionState.PrefillCacheHit

    if np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
        return AscendAttentionState.PrefillNoCache

    # One scheduled token per request is treated as the decode stage; this
    # also covers a prefill where only one token missed the cache.
    if np.all(num_scheduled_tokens == 1):
        # SpecDecoding supports seq_len=1 and seq_len=2; in the
        # prefill/decode disaggregation scenario it must handle seq_len=1.
        if _uses_mtp():
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly

    # Speculative decoding.
    if np.all(num_valid_tokens == 1):
        if _uses_mtp():
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.ChunkedPrefill

    # splitfuse
    if vllm_config.scheduler_config.enable_chunked_prefill:
        return AscendAttentionState.ChunkedPrefill

    return AscendAttentionState.PrefillCacheHit
def make_attention_mask(
    vllm_config: VllmConfig,
    attn_state: AscendAttentionState,
    dtype: ModelDType | torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create the attention mask for the NPU attention backend.

    The mask comes from the shared builder returned by
    ``get_attn_mask_builder(device)``:

    * pooling runner: ``get_attn_mask(2048, torch.bool)``;
    * MLA model in any state other than ``DecodeOnly``:
      ``get_mla_mask(dtype)``;
    * everything else: ``get_splitfuse_attn_mask()``.

    Args:
        vllm_config: Provides ``model_config.runner_type`` and
            ``model_config.use_mla``.
        attn_state: Current attention state; only checked for MLA models.
        dtype: Passed to ``get_mla_mask`` on the MLA prefill path.
        device: Passed to ``get_attn_mask_builder``.

    Returns:
        The mask produced by the selected builder method.

    Raises:
        ValueError: If no mask builder is available.
    """
    mask_builder = get_attn_mask_builder(device)
    # pcp situation.
    if mask_builder is None:
        raise ValueError("Attn mask builder is None")

    model_cfg = vllm_config.model_config

    # Pooling situation.
    if model_cfg.runner_type == "pooling":
        return mask_builder.get_attn_mask(2048, torch.bool)

    # TODO(Ronald1995): consider pcp.
    # MLA prefill: MLA model and not in the pure-decode state.
    is_mla_prefill = (model_cfg.use_mla
                      and attn_state != AscendAttentionState.DecodeOnly)
    if is_mla_prefill:
        return mask_builder.get_mla_mask(dtype)

    return mask_builder.get_splitfuse_attn_mask()