commit 24df76db9d (parent 80932c96e5), 2026-04-02 04:53:13 +00:00
1987 changed files with 447445 additions and 0 deletions


@@ -0,0 +1,982 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
# from vllm_vacc.vllm.attention.backends.vacc_attn import (VACCAttentionBackendImpl,
# VACCAttentionMetadata)
# from vllm_vacc.vllm.attention.backends.vacc_attn import VACCAttentionBackendImpl
# from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
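# Illustrative sketch (values assumed, not from a real run): with 2 heads and
# one sequence of length 3, each returned bias has shape
# (1, num_heads, seq_len, seq_len):
#   slopes = torch.tensor([0.5, 0.25])
#   biases = _make_alibi_bias(slopes, torch.float32, seq_lens=[3])
#   assert biases[0].shape == (1, 2, 3, 3)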
def _make_sliding_window_bias(
seq_lens: list[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
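# Illustrative sketch (values assumed, not from a real run): with seq_len=4
# and window_size=2, log(1)=0 inside the causal band and log(0)=-inf outside:
#   _make_sliding_window_bias([4], 2, torch.float32)[0][0] ==
#       [[0., -inf, -inf, -inf],
#        [0.,   0., -inf, -inf],
#        [-inf, 0.,   0., -inf],
#        [-inf, -inf, 0.,   0.]]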
@dataclass
class VACCAttentionMetadata(AttentionMetadata):
"""Metadata for VACCAttentionMetadata.
"""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
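    # Worked example of the arithmetic (block_size 16 assumed): slot 35
    # decomposes as block 35 // 16 = 2 with in-block offset 35 % 16 = 3.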
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill: bool
seq_lens: Optional[list[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
prefill_query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# For V1 logits index only
query_start_loc: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[list[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
        # Set during the execution of the first attention op.
        # These are lists because the bias must be set per prompt when ALiBi
        # slopes are used, due to a limitation of the xformers API.
        # They will not appear in __repr__ or __init__.
self.attn_bias: Optional[list[torch.Tensor]] = None
self.encoder_attn_bias: Optional[list[torch.Tensor]] = None
self.cross_attn_bias: Optional[list[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
        Arguments:
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[list[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
        Arguments:
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: list[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
        Arguments:
        * attn_bias: The desired attention bias value
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
        Arguments:
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
# class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
# self.chunked_prefill = input_builder.chunked_prefill
# self.input_builder = input_builder
# def prepare(self):
# self.input_data = self.input_builder.input_data
# def build(self, seq_lens: list[int], query_lens: list[int],
# cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
# input_data = self.input_data
# prefill_seq_lens = seq_lens[0:input_data.num_prefills]
# prefill_query_lens = query_lens[0:input_data.num_prefills]
# slot_mapping = torch.tensor(input_data.slot_mapping,
# dtype=torch.int32,
# device=self.input_builder.device)
# # For chunked-prefill
# if self.chunked_prefill and input_data.num_prefill_tokens != 0:
# prefill_block_tables = make_tensor_with_pad(
# self.input_data.prefill_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# query_lens_tensor = torch.tensor(prefill_query_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_lens_tensor = torch.tensor(prefill_seq_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# query_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# torch.cumsum(query_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=query_start_loc[1:])
# torch.cumsum(kv_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=kv_start_loc[1:])
# max_query_len = max(prefill_query_lens)
# max_kv_len = max(prefill_seq_lens)
# else:
# prefill_block_tables = None
# query_start_loc = None
# kv_start_loc = None
# max_query_len = None
# max_kv_len = None
# # For paged attention
# if input_data.num_decode_tokens != 0:
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[input_data.num_prefills:],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# block_tables = make_tensor_with_pad(
# self.input_data.decode_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # lowest_dim_size = block_tables.size(-1)
# # if lowest_dim_size < 1024:
# # padding_amount = 1024 - lowest_dim_size
# # padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
# # block_tables = torch.cat((block_tables, padding), dim=-1)
# else:
# block_tables = torch.tensor([])
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[:input_data.num_prefills],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # For multi-modal models
# placeholder_index_maps = None
# if len(input_data.multi_modal_inputs_list) != 0:
# placeholder_index_maps = {
# modality: placeholder_map.index_map()
# for modality, placeholder_map in
# input_data.multi_modal_placeholder_maps.items()
# }
# attn_metadata = VACCAttentionMetadata(
# chunked_prefill=self.chunked_prefill,
# seq_lens=seq_lens, #prefill_seq_lens,
# seq_lens_tensor=seq_lens_tensor,
# max_query_len=max_query_len,
# max_kv_len=max_kv_len,
# query_start_loc=query_start_loc,
# kv_start_loc=kv_start_loc,
# max_decode_seq_len=None,
# num_prefills=input_data.num_prefills,
# num_prefill_tokens=input_data.num_prefill_tokens,
# num_decode_tokens=input_data.num_decode_tokens,
# block_tables=block_tables,
# prefill_block_tables=prefill_block_tables,
# slot_mapping=slot_mapping,
# multi_modal_placeholder_index_maps=placeholder_index_maps,
# enable_kv_scales_calculation=False,
# )
# return attn_metadata
def fp32_attention(
query_layer,
key_layer,
value_layer,
mask,
norm_factor,
out_type=None,
):
ori_type = out_type if out_type is not None else query_layer.dtype
query_layer = query_layer.to(torch.float32)
key_layer = key_layer.to(torch.float32)
value_layer = value_layer.to(torch.float32)
    # GQA: replicate the KV heads so every query head has a matching KV head.
    if query_layer.size(1) != key_layer.size(1):
        assert query_layer.size(1) % key_layer.size(1) == 0, (
            "number of query heads must be a multiple of KV heads")
groups = query_layer.size(1) // key_layer.size(1)
key_layer = torch.repeat_interleave(key_layer, groups, dim=1)
value_layer = torch.repeat_interleave(value_layer, groups, dim=1)
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
) * norm_factor
mask_output = matmul_result
    if mask is not None:
mask = mask if mask.dim() >= 3 else mask.unsqueeze(0)
mask_output = matmul_result.masked_fill_(mask, -10000.0) # [b * np, sq, sk]
probs = torch.nn.Softmax(dim=-1)(mask_output)
context_layer = torch.bmm(probs, value_layer.transpose(0, 1))
return context_layer.transpose(0, 1).to(ori_type)
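# Illustrative sketch (shapes assumed, not from a real run): inputs are laid
# out as [seq, heads, head_dim], and GQA KV heads are replicated on the fly.
#   q = torch.randn(5, 8, 64)                   # 8 query heads
#   k = torch.randn(5, 2, 64)                   # 2 KV heads (GQA group of 4)
#   v = torch.randn(5, 2, 64)
#   out = fp32_attention(q, k, v, mask=None, norm_factor=64 ** -0.5)
#   assert out.shape == (5, 8, 64)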
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
# if logits_soft_cap is not None:
# logger.warning_once("Torch SPDA does not support logits soft cap. "
# "Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "VACC attention backend does not support FP8 KV cache.")
self.attn_type = attn_type
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: VACCAttentionMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive the token count from the query shape and treat them
            # as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
self._run_vacc_forward(
output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run.
# (
# seq_lens_arg,
# max_seq_len_arg,
# block_tables_arg,
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
            # NOTE: decode attention still uses the SDPA path.
            # Reshape the k/v cache to
            # (num_block_groups, block_group_size, num_heads, head_size).
            k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2],
                                     key_cache.shape[3])
            v_cache = value_cache.view(-1, env_blk_grp_size,
                                       value_cache.shape[2],
                                       value_cache.shape[3])
            # The native block size is 16 tokens, so each group spans
            # env_blk_grp_size // 16 blocks; convert per-block ids into
            # block-group indices.
            block_per_group = env_blk_grp_size // 16
            block_tables = (decode_meta.block_tables // block_per_group).to(
                torch.int32)
attn_outs = []
            for i in range(len(decode_meta.seq_lens_tensor)):
                seq_len = decode_meta.seq_lens_tensor[i]
                # Gather this request's KV block groups and trim to seq_len.
                k_slices = k_cache[block_tables[i], ...]
                k = torch.cat(
                    [k_slices[j, ...] for j in range(len(block_tables[i]))],
                    dim=0)[:seq_len]
                v_slices = v_cache[block_tables[i], ...]
                v = torch.cat(
                    [v_slices[j, ...] for j in range(len(block_tables[i]))],
                    dim=0)[:seq_len]
                q = query[i : i + 1, ...]
if q.dtype == torch.bfloat16:
attn_out = fp32_attention(
q.cpu(),
k.cpu(),
v.cpu(),
None,
self.scale
).to(query.dtype).to(query.device)
else:
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale,
)
attn_outs.append(attn_out)
output = torch.cat(attn_outs, dim=0)
# PagedAttention.forward_decode(
# output[attn_metadata.num_prefill_tokens:, :, :],
# query[attn_metadata.num_prefill_tokens:, :, :],
# key_cache,
# value_cache,
# block_tables_arg,
# seq_lens_arg,
# max_seq_len_arg,
# self.kv_cache_dtype,
# self.num_kv_heads,
# self.scale,
# self.alibi_slopes,
# layer._k_scale,
# layer._v_scale,
# )
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_vacc_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: VACCAttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
# if self.num_kv_heads != self.num_heads:
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
            sub_out = torch.vacc.scaled_dot_product_attention(
                query[start_q:end_q, :, :].to(torch.float16) * self.scale,
                key[start_kv:end_kv, :, :].to(torch.float16),
                value[start_kv:end_kv, :, :].contiguous().to(torch.float16),
                # NOTE: the per-sequence masks built above are not passed yet;
                # causality is handled via is_causal instead.
                attn_mask=None,
                dropout_p=0.0,
                is_causal=(attn_type == AttentionType.DECODER),
                is_train=False,
                recompute=False,
                flash_attention=False,
                # The query is pre-scaled by self.scale, so sm_scale stays 1.
                sm_scale=1)
            output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
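# Illustrative sketch (assumed layout): for a packed prefill batch with
# seq_lens_q = [4, 6], `_run_vacc_forward` receives a query of shape
# [10, num_heads, head_size] and runs SDPA on rows [0:4] and [4:10]
# independently, writing each result into the matching slice of `output`.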
class VACCAttentionBackend(AttentionBackend):
accept_output_buffer: bool = False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["VACCAttentionBackendImpl"]:
return VACCAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return VACCAttentionMetadata
# @staticmethod
# def get_state_cls() -> type["CommonAttentionState"]:
# return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["VACCAttentionMetadataBuilderV1"]:
return VACCAttentionMetadataBuilderV1
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
        cache_dtype_str: str,
) -> tuple[int, ...]:
# return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
# num_kv_heads, head_size)
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
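# Usage sketch (values assumed for illustration): the backend stores K and V
# stacked in one tensor, so for 128 blocks of 16 tokens, 8 KV heads and head
# size 128:
#   shape = VACCAttentionBackend.get_kv_cache_shape(
#       num_blocks=128, block_size=16, num_kv_heads=8, head_size=128,
#       cache_dtype_str="auto")
#   assert shape == (2, 128, 16, 8, 128)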
class VACCAttentionMetadataBuilderV1(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
# block_table: BlockTable) -> None:
# self.runner = runner
# self.block_table = block_table
# # For reorder
# self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.num_prompt_req: int = 0
# self.seq_start_loc_cpu = torch.zeros(
# runner.max_num_reqs + 1,
# dtype=torch.int32,
# device="cpu",
# )
# self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder
self.reorder_prompt_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device="cpu",
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
# def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# vllm_config: VllmConfig, device: torch.device):
# super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# self.block_size = kv_cache_spec.block_size
# model_config = vllm_config.model_config
# self.num_heads_q = model_config.get_num_attention_heads(
# vllm_config.parallel_config)
# self.num_heads_kv = model_config.get_num_kv_heads(
# vllm_config.parallel_config)
# self.headdim = model_config.get_head_size()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
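    # Illustrative sketch (assumed request states): given a batch ordered
    # [prompt, decode, prompt, decode] (indices 0..3), the loop above collects
    # prompt indices [0, 2] and decode indices [1, 3]; decode index 1 sits in
    # front of a prompt, so swap_states(2, 1) yields [prompt, prompt, decode,
    # decode] and reorder_batch returns True.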
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
# seq_lens = common_attn_metadata.seq_lens
# runner = self.runner
# block_table = self.block_table
# seq_lens = runner.seq_lens[:num_reqs]
num_prompt_req = self.num_prompt_req
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
num_prefill_tokens)
# print('query_start_loc_cpu', query_start_loc_cpu)
# print('num_prompt_req', num_prompt_req)
# print('num_reqs', num_reqs)
# num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
# num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
# ) - num_prefill_tokens
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens] #.long()
# block_table_tensor = block_table.get_device_tensor()
slot_mapping = common_attn_metadata.slot_mapping
block_table_tensor = common_attn_metadata.block_table_tensor
block_num_per_group = env_blk_grp_size // 16
        block_table_tensor_new = block_table_tensor[
            :num_reqs - num_prompt_req, ::block_num_per_group].contiguous()
        # [bs, seq//16] => [bs, seq//16//block_num_per_group, block_num_per_group]
        # => take the first (num_reqs - num_prompt_req) rows and every
        # block_num_per_group-th column (the first block id of each group).
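        # Worked example (sizes assumed): with env_blk_grp_size = 128 each
        # group spans 128 // 16 = 8 native blocks, so `[:, ::8]` keeps the
        # first block id of each group and the decode path addresses the
        # cache at block-group granularity.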
attn_metadata = VACCAttentionMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens,  # decode
            max_decode_seq_len=None,  # decode
            block_tables=block_table_tensor_new,  # decode
            chunked_prefill=False,
            # max_query_len=max_query_len,
            # max_kv_len=max_prefill_seq_len,
            # prefill_query_start_loc=runner.query_start_loc_cpu[
            #     :num_prompt_req + 1],  # prefill
            # kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + 1],  # prefill
            prefill_block_tables=block_table_tensor[:num_prompt_req],  # prefill
            query_start_loc=query_start_loc_cpu[:num_reqs + 1],  # logits index
# multi_modal_placeholder_index_maps=None,
# enable_kv_scales_calculation=False,
)
return attn_metadata


@@ -0,0 +1,953 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, List, Optional, Type, TypeVar
import torch
from vllm import _custom_ops as ops  # for advance_step_flashattn
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_per_layer_parameters,
                                              infer_global_hyperparameters,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
M = TypeVar("M", bound=MLACommonMetadata)
def vacc_paged_attention_naive(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: list[int],
    out: Optional[torch.Tensor] = None,
    sm_scale: float = -1,
) -> torch.Tensor:
    # Fast path: for batch size 1 the cache can be viewed flat, which
    # preserves single-request performance.
    if len(seq_lens) == 1:
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
    else:
        attn_outs = []
        for i in range(len(seq_lens)):
            # Gather this request's KV blocks and trim to its true length.
            k_slices = key_cache[block_table[i], :, :, :]
            k = k_slices.reshape(-1, key_cache.shape[2],
                                 key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_table[i], :, :, :]
            v = v_slices.reshape(-1, value_cache.shape[2],
                                 value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query[i:i+1,:,:],
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0)
return attn_out
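# Illustrative sketch (shapes and values assumed, not from a real run): caches
# are laid out as (num_blocks, block_size, num_heads, head_size). For a batch
# of one request the fast path views the cache flat and trims to seq_lens[0]:
#   key_cache = torch.randn(4, 16, 1, 576)    # 4 blocks of 16 tokens
#   out = vacc_paged_attention_naive(
#       query=torch.randn(1, 1, 576), key_cache=key_cache,
#       value_cache=key_cache[..., :512], block_table=torch.tensor([[0, 1]]),
#       seq_lens=[20], sm_scale=576 ** -0.5)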
@dataclass
class MLACommonPrefillMetadata:
""" Prefill Specific Metadata """
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
    block_tables: torch.Tensor  # renamed from block_table for V0 compatibility
query_start_loc: torch.Tensor
# max_query_len: int
seq_lens: list[int]
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class MLACommonDecodeMetadata:
    block_tables: torch.Tensor  # renamed from block_table for V0 compatibility
seq_lens: torch.Tensor
class VACCMLAMetadata(MLACommonMetadata):
"""Metadata for VACCMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if this is a decode-only batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
input_positions: Optional[torch.Tensor] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
num_prefill_tokens: int
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
def __post_init__(self):
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
@property
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = VACCMLAMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=None,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
head_dim=self.head_dim)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = VACCMLAMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=self.seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForVACCWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests
        # in the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
# self.num_prefill_tokens = 0
# self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
# assert self.max_query_len == 1
# assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
# self.max_decode_seq_len = None
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
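    # Sketch of the intended semantics (assumed, mirroring the upstream
    # flash-attn helper): after one decode step each of the num_queries
    # sequences grows by one token, so seq_lens, slot_mapping, input tokens
    # and positions are all advanced in-place without rebuilding the metadata.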
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
# blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**kwargs) -> None:
# super().__init__(num_heads, head_size, scale, num_kv_heads,
# alibi_slopes, sliding_window, kv_cache_dtype,
# logits_soft_cap, attn_type,
# kv_sharing_target_layer_name, **kwargs)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
        # Manually replicate the old-style base-class __init__ from kwargs.
        self.q_lora_rank = kwargs.get('q_lora_rank')
        self.kv_lora_rank = kwargs.get('kv_lora_rank')
        self.qk_nope_head_dim = kwargs.get('qk_nope_head_dim')
        self.qk_head_dim = kwargs.get('qk_head_dim')
        self.v_head_dim = kwargs.get('v_head_dim')
        self.rotary_emb = kwargs.get('rotary_emb')
        self.q_proj = kwargs.get('q_proj')
        self.kv_b_proj = kwargs.get('kv_b_proj')
        self.o_proj = kwargs.get('o_proj')
unsupported_features = [
alibi_slopes, sliding_window, logits_soft_cap
]
if any(unsupported_features):
            raise NotImplementedError(
                "VACCMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"VACCMLAImpl")
def extract_weights(self):
weights = {}
if hasattr(self, 'W_Q'):
weights["W_Q"] = self.W_Q
if hasattr(self, 'W_Q_scales'):
weights["W_Q_scales"] = self.W_Q_scales
if hasattr(self, 'W_QR'):
weights['W_QR'] = self.W_QR
if hasattr(self, 'W_QR_scales'):
weights["W_QR_scales"] = self.W_QR_scales
if hasattr(self, 'W_Q_QR'):
weights["W_Q_QR"] = self.W_Q_QR
if hasattr(self, 'W_Q_QR_scales'):
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
if hasattr(self, 'W_UK'):
weights['W_UK'] = self.W_UK
if hasattr(self, 'W_UK_scales'):
weights['W_UK_scales'] = self.W_UK_scales
if hasattr(self, 'W_Q_UK_scales'):
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
if hasattr(self, 'W_UV'):
weights['W_UV'] = self.W_UV
if hasattr(self, 'W_UV_scales'):
weights['W_UV_scales'] = self.W_UV_scales
if hasattr(self, 'W_UV_O'):
weights['W_UV_O'] = self.W_UV_O
if hasattr(self, 'W_UV_O_scales'):
weights['W_UV_O_scales'] = self.W_UV_O_scales
return weights
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, VACCMLAV1Metadata)
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v = v.contiguous()
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
# attn_output = torch.vacc.scaled_dot_product_attention(
# query=q,
# key=k,
# value=v_padded,
# attn_mask=None,
# dropout_p=0,
# is_causal=True,
# is_train=False,
# recompute=False,
# flash_attention=True,
# sm_scale=self.scale
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
seq_lens = attn_metadata.prefill_metadata.seq_lens
        if len(seq_lens) == 1:
            # VACC SDPA supports a v head dim different from the q/k head
            # dim, so v needs no padding here.
attn_output = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
else:
attn_outs = []
start = 0
for seq in seq_lens:
end = start + seq
attn_out = torch.vacc.scaled_dot_product_attention(
query=q[start:end, :],
key=k[start:end, :],
value=v[start:end, :],
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
start = end
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_out)[0]
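    # Worked shape example (DeepSeek-style dims assumed for illustration):
    # with qk_nope_head_dim=128, qk_rope_head_dim=64 and v_head_dim=128,
    # kv_b_proj expands kv_c_normed to per-head (128 + 128) features; k is the
    # concat of k_nope with the broadcast k_pe (head dim 192 = qk_head_dim),
    # while v keeps head dim 128, which VACC SDPA accepts without padding.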
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
# Run MQA using paged_attention
# o = torch.vacc.paged_attention(
# query=q,
# key_cache=kv_c_and_k_pe_cache,
# value_cache=kv_c_cache,
# block_table=decode_meta.block_tables,
# seq_len=decode_meta.seq_lens_tensor,
# out=o,
# sm_scale=self.scale
# )
# Run MQA using spda
# t0 = time.time()
o = vacc_paged_attention_naive(
q,
kv_c_and_k_pe_cache,
kv_c_cache,
block_table = decode_meta.block_tables,
# seq_lens = decode_meta.seq_lens_tensor,
seq_lens=decode_meta.seq_lens,
out = o,
sm_scale=self.scale)
return self._v_up_proj_and_o_proj(o)
# patch from MLACommonBackend
class VACCMLABackend(AttentionBackend):
accept_output_buffer: bool = False
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return VACCMLAMetadata
@staticmethod
def get_impl_cls() -> Type["VACCMLAImpl"]:
return VACCMLAImpl
@staticmethod
def get_builder_cls() -> type["MLAVaccMetadataBuilder"]:
return MLAVaccMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
D = TypeVar("D", bound=MLACommonDecodeMetadata)
# patch from MLACommonMetadata
@dataclass
class VACCMLAV1Metadata(Generic[D]):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
# The dimension of the attention heads
head_dim: Optional[int] = None
decode_metadata: Optional[D] = None
prefill_metadata: Optional[MLACommonPrefillMetadata] = None
def __post_init__(self):
if self.head_dim is not None:
MLACommonBackend.validate_head_size(self.head_dim)
class MLAVaccMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
    reorder_batch_threshold: ClassVar[int] = 2
    # TODO: use separate thresholds for prefill and decode
def __init__(self,
# runner: "GPUModelRunner",
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[type[M]] = None):
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names,
MLACommonImpl))
self.metadata_cls = metadata_cls \
if metadata_cls is not None else VACCMLAV1Metadata
self.kv_cache_spec = kv_cache_spec
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
# self.runner = runner
# scheduler_config = runner.scheduler_config
# model_config = runner.model_config
# cache_config = runner.cache_config
self.chunked_prefill_enabled = False
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
# Dont try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.kv_cache_spec.block_size
# if self.chunked_prefill_enabled:
# self.chunked_prefill_workspace_size = min(
# # Max sure there is enough for 8 full length request or at least
# # 4 pages of cache per request
# max(
# 8 * model_config.max_model_len, 4 *
# scheduler_config.max_num_seqs * cache_config.block_size),
# # For long-context models try not to over-allocate limiting
# # kv-cache space, limiting it to 64k tokens,
# # which would result in the workspace being:
# # 2*(576)*(64*1024) = 144mb
# # (assuming 576 MLA head dim, and fp16)
# # which would result in up-projected context being
# # 2*(192*128)*(64*1024) = 3gb
# # (assuming 192 QK head dim, 128 heads, and fp16)
# 128 * 1024)
# assert self.chunked_prefill_workspace_size >= \
# scheduler_config.max_num_seqs * cache_config.block_size
# self.chunked_prefill_workspace = torch.empty(
# (self.chunked_prefill_workspace_size,
# model_config.get_head_size()),
# dtype=model_config.dtype,
# device=runner.device,
# )
# self.block_table = block_table
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if scheduler_output.scheduled_cached_reqs.num_computed_tokens != []:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_tables=block_table_tensor,
seq_lens=seq_lens,
)
# def build_for_cudagraph_capture(
# self, common_attn_metadata: CommonAttentionMetadata) -> M:
# """
# This method builds the metadata for full cudagraph capture.
# Currently, only decode is supported for full cudagraphs with MLA.
# """
# m = common_attn_metadata
# assert m.num_reqs == m.num_actual_tokens, \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# # m.max_query_len = 1 # decode-only
# # Update state usually set in reorder_batch.
# self._num_decodes = m.num_reqs
# self._num_decode_tokens = m.num_actual_tokens
# self._num_prefills = 0
# self._num_prefill_tokens = 0
# return self.build(0, m)
    def append_seqlen(self, seq_len: list[int], all_len: int):
        # Expand per-request sequence lengths when speculative (MTP) tokens
        # make the total scheduled length a clean multiple of the request
        # count; otherwise return the lengths unchanged.
if all_len > len(seq_len) and all_len % len(seq_len) == 0:
new_seq_len = []
mtp_num = all_len // len(seq_len)
for start_len in seq_len:
for i in range(1,1+mtp_num):
new_seq_len.append(start_len-mtp_num+i)
return new_seq_len
return seq_len
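    # Worked example (assumed MTP expansion): seq_len=[7, 9] with all_len=4
    # gives mtp_num=2 and returns [6, 7, 8, 9], i.e. each request's final
    # length is unrolled backwards over its speculative tokens; if all_len is
    # not a clean multiple, seq_len is returned unchanged.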
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
# max_query_len = common_attn_metadata.max_query_len
# assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
# device = self.runner.device
# block_table = self.block_table
# block_table_tensor = block_table.get_device_tensor()#[:num_reqs]
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
        # Example decode-step CommonAttentionMetadata (reference trace):
        #   query_start_loc=[0, 2], seq_lens=[40], num_computed_tokens_cpu=[38],
        #   num_reqs=1, num_actual_tokens=2, max_query_len=2, max_seq_len=40,
        #   block_table_tensor=[[1536, 1537, 1538, ...]]
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens]
# num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
# split_decodes_and_prefills(common_attn_metadata,
# decode_threshold=self.reorder_batch_threshold)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = 0, 0, 0, 0
token_nums = common_attn_metadata.query_start_loc_cpu[1:] - common_attn_metadata.query_start_loc_cpu[:-1]
if token_nums.max().item() > self.reorder_batch_threshold:
num_prefills = num_reqs
num_prefill_tokens = num_actual_tokens
else:
num_decodes = num_reqs
num_decode_tokens = num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
# context_lens_cpu = self.runner.input_batch.\
# num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
# max_context_len_cpu = context_lens_cpu.max().item()
# num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
chunked_context_metadata = None
# if self.chunked_prefill_enabled and self._num_prefills > 0 \
# and max_context_len_cpu > 0:
# dont support chunked prefill
prefill_metadata = MLACommonPrefillMetadata(
block_tables=block_table_tensor[reqs_start:reqs_start+num_prefills, ...],
query_start_loc=prefill_query_start_loc,
# max_query_len=None,
chunked_context=chunked_context_metadata,
seq_lens=seq_lens[:num_prefills],
)
if not isinstance(seq_lens, list):
# TODO init set list in init: vllm/v1/spec_decode/eagle.py
seq_lens = seq_lens.tolist()
decode_metadata = None
if num_decodes > 0:
block_num_per_group = env_blk_grp_size // 16
block_table_tensor_new = block_table_tensor[:num_decodes, ::block_num_per_group].contiguous()
seq_lens_new = self.append_seqlen(seq_lens[:slot_mapping.shape[-1]], slot_mapping.shape[-1])
if slot_mapping.shape[-1] > num_decodes:
mtp_numbers = [query_start_loc[i+1]-query_start_loc[i] for i in range(len(query_start_loc)-1)] #query_start_loc[1:] - query_start_loc[:-1]
block_table_tensor_list = []
for bi,mtp_number in enumerate(mtp_numbers):
for _ in range(mtp_number):
block_table_tensor_list.append(block_table_tensor_new[bi:bi+1])
block_table_tensor_new = torch.concatenate(block_table_tensor_list, 0)
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor_new,
seq_lens=seq_lens_new,
)
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
# prefill_seq_lens=seq_lens[:self._num_prefills].tolist(), # device to host, todo optimiz
# MLACommonMetadata Chunk prefill specific
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill_metadata=prefill_metadata,
decode_metadata=decode_metadata,
)