2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
2025-04-18 12:23:32 +08:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
import numpy as np
|
2025-02-05 10:53:12 +08:00
|
|
|
import torch
|
2025-04-29 18:03:38 +08:00
|
|
|
import torch_npu
|
2025-04-23 16:23:25 +08:00
|
|
|
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
|
2025-04-29 18:03:38 +08:00
|
|
|
from torch.nn.functional import scaled_dot_product_attention
|
2025-02-05 10:53:12 +08:00
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
|
|
|
AttentionLayer,
|
2025-02-21 17:07:37 +08:00
|
|
|
AttentionMetadata, AttentionType,
|
|
|
|
|
MLAAttentionImpl)
|
2025-04-19 17:38:18 +08:00
|
|
|
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
2025-02-05 10:53:12 +08:00
|
|
|
CommonMetadataBuilder,
|
2025-02-21 17:07:37 +08:00
|
|
|
compute_slot_mapping,
|
2025-02-05 10:53:12 +08:00
|
|
|
compute_slot_mapping_start_idx,
|
|
|
|
|
is_block_tables_empty)
|
2025-04-23 16:23:25 +08:00
|
|
|
from vllm.config import get_current_vllm_config
|
2025-02-21 17:07:37 +08:00
|
|
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-04-28 08:53:06 +08:00
|
|
|
from vllm_ascend.ops.cache import concat_and_cache_mla
|
2025-05-20 09:31:30 +08:00
|
|
|
from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
2025-04-18 12:23:32 +08:00
|
|
|
from vllm_ascend.worker.model_runner import (
|
|
|
|
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
|
2025-04-17 19:31:50 +08:00
|
|
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
    """Build a square (max_seq_len, max_seq_len) causal attention mask.

    Entries strictly above the diagonal (future positions) are filled with
    ``mask_value``; the diagonal and everything below it stay zero.

    Args:
        max_seq_len: Side length of the square mask.
        dtype: Target dtype of the returned mask.
        mask_value: Fill value for masked positions. When ``None``, a
            dtype-dependent default is chosen (see below).

    Returns:
        A ``(max_seq_len, max_seq_len)`` tensor of ``dtype``.
    """
    # Currently for fp16 dtype, the mask value should be set to -inf:
    # float32 min overflows to -inf after the cast to fp16. Other dtypes
    # use 1 as a flag value instead.
    # TODO: Eliminate this part in the future.
    if mask_value is None:
        mask_value = (torch.finfo(torch.float32).min
                      if dtype == torch.float16 else 1)
    # True strictly above the diagonal marks the positions to mask out.
    masked_positions = torch.triu(torch.ones((max_seq_len, max_seq_len),
                                             dtype=torch.bool),
                                  diagonal=1)
    return torch.zeros(max_seq_len,
                       max_seq_len).masked_fill(masked_positions,
                                                mask_value).to(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionMaskBuilder:
    """Builds, caches and serves attention masks for the Ascend backend.

    A single square mask (``attn_mask_cache``) is kept and lazily regrown
    whenever a larger sequence length or a different dtype is requested;
    the public getters return slices/selections of that cache.
    """

    def __init__(self, attn_mask: torch.Tensor):
        # Side length of the cached square mask.
        self._seq_len_cached = attn_mask.shape[0]
        # The cached mask itself; regenerated on demand by update_attn_cache.
        self.attn_mask_cache = attn_mask
        # Fill value used for masked positions in the splitfuse
        # (chunked-prefill) path.
        self.splitfuse_mask_value = -10000

    @classmethod
    def initialize_from_len(cls,
                            max_seq_len: int,
                            dtype: torch.dtype = torch.float16,
                            mask_value: Optional[int] = None):
        """Alternate constructor: build the initial mask from a length."""
        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))

    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
                          device: torch.device):
        """Grow/retype the cached mask if needed, then move it to ``device``.

        Regeneration happens when the requested length exceeds the cache or
        the dtype differs; a pure device mismatch only triggers a transfer.
        """
        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
            self._seq_len_cached = seqlen
            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
        if self.attn_mask_cache.device != device:
            self.attn_mask_cache = self.attn_mask_cache.to(device)

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                      device: torch.device):
        """Return a contiguous (max_seq_len, max_seq_len) causal mask."""
        self.update_attn_cache(max_seq_len, dtype, device)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()

    def get_decode_attn_mask(
        self,
        input_lengths: torch.Tensor,
        max_s: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Return per-sequence decode masks of shape (batch, 1, max_s).

        Row ``input_lengths[i]`` of the cached mask masks everything beyond
        that sequence's current length.
        """
        self.update_attn_cache(max_s, dtype, device)
        return (self.attn_mask_cache.index_select(
            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())

    def get_splitfuse_attn_mask(
        self,
        seq_lens,
        query_lens,
        position,
        dtype,
        device,
    ) -> torch.Tensor:
        """Return a mask for splitfuse (chunked-prefill) batches.

        Fast path: when the cached mask already covers ``max(seq_lens)``,
        rows are selected by ``position``. Slow path: a dense
        (sum(query_lens), max_seq_len) mask is built on CPU, row block by
        row block, with ``splitfuse_mask_value`` in masked positions.

        NOTE(review): ``seq_lens``/``query_lens`` appear to be per-request
        int sequences and ``position`` an index tensor — confirm at callers.
        """
        max_seq_len = max(seq_lens, default=0)
        if max_seq_len <= self._seq_len_cached:
            self.update_attn_cache(max_seq_len, dtype, device)
            # FIXME: Currently the mask value of chunked-prefill situation
            # and Prefill-Only situation is not the same. Fix this in the
            # future when kernel is ready.
            # A positive off-diagonal entry means the cache holds the
            # "flag value 1" style mask; rescale it for chunked prefill.
            if self.attn_mask_cache.numel(
            ) > 1 and self.attn_mask_cache[0][1] > 0:
                attn_mask = self.get_attn_mask(  # type: ignore
                    max_seq_len, dtype, device)
                attn_mask *= -10000
            else:
                attn_mask = self.attn_mask_cache
            return torch.index_select(attn_mask, dim=0,
                                      index=position)[:, :max_seq_len]

        total_q_len = sum(query_lens)
        attn_mask = torch.zeros((total_q_len, max_seq_len),
                                dtype=dtype,
                                device="cpu")

        current_row = 0
        for i in range(len(query_lens)):
            seq_len = seq_lens[i]
            q_len = query_lens[i]
            # Tokens already computed for this request.
            context_len = seq_len - q_len

            assert context_len >= 0
            # Mask everything at/after the context, then re-open the causal
            # lower triangle within the [context_len, seq_len) window.
            attn_mask[current_row:current_row + q_len,
                      context_len:] = self.splitfuse_mask_value
            right_tensor = attn_mask[current_row:current_row + q_len,
                                     context_len:seq_len]
            right_tensor.masked_fill_(
                right_tensor.tril() == self.splitfuse_mask_value, 0)
            current_row += q_len

        return attn_mask.to(device, non_blocking=True)
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
class AscendAttentionBackend(AttentionBackend):
    """Vanilla attention backend for Ascend NPUs.

    Wires the Ascend-specific implementation, metadata, state and builder
    classes into vLLM's abstract ``AttentionBackend`` interface, and
    provides the KV-cache block swap/copy primitives.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registered name of this backend."""
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class for this backend."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        """Return the attention metadata class for this backend."""
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape.

        The leading ``2`` packs the key cache and the value cache into one
        tensor (index 0 = keys, index 1 = values).
        """
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy selected blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        ``src_to_dst`` is a ``(num_pairs, 2)`` tensor of
        ``(src_block_index, dst_block_index)`` pairs. Source blocks are
        moved to the destination cache's device as part of the copy.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Fix: value blocks were previously moved to dst_key_cache.device;
        # use the value cache's own device so the copy stays correct even
        # if key and value caches ever live on different devices.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate blocks in place within every layer's KV cache.

        ``src_to_dists`` is a ``(num_pairs, 2)`` tensor of
        ``(src_block_index, dst_block_index)`` pairs applied to both the
        key (index 0) and value (index 1) planes of each cache.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_builder_cls() -> Type["AscendMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return AscendMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
        """Instantiate this backend's metadata builder."""
        return cls.get_builder_cls()(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
class AscendMLAAttentionBackend(AscendAttentionBackend):
    """Multi-head latent attention (MLA) variant of the Ascend backend.

    Differs from the base backend only in the implementation class and in
    the KV cache layout (a single combined cache, no key/value split).
    """

    @staticmethod
    def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
        """Return the MLA attention implementation class."""
        return AscendMLAAttentionBackendImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the MLA KV cache shape (no leading key/value axis)."""
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AscendMetadata(AttentionMetadata):
    """Metadata for the Ascend backend.

    * modified from the XFormers backend

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether chunked prefill is enabled for this batch.
    chunked_prefill_enabled: bool

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    block_tables: Optional[torch.Tensor]

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # The query lengths of the input sequences.
    query_lens: Optional[List[int]] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Self-attention prefill/decode metadata cache.
    _cached_prefill_metadata: Optional["AscendMetadata"] = None
    _cached_decode_metadata: Optional["AscendMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation.
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences.
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder.
    num_encoder_tokens: Optional[int] = None

    # Mask for normal situation.
    attn_mask: Optional[torch.Tensor] = None

    # Mask for prefix caching.
    compress_mask: Optional[torch.Tensor] = None

    # Mask for chunked prefill.
    chunk_mask: Optional[torch.Tensor] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables.
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["AscendMetadata"]:
        """Return a prefill-only view of this metadata (cached), or None
        when the batch contains no prefill requests."""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure.
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))

        # Compute some attn_metadata fields which default to None.
        # Prefill entries occupy the leading slice of every per-batch field.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        query_lens = (None if self.query_lens is None else
                      self.query_lens[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure.
        self._cached_prefill_metadata = AscendMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            query_lens=query_lens,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            chunked_prefill_enabled=self.chunked_prefill_enabled,
            block_tables=block_tables,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["AscendMetadata"]:
        """Return a decode-only view of this metadata (cached), or None
        when the batch contains no decode tokens."""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure.
            return self._cached_decode_metadata

        # Compute some attn_metadata fields which default to None.
        # Decode entries occupy the trailing slice of every per-batch field.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[self.num_prefills:])
        query_lens = (None if self.query_lens is None else
                      self.query_lens[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        # Construct & cache decode-phase attention metadata structure.
        self._cached_decode_metadata = AscendMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            query_lens=query_lens,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            chunked_prefill_enabled=self.chunked_prefill_enabled,
            block_tables=block_tables,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForNPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Appends each sampled token to its sequence: bumps seq_lens, writes
        the sampled tokens/positions back into ``model_input``, and computes
        the next slot mapping — either via the AscendC custom op or a pure
        PyTorch fallback.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries

        if turn_prefills_into_decodes:
            # When Mutli-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        # Sanity checks: at this point the batch must be decode-only and
        # every per-sequence field sized consistently with num_seqs.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)
        if CUSTOM_OP_ENABLED:
            # Advance a step on NPU for existing inputs for a multi-step
            # runner if custom ops is enabled.
            torch.ops._C.advance_step_flashattn_ascendc(
                num_seqs=num_seqs,
                num_queries=num_queries,
                block_size=block_size,
                input_tokens=model_input.input_tokens,
                sampled_token_ids=sampled_token_ids,
                input_positions=model_input.input_positions,
                seq_lens=self.seq_lens_tensor,
                slot_mapping=self.slot_mapping,
                block_tables=self.block_tables)
        else:
            # Use traditional PyTorch method for updating these tensors.
            # Update input_tokens with the tokens sampled last step.
            sampled_token_ids_list = sampled_token_ids[:
                                                       num_queries].squeeze(  # type: ignore
                                                           -1)
            model_input.input_tokens[:
                                     num_queries] = sampled_token_ids_list  # type: ignore

            # Get seq_lens and input_positions.
            seq_lens = self.seq_lens_tensor[:num_queries]
            next_seq_lens = seq_lens + 1
            next_input_pos = next_seq_lens - 1

            # Update seq_lens and input_positions.
            self.seq_lens_tensor[:num_queries] = next_seq_lens
            model_input.input_positions[:
                                        num_queries] = next_input_pos  # type: ignore

            # Compute the block index and the offset within the block
            # for the next token position.
            block_idx = next_input_pos // block_size
            block_offset = next_input_pos % block_size

            current_block_table = self.block_tables.gather(
                1, block_idx.unsqueeze(-1)).squeeze(-1)
            slot_num = current_block_table * block_size + block_offset

            # Update slot_mapping.
            self.slot_mapping[:num_queries] = slot_num
|
2025-03-11 19:20:06 +08:00
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
_attn_mask_builder = None # noqa
|
2025-02-05 10:53:12 +08:00
|
|
|
|
[ModelRunner] Refactor model_runner for NPU (#6)
### What this PR does / why we need it?
This PR is a refactoring of model runner, to decouple it from the
classes specifically designed for GPU.
The changes of model runner are generally showed below:

**Other changes:** I have removed the code of `cuda`, `lora` and `prompt
adapter`, because NPU doesn`t support them now.
### Does this PR introduce _any_ user-facing change?
no.
### How was this patch tested?
I have used `AI-ModelScope/gpt2` for testing
`examples/offline_inference_npu.py`, and the results showed that it
worked well.
The test logs are showed below:
```bash
INFO 02-05 09:08:46 __init__.py:30] Available plugins for group vllm.platform_plugins:
INFO 02-05 09:08:46 __init__.py:32] name=ascend, value=vllm_ascend:register
INFO 02-05 09:08:46 __init__.py:34] all available plugins for group vllm.platform_plugins will be loaded.
INFO 02-05 09:08:46 __init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load.
INFO 02-05 09:08:46 __init__.py:44] plugin ascend loaded.
INFO 02-05 09:08:46 __init__.py:177] Platform plugin ascend is activated
INFO 02-05 09:08:48 config.py:2383] Downcasting torch.float32 to torch.float16.
INFO 02-05 09:08:59 config.py:542] This model supports multiple tasks: {'generate', 'score', 'embed', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 02-05 09:08:59 llm_engine.py:234] Initializing a V0 LLM engine (v0.1.dev1+gb3a0d01) with config: model='/home/sss/models/AI-ModelScope/gpt2', speculative_config=None, tokenizer='/home/sss/models/AI-ModelScope/gpt2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=npu, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/home/sss/models/AI-ModelScope/gpt2, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False,
WARNING 02-05 09:09:01 _custom_ops.py:21] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
INFO 02-05 09:09:01 importing.py:16] Triton not installed or not compatible; certain GPU-related functions will not be available.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 3.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 3.18it/s]
INFO 02-05 09:09:11 executor_base.py:110] # CPU blocks: 98557, # CPU blocks: 7281
INFO 02-05 09:09:11 executor_base.py:115] Maximum concurrency for 1024 tokens per request: 1539.95x
INFO 02-05 09:09:12 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 2.13 seconds
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.53it/s, est. speed input: 8.41 toks/s, output: 152.97 toks/s]
Prompt: 'Hello, my name is', Generated text: " John. I'm a writer, and I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm a writer. I'm"
Prompt: 'The president of the United States is', Generated text: ' States president. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United States. He is the president of the United'
Prompt: 'The capital of France is', Generated text: ' the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.\n\nThe French Republic is the capital of the French Republic.'
Prompt: 'The future of AI is', Generated text: '\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future of AI is a question of how to make it work.\n\nThe future'
```
---------
Signed-off-by: Shanshan Shen <467638484@qq.com>
2025-02-06 09:04:18 +08:00
|
|
|
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
    """Initialize the metadata builder from its owning input builder."""
    # Handles back to the owning builder and the runner configuration.
    self.input_builder = input_builder
    self.runner = input_builder.runner
    self.sliding_window = input_builder.sliding_window
    self.block_size = input_builder.block_size

    # Attention masks are populated later by build(); start unset.
    self.attn_mask = None
    self.compress_mask = None
    self.chunk_mask = None

    # The mask builder is shared at class level; create it lazily on the
    # first instantiation using the model dtype.
    if AscendMetadataBuilder._attn_mask_builder is None:
        AscendMetadataBuilder._attn_mask_builder = (
            AttentionMaskBuilder.initialize_from_len(
                128, self.input_builder.runner.model_config.dtype))
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
def _add_seq_group(
        self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
        chunked_prefill_enabled: bool):
    """Add a sequence group to the metadata. Specifically update/append
    1. context length.
    2. block table.
    3. slot mapping.
    """
    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables

    # Whether ANY sequence group in this batch hit the prefix cache is a
    # property of the whole batch, not of the current sequence, so compute
    # it once before the loop. (Previously this was re-evaluated on every
    # iteration via `any([...])`, and the comprehension variable shadowed
    # the `inter_data` parameter.)
    prefix_cache_hit = any(data.prefix_cache_hit
                           for data in self.input_builder.inter_data_list)

    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        self.context_lens.append(context_len)
        if is_prompt:
            # Prefill bookkeeping.
            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            # Decode bookkeeping.
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)

        # Compute block table.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        block_table: List[int] = []
        if prefix_cache_hit:
            # NOTE(woosuk): For flash-attn, the block table should
            # include the entries for the incoming prefill tokens.
            if block_tables is not None:
                block_table = block_tables[seq_id]
        elif ((chunked_prefill_enabled or not is_prompt)
              and block_tables is not None):
            if curr_sliding_window_block == 0:
                block_table = block_tables[seq_id]
            else:
                # Keep only the most recent blocks within the window.
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
        self.block_tables.append(block_table)

        # Compute slot mapping.
        is_profile_run = is_block_tables_empty(block_tables)
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(
            is_profile_run,
            self.slot_mapping,
            seq_id,
            seq_len,
            context_len,
            start_idx,
            self.block_size,
            inter_data.block_tables,
        )
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
def _get_graph_runner_block_tables(
|
|
|
|
|
self, num_seqs: int,
|
|
|
|
|
block_tables: List[List[int]]) -> torch.Tensor:
|
|
|
|
|
# The shape of graph_block_tables is
|
|
|
|
|
# [max batch size, max context len // block size].
|
|
|
|
|
|
|
|
|
|
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
|
|
|
|
assert max_batch_size >= num_seqs
|
|
|
|
|
|
|
|
|
|
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
|
|
|
|
for i, block_table in enumerate(block_tables):
|
|
|
|
|
if block_table:
|
|
|
|
|
num_blocks = len(block_table)
|
|
|
|
|
if num_blocks <= max_blocks:
|
|
|
|
|
graph_block_tables[i, :num_blocks] = block_table
|
|
|
|
|
else:
|
|
|
|
|
graph_block_tables[
|
|
|
|
|
i, :max_blocks] = block_table[:max_blocks]
|
|
|
|
|
|
|
|
|
|
return torch.from_numpy(graph_block_tables).to(
|
|
|
|
|
device=self.runner.device, non_blocking=True)
|
|
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
def build(
    self,
    seq_lens: List[int],
    query_lens: List[int],
    graph_pad_size: int,
):
    """Build attention metadata with on-device tensors.

    Args:
        seq_lens: The maybe padded sequence lengths of the input sequences.
        query_lens: The query lengths of the input sequences.
        graph_pad_size: Number of padding entries for NPU graph capture;
            -1 disables graph mode for this step.
    """
    # Fold every sequence group's context lens, block tables and slot
    # mapping into this builder's accumulators.
    for inter_data in self.input_builder.inter_data_list:
        self._add_seq_group(inter_data,
                            self.input_builder.chunked_prefill_enabled)

    device = self.runner.device
    dtype = self.runner.model_config.dtype
    # graph_pad_size == -1 signals that NPU graph capture is disabled.
    use_npu_graph = graph_pad_size != -1

    max_query_len = max(query_lens)
    max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
    max_decode_seq_len = max(self.curr_seq_lens, default=0)
    max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
    num_decode_tokens = self.num_decode_tokens

    if self.num_prefills == 0 and use_npu_graph:
        # Decode-only step under graph capture: pad the slot mapping and
        # block tables out to the fixed graph batch, then reuse the
        # runner's pre-allocated padded block-table buffer.
        num_seqs = len(seq_lens)
        self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
        self.block_tables.extend([[]] * graph_pad_size)
        block_tables = self._get_graph_runner_block_tables(
            num_seqs, self.block_tables)
    else:
        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int32,
            device=device,
        )

    if self.num_prefills > 0:
        # Pick the prefill mask variant based on the batch shape.
        if block_tables is None or block_tables.numel() == 0:
            # normal mask
            self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
                max_prefill_seq_len, dtype, device)
        elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
            # compress mask for prefix cache
            self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
                128, dtype, device)
        else:
            # chunk_mask for chunk prefill
            attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
                max_seq_len, dtype, device)
            # NOTE(review): a positive upper-triangular entry appears to
            # indicate a 0/1-style mask that is rescaled here into a large
            # negative additive bias — confirm against the mask builder.
            if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
                attn_mask *= -10000
            chunk_mask_list = []
            for i, seq_len in enumerate(seq_lens):
                context_len = self.context_lens[i]
                # Keep only the mask rows for tokens processed this step.
                chunk_mask_list.append(attn_mask[context_len:seq_len])
            self.chunk_mask = torch.cat(chunk_mask_list, 0)
    else:
        # Decode-only batches need no prefill masks.
        self.attn_mask = None
        self.compress_mask = None
        self.chunk_mask = None

    assert max_query_len > 0, "query_lens: {}".format(query_lens)

    assert device is not None
    # Asynchronous host-to-device copies of the accumulated metadata.
    slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                           device, self.runner.pin_memory)
    seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                       self.runner.pin_memory)
    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        self.multimodal_placeholder_maps.items()
    }

    return AscendMetadata(
        num_prefills=self.num_prefills,
        slot_mapping=slot_mapping_tensor,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        seq_lens=seq_lens,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=True,
        seq_lens_tensor=seq_lens_tensor,
        query_lens=query_lens,
        max_query_len=max_query_len,
        max_prefill_seq_len=max_prefill_seq_len,
        max_decode_seq_len=max_decode_seq_len,
        block_tables=block_tables,
        attn_mask=self.attn_mask,
        compress_mask=self.compress_mask,
        chunk_mask=self.chunk_mask,
        chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
    )
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendAttentionBackendImpl(AttentionImpl):
    """Attention implementation for Ascend NPUs.

    Dispatches between torch_npu fused kernels (flash attention, qlens
    prefix-cache attention, split-fuse, paged attention) and a torch SDPA
    fallback for encoder-only attention, based on the step's metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # Fall back to MHA when the model declares no KV head count.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            # Move ALiBi slopes onto the NPU once, at construction time.
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.seq_len_cpu_tensor = None
        self.query_len_cpu_tensor = None
        # Views into the kv_cache tensor, captured on the first forward
        # that receives a non-empty cache and reused afterwards.
        self.key_cache = None
        self.value_cache = None

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
                   num_tokens = batch_size * seq_len
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                               num_kv_heads, head_size]
                      key_cache = [num_blocks, block_size,
                                   num_kv_heads, head_size]
                      value_cache = [num_blocks, block_size,
                                     num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len * num_heads * head_size]
        """
        # KV scale factors other than 1.0 are not supported on this path.
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        # View q k v to BSH.
        num_tokens = query.shape[0]
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()
        # NOTE(review): the `attn_type` argument is overwritten by the
        # value captured at construction time; the parameter is unused.
        attn_type = self.attn_type

        # NOTE(review): a fresh output buffer is always allocated — the
        # `output` parameter passed by the caller is never reused.
        output = torch.empty(num_tokens,
                             self.num_heads,
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)

        if kv_cache.numel() > 0:
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

        if hasattr(layer, 'quant_method'):
            # Quantized attention: delegate the whole computation
            # (including cache writes) to the layer's quant method.
            isPrefill = True if attn_metadata.num_prefills > 0 else False
            if isPrefill:
                assert attn_metadata.prefill_metadata is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                        np.int32))
            else:
                assert attn_metadata.decode_metadata is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.decode_metadata.seq_lens).astype(
                        np.int32))
            block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
            # Details of kv_cache arrangement in attention quantization
            # are implemented by quant_method.
            layer.quant_method.apply(
                layer,
                query,
                key,
                value,
                self.key_cache,
                self.value_cache,
                self.scale,
                block_tables,
                isPrefill,
                attn_metadata,
                output,
                seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
        else:
            if self.key_cache is not None:
                # NOTE(review): `slots` is only bound above when kv_cache
                # is non-empty; this branch assumes the two conditions
                # always coincide — confirm.
                torch_npu._npu_reshape_and_cache(key=key,
                                                 value=value,
                                                 key_cache=self.key_cache,
                                                 value_cache=self.value_cache,
                                                 slot_indices=slots)

            if attn_metadata.num_prefills > 0:
                # Prefix cache disabled and chunk prefill disabled or no prefix cache hit
                if (attn_metadata.block_tables is None
                        or attn_metadata.block_tables.numel() == 0):
                    if attn_type == AttentionType.ENCODER_ONLY:
                        # TODO: change to use torch_npu encoder attention op, instead
                        # of torch sdpa
                        # Move the token axis next to the feature axis so the
                        # tensors match SDPA's expected layout.
                        query = query.movedim(0, query.dim() - 2)
                        key = key.movedim(0, key.dim() - 2)
                        value = value.movedim(0, value.dim() - 2)

                        causal_attn = (attn_type == AttentionType.DECODER)
                        if attn_metadata.seq_lens is not None:
                            seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
                        attn_masks = [None] * len(seq_lens_q)
                        # Run SDPA per sequence over contiguous slices of
                        # the packed token dimension.
                        start_q, start_kv = 0, 0
                        for seq_len_q, seq_len_kv, mask in zip(
                                seq_lens_q, seq_lens_kv, attn_masks):
                            end_q = start_q + seq_len_q
                            end_kv = start_kv + seq_len_kv
                            sub_out = scaled_dot_product_attention(
                                query[None, :, start_q:end_q, :],
                                key[None, :, start_kv:end_kv, :],
                                value[None, :, start_kv:end_kv, :],
                                attn_mask=mask,
                                dropout_p=0.0,
                                is_causal=causal_attn and mask is None,
                                scale=self.scale).squeeze(0).movedim(
                                    query.dim() - 2, 0)
                            output[start_q:end_q, :, :] = sub_out
                            start_q, start_kv = end_q, end_kv
                    else:
                        # Plain prefill: fused NPU flash attention.
                        assert attn_metadata.attn_mask is not None
                        mask = attn_metadata.attn_mask
                        assert attn_metadata.prefill_metadata is not None
                        self.seq_lens_tensor_cpu = torch.from_numpy(
                            np.array(attn_metadata.prefill_metadata.seq_lens).
                            astype(np.int32))
                        torch_npu._npu_flash_attention(
                            query=query,
                            key=key,
                            value=value,
                            mask=mask,
                            seq_len=self.seq_lens_tensor_cpu,
                            scale_value=self.scale,
                            num_heads=self.num_heads,
                            num_kv_heads=self.num_kv_heads,
                            out=output)
                # Prefix cache only and cache hit
                elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
                    assert kv_cache is not None
                    assert attn_metadata.prefill_metadata is not None
                    self.seq_lens_tensor_cpu = torch.from_numpy(
                        np.array(
                            attn_metadata.prefill_metadata.seq_lens).astype(
                                np.int32))
                    self.query_lens_tensor_cpu = torch.from_numpy(
                        np.array(
                            attn_metadata.prefill_metadata.query_lens).astype(
                                np.int32))
                    block_tables = attn_metadata.prefill_metadata.block_tables
                    assert attn_metadata.compress_mask is not None
                    compress_mask = attn_metadata.compress_mask
                    torch_npu._npu_flash_attention_qlens(
                        query=query,
                        key_cache=self.key_cache,
                        value_cache=self.value_cache,
                        block_table=block_tables,
                        mask=compress_mask,
                        seq_len=self.query_lens_tensor_cpu,
                        context_lens=self.seq_lens_tensor_cpu,
                        num_kv_heads=self.num_kv_heads,
                        num_heads=self.num_heads,
                        scale_value=self.scale,
                        out=output)
                # Splitfuse
                else:
                    assert kv_cache is not None
                    self.seq_lens_tensor_cpu = torch.from_numpy(
                        np.array(attn_metadata.seq_lens).astype(np.int32))
                    self.query_lens_tensor_cpu = torch.from_numpy(
                        np.array(attn_metadata.query_lens).astype(np.int32))
                    block_tables = attn_metadata.block_tables
                    assert attn_metadata.chunk_mask is not None
                    chunk_mask = attn_metadata.chunk_mask
                    torch_npu._npu_paged_attention_splitfuse(
                        query=query,
                        key_cache=self.key_cache,
                        value_cache=self.value_cache,
                        block_table=block_tables,
                        context_lens=self.seq_lens_tensor_cpu,
                        mask=chunk_mask,
                        seq_len=self.query_lens_tensor_cpu,
                        num_kv_heads=self.num_kv_heads,
                        num_heads=self.num_heads,
                        scale_value=self.scale,
                        out=output)
            # Decode only
            else:
                assert self.key_cache is not None
                assert self.value_cache is not None
                assert attn_metadata.decode_metadata is not None
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.decode_metadata.seq_lens).astype(
                        np.int32))
                block_tables = attn_metadata.decode_metadata.block_tables
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=block_tables,
                    context_lens=self.seq_lens_tensor_cpu,
                    out=output)

        return output.view(num_tokens, self.hidden_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    **extra_impl_args,
) -> None:
    """Set up the MLA attention implementation.

    The MLA-specific projections/ranks are supplied by the model through
    ``extra_impl_args``; required keys raise KeyError when missing.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    self.hidden_size = self.num_heads * self.head_size
    self.kv_cache_dtype = kv_cache_dtype
    self.sliding_window = sliding_window
    # Move ALiBi slopes onto the NPU once, up front.
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32, device="npu"))
    self.attn_type = attn_type

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.seq_len_cpu_tensor = None

    # MLA Args — required; bound in declaration order so a missing key
    # fails exactly as a direct lookup would.
    for mla_arg in ('q_lora_rank', 'kv_lora_rank', 'qk_nope_head_dim',
                    'qk_rope_head_dim', 'qk_head_dim', 'v_head_dim',
                    'rotary_emb', 'q_proj', 'kv_b_proj', 'o_proj'):
        setattr(self, mla_arg, extra_impl_args[mla_arg])
    # Optional MLA components.
    self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
                                                  None)
    self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
    self.k_pe_cache = None
    self.k_nope_cache = None
    self.w_kc = None
    self.w_vc = None

    # Graph mode can be switched on via vLLM's additional_config.
    self.enable_graph_mode = False
    additional_config = get_current_vllm_config().additional_config
    if additional_config:
        self.enable_graph_mode = additional_config.get(
            "enable_graph_mode", False)
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
def exec_kv(
|
|
|
|
|
self,
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
|
|
|
|
kv_cache: Tuple,
|
|
|
|
|
slots: torch.Tensor,
|
|
|
|
|
):
|
|
|
|
|
B = hidden_states.shape[0]
|
|
|
|
|
N = self.num_kv_heads
|
|
|
|
|
S = 1
|
|
|
|
|
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
|
|
|
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
|
|
|
|
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
|
|
|
|
|
2025-05-01 13:51:42 +08:00
|
|
|
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
2025-04-19 17:38:18 +08:00
|
|
|
kv,
|
|
|
|
|
self.kv_a_layernorm.weight,
|
|
|
|
|
cos,
|
|
|
|
|
sin,
|
|
|
|
|
slots.to(torch.int64),
|
|
|
|
|
kv_cache[1],
|
|
|
|
|
kv_cache[0],
|
|
|
|
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
|
|
|
|
cache_mode="PA",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return k_pe, k_nope
|
|
|
|
|
|
|
|
|
|
def apply_rotary_emb(
|
|
|
|
|
self,
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
|
|
|
|
is_neox_style: bool,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
x: [num_tokens, num_heads, head_size]
|
|
|
|
|
cos: [num_tokens, head_size // 2]
|
|
|
|
|
sin: [num_tokens, head_size // 2]
|
|
|
|
|
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
|
|
|
|
positional embeddings.
|
|
|
|
|
"""
|
|
|
|
|
cos = cos.unsqueeze(-2).to(x.dtype)
|
|
|
|
|
sin = sin.unsqueeze(-2).to(x.dtype)
|
|
|
|
|
if is_neox_style:
|
|
|
|
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
x1 = x[..., ::2]
|
|
|
|
|
x2 = x[..., 1::2]
|
|
|
|
|
o1 = x1 * cos - x2 * sin
|
|
|
|
|
o2 = x2 * cos + x1 * sin
|
|
|
|
|
if is_neox_style:
|
|
|
|
|
return torch.cat((o1, o2), dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
|
|
|
|
|
|
|
|
|
def rope_single(
|
|
|
|
|
self,
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
B, N, D = x.shape
|
|
|
|
|
S = 1
|
|
|
|
|
x = x.view(B, N, S, D)
|
|
|
|
|
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
|
|
|
|
return x.view(B, N, D)
|
|
|
|
|
|
|
|
|
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
|
|
|
|
if self.w_kc is None or self.w_vc is None:
|
|
|
|
|
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
|
|
|
|
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
|
|
|
|
self.kv_lora_rank)
|
|
|
|
|
self.w_kc = kv_b_proj_weight[:, :self.
|
|
|
|
|
qk_nope_head_dim, :].contiguous()
|
|
|
|
|
self.w_vc = kv_b_proj_weight[:,
|
|
|
|
|
self.qk_nope_head_dim:, :].transpose(
|
|
|
|
|
1, 2).contiguous()
|
|
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,
        hidden_states_or_kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend MLA attention.

        Args:
            layer: Owning attention layer; only its k/v scales are inspected
                here (they must be 1.0 -- a quantized KV cache is unsupported).
            hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
                num_tokens = batch_size * seq_len
            hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
            k_pe: shape = [num_tokens, num_kv_heads * head_size].
                May be None (graph mode / fused decode), in which case the
                rope part is recomputed from hidden_states_or_kv_c_normed.
            kv_cache: shape = [1, num_blocks, block_size,
                               num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len * num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        # The constructor-provided attention type always wins over the
        # call-time argument.
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        if attn_metadata is None:
            # for profile run
            return hidden_states_or_q_c

        num_tokens = hidden_states_or_q_c.shape[0]
        # Project queries and split into no-position (nope) and rope parts.
        q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
                                                      self.qk_head_dim)
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        if k_pe is None and attn_metadata.decode_metadata:
            # Fused decode path: apply RoPE through the NPU kernels and write
            # the KV cache in one shot via exec_kv.
            seq_len = self.rotary_emb.max_position_embeddings

            cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
            sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
            cos = cos[attn_metadata.input_positions]
            sin = sin[attn_metadata.input_positions]
            # Add singleton head/seq axes for the fused kernels' BNSD layout.
            cos = cos[:, None, None, :]
            sin = sin[:, None, None, :]

            q_pe = self.rope_single(q_pe, cos, sin)
            k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
                                        kv_cache, attn_metadata.slot_mapping)
        else:
            if k_pe is None:
                # NOTE: k_pe is None when graph mode enabled
                kv_c, k_pe = self.kv_a_proj_with_mqa(
                    hidden_states_or_kv_c_normed)[0].split(
                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
            else:
                # Caller already projected/normalized the latent KV.
                kv_c_normed = hidden_states_or_kv_c_normed
            k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
            if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
                # NOTE: When scaling not specified, the plain RotaryEmbedding
                # expects flattened [num_tokens, -1] inputs; restore the
                # original shapes afterwards.
                ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
                q_pe = q_pe.reshape(num_tokens, -1)
                k_pe = k_pe.reshape(num_tokens, -1)
                q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
                                             q_pe, k_pe)
                q_pe = q_pe.view(ori_q_pe_shape)
                k_pe = k_pe.view(ori_k_pe_shape)
            else:
                q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
                                             q_pe, k_pe)

        if attn_metadata.num_prefills > 0:
            # Prefill: expand the latent KV into full per-head keys/values.
            kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
                                                     self.num_heads, -1)
            k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
                                     dim=-1)
        else:
            # Decode: absorb W_kc into the query instead of expanding the KV
            # (MLA weight-absorption trick; attention then runs in latent
            # space with head size kv_lora_rank).
            q_nope_t = torch.transpose(q_nope, 0, 1)
            q_nope_out = torch.bmm(q_nope_t, self.w_kc)
            q_nope = torch.transpose(q_nope_out, 0, 1)

        query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
                                                       self.num_heads, -1)

        # TODO: Replace the env with more flexible expressions
        if self.enable_graph_mode:
            if len(kv_cache) > 0 and kv_cache[0].numel(
            ) > 0 and attn_metadata.num_prefills > 0:
                slots = attn_metadata.slot_mapping
                # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                    num_tokens, self.num_kv_heads, -1),
                                                 value=k_pe,
                                                 key_cache=kv_cache[0],
                                                 value_cache=kv_cache[1],
                                                 slot_indices=slots)
        elif kv_cache.numel() > 0:
            # TODO replace this naive implement with fusion kernel
            concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
                                 attn_metadata.slot_mapping)

        if attn_metadata.num_prefills > 0:
            attn_output = torch.empty(num_tokens,
                                      self.num_heads,
                                      self.v_head_dim,
                                      dtype=query.dtype,
                                      device=query.device)
            if (attn_metadata.block_tables is None
                    or attn_metadata.block_tables.numel() == 0):
                assert attn_metadata.attn_mask is not None
                assert attn_metadata.prefill_metadata is not None
                assert attn_metadata.prefill_metadata.seq_lens is not None
                mask = attn_metadata.attn_mask
                # _npu_flash_attention takes per-sequence lengths as an
                # int32 CPU tensor.
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                        np.int32))
                # Broadcast the shared rope key across all query heads before
                # concatenating with the expanded nope keys.
                k_pe = k_pe.repeat(1, self.num_heads, 1)
                key = torch.cat(
                    [k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
                torch_npu._npu_flash_attention(
                    query=query,
                    key=key,
                    value=value,
                    mask=mask,
                    seq_len=self.seq_lens_tensor_cpu,
                    scale_value=self.scale,
                    num_heads=self.num_heads,
                    num_kv_heads=self.num_heads,
                    out=attn_output)
            else:
                # TODO: Will support prefix cache and chunked prefill soon.
                raise RuntimeError(
                    "Prefix cache and chunked prefill are currently not supported."
                )
        elif attn_metadata.decode_metadata:
            assert kv_cache is not None
            if self.enable_graph_mode:
                # shape of query for npu graph mode should be:
                # [bs, num_heads_per_rank, seq_len, dim]
                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
                # shape of knope/k_pe for npu graph mode should be:
                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
                block_size = kv_cache[0].shape[1]
                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
                                     self.kv_lora_rank)
                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
                                 self.qk_rope_head_dim)
                # k_nope is passed as both key and value: attention runs in
                # the latent space, values are recovered through w_vc below.
                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                    q_nope,
                    k_nope,
                    k_nope,
                    query_rope=q_pe,
                    key_rope=k_pe,
                    num_heads=self.num_heads,
                    num_key_value_heads=self.num_kv_heads,
                    input_layout="BNSD",
                    atten_mask=attn_metadata.attn_mask,
                    scale=self.scale,
                    antiquant_mode=0,
                    antiquant_scale=None,
                    block_table=attn_metadata.block_tables,
                    block_size=block_size,
                    actual_seq_lengths_kv=attn_metadata.seq_lens,
                )
                # Un-absorb: project the latent output back through W_vc.
                attn_output = attn_output.view(num_tokens, -1,
                                               self.kv_lora_rank).transpose(
                                                   0, 1)
                attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
            else:
                # if torch.empty is used here, the preemptive scheduling case of
                # test_mtp_correctness.py will fail to run.
                attn_output = torch.randn(
                    [num_tokens, self.num_heads, self.kv_lora_rank],
                    dtype=query.dtype,
                    device=query.device)
                self.seq_lens_tensor_cpu = torch.from_numpy(
                    np.array(attn_metadata.decode_metadata.seq_lens).astype(
                        np.int32))
                block_tables = attn_metadata.decode_metadata.block_tables
                torch_npu._npu_paged_attention_mla(
                    query=query,
                    key_cache=kv_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=block_tables,
                    context_lens=self.seq_lens_tensor_cpu,
                    mla_vheadsize=self.kv_lora_rank,
                    out=attn_output)
                # Un-absorb: project the latent output back through W_vc.
                attn_output_t = torch.transpose(attn_output, 0, 1)
                attn_output_t = torch.bmm(attn_output_t, self.w_vc)
                attn_output = torch.transpose(attn_output_t, 0, 1)

        output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))

        return output
|