### What this PR does / why we need it? This PR refactors the model runner to decouple it from the classes specifically designed for GPU. The changes to the model runner are shown below:  **Other changes:** I have removed the code for `cuda`, `lora` and `prompt adapter`, because NPU doesn't support them yet. ### Does this PR introduce _any_ user-facing change? no. ### How was this patch tested? I used `AI-ModelScope/gpt2` to test `examples/offline_inference_npu.py`, and the results showed that it worked well. The test logs are shown below: ```bash INFO 02-05 09:08:46 __init__.py:30] Available plugins for group vllm.platform_plugins: INFO 02-05 09:08:46 __init__.py:32] name=ascend, value=vllm_ascend:register INFO 02-05 09:08:46 __init__.py:34] all available plugins for group vllm.platform_plugins will be loaded. INFO 02-05 09:08:46 __init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load. INFO 02-05 09:08:46 __init__.py:44] plugin ascend loaded. INFO 02-05 09:08:46 __init__.py:177] Platform plugin ascend is activated INFO 02-05 09:08:48 config.py:2383] Downcasting torch.float32 to torch.float16. INFO 02-05 09:08:59 config.py:542] This model supports multiple tasks: {'generate', 'score', 'embed', 'reward', 'classify'}. Defaulting to 'generate'. 
INFO 02-05 09:08:59 llm_engine.py:234] Initializing a V0 LLM engine (v0.1.dev1+gb3a0d01) with config: model='/home/sss/models/AI-ModelScope/gpt2', speculative_config=None, tokenizer='/home/sss/models/AI-ModelScope/gpt2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=npu, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/home/sss/models/AI-ModelScope/gpt2, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, WARNING 02-05 09:09:01 _custom_ops.py:21] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'") INFO 02-05 09:09:01 importing.py:16] Triton not installed or not compatible; certain GPU-related functions will not be available. 
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s] Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 3.18it/s] Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 3.18it/s] INFO 02-05 09:09:11 executor_base.py:110] # CPU blocks: 98557, # CPU blocks: 7281 INFO 02-05 09:09:11 executor_base.py:115] Maximum concurrency for 1024 tokens per request: 1539.95x INFO 02-05 09:09:12 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 2.13 seconds Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.53it/s, est. speed input: 8.41 toks/s, output: 152.97 toks/s] Prompt: 'Hello, my name is', Generated text: " John. I'm a writer, and I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm" Prompt: 'The president of the United States is', Generated text: ' States president. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. 
He is the president of the United' Prompt: 'The capital of France is', Generated text: ' the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.' Prompt: 'The future of AI is', Generated text: '\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future' ``` --------- Signed-off-by: Shanshan Shen <467638484@qq.com>
685 lines
28 KiB
Python
685 lines
28 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/vllm/attention/backends
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
|
|
|
import torch
|
|
|
|
try:
|
|
import torch_npu # noqa: F401
|
|
except ImportError:
|
|
print("Failed to import torch_npu.")
|
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
AttentionLayer,
|
|
AttentionMetadata, AttentionType)
|
|
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
|
CommonMetadataBuilder,
|
|
compute_slot_mapping_start_idx,
|
|
is_block_tables_empty)
|
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
|
PagedAttentionMetadata)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm_ascend.model_runner import ModelInputForNPUBuilder
|
|
|
|
# Module-level cache for the fixed 2048 x 2048 mask that gen_input_mask()
# builds when it is asked for more than 16384 tokens. None until first use.
SHARE_MASK_TRIL_PREFIX_CACHE = None
# Module-level cache for the square causal mask that gen_input_mask() builds
# otherwise. None until first use; rebuilt when a longer seq_len is requested.
SHARE_MASK_TRIL = None
|
|
|
|
|
|
class AscendAttentionBackend(AttentionBackend):
    """Attention backend entry point for Ascend NPU.

    Exposes the implementation, metadata, state and builder classes that
    belong to this backend, and provides helpers that describe and
    manipulate the KV cache layout.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        """Return the attention metadata class."""
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the KV cache tensor.

        The leading dimension of size 2 separates the key cache (index 0)
        from the value cache (index 1); KV heads and head size are flattened
        into the last dimension.
        """
        return (2, num_blocks, block_size, num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy whole blocks from one KV cache into another.

        Args:
            src_kv_cache: Source cache; index 0 is the key cache and index 1
                the value cache.
            dst_kv_cache: Destination cache with the same layout.
            src_to_dst: Two-column index tensor; column 0 holds the source
                block numbers and column 1 the destination block numbers.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Move the values to the value cache's own device instead of assuming
        # it is the same device as the key cache.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate blocks in place inside every given KV cache.

        Args:
            kv_caches: KV caches to update; in each one index 0 is the key
                cache and index 1 the value cache.
            src_to_dists: Two-column index tensor; column 0 holds the source
                block numbers and column 1 the destination block numbers.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_builder_cls() -> Type["AscendMetadataBuilder"]:
        """Return the metadata builder class."""
        return AscendMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
        """Instantiate the metadata builder, forwarding all arguments."""
        return cls.get_builder_cls()(*args, **kwargs)
|
|
|
|
|
|
class AscendPagedAttention(PagedAttention):
    """Paged-attention helper that writes the KV cache with torch_npu ops."""

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_indices: torch.Tensor,
    ) -> None:
        """Scatter new keys and values into their caches in place.

        The key cache is updated first, then the value cache; both updates
        use the same ``slot_indices``.
        """
        updates = ((key_cache, key), (value_cache, value))
        for cache, new_entries in updates:
            torch_npu.npu_scatter_nd_update_(cache, slot_indices, new_entries)
|
|
|
|
|
|
@dataclass
class AscendMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata consumed by the Ascend backend.

    * adapted from the XFormers backend

    NOTE: Plain Python objects stored here are not refreshed when a
    cuda-graph is replayed. Values that have to change dynamically should
    be kept in tensors, and those tensors have to be updated through the
    `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence lengths, kept as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Largest sequence length within the prefill batch; 0 when the batch
    # contains decoding requests only.
    max_prefill_seq_len: int
    # Largest sequence length within the decode batch; 0 when the batch
    # contains prefill requests only.
    max_decode_seq_len: int

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). Length of every sequence, i.e. computed tokens plus new
    # tokens. None for decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). Cumulative sequence lengths of the batch, used to
    # index into the sequences. E.g. for sequence lengths [4, 6] it is
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,). Context lengths (number of tokens computed so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Largest query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). Cumulative subquery lengths of the batch, used to
    # index into the subqueries. E.g. for subquery lengths [4, 6] it is
    # [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Lazily built prefill / decode views of this metadata.
    _cached_prefill_metadata: Optional["AscendMetadata"] = None
    _cached_decode_metadata: Optional["AscendMetadata"] = None

    # Encoder attention and encoder/decoder cross-attention fields follow.

    # Encoder sequence lengths, as a list and as a tensor.
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Largest sequence length among the encoder sequences.
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens fed to the encoder.
    num_encoder_tokens: Optional[int] = None

    attn_mask: Optional[torch.Tensor] = None
    pse_shift: Optional[torch.Tensor] = None
    sparse_mode: int = 0

    # Cross-attention memory mapping: slot mapping and block tables.
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    # slot_mapping: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["AscendMetadata"]:
        """Return the metadata restricted to the prefill part of the batch.

        Returns None when the batch holds no prefills. The view is built on
        first access and cached on this instance afterwards.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            # Reuse the prefill view built by an earlier access.
            return cached

        assert not (self.seq_lens is None and self.encoder_seq_lens is None)
        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        n_prefills = self.num_prefills
        n_prefill_tokens = self.num_prefill_tokens

        def _leading(values, count):
            # Keep the first `count` entries; optional fields stay None.
            if values is None:
                return None
            return values[:count]

        prefill_view = AscendMetadata(
            num_prefills=n_prefills,
            num_prefill_tokens=n_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=_leading(self.slot_mapping, n_prefill_tokens),
            seq_lens=_leading(self.seq_lens, n_prefills),
            seq_lens_tensor=_leading(self.seq_lens_tensor, n_prefills),
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=_leading(self.query_start_loc, n_prefills + 1),
            seq_start_loc=_leading(self.seq_start_loc, n_prefills + 1),
            context_lens_tensor=_leading(self.context_lens_tensor,
                                         n_prefills),
            block_tables=_leading(self.block_tables, n_prefills),
            use_cuda_graph=False,
            # Encoder and cross-attention fields are forwarded unchanged.
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        self._cached_prefill_metadata = prefill_view
        return prefill_view

    @property
    def decode_metadata(self) -> Optional["AscendMetadata"]:
        """Return the metadata restricted to the decode part of the batch.

        Returns None when the batch holds no decode tokens. The view is
        built on first access and cached on this instance afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            # Reuse the decode view built by an earlier access.
            return cached

        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        n_prefills = self.num_prefills

        def _trailing(values, start):
            # Drop the first `start` entries; optional fields stay None.
            if values is None:
                return None
            return values[start:]

        # A batch may be laid out as prefills|decodes, so shift the query
        # start offsets to be relative to the first decode. E.g. for
        # tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        query_start_loc = self.query_start_loc
        if query_start_loc is not None:
            query_start_loc = (query_start_loc[n_prefills:] -
                               query_start_loc[n_prefills])

        decode_view = AscendMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=_trailing(self.slot_mapping,
                                   self.num_prefill_tokens),
            seq_lens_tensor=_trailing(self.seq_lens_tensor, n_prefills),
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=_trailing(self.seq_start_loc, n_prefills),
            context_lens_tensor=None,
            block_tables=_trailing(self.block_tables, n_prefills),
            use_cuda_graph=self.use_cuda_graph,
            # Encoder and cross-attention fields are forwarded unchanged.
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        self._cached_decode_metadata = decode_view
        return decode_view
|
|
|
|
|
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
    """Collects per-sequence attention inputs used to build AscendMetadata."""

    _metadata_cls = AscendMetadata

    def __init__(self, input_builder: "ModelInputForNPUBuilder"):
        """Keep the input builder and the settings read from it."""
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id,
                                 seq_len, context_len, start_idx, block_size,
                                 block_tables, max_query_len):
        """Append the slot indices of one sequence to ``slot_indices``.

        The slot mapping in other vLLM backends stores flat slot indices,
        computed as ``block_number * block_size + block_offset``. In the
        Ascend backend the slot mapping stores
        ``[block_number, block_offset]`` pairs instead; the name
        ``slot_indices`` is used here to mark that difference.

        Args:
            is_profile_run: When true, only dummy entries are appended.
            slot_indices: Output list, extended in place.
            seq_id: Key of this sequence in ``block_tables``.
            seq_len: Sequence length; exclusive end of the mapped positions.
            context_len: Number of tokens already computed.
            start_idx: First position that gets a real slot.
            block_size: Number of token slots per block.
            block_tables: Mapping from sequence id to its block numbers.
            max_query_len: Length the real entries are padded up to.
        """
        if is_profile_run:
            # During memory profiling, the block tables are not
            # initialized yet. In this case, we just use a dummy
            # slot mapping.
            # In embeddings, the block tables are {seq_id: None}.
            slot_indices.extend([[PAD_SLOT_ID, 0]] * seq_len)
            return

        # Mask the [0, start_idx) tokens of the prompt with
        # [PAD_SLOT_ID, 0], where start_idx is max(0, seq_len -
        # sliding_window). For example, if the prompt len is 10 and the
        # sliding window is 8, the first two tokens are masked with
        # [PAD_SLOT_ID, 0] and the remaining eight tokens receive real
        # [block_number, block_offset] pairs.
        padding_mask_len = max(0, start_idx - context_len)
        slot_indices.extend([[PAD_SLOT_ID, 0]] * padding_mask_len)

        range_start = max(start_idx, context_len)
        range_end = seq_len
        numel = range_end - range_start
        block_table = block_tables[seq_id]

        for i in range(range_start, range_end):
            block_number = block_table[i // block_size]
            block_offset = i % block_size
            slot_indices.append([block_number, block_offset])
        # Pad the real entries up to max_query_len; a non-positive count
        # appends nothing.
        slot_indices.extend([[PAD_SLOT_ID, 0]] * (max_query_len - numel))

    def _add_seq_group(
            self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        inter_data_list = self.input_builder.inter_data_list
        max_query_len = max(
            max(data.query_lens) for data in inter_data_list)

        # Neither value depends on the individual sequence, so compute them
        # once for the whole group.
        prefix_cache_hit = any(data.prefix_cache_hit
                               for data in inter_data_list)
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table: List[int] = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                if block_tables is not None:
                    block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)

            self.compute_npu_slot_indices(is_profile_run, self.slot_mapping,
                                          seq_id, seq_len, context_len,
                                          start_idx, self.block_size,
                                          block_tables, max_query_len)
|
|
|
|
|
|
class AscendAttentionBackendImpl(AttentionImpl):
    """Decoder attention for Ascend NPU, built on torch_npu fused kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Store the attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Scale passed to the attention kernels.
            num_kv_heads: Number of key/value heads; None falls back to
                ``num_heads``.
            alibi_slopes: Optional ALiBi slopes; converted to a float32
                tensor on the "npu" device when given.
            sliding_window: Optional sliding-window size.
            kv_cache_dtype: Stored as is.
            blocksparse_params: Accepted but not used by this class.
            logits_soft_cap: Accepted but not used by this class.
            attn_type: Attention type; ``forward`` supports only
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            layer: Attention layer; its ``_k_scale`` and ``_v_scale`` must
                both equal 1.0.
            query: shape = [num_tokens, num_heads * head_size]
                   num_tokens = batch_size * seq_len
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                               num_kv_heads * head_size]
                      key_cache = [num_blocks, block_size,
                                   num_kv_heads * head_size]
                      value_cache = [num_blocks, block_size,
                                     num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            attn_type: Ignored; the attention type given to the constructor
                is used instead.
            output: Returned unchanged when the batch contains neither
                prefills nor decodes; otherwise replaced by the result.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If the configured attention type is not
                ``AttentionType.DECODER``.
        """
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendAttentionBackendImpl")
        # view q k v to BSH
        num_tokens = query.shape[0]

        if kv_cache is not None and len(kv_cache) >= 2:
            # Store the new keys/values in the paged cache before attending.
            slot_indices = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            AscendPagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_indices,
            )

        if attn_metadata.num_prefills > 0:
            if attn_metadata.attn_mask is None:
                if num_tokens > 16384:
                    attn_metadata.sparse_mode = 2
                attention_mask = gen_input_mask(
                    attn_metadata.max_prefill_seq_len, self.sliding_window,
                    num_tokens)
                # Stored on the metadata so the mask is only built when it
                # is still missing.
                attn_metadata.attn_mask = attention_mask

            if (self.alibi_slopes is not None
                    and attn_metadata.pse_shift is None):
                attn_metadata.pse_shift = _make_alibi_bias(
                    self.alibi_slopes,
                    self.num_kv_heads,
                    dtype=query.dtype,
                    seq_len=attn_metadata.max_prefill_seq_len,
                    batch_size=num_tokens,
                )

            # A missing KV cache is handled like an empty one, so that
            # len() is never called on None.
            if (kv_cache is None or len(kv_cache) == 0
                    or attn_metadata.block_tables is None
                    or attn_metadata.block_tables.numel() == 0):
                max_seq_len = attn_metadata.max_prefill_seq_len

                # shape of q/k/v [B,S*H] --> [B,S,N,D]
                query = query.view(-1, max_seq_len, self.num_heads,
                                   self.head_size).transpose(1, 2)
                key = key.view(-1, max_seq_len, self.num_kv_heads,
                               self.head_size).transpose(1, 2)
                value = value.view(-1, max_seq_len, self.num_kv_heads,
                                   self.head_size).transpose(1, 2)
                # FA for prefill phase
                output = torch_npu.npu_prompt_flash_attention(
                    query,
                    key,
                    value,
                    pse_shift=attn_metadata.pse_shift,
                    atten_mask=attn_metadata.attn_mask,
                    num_heads=self.num_heads,
                    # Use the configured scale, as the cached-KV paths do,
                    # instead of recomputing 1 / sqrt(head_size).
                    scale_value=self.scale,
                    input_layout="BNSD",
                    num_key_value_heads=self.num_kv_heads,
                    pre_tokens=65535,
                    next_tokens=0,
                    sparse_mode=attn_metadata.sparse_mode,
                )
                # reshape to [B,H]
                output = output.transpose(1, 2).reshape(
                    num_tokens, self.num_heads * self.head_size)
            else:
                # prefix-enabled attention
                assert attn_type == AttentionType.DECODER, (
                    "Only decoder-only models support prefix caching")
                assert attn_metadata.seq_lens is not None
                assert kv_cache is not None
                query = query.view(query.shape[0], -1,
                                   self.num_heads * self.head_size)
                output = torch.zeros(query.shape,
                                     device="npu",
                                     dtype=query.dtype)
                # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
                # support only when `S == 1`, OPTIMIZE ME when prefix caching
                # is supported in torch-npu ops.
                for i in range(query.shape[0]):
                    # FA for prefill phase
                    output[i] = torch_npu.npu_incre_flash_attention(
                        query[i].unsqueeze(0),
                        key_cache,
                        value_cache,
                        num_heads=self.num_heads,
                        num_key_value_heads=self.num_kv_heads,
                        scale_value=self.scale,
                        input_layout="BSH",
                        block_table=attn_metadata.block_tables,
                        block_size=key_cache.
                        shape[1],  # max val of block_size == 512
                        actual_seq_lengths=attn_metadata.seq_lens,
                    )
                # [B,S,H] --> [B,H]
                output = output.squeeze(1)

        elif attn_metadata.decode_metadata:
            # FA for decoding phase
            assert kv_cache is not None
            # shape of query [B,S*H] --> [B,S,H]
            query = query.view(
                -1,
                1,
                self.head_size * self.num_heads,
            )
            output = torch_npu.npu_incre_flash_attention(
                query,
                key_cache,
                value_cache,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                scale_value=self.scale,
                input_layout="BSH",
                block_table=attn_metadata.block_tables,
                block_size=key_cache.shape[1],  # max val of block_size == 512
                actual_seq_lengths=attn_metadata.seq_lens,
            )

            # [B,S,H] --> [B,H]
            output = output.squeeze(1)
        return output
|
|
|
|
|
|
def gen_input_mask(seq_len, sliding_window, len):
    """Return a boolean attention mask on the "npu" device.

    For ``len`` above 16384 a fixed (1, 1, 2048, 2048) mask that is True
    strictly above the main diagonal is returned, and ``sliding_window`` is
    not applied. Otherwise the result is a square mask that is True strictly
    above the main diagonal and, when ``sliding_window`` is given, also True
    for entries more than ``sliding_window - 1`` positions below it.

    Both base masks are cached in module-level globals.

    NOTE(review): the square mask is rebuilt only when the cached one is
    smaller than ``seq_len``, so the returned mask can be larger than
    ``seq_len`` x ``seq_len``.
    """
    global SHARE_MASK_TRIL_PREFIX_CACHE, SHARE_MASK_TRIL

    if len > 16384:
        # improve computing performance on NPU when input tokens are huge
        if SHARE_MASK_TRIL_PREFIX_CACHE is None:
            all_true = torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu")
            SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(all_true, diagonal=1)
        return SHARE_MASK_TRIL_PREFIX_CACHE

    if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
        lower_triangle = torch.tril(
            torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
        SHARE_MASK_TRIL = ~lower_triangle

    if sliding_window is None:
        return SHARE_MASK_TRIL

    # Invert to the lower triangle, keep only the band of the last
    # `sliding_window` positions, then invert back.
    in_window = torch.triu(~SHARE_MASK_TRIL, diagonal=1 - sliding_window)
    return ~in_window
|
|
|
|
|
|
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
    batch_size: int,
):
    """Build the ALiBi bias for a square attention window.

    Args:
        alibi_slopes: Per-head slopes; its first dimension is the number of
            heads and its device is used for the result.
        num_kv_heads: Number of KV heads. When it differs from the number of
            heads, the head dimension is split into
            (num_kv_heads, num_heads // num_kv_heads).
        dtype: Data type of the bias.
        seq_len: Side length of the bias matrix.
        batch_size: Not used by this function; the leading dimension of the
            result is always 1.

    Returns:
        A view of shape (1, num_heads, seq_len, seq_len), or
        (1, num_kv_heads, num_heads // num_kv_heads, seq_len, seq_len) when
        the heads are grouped. Entry [..., i, j] is ``(j - i) * slope``. The
        backing storage pads the last dimension to a multiple of 8.
    """
    device = alibi_slopes.device
    head_count = alibi_slopes.shape[0]

    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Entry (i, j) of `relative` is position j minus position i.
    positions = torch.arange(seq_len, dtype=dtype, device=device)
    relative = positions.unsqueeze(0) - positions.unsqueeze(1)

    # Round the row width up to the next multiple of 8.
    row_width = seq_len + (-seq_len) % 8
    storage = torch.empty(
        1,
        head_count,
        seq_len,
        row_width,
        device=device,
        dtype=dtype,
    )
    bias = storage[:, :, :, :seq_len]
    bias.copy_(relative)
    bias.mul_(alibi_slopes[:, None, None])

    if head_count == num_kv_heads:
        return bias
    return bias.unflatten(1, (num_kv_heads, head_count // num_kv_heads))
|