### What this PR does / why we need it?
Add support for V1 Engine.
Please note that this is just the initial version, and some parts may
still need to be fixed or optimized in the future. Feel free to leave
us comments.
### Does this PR introduce _any_ user-facing change?
To use the V1 Engine on an NPU device, you need to set the environment
variables shown below:
```bash
export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
```
If you are using vLLM for offline inference, you must add a `__main__`
guard, for example:
```python
if __name__ == '__main__':
    llm = vllm.LLM(...)
```
Find more details
[here](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing).
### How was this patch tested?
I have tested the online serving with `Qwen2.5-7B-Instruct` using this
command:
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240
```
Query the model with input prompts:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-7B-Instruct",
"prompt": "The future of AI is",
"max_tokens": 7,
"temperature": 0
}'
```
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
1044 lines
44 KiB
Python
1044 lines
44 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
from dataclasses import dataclass
|
|
from itertools import accumulate
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.nn.functional import scaled_dot_product_attention
|
|
|
|
try:
|
|
import torch_npu # noqa: F401
|
|
except ImportError:
|
|
print("Failed to import torch_npu.")
|
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
AttentionLayer,
|
|
AttentionMetadata, AttentionType,
|
|
MLAAttentionImpl)
|
|
from vllm.attention.backends.utils import (CommonAttentionState,
|
|
CommonMetadataBuilder,
|
|
compute_slot_mapping,
|
|
compute_slot_mapping_start_idx,
|
|
is_block_tables_empty)
|
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm_ascend.worker.model_runner import (
|
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
|
|
|
|
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
    """Build a square causal attention mask of side ``max_seq_len``.

    Entries on or below the main diagonal are 0. Entries strictly above the
    diagonal hold the mask value: the float32 minimum (then cast to
    ``dtype``) when ``dtype`` is ``torch.float16``, otherwise 1.

    Args:
        max_seq_len: Number of rows and columns of the mask.
        dtype: Data type of the returned tensor.

    Returns:
        A ``(max_seq_len, max_seq_len)`` tensor of type ``dtype``.
    """
    shape = (max_seq_len, max_seq_len)
    # Lower triangle (diagonal included) marks the visible positions;
    # inverting it marks the strictly-upper positions that must be masked.
    visible = torch.tril(torch.ones(shape, dtype=torch.bool))
    masked_positions = ~visible.view(max_seq_len, max_seq_len)
    # Currently for fp16 dtype, the mask value should be set to -inf.
    # TODO: Eliminate this part in the future.
    fill_value = (torch.finfo(torch.float32).min
                  if dtype == torch.float16 else 1)
    mask = torch.zeros(size=shape).masked_fill(masked_positions, fill_value)
    return mask.to(dtype)
|
|
|
|
|
|
class AttentionMaskBuilder:
    """Cache an attention mask and hand out slices of it on demand.

    The cached mask is rebuilt through ``generate_attn_mask`` whenever a
    longer sequence length or a different dtype is requested, and it is
    moved to the requested device when it lives elsewhere.
    """

    def __init__(self, attn_mask: torch.Tensor):
        # Number of rows of the mask currently held in the cache.
        self._seq_len_cached = attn_mask.shape[0]
        self.attn_mask_cache = attn_mask

    @classmethod
    def initialize_from_len(cls,
                            max_seq_len: int,
                            dtype: torch.dtype = torch.float16):
        """Create a builder whose cache is a new mask of ``max_seq_len``."""
        return cls(generate_attn_mask(max_seq_len, dtype))

    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
                          device: torch.device):
        """Make the cached mask cover ``seqlen`` with ``dtype`` on ``device``.

        The mask is regenerated (at size ``seqlen``) when the request is
        longer than the cache or uses another dtype; it is then moved to
        ``device`` if it is not already there.
        """
        needs_rebuild = (seqlen > self._seq_len_cached
                         or self.attn_mask_cache.dtype != dtype)
        if needs_rebuild:
            self._seq_len_cached = seqlen
            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
        if self.attn_mask_cache.device != device:
            self.attn_mask_cache = self.attn_mask_cache.to(device)

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                      device: torch.device):
        """Return the top-left ``max_seq_len`` square of the cached mask.

        The result is a contiguous tensor.
        """
        self.update_attn_cache(max_seq_len, dtype, device)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()

    def get_decode_attn_mask(
        self,
        input_lengths: torch.Tensor,
        max_s: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Return one mask row per entry of ``input_lengths``.

        Args:
            input_lengths: Index tensor selecting rows of the cached mask.
            max_s: Number of leading columns kept from each selected row.
            dtype: Requested mask dtype.
            device: Requested mask device.

        Returns:
            A contiguous tensor of shape ``(len(input_lengths), 1, max_s)``.
        """
        self.update_attn_cache(max_s, dtype, device)
        selected_rows = self.attn_mask_cache.index_select(0, input_lengths)
        return selected_rows[:, :max_s].view(-1, 1, max_s).contiguous()
|
|
|
|
|
|
class AscendAttentionBackend(AttentionBackend):
    """Attention backend descriptor for Ascend NPU.

    Exposes the implementation, metadata, state and metadata-builder classes
    of this backend, the KV-cache layout, and helpers that copy or swap
    cache blocks.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        """Return the attention metadata class."""
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The leading dimension of 2 is indexed as 0 for keys and 1 for
        values by ``swap_blocks`` and ``copy_blocks``.
        """
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy selected blocks from a source cache into a destination cache.

        Args:
            src_kv_cache: Source cache; index 0 is keys, index 1 is values.
            dst_kv_cache: Destination cache with the same indexing.
            src_to_dst: 2-D tensor whose column 0 holds source block indices
                and column 1 the matching destination block indices.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Move the value blocks to the value cache's own device rather than
        # assuming it matches the key cache's device.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate blocks inside every cache of ``kv_caches``.

        Args:
            kv_caches: Caches to update in place; index 0 of each is keys,
                index 1 is values.
            src_to_dists: 2-D tensor whose column 0 holds source block
                indices and column 1 the destination block indices.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_builder_cls() -> Type["AscendMetadataBuilder"]:
        """Return the metadata builder class."""
        return AscendMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
        """Instantiate the metadata builder, forwarding all arguments."""
        return cls.get_builder_cls()(*args, **kwargs)
|
|
|
|
|
|
class AscendMLAAttentionBackend(AscendAttentionBackend):
    """Ascend backend variant that selects the MLA attention implementation."""

    @staticmethod
    def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
        """Return the MLA attention implementation class."""
        return AscendMLAAttentionBackendImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape used by the MLA backend.

        The last dimension is the product ``num_kv_heads * head_size`` and
        the leading dimension is 1.
        """
        flattened_width = num_kv_heads * head_size
        cache_shape = (1, num_blocks, block_size, flattened_width)
        return cache_shape
|
|
|
|
|
|
@dataclass
class AscendMetadata(AttentionMetadata):
    """Metadata for Ascendbackend.

    * modified from XFormersbackend
    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Avoid mypy error
    # NOTE(review): presumably these three fields re-declare fields of the
    # AttentionMetadata base class - confirm against the base definition.
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    block_tables: Optional[torch.Tensor]

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["AscendMetadata"] = None
    _cached_decode_metadata: Optional["AscendMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Attention mask tensor; it is not forwarded to the cached
    # prefill/decode views built by the properties below.
    attn_mask: Optional[torch.Tensor] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["AscendMetadata"]:
        """Return (and cache) the metadata restricted to prefill requests.

        Returns ``None`` when the batch contains no prefills. Per-sequence
        fields are sliced to the first ``num_prefills`` entries and
        ``slot_mapping`` to the first ``num_prefill_tokens`` entries.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure.
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))

        # Compute some attn_metadata fields which default to None.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure.
        self._cached_prefill_metadata = AscendMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["AscendMetadata"]:
        """Return (and cache) the metadata restricted to decode requests.

        Returns ``None`` when the batch contains no decode tokens.
        Per-sequence fields are sliced from index ``num_prefills`` onwards
        and ``slot_mapping`` from index ``num_prefill_tokens`` onwards.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure.
            return self._cached_decode_metadata

        # Compute some attn_metadata fields which default to None.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[self.num_prefills:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure.
        self._cached_decode_metadata = AscendMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForNPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Args:
            model_input: Model input whose ``input_tokens`` and
                ``input_positions`` are overwritten for the first
                ``num_queries`` entries.
            sampled_token_ids: Tokens sampled in the previous step; the
                first ``num_queries`` rows become the next input tokens.
                It is indexed without a ``None`` check.
            block_size: Number of slots per block, used to turn a position
                into a block index and an in-block offset.
            num_seqs: Number of sequences the metadata tensors are sized for.
            num_queries: Number of sequences actually updated.
            turn_prefills_into_decodes: When True, the prefills of the
                batch are first re-counted as decodes.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # TODO optimize these codes using ascendc just like flash attention backend using cuda

        # update input_tokens
        sampled_token_ids_list = sampled_token_ids[:
                                                   num_queries].squeeze(  # type: ignore
                                                       -1)
        model_input.input_tokens[:
                                 num_queries] = sampled_token_ids_list  # type: ignore

        # get seq_lens and input_positions
        seq_lens = self.seq_lens_tensor[:num_queries]
        next_seq_lens = seq_lens + 1
        next_input_pos = next_seq_lens - 1

        # update seq_lens and input_positions
        self.seq_lens_tensor[:num_queries] = next_seq_lens
        model_input.input_positions[:
                                    num_queries] = next_input_pos  # type: ignore

        # Compute block index and offset
        block_idx = next_input_pos // block_size
        block_offset = next_input_pos % block_size

        current_block_table = self.block_tables.gather(
            1, block_idx.unsqueeze(-1)).squeeze(-1)
        slot_num = current_block_table * block_size + block_offset

        # update slot_mapping
        self.slot_mapping[:num_queries] = slot_num

        # The cached prefill/decode views were built from the pre-step state:
        # they hold a sliced copy of the seq_lens list and copies of the
        # scalar counters, so they no longer match the fields updated above.
        # Drop them so the properties rebuild from the current values.
        self._cached_prefill_metadata = None
        self._cached_decode_metadata = None
|
|
|
|
|
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
    """Collect per-sequence-group data and build an ``AscendMetadata``.

    NOTE(review): the accumulators used here (``context_lens``,
    ``slot_mapping``, ``block_tables``, ``prefill_seq_lens``,
    ``curr_seq_lens``, the prefill/decode counters and
    ``multimodal_placeholder_maps``) are not initialised in this class;
    presumably the base class sets them up - confirm.
    """

    # Mask builder shared by all instances; created by the first instance.
    _attn_mask_builder = None  # noqa

    def __init__(self, input_builder: "ModelInputForNPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        self.attn_mask = None
        if AscendMetadataBuilder._attn_mask_builder is None:
            # Start with a 128 x 128 mask in the model dtype.
            AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
                128, self.input_builder.runner.model_config.dtype)

    def _add_seq_group(
            self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        # Whether any sequence group of the batch hit the prefix cache. This
        # does not depend on the sequence being processed, so it is computed
        # once here instead of once per sequence.
        prefix_cache_hit = any(
            data.prefix_cache_hit
            for data in self.input_builder.inter_data_list)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table: List[int] = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                if block_tables is not None:
                    block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks of the sliding window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(
                is_profile_run,
                self.slot_mapping,
                seq_id,
                seq_len,
                context_len,
                start_idx,
                self.block_size,
                inter_data.block_tables,
            )

    def build(
        self,
        seq_lens: List[int],
        query_lens: List[int],
    ):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.

        Returns:
            The populated ``AscendMetadata``. Its ``attn_mask`` is set only
            when the batch contains at least one prefill.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device

        max_query_len = max(query_lens)
        # Query lengths of the decode requests, which follow the prefills.
        decode_query_lens = query_lens[self.num_prefills:]
        max_decode_query_len = max(decode_query_lens, default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        if self.num_prefills > 0:
            self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
                max_prefill_seq_len,
                self.input_builder.runner.model_config.dtype,
                self.input_builder.runner.device)
        else:
            self.attn_mask = None
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int32,
            device=device,
        )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        # Built directly as int64; AscendMetadata.advance_step derives the
        # index passed to Tensor.gather from this tensor.
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.long,
                                       device=device)

        return AscendMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            attn_mask=self.attn_mask,
        )
|
|
|
|
|
|
class AscendAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to ``torch_npu`` kernels.

    Prefill uses ``torch_npu._npu_flash_attention`` (or torch SDPA for
    encoder-only attention), decode uses ``torch_npu._npu_paged_attention``,
    and layers exposing a ``quant_method`` delegate to that method.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Kept for backward compatibility; forward() never reads it.
        self.seq_len_cpu_tensor = None
        # Attribute actually written and read by forward().
        self.seq_lens_tensor_cpu = None
        # Key/value cache views, bound on the first forward() call that
        # receives a non-empty kv_cache and reused afterwards.
        self.key_cache = None
        self.value_cache = None

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            layer: Attention layer; must have unit k/v scales. If it has a
                ``quant_method`` attribute, attention is delegated to it.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: indexed as kv_cache[0] for keys and kv_cache[1] for
                values; AscendAttentionBackend.get_kv_cache_shape gives
                [2, num_blocks, block_size, num_kv_heads, head_size].
            attn_metadata: Metadata for attention.
            attn_type: Ignored; the value given to ``__init__`` is used.
            output: Ignored; a new output buffer is always allocated.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            RuntimeError: if the batch has prefills and a non-empty block
                table (prefix cache / chunked prefill) on the
                non-quantized path.
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        # Split the flattened hidden dimension into (heads, head_size).
        num_tokens = query.shape[0]
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()
        attn_type = self.attn_type

        output = torch.empty(num_tokens,
                             self.num_heads,
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)

        if kv_cache.numel() > 0:
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

        if hasattr(layer, 'quant_method'):
            is_prefill = attn_metadata.num_prefills > 0
            if is_prefill:
                assert attn_metadata.prefill_metadata is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                        np.int32))
            else:
                assert attn_metadata.decode_metadata is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.decode_metadata.seq_lens).astype(
                        np.int32))
            block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
            # Details of kv_cache arrangement in attention quantization
            # are implemented by quant_method.
            layer.quant_method.apply(
                layer,
                query,
                key,
                value,
                self.key_cache,
                self.value_cache,
                self.scale,
                block_tables,
                is_prefill,
                attn_metadata,
                output,
                seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
        else:
            if self.key_cache is not None:
                # Store the new keys/values into their cache slots.
                torch_npu._npu_reshape_and_cache(key=key,
                                                 value=value,
                                                 key_cache=self.key_cache,
                                                 value_cache=self.value_cache,
                                                 slot_indices=slots)

            if attn_metadata.num_prefills > 0:

                if (attn_metadata.block_tables is None
                        or attn_metadata.block_tables.numel() == 0):
                    if attn_type == AttentionType.ENCODER_ONLY:
                        # TODO: change to use torch_npu encoder attention op, instead
                        # of torch sdpa
                        query = query.movedim(0, query.dim() - 2)
                        key = key.movedim(0, key.dim() - 2)
                        value = value.movedim(0, value.dim() - 2)

                        causal_attn = (attn_type == AttentionType.DECODER)
                        if attn_metadata.seq_lens is not None:
                            seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
                        attn_masks = [None] * len(seq_lens_q)
                        start_q, start_kv = 0, 0
                        # Run SDPA one sequence at a time over the packed
                        # token dimension.
                        for seq_len_q, seq_len_kv, mask in zip(
                                seq_lens_q, seq_lens_kv, attn_masks):
                            end_q = start_q + seq_len_q
                            end_kv = start_kv + seq_len_kv
                            sub_out = scaled_dot_product_attention(
                                query[None, :, start_q:end_q, :],
                                key[None, :, start_kv:end_kv, :],
                                value[None, :, start_kv:end_kv, :],
                                attn_mask=mask,
                                dropout_p=0.0,
                                is_causal=causal_attn and mask is None,
                                scale=self.scale).squeeze(0).movedim(
                                    query.dim() - 2, 0)
                            output[start_q:end_q, :, :] = sub_out
                            start_q, start_kv = end_q, end_kv
                    else:
                        assert attn_metadata.attn_mask is not None
                        mask = attn_metadata.attn_mask
                        assert attn_metadata.prefill_metadata is not None
                        self.seq_lens_tensor_cpu = torch.from_numpy(
                            np.array(attn_metadata.prefill_metadata.seq_lens).
                            astype(np.int32))
                        torch_npu._npu_flash_attention(
                            query=query,
                            key=key,
                            value=value,
                            mask=mask,
                            seq_len=self.seq_lens_tensor_cpu,
                            scale_value=self.scale,
                            num_heads=self.num_heads,
                            num_kv_heads=self.num_kv_heads,
                            out=output)
                else:
                    # TODO: Will support prefix cache and chunked prefill soon.
                    raise RuntimeError(
                        "Prefix cache and chunked prefill are currently not supported."
                    )
            elif attn_metadata.decode_metadata:
                assert self.key_cache is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.decode_metadata.seq_lens).astype(
                        np.int32))
                block_tables = attn_metadata.decode_metadata.block_tables
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=block_tables,
                    context_lens=self.seq_lens_tensor_cpu,
                    out=output)

        return output.view(num_tokens, self.hidden_size)
|
|
|
|
|
|
class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
    """Multi-head latent attention (MLA) implementation for Ascend NPUs.

    Prefill runs flash attention over fully up-projected keys/values, while
    decode runs paged attention directly on the compressed latent cache by
    folding the two halves of ``kv_b_proj`` into the query (``w_kc``) and
    into the attention output (``w_vc``).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        **extra_impl_args,
    ) -> None:
        """Store the attention hyper-parameters and the MLA sub-modules.

        Args:
            num_heads: Number of query heads.
            head_size: Size of one attention head.
            scale: Softmax scaling factor passed to the NPU kernels.
            num_kv_heads: Number of KV heads; ``None`` falls back to
                ``num_heads``.
            alibi_slopes: Optional ALiBi slopes, moved to the NPU as float32.
            sliding_window: Stored as-is; not read elsewhere in this class.
            kv_cache_dtype: Stored as-is; not read elsewhere in this class.
            blocksparse_params: Accepted for interface compatibility; unused.
            logits_soft_cap: Accepted for interface compatibility; unused.
            attn_type: Attention type; ``forward`` only supports
                ``AttentionType.DECODER``.
            **extra_impl_args: Must contain the MLA entries ``q_lora_rank``,
                ``kv_lora_rank``, ``qk_nope_head_dim``, ``qk_rope_head_dim``,
                ``qk_head_dim``, ``v_head_dim``, ``rotary_emb``, ``q_proj``,
                ``kv_b_proj`` and ``o_proj``.

        Raises:
            KeyError: If a required MLA entry is missing from
                ``extra_impl_args``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Kept for backward compatibility with any external reader; forward()
        # itself writes `seq_lens_tensor_cpu`, initialised right below.
        self.seq_len_cpu_tensor = None
        self.seq_lens_tensor_cpu = None

        # MLA Args
        self.q_lora_rank = extra_impl_args['q_lora_rank']
        self.kv_lora_rank = extra_impl_args['kv_lora_rank']
        self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
        self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
        self.qk_head_dim = extra_impl_args['qk_head_dim']
        self.v_head_dim = extra_impl_args['v_head_dim']
        self.rotary_emb = extra_impl_args['rotary_emb']
        self.q_proj = extra_impl_args['q_proj']
        self.kv_b_proj = extra_impl_args['kv_b_proj']
        self.o_proj = extra_impl_args['o_proj']
        # Per-head slices of kv_b_proj.weight, built lazily on the first
        # forward() call.
        self.w_kc = None
        self.w_vc = None

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend MLA attention.

        Args:
            layer: Attention layer; its k/v scales must both be 1.0.
            hidden_states_or_q_c: Input to ``q_proj``; the first dimension is
                ``num_tokens``.
            kv_c_normed: Compressed KV latent, viewed here as
                ``[num_tokens, num_kv_heads, -1]`` (presumably the last
                dimension is ``kv_lora_rank`` -- confirm against the caller).
            k_pe: Rotary part of the key, viewed here as
                ``[num_tokens, num_kv_heads, -1]`` (presumably
                ``qk_rope_head_dim`` wide -- confirm against the caller).
            kv_cache: Paged cache; ``kv_cache[0]`` is viewed as
                ``[num_blocks, block_size, num_kv_heads,
                qk_rope_head_dim + kv_lora_rank]``. An empty tensor skips the
                cache write.
            attn_metadata: Metadata for attention.
            attn_type: Ignored; ``self.attn_type`` is used instead.
            output: Unused; the result is always freshly allocated.

        Returns:
            The first element returned by ``o_proj`` applied to the attention
            output reshaped to ``[num_tokens, -1]``.

        Raises:
            NotImplementedError: If the attention type is not DECODER.
            RuntimeError: If prefix cache / chunked prefill is requested, if
                decode is requested without a populated KV cache, or if the
                metadata describes neither prefill nor decode.
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendMLAAttentionBackendImpl")

        num_tokens = hidden_states_or_q_c.shape[0]
        q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
                                                      self.qk_head_dim)
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)

        k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)

        if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
            # The plain RotaryEmbedding is fed flattened [num_tokens, -1]
            # inputs; the original shapes are restored afterwards.
            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
            q_pe = q_pe.reshape(num_tokens, -1)
            k_pe = k_pe.reshape(num_tokens, -1)
            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
                                         k_pe)
            q_pe = q_pe.view(ori_q_pe_shape)
            k_pe = k_pe.view(ori_k_pe_shape)
        else:
            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
                                         k_pe)

        if self.w_kc is None or self.w_vc is None:
            # Split kv_b_proj.weight per head into its key ("nope") rows and
            # its value rows; both are cached on the instance.
            kv_b_proj_weight = self.kv_b_proj.weight.reshape(
                self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
                self.kv_lora_rank)
            self.w_kc = kv_b_proj_weight[:, :self.
                                         qk_nope_head_dim, :].contiguous()
            self.w_vc = kv_b_proj_weight[:,
                                         self.qk_nope_head_dim:, :].transpose(
                                             1, 2).contiguous()

        is_prefill = attn_metadata.num_prefills > 0

        # Entry written to the paged cache: latent KV concatenated with the
        # rotated k_pe. Must be built before k_pe is expanded across heads.
        k_cache = torch.cat(
            [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
            dim=2)

        if is_prefill:
            # Prefill attends over fully up-projected keys and values.
            kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
                                                     self.num_heads, -1)
            k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
                                     dim=-1)
            k_pe = k_pe.expand(-1, self.num_heads, -1)
            key = torch.cat(
                [k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
        else:
            # Decode: fold w_kc into the query so attention can run against
            # the latent cache directly.
            q_nope_t = torch.transpose(q_nope, 0, 1)
            q_nope_out = torch.bmm(q_nope_t, self.w_kc)
            q_nope = torch.transpose(q_nope_out, 0, 1)

        query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
                                                       self.num_heads, -1)

        key_cache = None
        if kv_cache.numel() > 0:
            key_cache = kv_cache[0]
            num_blocks, block_size, _ = key_cache.shape

            key_cache = key_cache.view(
                num_blocks, block_size, self.num_kv_heads,
                self.qk_rope_head_dim + self.kv_lora_rank)
            slots = attn_metadata.slot_mapping
            torch_npu._npu_reshape_and_cache_siso(key=k_cache,
                                                  key_cache=key_cache,
                                                  slot_indices=slots)

        if is_prefill:
            attn_output = torch.empty(num_tokens,
                                      self.num_heads,
                                      self.v_head_dim,
                                      dtype=query.dtype,
                                      device="npu")
            if (attn_metadata.block_tables is None
                    or attn_metadata.block_tables.numel() == 0):
                assert attn_metadata.attn_mask is not None
                mask = attn_metadata.attn_mask
                assert attn_metadata.prefill_metadata is not None
                assert attn_metadata.prefill_metadata.seq_lens is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                        np.int32))
                torch_npu._npu_flash_attention(
                    query=query,
                    key=key,
                    value=value,
                    mask=mask,
                    seq_len=self.seq_lens_tensor_cpu,
                    scale_value=self.scale,
                    num_heads=self.num_heads,
                    num_kv_heads=self.num_heads,
                    out=attn_output)
            else:
                # TODO: Will support prefix cache and chunked prefill soon.
                raise RuntimeError(
                    "Prefix cache and chunked prefill are currently not supported."
                )
        elif attn_metadata.decode_metadata:
            if key_cache is None:
                # Without this check the kernel call below would fail with an
                # unbound-local error on key_cache.
                raise RuntimeError(
                    "MLA decode requires a non-empty KV cache.")
            attn_output = torch.empty(num_tokens,
                                      self.num_heads,
                                      self.kv_lora_rank,
                                      dtype=query.dtype,
                                      device="npu")
            self.seq_lens_tensor_cpu = torch.from_numpy(
                np.array(attn_metadata.decode_metadata.seq_lens).astype(
                    np.int32))
            block_tables = attn_metadata.decode_metadata.block_tables
            torch_npu._npu_paged_attention_mla(
                query=query,
                key_cache=key_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=block_tables,
                context_lens=self.seq_lens_tensor_cpu,
                mla_vheadsize=self.kv_lora_rank,
                out=attn_output)
            # Project the latent-space attention output back to v_head_dim
            # per head with w_vc.
            attn_output_t = torch.transpose(attn_output, 0, 1)
            attn_output_t = torch.bmm(attn_output_t, self.w_vc)
            attn_output = torch.transpose(attn_output_t, 0, 1)
        else:
            # Previously this fell through and crashed on an unbound
            # attn_output; fail with an explicit message instead.
            raise RuntimeError(
                "Attention metadata contains neither prefill nor decode "
                "requests.")

        output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))

        return output
|