# enginex-vastai-va16-vllm/vllm_vacc/vllm/v1/attention/backends/vacc_attn.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import os
import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
# from vllm_vacc.vllm.attention.backends.vacc_attn import (VACCAttentionBackendImpl,
# VACCAttentionMetadata)
# from vllm_vacc.vllm.attention.backends.vacc_attn import VACCAttentionBackendImpl
# from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
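# Illustrative sketch (not executed): for seq_len = 3 and a single head with
# slope s, `bias[None, :] - bias[:, None]` gives the relative positions
#   [[ 0,  1,  2],
#    [-1,  0,  1],
#    [-2, -1,  0]]
# which are scaled by s; the strictly upper triangle is then pushed to -inf,
# so only the causal part of the ALiBi bias survives the softmax.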
def _make_sliding_window_bias(
seq_lens: list[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
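# Illustrative sketch (not executed): with seq_len = 4 and window_size = 2,
# torch.tril keeps the causal lower triangle of ones, torch.triu drops
# entries older than the window, and torch.log maps 1 -> 0 and 0 -> -inf:
#   [[   0, -inf, -inf, -inf],
#    [   0,    0, -inf, -inf],
#    [-inf,    0,    0, -inf],
#    [-inf, -inf,    0,    0]]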
@dataclass
class VACCAttentionMetadata(AttentionMetadata):
"""Metadata for VACCAttentionMetadata.
"""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
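# Concretely, slot 35 with block size 16 decomposes as block 35 // 16 = 2
# and offset 35 % 16 = 3, i.e. the token is written at offset 3 within
# physical block 2 of the paged KV cache.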
slot_mapping: torch.Tensor
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum decode sequence length in the batch; 0 if it is a prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# The 2nd dimension is padded up to max_blocks_per_seq when the run is
# CUDA-graph captured.
block_tables: Optional[torch.Tensor]
"""Metadata for TorchSDPABackend.
"""
# Whether chunked prefill is enabled (prefill and decode tokens may be
# mixed in one batch).
chunked_prefill: bool
seq_lens: Optional[list[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
prefill_query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# For V1 logits index only
query_start_loc: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[list[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt when ALiBi
# slopes are used, due to a limitation of the xFormers API.
# Will not appear in the __repr__ and __init__.
self.attn_bias: Optional[list[torch.Tensor]] = None
self.encoder_attn_bias: Optional[list[torch.Tensor]] = None
self.cross_attn_bias: Optional[list[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[list[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: list[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
# class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
# self.chunked_prefill = input_builder.chunked_prefill
# self.input_builder = input_builder
# def prepare(self):
# self.input_data = self.input_builder.input_data
# def build(self, seq_lens: list[int], query_lens: list[int],
# cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
# input_data = self.input_data
# prefill_seq_lens = seq_lens[0:input_data.num_prefills]
# prefill_query_lens = query_lens[0:input_data.num_prefills]
# slot_mapping = torch.tensor(input_data.slot_mapping,
# dtype=torch.int32,
# device=self.input_builder.device)
# # For chunked-prefill
# if self.chunked_prefill and input_data.num_prefill_tokens != 0:
# prefill_block_tables = make_tensor_with_pad(
# self.input_data.prefill_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# query_lens_tensor = torch.tensor(prefill_query_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_lens_tensor = torch.tensor(prefill_seq_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# query_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# torch.cumsum(query_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=query_start_loc[1:])
# torch.cumsum(kv_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=kv_start_loc[1:])
# max_query_len = max(prefill_query_lens)
# max_kv_len = max(prefill_seq_lens)
# else:
# prefill_block_tables = None
# query_start_loc = None
# kv_start_loc = None
# max_query_len = None
# max_kv_len = None
# # For paged attention
# if input_data.num_decode_tokens != 0:
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[input_data.num_prefills:],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# block_tables = make_tensor_with_pad(
# self.input_data.decode_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # lowest_dim_size = block_tables.size(-1)
# # if lowest_dim_size < 1024:
# # padding_amount = 1024 - lowest_dim_size
# # padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
# # block_tables = torch.cat((block_tables, padding), dim=-1)
# else:
# block_tables = torch.tensor([])
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[:input_data.num_prefills],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # For multi-modal models
# placeholder_index_maps = None
# if len(input_data.multi_modal_inputs_list) != 0:
# placeholder_index_maps = {
# modality: placeholder_map.index_map()
# for modality, placeholder_map in
# input_data.multi_modal_placeholder_maps.items()
# }
# attn_metadata = VACCAttentionMetadata(
# chunked_prefill=self.chunked_prefill,
# seq_lens=seq_lens, #prefill_seq_lens,
# seq_lens_tensor=seq_lens_tensor,
# max_query_len=max_query_len,
# max_kv_len=max_kv_len,
# query_start_loc=query_start_loc,
# kv_start_loc=kv_start_loc,
# max_decode_seq_len=None,
# num_prefills=input_data.num_prefills,
# num_prefill_tokens=input_data.num_prefill_tokens,
# num_decode_tokens=input_data.num_decode_tokens,
# block_tables=block_tables,
# prefill_block_tables=prefill_block_tables,
# slot_mapping=slot_mapping,
# multi_modal_placeholder_index_maps=placeholder_index_maps,
# enable_kv_scales_calculation=False,
# )
# return attn_metadata
def fp32_attention(
query_layer,
key_layer,
value_layer,
mask,
norm_factor,
out_type=None,
):
ori_type = out_type if out_type is not None else query_layer.dtype
query_layer = query_layer.to(torch.float32)
key_layer = key_layer.to(torch.float32)
value_layer = value_layer.to(torch.float32)
# GQA
if query_layer.size(1) != key_layer.size(1):
assert query_layer.size(1) % key_layer.size(1) == 0, (
"GQA requires the number of query heads to be divisible by the number of KV heads")
groups = query_layer.size(1) // key_layer.size(1)
key_layer = torch.repeat_interleave(key_layer, groups, dim=1)
value_layer = torch.repeat_interleave(value_layer, groups, dim=1)
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
) * norm_factor
mask_output = matmul_result
if mask is not None:
mask = mask if mask.dim() >= 3 else mask.unsqueeze(0)
mask_output = matmul_result.masked_fill_(mask, -10000.0) # [b * np, sq, sk]
probs = torch.nn.Softmax(dim=-1)(mask_output)
context_layer = torch.bmm(probs, value_layer.transpose(0, 1))
return context_layer.transpose(0, 1).to(ori_type)
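# Rough shape walk-through of fp32_attention (assuming q/k/v are laid out as
# [seq, num_heads, head_size], as in the decode caller below): with
# q = [1, 32, 128] and k = v = [seq_k, 8, 128], GQA repeats k/v 4x to
# 32 heads, the first bmm yields scores of shape [32, 1, seq_k], softmax is
# taken over the last dim, and the final transpose returns [1, 32, 128].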
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
# if logits_soft_cap is not None:
# logger.warning_once("Torch SPDA does not support logits soft cap. "
# "Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
self.attn_type = attn_type
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: VACCAttentionMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
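# Presumably split_kv_cache splits the leading dim of the
# (2, num_blocks, block_size, num_kv_heads, head_size) cache into
# key_cache and value_cache, each (num_blocks, block_size, num_kv_heads,
# head_size); the decode path below relies on dims 2 and 3 being
# (num_kv_heads, head_size).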
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive the token count from the encoder metadata and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
self._run_vacc_forward(
output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run.
# (
# seq_lens_arg,
# max_seq_len_arg,
# block_tables_arg,
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
# Note:
# decode attention still use SDPA method
# reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size)
k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
block_per_group = env_blk_grp_size // 16
# convert block_tables to 8K group index
block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
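# Worked example (assuming BLOCK_GROUP_SIZE = 8192, as the "8K group"
# comment above suggests, so block_per_group = 512): vLLM block ids 0..511
# all map to grouped-cache row 0, ids 512..1023 to row 1, and so on, which
# is exactly the row index needed for the reshaped k_cache / v_cache views.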
attn_outs = []
for i in range(len(decode_meta.seq_lens_tensor)):
seq_len = decode_meta.seq_lens_tensor[i]
k_slices = k_cache[block_tables[i], ...]
# Concatenate this request's block groups and trim to its true length.
# A separate index j avoids shadowing the outer per-request index i.
k = torch.cat([k_slices[j, ...] for j in range(len(block_tables[i]))],
dim=0)[:seq_len]
v_slices = v_cache[block_tables[i], ...]
v = torch.cat([v_slices[j, ...] for j in range(len(block_tables[i]))],
dim=0)[:seq_len]
q = query[i : i + 1, ...]
if q.dtype == torch.bfloat16:
attn_out = fp32_attention(
q.cpu(),
k.cpu(),
v.cpu(),
None,
self.scale
).to(query.dtype).to(query.device)
else:
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale,
)
attn_outs.append(attn_out)
output = torch.cat(attn_outs, dim=0)
# '''
# PagedAttention.forward_decode(
# output[attn_metadata.num_prefill_tokens:, :, :],
# query[attn_metadata.num_prefill_tokens:, :, :],
# key_cache,
# value_cache,
# block_tables_arg,
# seq_lens_arg,
# max_seq_len_arg,
# self.kv_cache_dtype,
# self.num_kv_heads,
# self.scale,
# self.alibi_slopes,
# layer._k_scale,
# layer._v_scale,
# )
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_vacc_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: VACCAttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
# if self.num_kv_heads != self.num_heads:
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = torch.vacc.scaled_dot_product_attention(
query[start_q:end_q, :, :].to(torch.float16) * self.scale,
key[start_kv:end_kv, :, :].to(torch.float16),
value[start_kv:end_kv, :, :].contiguous().to(torch.float16),
attn_mask=None,
dropout_p=0.0,
is_causal=(attn_type == AttentionType.DECODER),  # causal_attn and not self.need_mask
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=1)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
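# Rough illustration of the packing loop above: with seq_lens_q = [3, 5]
# the query slices are [0:3] and [3:8], so each prompt attends only within
# its own span of the flattened [num_tokens, num_heads, head_size] tensors,
# with causal masking applied per sequence for decoder attention.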
return output
class VACCAttentionBackend(AttentionBackend):
accept_output_buffer: bool = False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["VACCAttentionBackendImpl"]:
return VACCAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return VACCAttentionMetadata
# @staticmethod
# def get_state_cls() -> type["CommonAttentionState"]:
# return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["VACCAttentionMetadataBuilderV1"]:
return VACCAttentionMetadataBuilderV1
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str,
) -> tuple[int, ...]:
# return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
# num_kv_heads, head_size)
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
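# For example, num_blocks = 1024, block_size = 16, num_kv_heads = 8 and
# head_size = 128 give a cache tensor of shape (2, 1024, 16, 8, 128); the
# leading 2 presumably separates the key and value planes consumed by
# split_kv_cache in the attention impl.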
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
class VACCAttentionMetadataBuilderV1(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
# block_table: BlockTable) -> None:
# self.runner = runner
# self.block_table = block_table
# # For reorder
# self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.num_prompt_req: int = 0
# self.seq_start_loc_cpu = torch.zeros(
# runner.max_num_reqs + 1,
# dtype=torch.int32,
# device="cpu",
# )
# self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder
self.reorder_prompt_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device="cpu",
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
# def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# vllm_config: VllmConfig, device: torch.device):
# super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# self.block_size = kv_cache_spec.block_size
# model_config = vllm_config.model_config
# self.num_heads_q = model_config.get_num_attention_heads(
# vllm_config.parallel_config)
# self.num_heads_kv = model_config.get_num_kv_heads(
# vllm_config.parallel_config)
# self.headdim = model_config.get_head_size()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
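# Sketch of the reorder: if requests arrive as [P, D, P, D] at indices
# [0, 1, 2, 3] (P = still in prefill, D = decoding), decode request 1 sits
# before prompt request 2, so the two are swapped and the batch becomes
# prompt-first: [P, P, D, D]. build() below assumes this prompt-then-decode
# ordering when splitting query_start_loc.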
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
# seq_lens = common_attn_metadata.seq_lens
# runner = self.runner
# block_table = self.block_table
# seq_lens = runner.seq_lens[:num_reqs]
num_prompt_req = self.num_prompt_req
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
num_prefill_tokens)
# print('query_start_loc_cpu', query_start_loc_cpu)
# print('num_prompt_req', num_prompt_req)
# print('num_reqs', num_reqs)
# num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
# num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
# ) - num_prefill_tokens
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens] #.long()
# block_table_tensor = block_table.get_device_tensor()
slot_mapping = common_attn_metadata.slot_mapping
block_table_tensor = common_attn_metadata.block_table_tensor
block_num_per_group = env_blk_grp_size // 16
block_table_tensor_new = block_table_tensor[:num_reqs-num_prompt_req, ::block_num_per_group].contiguous()
# [bs, seq//16] => [bs, seq//16//block_num_per_group, block_num_per_group]
# => [:num_reqs, :, 0]: take the first num_reqs rows and keep only every
# block_num_per_group-th block index (the first block of each group)
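# Worked example (assuming block_num_per_group = 512): a decode row of
# 16-token block ids [0, 1, ..., 1023] collapses to [0, 512]; forward()
# later divides these by block_per_group to get group indices [0, 1] into
# the reshaped (num_block_grp, block_grp_size, ...) cache.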
attn_metadata = VACCAttentionMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens, # decode
max_decode_seq_len=None, # decode
block_tables=block_table_tensor_new, # decode
chunked_prefill=False,
# max_query_len=max_query_len,
# max_kv_len=max_prefill_seq_len,
# prefill_query_start_loc=runner.
# query_start_loc_cpu[:num_prompt_req + 1], # prefill
# kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
# 1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
query_start_loc=query_start_loc_cpu[:num_reqs +
1], # for logits index
# multi_modal_placeholder_index_maps=None,
# enable_kv_scales_calculation=False,
)
return attn_metadata