RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The metadata dataclass contains an excessive number of fields. We
will inherit the community (upstream vLLM) metadata class and, at the
same time, remove fields that are no longer needed.
Todo:
1. Partially remove `attn_state`.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
159 lines
6.0 KiB
Python
159 lines
6.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import torch
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.model import ModelDType
|
|
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
|
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
|
|
|
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|
AscendPrefillContextParallelMetadata)
|
|
|
|
# Process-wide AttentionMaskBuilder; created on the first call to
# get_attn_mask_builder() and reused afterwards.
_ATTENTION_MASK_BUILDER = None


def get_attn_mask_builder(device: torch.device):
    """Return the shared AttentionMaskBuilder, creating it on first use.

    Only one builder instance ever exists. ``device`` is used to construct it
    on the first call; subsequent calls return the existing instance and do
    not look at ``device``.
    """
    global _ATTENTION_MASK_BUILDER
    builder = _ATTENTION_MASK_BUILDER
    if builder is None:
        builder = AttentionMaskBuilder(device)
        _ATTENTION_MASK_BUILDER = builder
    return builder
|
|
|
|
|
|
def _max_query_len(query_start_loc_cpu: torch.Tensor) -> int:
    """Return the longest single-request query length in the batch.

    Assumes ``query_start_loc_cpu`` is a 1-D tensor of cumulative token
    offsets (``[0, end_0, end_1, ...]``), as its name suggests. Per-request
    query lengths are then the differences between consecutive entries;
    taking ``max()`` of the raw offsets would instead yield the total token
    count. Returns 0 when fewer than two offsets are given.
    """
    if query_start_loc_cpu.numel() < 2:
        return 0
    return int(torch.diff(query_start_loc_cpu).max())


def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_computed_tokens_cpu: torch.Tensor,
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
    decode_token_per_req: int,
    actual_seq_lengths_q: list[int],
    positions: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    spec_attn_mask: torch.Tensor | None = None,
    attn_state: Any | None = None,
    graph_pad_size: int = -1,
    num_input_tokens: int = 0,
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
    | None = None,
) -> dict[str, Any]:
    """Build per-layer attention metadata for Ascend NPUs.

    For every KV-cache group ``i`` in ``kv_cache_config.kv_cache_groups``,
    this packs the batch description into an ``AscendCommonAttentionMetadata``
    (using ``block_tables[i]`` and ``slot_mappings[i]``), calls
    ``attn_metadata_builders[i].build`` on it, and assigns the result to every
    layer name in that group.

    Args:
        attn_metadata_builders: One builder per KV-cache group, indexed in the
            same order as ``kv_cache_config.kv_cache_groups``.
        num_reqs: Number of requests; ``seq_lens`` and ``seq_lens_cpu`` are
            truncated to this length.
        num_tokens: Forwarded as ``num_actual_tokens``.
        query_start_loc_gpu: Query start offsets on the device, forwarded
            unchanged.
        query_start_loc_cpu: Host copy of the query start offsets; forwarded
            unchanged and also used to derive ``max_query_len``.
        block_tables: Block table per KV-cache group.
        slot_mappings: Slot mapping per KV-cache group (indexed by group).
        kv_cache_config: Supplies the KV-cache groups and their layer names.
        The remaining arguments are forwarded unchanged to
        ``AscendCommonAttentionMetadata``.

    Returns:
        Mapping from layer name to the metadata built for its KV-cache group.
    """
    # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
    # The longest query comes from consecutive offset differences; max() of
    # the offsets themselves would be the total token count instead. Only the
    # first ``num_reqs + 1`` offsets describe scheduled requests.
    max_query_len = _max_query_len(query_start_loc_cpu[:num_reqs + 1])
    # These per-request views are the same for every KV-cache group.
    seq_lens_cpu_reqs = seq_lens_cpu[:num_reqs]
    seq_lens_reqs = seq_lens[:num_reqs]

    attn_metadata: dict[str, Any] = {}
    # Entries of kv_cache_groups are groups (they expose ``layer_names``),
    # not individual KV-cache specs.
    for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=seq_lens_cpu_reqs,
            seq_lens=seq_lens_reqs,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            decode_token_per_req=decode_token_per_req,
            block_table_tensor=block_tables[i],
            slot_mapping=slot_mappings[i],
            actual_seq_lengths_q=actual_seq_lengths_q,
            positions=positions,
            attn_mask=attn_mask,
            spec_attn_mask=spec_attn_mask,
            attn_state=attn_state,
            graph_pad_size=graph_pad_size,
            num_input_tokens=num_input_tokens,
            prefill_context_parallel_metadata=prefill_context_parallel_metadata,
        )

        metadata = attn_metadata_builders[i].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,  # type: ignore
        )
        for layer_name in kv_cache_group.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata
|
|
|
|
|
|
def build_attn_state(
    vllm_config: VllmConfig,
    seq_lens_np: np.ndarray,
    num_reqs,
    num_scheduled_tokens,
    num_valid_tokens,
):
    """Choose the AscendAttentionState for the NPU attention backend.

    Rules are checked in priority order and the first match wins:

    1. Pooling runner: ``PrefillNoCache`` if the first KV-cache group uses an
       ``EncoderOnlyAttentionSpec``, otherwise ``PrefillCacheHit``.
    2. ``seq_lens_np[:num_reqs]`` equals ``num_scheduled_tokens``
       element-wise: ``PrefillNoCache``.
    3. Every request schedules exactly one token: ``DecodeOnly``, or
       ``SpecDecoding`` when the speculative method is ``'mtp'``.
    4. Every entry of ``num_valid_tokens`` is 1: ``SpecDecoding`` with
       ``'mtp'``, otherwise ``ChunkedPrefill``.
    5. Chunked prefill enabled in the scheduler config: ``ChunkedPrefill``.
    6. Otherwise: ``PrefillCacheHit``.
    """

    def _uses_mtp() -> bool:
        spec_config = vllm_config.speculative_config
        return bool(spec_config and spec_config.method == 'mtp')

    if vllm_config.model_config.runner_type == "pooling":
        # NOTE(review): relies on vllm_config exposing ``kv_cache_config``;
        # confirm the caller populates it.
        first_group_spec = (
            vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec)
        if isinstance(first_group_spec, EncoderOnlyAttentionSpec):
            return AscendAttentionState.PrefillNoCache
        return AscendAttentionState.PrefillCacheHit

    if np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
        return AscendAttentionState.PrefillNoCache

    # Assumed to be the decode stage: at most one token per request misses
    # the cache.
    if np.all(num_scheduled_tokens == 1):
        # SpecDecoding supports seq_len=1 and seq_len=2; in the prefill/decode
        # disaggregation scenario it must also handle seq_len=1.
        if _uses_mtp():
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly

    # Speculative decoding.
    if np.all(num_valid_tokens == 1):
        if _uses_mtp():
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.ChunkedPrefill

    # Split-fuse.
    if vllm_config.scheduler_config.enable_chunked_prefill:
        return AscendAttentionState.ChunkedPrefill
    return AscendAttentionState.PrefillCacheHit
|
|
|
|
|
|
def make_attention_mask(
    vllm_config: VllmConfig,
    attn_state: AscendAttentionState,
    dtype: ModelDType | torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Make the attention mask for the NPU attention backend.

    The mask comes from the shared builder returned by
    ``get_attn_mask_builder(device)``:

    - Pooling runners: ``get_attn_mask(2048, torch.bool)``.
    - MLA models in any state other than ``DecodeOnly``:
      ``get_mla_mask(dtype)``.
    - Everything else, including MLA in ``DecodeOnly``:
      ``get_splitfuse_attn_mask()``.

    Raises:
        ValueError: If no mask builder is available.
    """
    builder = get_attn_mask_builder(device)
    if builder is None:
        # Defensive guard: the shared builder is expected to be created
        # lazily by get_attn_mask_builder.
        raise ValueError("Attn mask builder is None")

    model_config = vllm_config.model_config
    if model_config.runner_type == "pooling":
        return builder.get_attn_mask(2048, torch.bool)

    # TODO(Ronald1995): consider prefill context parallel (pcp).
    wants_mla_prefill_mask = (model_config.use_mla
                              and attn_state != AscendAttentionState.DecodeOnly)
    if wants_mla_prefill_mask:
        return builder.get_mla_mask(dtype)
    return builder.get_splitfuse_attn_mask()
|