This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

View File

@@ -0,0 +1,919 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.gpu_input_batch import InputBatch
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
# AttributeError is to handle a bug in ipex
# https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
_use_ipex = False
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend):
    """Backend descriptor for the torch-SDPA attention path.

    Exposes the implementation, metadata and builder classes of this
    backend and delegates head-size validation and KV-cache shape to the
    paged-attention implementation chosen by ``_get_paged_attn_impl``.
    """

    accept_output_buffer: bool = False

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16, torch.float32]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if the paged-attention impl rejects
        ``head_size``; return silently otherwise."""
        paged_impl = _get_paged_attn_impl()
        accepted, supported_head_sizes = paged_impl.validate_head_size(
            head_size)
        if accepted:
            return
        # Backend label used in the message: class name minus "Backend".
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        """Return the attention implementation class."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TorchSDPAMetadata

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        """Return the metadata builder class."""
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape reported by the paged-attention impl.

        ``cache_dtype_str`` is accepted for interface compatibility and is
        not forwarded.
        """
        paged_impl = _get_paged_attn_impl()
        return paged_impl.get_kv_cache_shape(num_blocks, block_size,
                                             num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
@dataclass
class TorchSDPAMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""

    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    # ---- Metadata for PagedAttention ----
    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # ---- Metadata for TorchSDPABackend ----
    # Selects the prefill path in the attention impl: per-sequence SDPA when
    # False, the paged varlen kernel when True.
    chunked_prefill: bool
    seq_lens: Optional[list[int]] = None  # For non-chunked prefill

    # For chunked prefill only
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    prefill_query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # For V1 logits index only
    query_start_loc: Optional[torch.Tensor] = None

    # ---- Encoder attention & encoder/decoder cross-attention fields ----
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[list[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None
    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Bias caches, filled in lazily by the first attention op that needs
        # them. Each is a list because a bias is kept per prompt. They are
        # plain attributes rather than dataclass fields, so they do not
        # appear in __repr__ or __init__.
        self.attn_bias: Optional[list[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[list[torch.Tensor]] = None
        self.cross_attn_bias: Optional[list[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        """True when every field needed by encoder attention is set."""
        required = (self.encoder_seq_lens, self.encoder_seq_lens_tensor,
                    self.max_encoder_seq_len)
        return all(value is not None for value in required)

    @property
    def is_all_cross_attn_metadata_set(self):
        """True when every field needed by enc/dec cross-attention is set.

        This is a superset of what encoder attention requires.
        """
        if not self.is_all_encoder_attn_metadata_set:
            return False
        return (self.cross_slot_mapping is not None
                and self.cross_block_tables is not None)

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` if the batch has prefill tokens, else ``None``."""
        return None if self.num_prefill_tokens == 0 else self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` if the batch has decode tokens, else ``None``."""
        return None if self.num_decode_tokens == 0 else self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        """Pick the query and key/value sequence lengths for ``attn_type``.

        Arguments:
        * attn_type: encoder attention, decoder self-attention,
          encoder-only attention or encoder/decoder cross-attention

        Returns:
        * Sequence lengths for the query
        * Sequence lengths for key & value

        Raises ``AttributeError`` for an unknown attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.seq_lens, self.seq_lens
        if attn_type == AttentionType.ENCODER:
            return self.encoder_seq_lens, self.encoder_seq_lens
        if attn_type == AttentionType.ENCODER_DECODER:
            # Queries come from the decoder, keys/values from the encoder.
            return self.seq_lens, self.encoder_seq_lens
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[list[torch.Tensor]]:
        """Return the cached attention bias that belongs to ``attn_type``.

        Arguments:
        * attn_type: encoder attention, decoder self-attention,
          encoder-only attention or encoder/decoder cross-attention

        Returns:
        * The stored bias list, or ``None`` if none has been set yet

        Raises ``AttributeError`` for an unknown attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.attn_bias
        if attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: list[torch.Tensor],
        attn_type: str,
    ) -> None:
        """Store ``attn_bias`` in the slot that belongs to ``attn_type``.

        Arguments:
        * attn_bias: The bias list to store
        * attn_type: encoder attention, decoder self-attention,
          encoder-only attention or encoder/decoder cross-attention

        Raises ``AttributeError`` for an unknown attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
            return
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        """Select sequence-length and block-table fields for ``attn_type``.

        Decoder / encoder-only attn -> decoder self-attention fields
        Encoder/decoder cross-attn  -> encoder sequence lengths and the
                                       cross-attention block tables
        Encoder attn                -> encoder sequence lengths, no block
                                       tables

        Arguments:
        * attn_type: encoder attention, decoder self-attention,
          encoder-only attention or encoder/decoder cross-attention

        Returns:
        * Sequence-lengths tensor
        * Max sequence-length scalar
        * Block tables (or None)

        Raises ``AttributeError`` for an unknown attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        if attn_type == AttentionType.ENCODER_DECODER:
            # Cross-attention KVs match the encoder sequence length and use
            # the dedicated "cross" block tables.
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        if attn_type == AttentionType.ENCODER:
            # No block tables are associated with encoder attention.
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds ``TorchSDPAMetadata`` for a batch ordered prompts-first.

    ``reorder_batch`` records how many requests are still in the prompt
    stage; ``build`` relies on that count to split the batch.
    """

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device) -> None:
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.scheduler_config = vllm_config.scheduler_config
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs

        # Scratch buffers reused by reorder_batch on every call.
        self.reorder_prompt_req_index_list = np.empty(max_num_seqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(max_num_seqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        # Cumulative sequence lengths; the numpy array is a view over the
        # same storage as the CPU tensor.
        self.seq_start_loc_cpu = torch.zeros(
            max_num_seqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        """Move prompt-stage requests in front of decode-stage requests.

        Also updates ``self.num_prompt_req``. Returns True if any request
        states were swapped, False if the batch was already ordered.
        """
        num_prompts = 0
        num_decodes = 0
        for req_index in range(input_batch.num_reqs):
            still_in_prompt = (input_batch.num_computed_tokens_cpu[req_index]
                               < input_batch.num_prompt_tokens[req_index])
            if still_in_prompt:
                self.reorder_prompt_req_index_list[num_prompts] = req_index
                num_prompts += 1
            else:
                self.reorder_decode_req_index_list[num_decodes] = req_index
                num_decodes += 1
        assert num_decodes + num_prompts == input_batch.num_reqs

        # Update prompt requests number
        self.num_prompt_req = num_prompts

        # Decode indices were collected in ascending order, so those that sit
        # inside the prompt region [0, num_prompts) form a leading run.
        misplaced = 0
        for decode_req in self.reorder_decode_req_index_list[:num_decodes]:
            if decode_req >= num_prompts:
                break
            misplaced += 1
        if misplaced == 0:
            return False

        # Pair the misplaced decodes with the trailing prompt requests.
        prompts_to_move = (
            self.reorder_prompt_req_index_list[:num_prompts][-misplaced:])
        decodes_to_move = (
            self.reorder_decode_req_index_list[:num_decodes][:misplaced])
        assert decodes_to_move.size == prompts_to_move.size

        for prompt_req, decode_req in zip(prompts_to_move, decodes_to_move):
            input_batch.swap_states(prompt_req.item(), decode_req.item())
        return True

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> TorchSDPAMetadata:
        """Assemble the metadata for the current batch.

        The first ``self.num_prompt_req`` requests are treated as prefill and
        the rest as decode. ``common_prefix_len`` and ``fast_build`` are not
        used here.
        """
        num_reqs = common_attn_metadata.num_reqs
        max_query_len = common_attn_metadata.max_query_len
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        seq_lens_np = seq_lens_cpu.numpy()
        num_prefill_reqs = self.num_prompt_req

        if num_prefill_reqs > 0:
            max_prefill_seq_len = seq_lens_np[:num_prefill_reqs].max().item()
        else:
            max_prefill_seq_len = 0
        if num_prefill_reqs < num_reqs:
            max_decode_seq_len = (
                seq_lens_np[num_prefill_reqs:num_reqs].max().item())
        else:
            max_decode_seq_len = 0

        # Refresh the cumulative sequence lengths in the shared buffer.
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_prefill_tokens = int(query_start_loc_cpu[num_prefill_reqs].item())
        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
                                num_prefill_tokens)

        slot_mapping = common_attn_metadata.slot_mapping.long()
        block_table_tensor = common_attn_metadata.block_table_tensor

        return TorchSDPAMetadata(
            num_prefills=num_prefill_reqs,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # to ensure inference when chunked_prefill is disabled
            seq_lens=seq_lens_cpu.tolist(),
            # decode-only views
            seq_lens_tensor=seq_lens_cpu[num_prefill_reqs:num_reqs],
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_table_tensor[num_prefill_reqs:num_reqs],
            chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            # prefill-only views
            prefill_query_start_loc=query_start_loc_cpu[:num_prefill_reqs +
                                                        1],
            kv_start_loc=self.seq_start_loc_cpu[:num_prefill_reqs + 1],
            prefill_block_tables=block_table_tensor[:num_prefill_reqs],
            # for logits index
            query_start_loc=query_start_loc_cpu[:num_reqs + 1],
        )
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    """Attention implementation combining torch SDPA and paged attention.

    Prefill runs either per-sequence ``scaled_dot_product_attention``
    (chunked prefill disabled) or the IPEX varlen kernel over the paged
    cache (chunked prefill enabled). Decode always goes through the
    paged-attention implementation returned by ``_get_paged_attn_impl``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the layer configuration.

        Raises:
            NotImplementedError: if KV sharing is requested, or if a
                quantized KV cache is requested without IPEX available.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if logits_soft_cap is not None:
            # The soft cap is ignored rather than rejected; warn once.
            logger.warning_once("Torch SDPA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.paged_attn_impl = _get_paged_attn_impl()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # Number of query heads that share one KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # An explicit bias is required for ALiBi or a sliding window.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
            raise NotImplementedError(
                "Torch SDPA backend FP8 KV cache requires "
                "intel_extension_for_pytorch support.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape =
                [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: ignored; a fresh output tensor is always allocated.
            output_scale, output_block_scale: must be None.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if fused output quantization is requested.
            AttributeError: if encoder or cross-attention metadata is
                incomplete for the configured attention type.
        """
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TorchSDPABackendImpl")

        # For warming-up
        if attn_metadata is None:
            return query

        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = self.paged_attn_impl.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                self.paged_attn_impl.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # take the token count from the metadata and treat all of
            # them as prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        # Single output buffer shared by the prefill and decode branches;
        # each branch writes its own token range.
        output = torch.empty_like(query)

        if prefill_meta := attn_metadata.prefill_metadata:
            if not prefill_meta.chunked_prefill:  # type: ignore
                assert attn_metadata.seq_lens is not None
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                # Imported here so a missing IPEX install fails with an
                # ImportError at the point of use.
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.prefill_query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            # Decode tokens follow the prefill tokens in the batch.
            self.paged_attn_impl.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Run SDPA one sequence at a time and write results into ``output``.

        The per-sequence bias list (ALiBi, sliding window, or all ``None``)
        is built on first use and cached on ``attn_metadata`` for the given
        attention type.
        """
        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        # Move the token axis next to the head-size axis:
        # [tokens, heads, head_size] -> [heads, tokens, head_size].
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        if self.num_kv_heads != self.num_heads:
            # Expand KV heads so every query head has a matching KV head.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            # An explicit mask already encodes causality, so is_causal is
            # only set when there is no mask.
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
seq_lens: list[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
class _PagedAttention:
@staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256]
return head_size in SUPPORT_HS, SUPPORT_HS
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[int, ...]:
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
class _IPEXPagedAttention(_PagedAttention):
    """Paged-attention helpers backed by ``ipex_modules.PagedAttention``."""

    @staticmethod
    def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
        """Accept every head size; no list of supported sizes is reported."""
        return True, []

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``kv_cache`` into key and value views, both shaped
        ``(num_blocks, num_kv_heads, -1, head_size)``."""
        num_blocks = kv_cache.shape[1]
        view_shape = (num_blocks, num_kv_heads, -1, head_size)
        key_cache = kv_cache[0].view(*view_shape)
        value_cache = kv_cache[1].view(*view_shape)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Write ``key``/``value`` into the paged caches through IPEX.

        ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are not forwarded.
        """
        # The slot mapping is flattened and converted with .int() before
        # being handed to the IPEX op.
        int_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, int_slots)

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Run IPEX single-query cached-KV attention for decode tokens,
        writing into ``output``."""
        # Block size is the third dimension of the value-cache view.
        block_size = value_cache.shape[2]

        # Map each query head to its KV head: every KV head index is
        # repeated once per query head that shares it.
        queries_per_kv = query.size(1) // num_kv_heads
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        )
        head_mapping = kv_head_ids.view(
            num_kv_heads, 1).repeat_interleave(queries_per_kv).flatten()

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output, query.contiguous(), key_cache, value_cache, head_mapping,
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)
def _get_paged_attn_impl():
    """Return the IPEX-backed paged-attention class when IPEX imported
    successfully, otherwise the custom-ops one."""
    return _IPEXPagedAttention if _use_ipex else _PagedAttention

View File

@@ -0,0 +1,795 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version,
is_flash_attn_varlen_func_available)
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention attention path."""

    accept_output_buffer: bool = True
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not supported."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            return
        # Backend label used in the message: class name minus "Backend".
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        Raises ``ValueError`` unless ``block_size`` is a multiple of 16.
        ``cache_dtype_str`` is accepted but not used.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation that takes ``get_kv_cache_shape`` to the
        actual memory layout, based on the configured cache layout.

        Raises ``ValueError`` for a layout other than "NHD" or "HND".
        """
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            # Logical order already matches memory order.
            return (0, 1, 2, 3, 4)
        if cache_layout == "HND":
            # Swap the block_size and num_kv_heads axes.
            return (0, 1, 3, 2, 4)
        raise ValueError(f"Unknown cache layout format {cache_layout}.")

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the torch dtype.

        Raises ``ValueError`` for anything other than "fp8" or "fp8_e4m3".
        """
        if kv_cache_dtype not in ("fp8", "fp8_e4m3"):
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
        return torch.float8_e4m3fn
@dataclass
class FlashAttentionMetadata:
    """Per-batch attention metadata for the FlashAttention backend."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Cascade attention settings.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional ahead-of-time (AOT) scheduling metadata.
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    causal: bool = True
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every ``Attention`` layer found in ``vllm_config`` must use a
    ``FlashAttentionImpl``; its ``sliding_window`` value is added to the
    returned set.
    """
    attention_layers = get_layers_from_vllm_config(vllm_config, Attention)
    window_configs: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attention_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        window_configs.add(impl.sliding_window)
    return window_configs
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` from `CommonAttentionMetadata`.

    When `get_flash_attn_version()` reports 3, the builder can also compute
    ahead-of-time (AOT) scheduler metadata via `get_scheduler_metadata`.
    """
    # FA3:
    #   Supports full cudagraphs for all cases.
    #
    # FA2:
    #   For FA2, if a graph is captured with max_query_len=1 (which is what
    #   we capture by default for num_tokens <= max_num_seqs when there is no
    #   spec-decode), then these graphs will not work for mixed
    #   prefill-decode (unlike FA3). This is due to special max_query_len=1
    #   packed-GQA handling in FA2.
    #
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = AttentionCGSupport.ALWAYS \
        if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache model/cache/compilation settings needed by `build`.

        Raises:
            ValueError: if full CUDA graphs and AOT scheduling are both
                enabled and `max_capture_size` exceeds 992.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only attempted with FlashAttention 3.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that `build` copies the per-step scheduler
            # metadata into, so the tensor handed out stays the same storage.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Assemble the `FlashAttentionMetadata` for one batch.

        fast_build disables AOT scheduling, used when there will be few
        iterations i.e. spec-decode

        Args:
            common_prefix_len: Length of the KV prefix shared by all
                requests; a value > 0 turns on cascade attention.
            common_attn_metadata: Backend-agnostic batch description the
                per-batch tensors and lengths are read from.
            fast_build: Skip AOT scheduling for this call.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # One-time initialisation, triggered by the first build() call.
        # NOTE(review): if that first call has fast_build=True, the layer
        # sliding-window configs are never inspected afterwards because
        # aot_sliding_window is no longer None - confirm this ordering
        # cannot occur in practice.
        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # Layers disagree on the window: give up on AOT
                    # scheduling for this and all later calls.
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
                num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # Returns the AOT scheduler metadata for one kernel launch, or
            # None when AOT scheduling is disabled for this build() call.
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Prefix pass: all tokens form a single "sequence" that attends
            # non-causally to the shared prefix.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            # Suffix pass: per-request KV lengths with the prefix removed.
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)
        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            # Hand out a view of the persistent buffer rather than the
            # freshly computed tensor.
            scheduler_metadata = self.scheduler_metadata[:n]

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Delegate to the module-level `use_cascade_attention` heuristic."""
        return use_cascade_attention(*args, **kwargs)
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to FlashAttention kernels.

    Decoder attention goes through the paged KV cache (optionally using
    cascade attention); encoder attention runs directly on Q/K/V.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Normalise the layer configuration into kernel-ready attributes.

        Raises:
            NotImplementedError: if `kv_cache_dtype` is a quantized dtype
                and `flash_attn_supports_fp8()` is false.
            ValueError: from `FlashAttentionBackend.validate_head_size` for
                an unsupported `head_size` (as raised by that helper).
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # The window is stored as a 2-tuple; (-1, -1) stands for "no
        # sliding window". Encoder-only layers get a symmetric window,
        # every other attention type only looks backwards.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
                "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Attention layer; its `_q_scale`, `_k_scale` and
                `_v_scale` attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. None during a profiling
                run, in which case `output` is returned untouched.
            output: Pre-allocated output buffer (required); written in place.
            output_scale: Must be None (fused output quantization is not
                supported here).
            output_block_scale: Must be None, same reason.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if `output_scale` or `output_block_scale`
                is given.
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (self.kv_sharing_target_layer_name is None and key is not None
                and value is not None):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            # Reinterpret the cache storage as the FP8 dtype (a view, no copy).
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            # (num_sequences, num_kv_heads), see the NOTE in the docstring.
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=attn_metadata.max_num_splits,
                s_aux=self.sinks,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        Returns:
            `output`, which the kernel writes in place.
        Raises:
            NotImplementedError: if the KV-cache dtype is an fp8 variant.
        """
        # FP8 (quantized) KV-cache dtypes are rejected on the encoder path.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Use encoder-specific metadata for sequence information
        # Keys use the same offsets and max length as the queries.
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads)

        # Call flash attention directly on Q, K, V tensors
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

        return output
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    Two questions are answered in turn: is cascade attention supported for
    this configuration, and if so, is it heuristically expected to beat
    regular attention?
    """
    # Fast exit for the common case: a shared prefix shorter than the
    # (arbitrary, untuned) 256-token threshold is not worth the extra pass.
    # NOTE(woosuk): keep this check first so we bail out before doing any
    # other work.
    if common_prefix_len < 256:
        return False
    # Variants that cascade attention does not support yet.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # With fewer than 8 requests (arbitrary threshold, TODO: tune) the
    # sharing is unlikely to pay off.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristic 1: if regular attention would not use FlashDecoding, cascade
    # attention is likely faster because it saves memory bandwidth.
    # FlashDecoding criteria:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # (ALiBi and sliding window were already excluded above.)
    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv <= 1 or not np.all(query_lens == 1):
        return True

    # Heuristic 2: FlashDecoding is in play and may launch more CTAs than
    # cascade attention, so compare the two with a simple cost model.
    # NOTE(woosuk): the model is very rough and may not be accurate.
    # Default tile sizes; flash-attn may pick others (e.g. 64 or 256)
    # depending on the configuration.
    q_tile = 128
    kv_tile = 128
    prefix_tiles = cdiv(common_prefix_len, kv_tile)

    # Pure decode batch: one token per request.
    total_tokens = num_reqs
    ctas_cascade = num_query_heads * cdiv(total_tokens, q_tile)
    waves_cascade = cdiv(ctas_cascade, num_sms)
    cost_cascade = waves_cascade * prefix_tiles

    ctas_decoding = (num_reqs * num_kv_heads *
                     cdiv(num_queries_per_kv, q_tile)) * prefix_tiles
    cost_decoding = cdiv(ctas_decoding, num_sms)

    # Cascade attention wins only if its estimated cost is strictly lower.
    return cost_cascade < cost_decoding
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Attend to a shared KV prefix and per-request suffixes, then merge.

    Two kernel launches are made over the same queries:
      1. a non-causal pass in which all tokens form one sequence attending
         to the first `common_prefix_len` cached tokens (read through the
         first row of `block_table`), and
      2. a causal per-request pass over the remaining cached tokens (block
         table columns after the shared prefix blocks).
    `merge_attn_states` then combines both partial results, using their
    log-sum-exp values, into `output`.

    Args:
        output: Destination buffer; written in place.
        common_prefix_len: Shared prefix length in tokens. Must be a
            positive multiple of the cache block size.
        alibi_slopes: Must be None (ALiBi is unsupported here).
        sliding_window: Must be (-1, -1) (sliding window is unsupported).
        q_descale, k_descale, v_descale: Optional scales, expanded to
            (num_sequences, num_kv_heads) for each launch.

    Returns:
        `output`, after the merged result has been stored in it.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        # Every request shares the prefix blocks, so one row is enough.
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )

    # The suffix launch has one sequence per request, so the descale shape
    # has to be recomputed from the per-request offsets.
    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        # Skip the block-table columns that hold the shared prefix.
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    # Honour the declared `-> torch.Tensor` contract instead of implicitly
    # returning None.
    return output

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,861 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlexAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch._dynamo.decorators
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
_score_mod_signature, and_masks,
create_block_mask,
flex_attention)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
# torch.compile-wrapped variants of the FlexAttention entry points, created
# once at import time. `fullgraph=True` makes compilation raise on a graph
# break instead of silently falling back to eager for part of the function.
create_block_mask_compiled = torch.compile(create_block_mask,
                                           fullgraph=True,
                                           mode="reduce-overhead")
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
    """Expand cumulative offsets into one document id per element.

    Example: offsets [0, 2, 5] -> ids [0, 0, 1, 1, 1] (dtype int32).
    """
    starts, ends = offsets[:-1], offsets[1:]
    lengths = ends - starts
    doc_ids = torch.arange(len(lengths),
                           device=offsets.device,
                           dtype=torch.int32)
    return torch.repeat_interleave(doc_ids, lengths)
def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
    """Zero-pad the end of `dim` so its size becomes a multiple of `multiple`.

    `x` itself is returned (no copy) when the size already is a multiple.
    """
    remainder = x.shape[dim] % multiple
    if remainder == 0:
        return x
    shortfall = multiple - remainder

    axis = dim if dim >= 0 else x.ndim + dim
    # F.pad consumes (before, after) pairs starting from the LAST dimension,
    # so dimensions after `axis` get explicit zero padding first.
    trailing_dims = x.ndim - 1 - axis
    pad_spec = [0, 0] * trailing_dims + [0, shortfall]
    return F.pad(x, pad_spec, mode="constant", value=0)
class FlexAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the FlexAttention components."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """List the activation dtypes this backend accepts."""
        supported = [torch.float16, torch.bfloat16, torch.float32]
        return supported

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Accept every head size: FlexAttention has no restriction."""
        return None

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "FLEX_ATTENTION"

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        """Return the attention implementation class."""
        return FlexAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class consumed by the implementation."""
        return FlexAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The leading axis of size 2 holds the key and value caches;
        `cache_dtype_str` does not influence the shape.
        """
        kv_planes = 2
        return (kv_planes, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlexAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used with this backend."""
        return False
#@torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(block_table: torch.Tensor,
                                seq_lens: torch.Tensor, block_size: int,
                                total_blocks: int) -> torch.Tensor:
    """
    Invert the per-request block table: physical block -> logical index.

    The block table maps logical blocks to physical locations:

    Logical to Physical (Original block_table):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Logical Blocks:  0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Physical Blocks: 3  5  1  7  4  2  0  6   │
    └───────────────────────────────────────────┘

    The result holds the inverse mapping:

    Physical to Logical (Inverse mapping):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Physical Blocks: 0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Logical Blocks:  6  2  5  0  4  1  7  3   │
    └───────────────────────────────────────────┘

    Physical blocks that no logical block of a request points at keep the
    value -1. Column 0 is always reset to -1 at the end (physical block 0 is
    treated as never holding live data), so the "6" shown under physical
    block 0 in the diagram is illustrative only.

    NOTE(review): the inverse is written with a plain `scatter_`. Should a
    row ever reference the same physical block from two valid logical
    positions, PyTorch does not specify which logical index survives, so
    callers should not rely on "first/minimum wins".

    Garbage-value protection
    ────────────────────────
    Entries of `block_table` beyond the blocks a sequence actually needs may
    contain stale values, e.g. for a 3-block sequence:

        block_table[0] = [10, 25, 7, 999, 1234, 888, ...]
                                    ^^^^^^^^^^^^^^^^^^^^
                                    garbage values

    Left alone, those could coincide with real physical blocks and receive
    logical indices. They are therefore masked out using `seq_lens` and
    `block_size`: masked entries are redirected to physical block 0 with
    logical index 0, and column 0 is cleared afterwards.

    Args:
        block_table: Tensor of shape [max_reqs, max_num_blocks] mapping
            logical blocks to physical locations. May contain garbage
            values in unused positions.
        seq_lens: Sequence length per request; determines how many leading
            block-table entries of each row are valid.
        block_size: Number of tokens per block.
        total_blocks: Total number of physical blocks available.

    Returns:
        A tensor of shape [max_reqs, total_blocks] (dtype long) where
        result[req_id, physical_block] is the logical block index for that
        physical block, or -1 if unused.
    """
    num_rows, blocks_per_row = block_table.shape
    dev = block_table.device

    # Logical block ids 0..max_num_blocks-1, broadcast across requests.
    logical_ids = torch.arange(blocks_per_row, device=dev)[None, :]
    # Only the first ceil(seq_len / block_size) entries of a row are live.
    live_blocks = cdiv(seq_lens, block_size)
    in_use = logical_ids < live_blocks[:, None]

    # Dead entries are pointed at (physical 0, logical 0); see docstring.
    scatter_index = torch.where(in_use, block_table, 0).to(torch.int64)
    scatter_src = torch.where(in_use, logical_ids, 0)

    inverse = torch.full((num_rows, total_blocks),
                         -1,
                         dtype=torch.long,
                         device=dev)
    inverse.scatter_(-1, scatter_index, scatter_src)
    # NB - Seems like block 0 is always empty so we reset it manually
    inverse[:, 0] = -1
    return inverse
def unique_static_unsorted(
    x: torch.Tensor,
    *,
    M: int,  # maximum positive value (0 is “skip me”)
    dim: int = -1,  # axis along which to deduplicate
    ignored_val: int = 0,  # value to ignore
    pad_val: int = -1,  # sentinel for unused slots
) -> torch.Tensor:
    """
    Order-preserving de-duplication along `dim` with a static output shape.

    - Keeps the first occurrence of every value other than `ignored_val`,
      packs the survivors to the left in their original order and fills the
      remaining slots with `pad_val`.
    - Returns the packed tensor, which has the *same shape* as `x`.
    - Requires all values of `x` to lie in the range [0, M].

    Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.

    Example:
    x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]

    Raises:
        ValueError: if `pad_val` is outside [-1, M].
    """
    if pad_val < -1 or pad_val > M:
        raise ValueError("`pad_val` must lie in [-1, M]")

    # Bring the target axis last and flatten everything else into rows.
    axis = dim % x.ndim
    moved = x.movedim(axis, -1)  # shape [..., N]
    width = moved.shape[-1]
    num_rows = moved.numel() // width
    rows2d = moved.reshape(num_rows, width)  # [B, N]

    dev = x.device
    col = torch.arange(width, device=dev).expand(num_rows, width)

    # earliest[b, v] = smallest column at which value v occurs in row b;
    # `width` acts as "never seen".
    earliest = torch.full((num_rows, M + 1), width, device=dev)
    earliest.scatter_reduce_(1, rows2d, col, reduce="amin")

    # An element survives if it is the first of its value and not ignored.
    is_first = col == earliest.gather(1, rows2d)
    keep = (rows2d != ignored_val) & is_first  # [B, N]

    # Running count of survivors gives each one its left-packed position.
    slot = torch.cumsum(keep.to(torch.long), dim=1) - 1
    packed2d = torch.full_like(rows2d, pad_val)
    row_ids, col_ids = torch.nonzero(keep, as_tuple=True)
    packed2d[row_ids, slot[row_ids, col_ids]] = rows2d[row_ids, col_ids]

    # Undo the flatten and the axis move.
    return packed2d.reshape(moved.shape).movedim(-1, axis)
def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
                    kv_idx: torch.Tensor):
    """Causal mask: a query may see keys at its own position or earlier.

    `b` and `h` are accepted for signature compatibility and are unused.
    """
    is_visible = q_idx >= kv_idx
    return is_visible
@dataclass
class FlexAttentionMetadata:
causal: bool
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
# Block info
total_cache_tokens: int
block_size: int
max_possible_sequence_length: int
num_reqs: int
physical_to_logical: torch.Tensor
decode_offset: torch.Tensor
num_blocks_per_seq: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# Flex Metadata
num_blocks = 0
block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod
doc_ids: Optional[torch.Tensor] = None
direct_build: bool = True
q_block_size: int = 16
kv_block_size: int = 16
transformed_score_mod: Optional[_score_mod_signature] = None
sliding_window: Optional[int] = None
def _convert_physical_to_logical(
    self,
    request_lookup: torch.Tensor,
    q_idx: torch.Tensor,
    physical_kv_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Translate packed query / physical KV indices to logical positions.

    NB is_within_lower_bound: do sequences start on block_boundaries?

    Returns:
        tuple of (is_valid, logical_q_idx, logical_kv_idx)
    """
    # Request that each query token belongs to.
    req = request_lookup[q_idx]

    # Split the physical KV index into (block, offset) and look up which
    # logical block that physical block is for this request (-1 if none).
    phys_block = physical_kv_idx // self.block_size
    offset_in_block = physical_kv_idx % self.block_size
    logical_block = self.physical_to_logical[req, phys_block]
    logical_kv_idx = logical_block * self.block_size + offset_in_block

    # A KV slot counts only if its block is mapped and the logical position
    # falls inside [0, seq_len) of the request.
    is_valid = ((logical_block >= 0)
                & (logical_kv_idx < self.seq_lens[req])
                & (logical_kv_idx >= 0))

    # Position of the query inside its request, shifted by decode_offset.
    pos_in_request = q_idx - self.query_start_loc[req]
    logical_q_idx = pos_in_request + self.decode_offset[req]

    return is_valid, logical_q_idx, logical_kv_idx
def get_causal_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
1. The paged attention block mapping
2. The mapping from packed query sequences to logical query entries
It also by defaults adds the decoding offset to the query indices.
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
"""
assert self.doc_ids is not None
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
# Apply mask modification only for valid indices
return torch.where(
is_valid,
self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
    """Build the encoder (bidirectional) mask_mod for FlexAttention.

    Encoder attention runs without the KV cache, so both index arguments
    refer to the packed query layout. A query may attend to a key exactly
    when both tokens map to the same request.
    """
    # Packed token index -> request number.
    token_to_request = _offsets_to_doc_ids_tensor(self.query_start_loc)

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        q_owner = token_to_request[q_idx]
        kv_owner = token_to_request[kv_idx]
        return q_owner == kv_owner

    return final_mask_mod
def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
"""Creates the sliding window mask_mod function for FlexAttention.
Note that the sliding window mask here is bidirectional, we need
to mask it with the bidirectional/causal mask for encoder/decoder.
"""
if self.sliding_window is None:
raise ValueError(
"sliding_window must be set for sliding window attention")
def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
q_idx: torch.Tensor, kv_idx: torch.Tensor):
return torch.abs(q_idx - kv_idx) < self.sliding_window
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
return torch.where(
is_valid,
sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod if self.causal else sliding_window_mask_mod
def get_mask_mod(self):
    """Return the mask_mod for this batch.

    The base mask is causal for causal metadata and bidirectional
    otherwise. When ``sliding_window`` is set, the base mask is AND-ed with
    the sliding-window mask.
    """
    # Stage 1: base mask (causal for decoder, bidirectional for encoder).
    base_mask_mod = (self.get_causal_mask_mod()
                     if self.causal else self.get_bidirectional_mask_mod())

    # Stage 2: optionally restrict the base mask to the sliding window.
    if self.sliding_window is None:
        return base_mask_mod
    window_mask_mod = self.get_sliding_window_mask_mod()
    return and_masks(base_mask_mod, window_mask_mod)
def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
"""Creates the transformed score_mod function for FlexAttention.
This function wraps the user's score_mod to handle physical-to-logical
index conversion, similar to how get_mask_mod works for mask functions.
"""
if self.score_mod is None:
return None
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
user_score_mod = self.score_mod
def transformed_score_mod(
score: torch.Tensor,
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
request_lookup, q_idx, physical_kv_idx)
return torch.where(
is_valid,
user_score_mod(score,
b,
h,
logical_q_idx,
logical_kv_idx,
physical_q=q_idx), -float('inf'))
return transformed_score_mod
def _build_block_mask_direct(self) -> BlockMask:
    """Direct block mask construction for standard causal attention.

    This method constructs the block mask directly using
    BlockMask.from_kv_blocks which is much more efficient than the
    generic create_block_mask approach.

    The direct path works as follows:
    1. For each query token, fetch blocks from block_table using max_seq_len
       (this fetches more blocks than needed for shorter sequences)
    2. Group query tokens into chunks of q_block_size
    3. For each group, deduplicate the blocks using unique_static_unsorted
    4. Create BlockMask using the deduplicated block indices

    Over-estimation occurs when a group of q_block_size tokens contains
    multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
    each sequence represented in the group, even though individual query
    tokens may only need a subset of those blocks based on causal masking
    and their position.

    Raises:
        ValueError: if the cache block size differs from kv_block_size.
    """
    # Compare the sizes exactly. An integer-division ratio check would
    # accept unequal sizes whose quotient happens to be 1 (e.g. 24 vs 16).
    if self.kv_block_size != self.block_size:
        raise ValueError(
            f"FlexAttention currently requires the cache block size "
            f"({self.block_size}) to be equal to the kv_block_size "
            f"({self.kv_block_size}). Please check your model's "
            f"configuration.")

    # One block-table row per query token, truncated to the number of
    # blocks the longest sequence can occupy.
    max_blocks_per_seq = cdiv(self.max_seq_len, self.block_size)
    used_pages = self.block_table[self.doc_ids, :max_blocks_per_seq]
    used_pages_padded = pad_to_multiple(used_pages,
                                        multiple=self.q_block_size,
                                        dim=0)
    # Merge the rows of every q_block_size consecutive query tokens so that
    # each row lists all pages referenced by one query tile.
    used_pages_padded = used_pages_padded.reshape(
        used_pages_padded.shape[0] // self.q_block_size, -1)
    # Cache pages and KV tiles have the same size (checked above), so page
    # ids can be used as KV tile ids without rescaling.
    kv_indices = unique_static_unsorted(used_pages_padded.long(),
                                        M=self.num_blocks).to(torch.int32)
    kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)

    block_mask_kwargs = {
        "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
        "kv_num_blocks": kv_num_blocks[None, None],
        "kv_indices": kv_indices[None, None],
        "full_kv_num_blocks": None,
        "full_kv_indices": None,
        "BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
        "mask_mod": self.mask_mod,
    }
    # compute_q_blocks parameter is available in PyTorch 2.9+
    if is_torch_equal_or_newer("2.9.0.dev0"):
        block_mask_kwargs["compute_q_blocks"] = False
    return BlockMask.from_kv_blocks(**block_mask_kwargs)
def build_block_mask(self) -> BlockMask:
    """Build the block mask through the generic (compiled) path.

    The KV length is the whole cache (``total_cache_tokens``) for causal
    metadata and the number of actual tokens otherwise.
    """
    mask_mod = self.get_mask_mod()
    if self.causal:
        kv_len = self.total_cache_tokens
    else:
        kv_len = self.num_actual_tokens
    tile_shape = (self.q_block_size, self.kv_block_size)
    return create_block_mask_compiled(
        mask_mod,
        None,
        None,
        self.num_actual_tokens,
        kv_len,
        device=self.block_table.device,
        BLOCK_SIZE=tile_shape,
    )
def __post_init__(self):
    """Validate unsupported options and derive the Flex-specific state.

    Cascade attention is not implemented, so every cascade-related field
    must be unset. Afterwards the request lookup, block count, mask_mod,
    transformed score_mod and block mask are computed.
    """
    assert self.use_cascade is False, "Not implemented yet."
    assert self.common_prefix_len == 0, "Not implemented yet."
    assert self.cu_prefix_query_lens is None, "Not implemented yet."
    assert self.prefix_kv_lens is None, "Not implemented yet."
    assert self.suffix_kv_lens is None, "Not implemented yet."

    self.num_blocks = self.total_cache_tokens // self.block_size
    # Lookup from packed query index to request number.
    self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)

    self.mask_mod = self.get_mask_mod()
    self.transformed_score_mod = self.get_transformed_score_mod()

    # The direct construction is only used for causal attention.
    use_direct_path = self.direct_build and self.causal
    mask_builder = (self._build_block_mask_direct
                    if use_direct_path else self.build_block_mask)
    self.block_mask = mask_builder()
class FlexAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlexAttentionMetadata]):
    """Assembles ``FlexAttentionMetadata`` from the common attention metadata.

    Cascade attention is not supported: ``build`` raises when a common
    prefix is requested and ``use_cascade_attention`` always answers False.
    """

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.cache_config = vllm_config.cache_config

        self.num_heads_q = model_config.get_num_attention_heads(
            parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec

        # Direct block-mask construction and 16x16 tiles are enabled on
        # torch >= 2.9.0.dev0; older versions use 128x128 tiles.
        supports_direct_build = is_torch_equal_or_newer("2.9.0.dev0")
        mask_tile = 16 if supports_direct_build else 128
        self.direct_build: bool = supports_direct_build
        self.q_block_size: int = mask_tile
        self.kv_block_size: int = mask_tile

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Never reorders the batch; always returns False."""
        return False

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlexAttentionMetadata:
        """Create the per-step metadata.

        Raises:
            NotImplementedError: if ``common_prefix_len`` is positive
                (cascade attention).
        """
        common = common_attn_metadata
        num_blocks_per_seq = cdiv(common.seq_lens, self.block_size)

        use_cascade = common_prefix_len > 0
        if use_cascade:
            raise NotImplementedError("Not yet my friend")

        block_size = self.kv_cache_spec.block_size
        max_possible_seq_len = self.model_config.max_model_len
        num_gpu_blocks = self.cache_config.num_gpu_blocks
        assert num_gpu_blocks is not None, \
            "FlexAttention requires num_gpu_blocks to be set"
        total_cache_tokens = num_gpu_blocks * block_size

        inverse_block_table = physical_to_logical_mapping(
            common.block_table_tensor, common.seq_lens, block_size,
            num_gpu_blocks)

        # Number of already-computed tokens per request, moved to the
        # builder's device.
        offset_tensor = common.num_computed_tokens_cpu.to(self.device,
                                                          non_blocking=True)

        return FlexAttentionMetadata(
            causal=common.causal,
            num_actual_tokens=common.num_actual_tokens,
            max_query_len=common.max_query_len,
            query_start_loc=common.query_start_loc,
            max_seq_len=common.max_seq_len,
            seq_lens=common.seq_lens,
            block_table=common.block_table_tensor,
            slot_mapping=common.slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
            block_size=block_size,
            max_possible_sequence_length=max_possible_seq_len,
            num_reqs=common.num_reqs,
            physical_to_logical=inverse_block_table,
            total_cache_tokens=total_cache_tokens,
            decode_offset=offset_tensor,
            num_blocks_per_seq=num_blocks_per_seq,
            direct_build=self.direct_build,
            q_block_size=self.q_block_size,
            kv_block_size=self.kv_block_size,
        )

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is not supported; always returns False."""
        return False
class FlexAttentionImpl(AttentionImpl):
    """Attention layer implementation backed by FlexAttention.

    Supports decoder and encoder-only attention. ALiBi slopes, logits soft
    cap, KV sharing and quantized KV caches are rejected at construction.
    """

    sliding_window: Optional[int]
    alibi_slopes: Optional[torch.Tensor]
    logits_soft_cap: Optional[float]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Store the layer configuration and reject unsupported features.

        Raises:
            NotImplementedError: for attention types other than decoder /
                encoder-only, and when ALiBi slopes, a logits soft cap,
                KV sharing or a quantized KV cache are requested.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.attn_type = attn_type

        if attn_type not in (AttentionType.ENCODER_ONLY,
                             AttentionType.DECODER):
            raise NotImplementedError(
                f"FlexAttention does not support {attn_type} attention")

        if alibi_slopes is not None:
            raise NotImplementedError(
                "FlexAttention does not support alibi slopes yet.")
        else:
            self.alibi_slopes = None

        self.sliding_window = sliding_window

        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        if self.logits_soft_cap is not None:
            raise NotImplementedError(
                "FlexAttention does not support logits soft cap yet.")

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError(
                "FlexAttention does not support kv sharing yet.")

        FlexAttentionBackend.validate_head_size(head_size)
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlexAttention does not support quantized kv-cache. Yet")

    @staticmethod
    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
        """View a 3d tensor as 4D."""
        if tensor.ndim == 4:
            return tensor
        assert tensor.ndim == 3
        return tensor[None, :, :, :]

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlexAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            The ``output`` buffer; its first ``num_actual_tokens`` rows are
            overwritten with the attention result.
        Raises:
            NotImplementedError: if fused output quantization is requested.
        """
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlexAttentionImpl")

        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
            # Profiling run.
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        # The metadata is built without knowledge of this layer's sliding
        # window; rebuild the block mask when the two disagree.
        if attn_metadata.sliding_window != self.sliding_window:
            attn_metadata.sliding_window = self.sliding_window
            # Same condition as FlexAttentionMetadata.__post_init__: the
            # direct builder derives the mask from the paged KV cache and is
            # only valid for causal attention. Encoder-only batches attend
            # over the packed tokens and must use the generic builder.
            if attn_metadata.direct_build and attn_metadata.causal:
                # TODO: Support skipping the computation of sliding window
                # in direct block mask building code path.
                logger.warning_once(
                    "Using direct block mask building with sliding window, "
                    "which is suboptimal now. Performance may be degraded.")
                # update mask mod in attention metadata
                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
                attn_metadata.block_mask = (
                    attn_metadata._build_block_mask_direct())
            else:
                attn_metadata.block_mask = attn_metadata.build_block_mask()

        if not attn_metadata.causal:
            assert self.attn_type == AttentionType.ENCODER_ONLY

            query, key_tensor, value_tensor = map(
                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
                (query, key, value),
            )

            query = query[:, :, :num_actual_tokens, :]
            if ((key_tensor.size(-2) > num_actual_tokens)
                    or (value_tensor.size(-2) > num_actual_tokens)):
                # In the encoder-only model with torch.compile,
                # qkv might be padded, which might cause exception.
                # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
                key_tensor = key_tensor[:, :, :num_actual_tokens, :]
                value_tensor = value_tensor[:, :, :num_actual_tokens, :]
        else:
            assert self.attn_type == AttentionType.DECODER
            key_cache, value_cache = kv_cache.unbind(0)

            # Write the new keys/values into the paged cache first; the
            # attention below reads keys and values from the cache.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

            # View out the block_size dim
            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
            value_cache = value_cache.view(-1, self.num_kv_heads,
                                           self.head_size)
            query, key_tensor, value_tensor = map(
                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
                (query, key_cache, value_cache),
            )

            query = query[:, :, :num_actual_tokens, :]
            # Doesn't work for now -> constraint violation
            # torch._dynamo.try_mark_dynamic(query, 2)

        assert attn_metadata.block_mask is not None
        block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE

        kernel_options = get_kernel_options(query, block_m, block_n,
                                            attn_metadata.direct_build)
        out = flex_attention_compiled(
            query,
            key_tensor,
            value_tensor,
            attn_metadata.transformed_score_mod,
            attn_metadata.block_mask,
            self.scale,
            enable_gqa=enable_gqa,
            kernel_options=kernel_options,
        )

        # Flex doesn't have an out variant today, rely on epilogue fusion
        out = out.permute(0, 2, 1, 3).squeeze(0)
        output[:num_actual_tokens, :, :].copy_(out)
        return output
def get_kernel_options(query, block_m, block_n,
                       use_direct_build: bool) -> dict[str, Union[int, bool]]:
    """Choose the kernel options passed to the FlexAttention call.

    With ``use_direct_build`` the kernel tile sizes mirror the block mask's
    (``block_m``, ``block_n``). Otherwise a square tile is used: 64, or 32
    for float32 queries, halved again when CUDA is available and the device
    reports less than 144 KiB of opt-in shared memory per block.
    """
    if use_direct_build:
        return {
            "FORCE_USE_FLEX_ATTENTION": True,
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
        }

    tile = 32 if query.dtype == torch.float32 else 64
    if torch.cuda.is_available():
        device_props = torch.cuda.get_device_properties()
        if device_props.shared_memory_per_block_optin < 144 * 1024:
            tile = tile // 2

    kernel_options: dict[str, Union[int, bool]] = {
        "FORCE_USE_FLEX_ATTENTION": True,
        "BLOCK_M": tile,
        "BLOCK_N": tile,
    }
    return kernel_options

View File

@@ -0,0 +1,332 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
    """Attention backend entry for GatedDeltaNet layers."""

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        builder_cls = GDNAttentionMetadataBuilder
        return builder_cls
@dataclass
class GDNAttentionMetadata:
    """Per-step metadata for GatedDeltaNet attention.

    The batch is described by request / token counts for prefills, plain
    decodes and speculative decodes. The optional tensors carry the
    speculative ("spec") and non-speculative ("non_spec") views separately.
    """

    # Request and token counts.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    has_initial_state: Optional[torch.Tensor] = None

    # shape: [num_spec_decodes + 1,]
    spec_query_start_loc: Optional[torch.Tensor] = None
    # shape: [batch - num_spec_decodes + 1,]
    non_spec_query_start_loc: Optional[torch.Tensor] = None

    # shape: [batch, num_spec]
    spec_state_indices_tensor: Optional[torch.Tensor] = None
    # shape: [batch - num_spec_decodes,]
    non_spec_state_indices_tensor: Optional[torch.Tensor] = None

    # shape: [batch,]
    spec_sequence_masks: Optional[torch.Tensor] = None
    # shape: [num_prefill_tokens + num_decode_tokens,]
    spec_token_masks: Optional[torch.Tensor] = None
    # shape: [batch,]
    num_accepted_tokens: Optional[torch.Tensor] = None

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None
class GDNAttentionMetadataBuilder(
        AttentionMetadataBuilder[GDNAttentionMetadata]):
    """Builds ``GDNAttentionMetadata`` for a scheduled batch.

    Requests are split into speculative-decode requests and ordinary
    prefill / decode requests. For decode-only batches that fit the
    captured sizes and when full CUDA graphs are enabled, the tensors are
    copied into buffers preallocated in ``__init__`` and padded.
    """

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs *
            (self.num_spec + 1), self.compilation_config.max_capture_size)

        # Persistent buffers reused across build() calls on the CUDA-graph
        # path. They are allocated uninitialized, so build() must overwrite
        # every element it hands out, padding included.
        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        """Build the metadata for one step.

        Args:
            common_prefix_len: not used by this builder.
            common_attn_metadata: backend-independent batch description.
            num_accepted_tokens: per-request accepted-token counts; must be
                provided whenever the batch contains speculative decodes.
            num_decode_draft_tokens_cpu: per-request draft-token counts on
                the CPU; requests with a value >= 0 are treated as
                speculative decodes.
            fast_build: not used by this builder.
        """
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        # Decide which requests (if any) are speculative decodes.
        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
                or num_decode_draft_tokens_cpu[
                    num_decode_draft_tokens_cpu >= 0].sum().item() == 0):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks.sum().item()
            if num_spec_decodes == 0:
                spec_sequence_masks = None
            else:
                spec_sequence_masks = spec_sequence_masks.to(
                    query_start_loc.device, non_blocking=True)

        if spec_sequence_masks is None:
            # No speculative requests: everything is "non_spec".
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1))
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item(
            ) - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                # Every request is a speculative decode.
                spec_token_masks = torch.ones(
                    (min(num_spec_decodes *
                         (self.num_spec + 1), query_start_loc[-1].item())),
                    dtype=torch.bool,
                    device=query_start_loc.device)
                spec_state_indices_tensor = m.block_table_tensor[:, :self.
                                                                 num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                # Mixed batch: split per-request data by the spec mask and
                # rebuild cumulative query offsets for each side.
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens)
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, :self.num_spec + 1]
                non_spec_state_indices_tensor = \
                    m.block_table_tensor[~spec_sequence_masks, 0]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[spec_sequence_masks],
                             dim=0,
                             out=spec_query_start_loc[1:])
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[~spec_sequence_masks],
                             dim=0,
                             out=non_spec_query_start_loc[1:])

            num_spec_decode_tokens = (query_lens.sum().item() -
                                      num_prefill_tokens - num_decode_tokens)
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            nums_dict, batch_ptr, token_chunk_offset_ptr = \
                compute_causal_conv1d_metadata(non_spec_query_start_loc)
        else:
            has_initial_state = None
        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
            num_spec_decode_tokens

        # prepare tensors for cudagraph
        #
        # With speculative decoding, the xgrammar backend may roll back
        # tokens, leaving some sequences with fewer draft tokens than
        # self.num_spec.
        #
        # In that case the max possible batch size for n tokens can be
        # min(n, cudagraph_max_bs).
        if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                and num_spec_decodes <= self.decode_cudagraph_max_bs
                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True)
            spec_state_indices_tensor = self.spec_state_indices_tensor[:
                                                                       batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True)
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            # Remember the unpadded length before rebinding the name to the
            # padded buffer view. Otherwise the slice below would start at
            # the padded length and the padding tail would keep stale values
            # from earlier builds.
            num_spec_tokens = spec_token_masks.size(0)
            self.spec_token_masks[:num_spec_tokens].copy_(
                spec_token_masks, non_blocking=True)
            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
            spec_token_masks[num_spec_tokens:].fill_(False)

            self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True)
            spec_num_query_tokens = spec_query_start_loc[
                -1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
            spec_query_start_loc[num_spec_decodes +
                                 1:].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True)
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (self.use_full_cuda_graph and num_prefills == 0
                and num_spec_decodes == 0
                and num_decodes <= self.decode_cudagraph_max_bs):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            batch_size = num_actual_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True)
            non_spec_state_indices_tensor = \
                self.non_spec_state_indices_tensor[:batch_size]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[:num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True)
            non_spec_num_query_tokens = non_spec_query_start_loc[
                -1]  # type: ignore[index]
            non_spec_query_start_loc = \
                self.non_spec_query_start_loc[:batch_size + 1]
            non_spec_query_start_loc[num_decodes +
                                     1:].fill_(non_spec_num_query_tokens)

        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=num_actual_tokens,
            has_initial_state=has_initial_state,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.

        Note that ``num_computed_tokens_cpu`` of the passed metadata is
        overwritten in place before ``build`` is called.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
                f"GDN only supports decode-only full CUDAGraph capture. "
                f"Make sure batch size ({m.num_reqs}) <= "
                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
                f"and number of tokens ({m.num_actual_tokens}) <= "
                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

        # Treat every query token of a request as accepted, with all but
        # the first counted as draft tokens.
        num_accepted_tokens = torch.diff(m.query_start_loc)
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

        return self.build(0, m, num_accepted_tokens,
                          num_decode_draft_tokens_cpu)

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend):
    """Attention backend entry for linear-attention layers."""

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        builder_cls = LinearAttentionMetadataBuilder
        return builder_cls
@dataclass
class LinearAttentionMetadata:
    """Per-step metadata for linear attention."""

    # Request and token counts for the prefill / decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    # shape: [batch,]
    state_indices_tensor: torch.Tensor
class LinearAttentionMetadataBuilder(
        AttentionMetadataBuilder[LinearAttentionMetadata]):
    """Builds ``LinearAttentionMetadata`` from the common attention metadata."""

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> LinearAttentionMetadata:
        """Create the metadata for one step.

        The state index of each request is taken from the first column of
        the block table; the decode / prefill split uses
        ``reorder_batch_threshold`` as the decode threshold.
        """
        common = common_attn_metadata
        state_indices = common.block_table_tensor[:, 0]

        split = split_decodes_and_prefills(
            common, decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

        return LinearAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=common.query_start_loc,
            seq_lens=common.seq_lens,
            state_indices_tensor=state_indices,
        )

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
split_decodes_and_prefills)
class Mamba1AttentionBackend(AttentionBackend):
    """Attention backend entry for Mamba1 layers."""

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        builder_cls = Mamba1AttentionMetadataBuilder
        return builder_cls
@dataclass
class Mamba1AttentionMetadata:
    """Per-step metadata for Mamba1 layers."""

    query_start_loc: torch.Tensor
    context_lens_tensor: torch.Tensor
    state_indices_tensor: torch.Tensor
    # None when the batch contains no prefill requests.
    has_initial_states: Optional[torch.Tensor]

    # Request and token counts for the prefill / decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    # Decode count after CUDA-graph padding; equals num_decodes otherwise.
    num_padded_decodes: int
class Mamba1AttentionMetadataBuilder(
        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):
    """Builds ``Mamba1AttentionMetadata`` from the common attention metadata.

    NOTE(review): ``reorder_batch_threshold``, ``decode_cudagraph_max_bs``,
    ``compilation_config``, ``vllm_config`` and the persistent
    ``state_indices_tensor`` buffer are not defined here; presumably they
    come from the base builder -- confirm.
    """

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        """Create the metadata for one step.

        For a decode-only batch that fits ``decode_cudagraph_max_bs`` with
        full CUDA graphs enabled, the state indices are copied into the
        persistent buffer and padded with ``PAD_SLOT_ID``.
        """
        common = common_attn_metadata
        query_start_loc = common.query_start_loc
        state_indices_tensor = common.block_table_tensor[:, 0]
        context_lens_tensor = common.num_computed_tokens_cpu.to(
            query_start_loc.device)

        split = split_decodes_and_prefills(
            common, decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

        has_initial_states = None
        padded_decodes = num_decodes

        if num_prefills > 0:
            has_initial_states = context_lens_tensor > 0
        elif (0 < num_decodes <= self.decode_cudagraph_max_bs
              and self.compilation_config.full_cuda_graph):
            decode_state_indices = state_indices_tensor[:num_decodes]
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)

            # Stage the live indices in the persistent buffer, then mark
            # the padded tail as unused slots.
            self.state_indices_tensor[:num_decodes].copy_(
                decode_state_indices, non_blocking=True)
            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_padded_decodes=padded_decodes,
        )

View File

@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
def _query_start_loc_to_chunk_indices_offsets(
        query_start_loc: torch.Tensor, chunk_size: int,
        total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Map logical (pseudo) chunks to physical mamba chunks.

    Args:
        query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
            lengths, shape (num_seqs + 1,).
            The first element should be 0. Each entry represents the starting
            index of a sequence in the flattened token array.
        chunk_size (int): The size of each physical mamba chunk
            (number of tokens per chunk).
        total_seqlens (int): The total number of tokens in the batch.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - chunk_indices (torch.Tensor): 1D tensor of indices
                indicating the physical chunk for each logical chunk.
            - chunk_offsets (torch.Tensor): 1D tensor of offsets
                indicating the starting index of each logical chunk within
                its physical chunk.
        Both have dtype ``torch.int``, length N (the number of logical
        chunks) and live on the device of ``query_start_loc``.

    A logical chunk is a sequence of tokens that are all part of the same
    sequence and are all in the same physical mamba chunk.
    In other words, a logical chunk changes every time we cross a sequence
    boundary or a physical mamba chunk boundary.

    Logical chunks are needed to handle batched requests with initial states
    (see _state_passing_fwd and _chunk_scan_fwd).

    Example:
        query_start_loc = [0, 5, 10]
        chunk_size = 8
        total_seqlens = 10
        -> chunk_indices = [0, 0, 1]
        -> chunk_offsets = [0, 5, 0]

        In this example, we have 2 sequences, each with 5 tokens. The physical
        chunk size is 8 tokens.
        We have three logical chunks:
        - the first logical chunk starts at token 0 in the first physical chunk
            and contains all 5 tokens from the first sequence
        - the second logical chunk starts at token 5 in the first physical
            chunk and contains first 3 tokens from the second sequence
        - the third logical chunk starts at token 0 in the second physical
            chunk and contains the remaining 2 tokens from the second sequence
    """
    device = query_start_loc.device

    # Pull the boundaries to the host once. The arithmetic below then runs on
    # plain Python ints instead of indexing/slicing with 0-d tensors, which
    # would force a host-device synchronization per element on CUDA.
    cu_seqlens = query_start_loc[1:].tolist()  # remove prepended 0
    seq_starts = cu_seqlens[:-1]
    seq_ends = cu_seqlens[1:]

    # outputs will have length expansion of chunks that do not divide
    # chunk_size: every sequence start that is not chunk-aligned splits one
    # physical chunk into two logical chunks.
    N = math.ceil(total_seqlens / chunk_size) + sum(
        s % chunk_size > 0 for s in seq_starts)
    chunk_indices = torch.arange(N, dtype=torch.int, device=device)
    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=device)

    p = 0  # num of insertions
    for s, e in zip(seq_starts, seq_ends):
        # if does not divide chunk_size, then there is one chunk insertion
        p += s % chunk_size > 0

        # get the dimensions
        # - the + 1 for _e is to shift the boundary by one chunk
        # - this shifting is not needed if chunk_size divides e
        _s = s // chunk_size + p
        _e = e // chunk_size + p + (e % chunk_size > 0)

        # adjust indices and offsets
        chunk_indices[_s:_e] -= p
        chunk_offsets[_s] = s % chunk_size

    return chunk_indices, chunk_offsets
class Mamba2AttentionBackend(AttentionBackend):
    """Attention-backend entry for Mamba2 layers.

    The visible definition only provides the metadata-builder hook.
    """

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        # Builder class that produces Mamba2AttentionMetadata for each step.
        return Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata:
    """Per-step metadata produced by Mamba2AttentionMetadataBuilder.build."""

    # Counts returned by split_decodes_and_prefills.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    # query_start_loc restricted to the prefill requests and shifted by
    # num_decode_tokens; the builder leaves it as None when the batch has
    # no prefill.
    query_start_loc_p: torch.Tensor
    # Taken unchanged from CommonAttentionMetadata.seq_lens.
    seq_lens: torch.Tensor

    # True when at least one prefill request already has computed tokens.
    prep_initial_states: bool
    # Mamba chunk size read from the model config.
    chunk_size: int

    # The following tensors only contain prefill requests and will be None if
    # the batch has no prefill request.
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]
    # chunk_indices_p / chunk_offsets_p are only computed when
    # prep_initial_states is True.
    chunk_indices_p: Optional[torch.Tensor]
    chunk_offsets_p: Optional[torch.Tensor]

    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None
class Mamba2AttentionMetadataBuilder(
        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):
    """Assembles `Mamba2AttentionMetadata` for one scheduler step."""

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the base builder and read the mamba chunk size."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        model_config = vllm_config.model_config
        self.chunk_size = model_config.get_mamba_chunk_size()
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models")

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> Mamba2AttentionMetadata:
        """Build the Mamba2 metadata for the current batch.

        Args:
            common_prefix_len: Not read by this builder.
            common_attn_metadata: Backend-agnostic description of the batch.
            fast_build: Not read by this builder.

        Returns:
            The populated `Mamba2AttentionMetadata`. Prefill-only fields stay
            None when the batch has no prefill request.
        """
        meta = common_attn_metadata
        num_reqs = meta.num_reqs
        seq_lens = meta.seq_lens
        state_indices_tensor = meta.block_table_tensor[:, 0]

        # Prefill-only outputs; filled in below when prefills are present.
        query_start_loc_p = None
        seq_idx_p = None
        chunk_indices_p = None
        chunk_offsets_p = None
        has_initial_states_p = None
        prep_initial_states = False

        # Inputs for causal_conv1d.
        nums_dict = None
        batch_ptr = None
        token_chunk_offset_ptr = None

        split = split_decodes_and_prefills(
            meta, decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split)

        if num_prefills > 0:
            # The slice [num_reqs - num_prefills, num_reqs) selects the
            # prefill requests.
            prefill_computed_cpu = meta.num_computed_tokens_cpu[
                num_reqs - num_prefills:num_reqs]
            has_initial_states_cpu = prefill_computed_cpu > 0
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states_p = has_initial_states_cpu.to(
                meta.query_start_loc.device)

            # Re-base the prefill start offsets so the first prefill token
            # is at index 0.
            prefill_starts = meta.query_start_loc[-num_prefills - 1:]
            query_start_loc_p = prefill_starts - num_decode_tokens

            # One sequence id per prefill token.
            prefill_ids = torch.arange(num_prefills,
                                       dtype=torch.int32,
                                       device=query_start_loc_p.device)
            seq_idx_p = torch.repeat_interleave(
                prefill_ids,
                query_start_loc_p.diff(),
                output_size=num_prefill_tokens)

            # Chunk metadata is computed once here and reused by the mamba
            # layers; it is only needed when some request has initial state.
            if prep_initial_states:
                chunk_indices_p, chunk_offsets_p = (
                    _query_start_loc_to_chunk_indices_offsets(
                        query_start_loc_p, self.chunk_size,
                        num_prefill_tokens))

            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(query_start_loc_p))

        elif num_decodes <= self.decode_cudagraph_max_bs:
            # Decode-only batch: stage state indices in the persistent
            # buffer and pad the tail for CUDA graph replay.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            persistent = self.state_indices_tensor
            persistent[:num_decodes].copy_(state_indices_tensor,
                                           non_blocking=True)
            state_indices_tensor = persistent[:num_input_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        return Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc_p=query_start_loc_p,
            seq_lens=seq_lens,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            has_initial_states_p=has_initial_states_p,
            seq_idx_p=seq_idx_p,
            chunk_indices_p=chunk_indices_p,
            chunk_offsets_p=chunk_offsets_p,
            state_indices_tensor=state_indices_tensor,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from typing import ClassVar, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
# Type variable for the concrete metadata class a builder produces.
M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    """Shared base for the Mamba metadata builders.

    Owns the persistent `state_indices_tensor` buffer that subclasses fill
    when building decode-only batches for CUDA graph replay.
    """

    reorder_batch_threshold: int = 1
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Validate the cache spec and allocate the persistent index buffer."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

        self.compilation_config = vllm_config.compilation_config
        # Largest decode batch that can be captured: bounded by both the
        # scheduler's sequence limit and the largest capture size.
        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
        max_capture_size = self.compilation_config.max_capture_size
        self.decode_cudagraph_max_bs = min(max_num_seqs, max_capture_size)

        buffer_shape = (self.decode_cudagraph_max_bs, )
        self.state_indices_tensor = torch.empty(
            buffer_shape,
            dtype=torch.int32,
            device=device,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """Build metadata for full CUDA graph capture.

        Only decode-only batches are accepted: the number of requests must
        equal the number of tokens. `max_query_len` on the passed metadata
        is overwritten with 1 before delegating to `build`.
        """
        meta = common_attn_metadata
        is_decode_only = meta.num_reqs == meta.num_actual_tokens
        assert is_decode_only, \
            "Mamba only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        meta.max_query_len = 1  # decode-only

        return self.build(0, meta)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import ClassVar, Optional, Union
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
# Module-level logger; used below to warn when FORCE_NUM_KV_SPLITS is set.
logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the CUTLASS MLA backend.

    Inherits all building logic; only the CUDA graph support level is set.
    """

    # enable full CUDA Graph support for decode-only capture
    cudagraph_support: ClassVar[
        AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
class CutlassMLABackend(MLACommonBackend):
    """Backend descriptor wiring together the CUTLASS MLA classes."""

    @staticmethod
    def get_name() -> str:
        # Identifier string of this backend.
        return "CUTLASS_MLA"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        # Attention implementation class for this backend.
        return CutlassMLAImpl

    @staticmethod
    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
        # Metadata builder class for this backend.
        return CutlassMLAMetadataBuilder
class SM100Workspace:
    """Growable byte buffer used as scratch space by the CUTLASS MLA kernel."""

    def __init__(self, initial_workspace_size):
        """Allocate `initial_workspace_size` bytes on the CUDA device."""
        self._block_size = 128  # Forced to 128
        self._workspace_buf = torch.empty(initial_workspace_size,
                                          device="cuda",
                                          dtype=torch.uint8)

        # Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
        # (assumes all devices are similar)
        device0 = torch.device("cuda:0")
        self._sm_count = torch.cuda.get_device_properties(
            device0).multi_processor_count

    def get_buf(self):
        """Return the current workspace tensor."""
        return self._workspace_buf

    def ensure_size(self, attn_metadata: MLACommonMetadata,
                    num_kv_splits: int):
        """Grow the buffer if the kernel needs more than is allocated.

        The required size is queried from
        `ops.sm100_cutlass_mla_get_workspace_size`; the buffer is resized in
        place and never shrunk.
        """
        # NOTE(review): the length passed to the size query is derived from
        # max_query_len, not a KV sequence length - confirm this is intended.
        required_bytes = ops.sm100_cutlass_mla_get_workspace_size(
            attn_metadata.max_query_len * self._block_size,
            attn_metadata.num_reqs,
            self._sm_count,
            num_kv_splits=num_kv_splits)

        buf = self._workspace_buf
        if buf.shape[0] < required_bytes:
            buf.resize_(required_bytes)
# Workspace shared by every CutlassMLAImpl instance; constructing it here
# allocates the buffer on the CUDA device when this module is imported.
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128MB
# Head count passed as q_pad_num_heads and used to size the decode kernel's
# output and LSE buffers.
MAX_HEADS = 128
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation whose decode path calls the SM100 CUTLASS kernel."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and validate supported features.

        Raises:
            NotImplementedError: If alibi slopes, a sliding window or a
                logits soft cap is requested, or if `attn_type` is not
                `AttentionType.DECODER`.
        """
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         alibi_slopes,
                         sliding_window,
                         kv_cache_dtype,
                         logits_soft_cap,
                         attn_type,
                         kv_sharing_target_layer_name,
                         q_pad_num_heads=MAX_HEADS,
                         **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
        #       issues. In case the code hangs, use:
        #       FORCE_NUM_KV_SPLITS=1
        force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
        if force_num_kv_splits:
            # Parse once so the logged and the stored value cannot diverge.
            forced_splits = int(force_num_kv_splits)
            logger.warning_once("Forcing num_kv_splits to %d", forced_splits)
            self._num_kv_splits = forced_splits
        else:
            self._num_kv_splits = -1  # => Auto-detect

        # Share workspace buffer across all executions
        self._workspace = g_sm100_workspace

    def _sm100_cutlass_mla_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        seq_lens: torch.Tensor,
        page_table: torch.Tensor,
        workspace: torch.Tensor,
        sm_scale: float,
        num_kv_splits: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Validate the inputs and run `ops.sm100_cutlass_mla_decode`.

        Args:
            q_nope: 3D query tensor (B, H, 512).
            q_pe: 3D rotary query tensor (B, H, 64).
            kv_c_and_k_pe_cache: 3D cache (_, PAGE_SIZE, 576).
            seq_lens: int32 sequence lengths.
            page_table: 2D int32 block table with B rows.
            workspace: Scratch buffer handed to the kernel.
            sm_scale: Softmax scale.
            num_kv_splits: Number of KV splits handed to the kernel.

        Returns:
            `(out, lse)`. Both are allocated for `MAX_HEADS` heads and
            sliced to the first H heads when H is smaller. `lse` is an
            empty tensor when no LSE was requested.
        """
        assert (q_nope.ndim == 3
                ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
        assert (
            q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
        assert (kv_c_and_k_pe_cache.ndim == 3), (
            "kv_c_and_k_pe_cache must be a 3D tensor, but got "
            f"{kv_c_and_k_pe_cache.ndim}")

        B_q, H, D_q_nope = q_nope.shape
        B_q_2, H_2, D_q_pe = q_pe.shape
        assert (B_q == B_q_2) and (H == H_2)

        _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

        # Fixed head dimensions required by the checks below.
        D_latent = 512
        D_rope = 64
        assert D_q_nope == D_latent
        assert D_q_pe == D_rope
        assert D_ckv == D_latent + D_rope

        # MAX_HEADS is the module-level constant, the same value that was
        # passed as q_pad_num_heads in __init__.
        assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"

        assert len(page_table.shape) == 2
        B_block_table, block_num = page_table.shape
        assert B_block_table == B_q
        assert (block_num
                > 0), f"block num must be greater than 0, got {block_num}"
        # The block count must be a multiple of the number of pages that
        # make up 128 tokens.
        assert block_num % (128 / PAGE_SIZE) == 0

        assert q_nope.dtype in (
            torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
                f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
                f"{q_nope.dtype}.")
        assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
        assert (
            seq_lens.dtype == torch.int32
        ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
        assert (
            page_table.dtype == torch.int32
        ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

        # With a quantized KV cache the output is produced in bf16.
        dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
                 else q_nope.dtype)
        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
        lse = (torch.empty(
            (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
               if self.need_to_return_lse_for_decode else torch.Tensor())

        ops.sm100_cutlass_mla_decode(
            out,
            lse,
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            seq_lens,
            page_table,
            workspace,
            sm_scale,
            num_kv_splits,
        )

        if H < MAX_HEADS:
            # Extract the subsets of the outputs
            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
            out = out[:, :H]

        return out, lse

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run MLA decode for the current batch.

        Args:
            q: Either a `(q_nope, q_pe)` tuple or a single tensor that is
                split along the last dim into
                `[kv_lora_rank, qk_rope_head_dim]`.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose `decode` part must be set.
            layer: Not read by this implementation.

        Returns:
            `(output, lse)`; `lse` is None unless LSE was requested.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Adjust workspace size (if necessary)
        self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

        # Run MLA
        o, lse = self._sm100_cutlass_mla_decode(
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            attn_metadata.decode.seq_lens,
            attn_metadata.decode.block_table,
            self._workspace.get_buf(),
            self.scale,
            self._num_kv_splits,
        )

        return o, (lse if self.need_to_return_lse_for_decode else None)

View File

@@ -0,0 +1,271 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
# Module-level logger.
logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
    """Backend descriptor wiring together the FlashAttention MLA classes."""

    @staticmethod
    def get_name() -> str:
        # Identifier string of this backend.
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        # Metadata class produced by the builder.
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        # Metadata builder class for this backend.
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        # Attention implementation class for this backend.
        return FlashAttnMLAImpl
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the inputs of `flash_attn_varlen_func`."""

    # Device-side cumulative query start offsets.
    query_start_loc: torch.Tensor
    # Largest per-request query length in the decode batch.
    max_query_len: int
    # Largest sequence length in the decode batch.
    max_seq_len: int
    # Result of get_scheduler_metadata, or None when it is not computed.
    scheduler_metadata: Optional[torch.Tensor] = None
    # Forwarded to the kernel as num_splits; the builder sets a non-zero
    # value only on its full CUDA graph path.
    max_num_splits: int = 0
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    """Common MLA metadata specialised to FlashAttnMLADecodeMetadata."""
    pass
class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    """Builds decode metadata for the FlashAttention MLA backend."""

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH)

    reorder_batch_threshold: int = 512

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Set up scheduling flags and, if needed, the persistent buffer.

        Raises:
            ValueError: If full CUDA graphs are used with ahead-of-time
                scheduling and the maximum capture size exceeds 992.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashAttnMLAMetadata)
        self.max_num_splits = 0  # No upper bound on the number of splits.
        # Ahead-of-time scheduling is only used with FlashAttention 3.
        self.fa_aot_schedule = (get_flash_attn_version() == 3)

        cudagraph_mode = self.compilation_config.cudagraph_mode
        self.use_full_cuda_graph = cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that captured graphs read their scheduler
            # metadata from.
            buffer_len = vllm_config.scheduler_config.max_num_seqs + 1
            self.scheduler_metadata = torch.zeros(
                buffer_len,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # TODO(lucas): Until we add support for the DCP custom masking we need
        #  to restrict decodes to q_len == 1 when DCP is enabled.
        dcp_enabled = get_dcp_group().world_size > 1
        self.reorder_batch_threshold = (1 if dcp_enabled else
                                        self.reorder_batch_threshold)

    def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                         max_seq_len, causal):
        """Return FA3 scheduler metadata, or None without AOT scheduling."""
        if not self.fa_aot_schedule:
            return None

        mla_dims = self.mla_dims
        return get_scheduler_metadata(
            batch_size=num_reqs,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_seq_len,
            num_heads_q=self.num_heads,
            num_heads_kv=1,
            headdim=mla_dims.qk_rope_head_dim,
            cache_seqlens=seqlens,
            qkv_dtype=self.kv_cache_spec.dtype,
            headdim_v=mla_dims.kv_lora_rank,
            page_size=self.page_size,
            cu_seqlens_q=cu_query_lens,
            causal=causal,
            num_splits=self.max_num_splits,
        )

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
        """Build the decode metadata for the current batch.

        On the full CUDA graph path with scheduler metadata available, the
        metadata is copied into the persistent buffer, the unused tail of
        that buffer is zeroed, and a view of it is returned instead.
        """
        per_request_query_len = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])
        max_query_len = per_request_query_len.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
        )

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            persistent = self.scheduler_metadata
            n = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= persistent.shape[0], \
                f"Scheduler metadata size {n} exceeds buffer size " + \
                f"{self.scheduler_metadata.shape[0]}"
            persistent[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            persistent[n:] = 0
            scheduler_metadata = persistent[:n]

            if num_decode_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
        )
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    """MLA implementation whose decode path calls `flash_attn_varlen_func`."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and validate supported features.

        Raises:
            NotImplementedError: If alibi slopes, a sliding window or a
                logits soft cap is requested, if `attn_type` is not
                `AttentionType.DECODER`, or if the KV cache is quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert flash_attn_supports_mla(), \
            "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttnMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run MLA decode for the current batch.

        Args:
            q: Either a `(q_nope, q_pe)` tuple or a single tensor that is
                split along the last dim into
                `[kv_lora_rank, qk_rope_head_dim]`.
            kv_c_and_k_pe_cache: Non-empty KV cache; its last dim holds the
                first `kv_lora_rank` latent entries followed by the rope
                part.
            attn_metadata: Metadata whose `decode` part must be set.
            layer: Not read by this implementation.

        Returns:
            `(output, lse)`; `lse` is transposed to `[B, H]` when requested
            and None otherwise.

        Raises:
            NotImplementedError: If the KV cache dtype starts with "fp8".
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
        else:
            o = attn_out
            return o, None

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
# Module-level logger.
logger = init_logger(__name__)
# Size in bytes (128 MiB) of the workspace buffer allocated below.
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLABackend(MLACommonBackend):
    """Backend descriptor for the FlashInfer MLA implementation."""

    @staticmethod
    def get_name() -> str:
        # Identifier string of this backend.
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        # Attention implementation class for this backend.
        return FlashInferMLAImpl
# Zero-initialised workspace shared by every FlashInferMLAImpl instance;
# allocated on the CUDA device when this module is imported.
g_fi_workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation whose decode path calls the FlashInfer TRT-LLM kernel."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and validate supported features.

        Raises:
            NotImplementedError: If alibi slopes, a sliding window or a
                logits soft cap is requested, or if `attn_type` is not
                `AttentionType.DECODER`.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        if any([alibi_slopes, sliding_window, logits_soft_cap]):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferMLAImpl")

        self._workspace_buffer = g_fi_workspace
        # Both scales are filled in lazily on the first decode call.
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run MLA decode for the current batch.

        Args:
            q: Either a `(q_nope, q_pe)` tuple, concatenated along the last
                dim, or an already concatenated tensor.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose `decode` part must be set.
            layer: Source of the q/k/v scale floats used for the two scales.

        Returns:
            `(output, None)`; this implementation does not return an LSE.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        decode_meta = attn_metadata.decode

        if isinstance(q, tuple):
            q_nope, q_pe = q
            q = torch.cat((q_nope, q_pe), dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        query = q.unsqueeze(1)

        # Compute the scales once and reuse them on later calls.
        if self.bmm1_scale is None:
            qk_scale = layer._q_scale_float * layer._k_scale_float
            self.bmm1_scale = qk_scale * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=decode_meta.block_table,
            seq_lens=decode_meta.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return o, None

View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
# Module-level logger.
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
    """Backend descriptor wiring together the FlashMLA classes."""

    @staticmethod
    def get_name() -> str:
        # Identifier string of this backend.
        return "FLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        # Metadata class produced by the builder.
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        # Metadata builder class for this backend.
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        # Attention implementation class for this backend.
        return FlashMLAImpl
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the two outputs of `get_mla_metadata`."""

    # First value returned by get_mla_metadata; on the full CUDA graph path
    # this is a view of the builder's persistent buffer.
    tile_scheduler_metadata: torch.Tensor
    # Second value returned by get_mla_metadata; same buffering applies.
    num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialised to FlashMLADecodeMetadata."""
    pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds decode metadata for the FlashMLA backend."""

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Record the query head count and allocate CUDA graph buffers.

        The two persistent buffers are only allocated when the compilation
        config uses full CUDA graphs; otherwise they stay None.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        num_sms = torch.cuda.get_device_properties(
            self.device).multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
            # TileSchedulerMetaDataSize = 8
            tile_buf_shape = (num_sms, 8)
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                tile_buf_shape,
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
        """Build the decode metadata for the current batch.

        With full CUDA graphs enabled, the freshly computed scheduler
        metadata and split counts are copied into the persistent buffers
        and views of those buffers are returned instead.
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens_device,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            tile_buf = self.cg_buf_tile_scheduler_metadata
            splits_buf = self.cg_buf_num_splits
            assert tile_buf is not None
            assert splits_buf is not None

            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
            sm_parts = tile_scheduler_metadata.size(0)
            assert sm_parts <= tile_buf.size(0)
            tile_view = tile_buf[:sm_parts]
            tile_view.copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = tile_view

            # Num splits is per-batch, varying size (batch_size,)
            n = num_splits.size(0)
            # make sure static buffer is large enough
            assert n <= splits_buf.size(0)
            splits_view = splits_buf[:n]
            splits_view.copy_(num_splits)
            # Num splits needs to monotonically increasing
            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
            #  it needs to monotonically increasing by 1)
            splits_buf[n:].fill_(num_splits[-1])
            num_splits = splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA implementation whose decode path calls `flash_mla_with_kvcache`."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and validate supported features.

        Raises:
            NotImplementedError: If alibi slopes, a sliding window or a
                logits soft cap is requested, or if `attn_type` is not
                `AttentionType.DECODER`.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        is_supported, reason = is_flashmla_supported()
        assert is_supported, reason

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run MLA decode for the current batch.

        Args:
            q: Either a tuple of tensors, concatenated along the last dim,
                or an already concatenated tensor.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose `decode` part must be set.
            layer: Source of the `_q_scale` / `_k_scale` descale tensors.

        Returns:
            The `(output, lse)` pair returned by `flash_mla_with_kvcache`.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        o, lse = flash_mla_with_kvcache(
            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        return o, lse

View File

@@ -0,0 +1,544 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
    """Rescale a base-2 log-sum-exp tensor to natural-log units.

    The input is cast to float32 first so that the later merge keeps
    full precision; the result is always float32.
    """
    # log_e(x) == log_2(x) * ln(2)
    lse_fp32 = lse_base2.to(torch.float32)
    return lse_fp32 * math.log(2.0)
class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor wiring together the sparse FlashMLA classes."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        """Return the attention implementation class."""
        return FlashMLASparseImpl

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the per-step metadata class."""
        return FlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, last_dim)``.

        For ``"fp8_ds_mla"`` the last dimension is the custom 656-byte
        per-token storage format (see the FlashMLA readme.md for details);
        for every other dtype string it is ``head_size``.
        """
        last_dim = 656 if cache_dtype_str == "fp8_ds_mla" else head_size
        return (num_blocks, block_size, last_dim)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the supported model dtypes (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes (576 only)."""
        return [576]
@dataclass
class MLASparsePrefillMetadata:
    """Container for prefill-side metadata of sparse MLA attention."""
    # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
    # the kernel is not from flashmla
    # Block table tensor (required).
    block_table: torch.Tensor
    # Whether context is present; defaults to False.
    # NOTE(review): presumably "context" means previously cached tokens --
    # confirm against the producer of this object.
    has_context: bool = False
    # Optional context lengths; None when not provided.
    context_lens: Optional[torch.Tensor] = None
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
    """Optional tensors used by the sparse decode / context path.

    Every field defaults to ``None``, so all of them are annotated as
    ``Optional`` tensors.
    """
    scheduler_metadata: Optional[torch.Tensor] = None
    num_splits: Optional[torch.Tensor] = None
    cache_lens: Optional[torch.Tensor] = None
    prefill_context_lengths: Optional[torch.Tensor] = None
    prefill_new_k_start_locs: Optional[torch.Tensor] = None
    dummy_block_table: Optional[torch.Tensor] = None

    def filter_prefill_indices(
            self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``indices`` into context indices and new-token indices.

        ``prefill_context_lengths`` is broadcast against the last
        dimension of ``indices``. Entries below the context length are
        kept in the first output; entries at or above it are shifted down
        by the context length and kept in the second output. Every other
        position is set to ``-1``.

        Args:
            indices: Index tensor whose leading dimensions broadcast with
                ``prefill_context_lengths.unsqueeze(-1)``.

        Returns:
            ``(context_indices, new_token_indices)``.

        Raises:
            AssertionError: If ``prefill_context_lengths`` is ``None``.
        """
        assert self.prefill_context_lengths is not None
        ctx_lens = self.prefill_context_lengths.unsqueeze(-1)
        context_indices = torch.where(indices < ctx_lens, indices, -1)
        new_token_indices = torch.where(indices >= ctx_lens,
                                        indices - ctx_lens, -1)
        return context_indices, new_token_indices
@dataclass
class FlashMLASparseMetadata:
    """Per-step attention metadata for the sparse FlashMLA backend."""
    num_reqs: int
    max_query_len: int
    max_seq_len: int
    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    block_table: torch.Tensor
    # NOTE(review): presumably the request index of every token -- confirm
    # against the metadata builder.
    req_id_per_token: torch.Tensor
    # Default KV-cache block size and default number of top-k tokens.
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        """Extra kernel inputs; populated only for the FP8 cache path."""
        scheduler_metadata: Optional[torch.Tensor]
        num_splits: torch.Tensor
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    # None unless FP8 kernel metadata was attached.
    fp8_extra_metadata: Optional[FP8KernelMetadata] = None
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # Translate per-request token indices into global cache slot indices:
    #   out = block_table[req, tok // BLOCK_SIZE] * BLOCK_SIZE
    #         + tok % BLOCK_SIZE
    # Negative tokens and out-of-range block ids produce -1.
    #
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0

    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access on BOTH sides. A negative token may yield a
    # negative block_id (depending on how integer division rounds), which
    # would otherwise address memory before the start of the row. Such
    # lanes already output -1 below, so masking the load leaves the result
    # unchanged.
    valid_block = (block_id >= 0) & (block_id < max_num_blocks_per_req)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    base = tl.load(bt_ptr, mask=valid_block, other=0)

    # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
    out_val = tl.where(is_invalid_tok | (~valid_block), -1,
                       base * BLOCK_SIZE + inblock_off)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)
def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
):
    """Map per-request token indices to global KV-cache slot indices.

    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1.
    For safety, we also output -1 if the derived block_id would be
    out-of-bounds.

    Args:
        req_id: int32 tensor with one request index per token.
        block_table: int32 2-D block table.
        token_indices: int32 tensor with ``NUM_TOPK_TOKENS`` columns.
        BLOCK_SIZE: Number of tokens per KV-cache block.
        NUM_TOPK_TOKENS: Expected number of columns of ``token_indices``.
        BLOCK_N: Column tile width; must divide ``NUM_TOPK_TOKENS``.

    Returns:
        A new tensor shaped and typed like ``token_indices``.

    Raises:
        AssertionError: On a dtype mismatch, an unexpected column count,
            or when ``NUM_TOPK_TOKENS`` is not a multiple of ``BLOCK_N``.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by "
        f"BLOCK_N ({BLOCK_N})")

    num_tokens = req_id.shape[0]
    # Two-value unpacking also enforces that block_table is 2-D.
    _, max_num_blocks_per_req = block_table.shape
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Exact 2D grid: tokens x column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out
@dataclass
class FlashMLASparseMetadataBuilder(
        AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds ``FlashMLASparseMetadata`` and owns its persistent buffers.

    NOTE(review): ``@dataclass`` on a class with a hand-written
    ``__init__`` and no instance fields looks unintentional -- confirm
    before relying on the generated ``__eq__`` / ``__repr__``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Read the configuration and allocate the reusable device buffers."""
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        sm_count = torch.cuda.get_device_properties(
            device).multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"

        # One-element int32 tensors holding top-k and max_model_len.
        self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
                                               device=device,
                                               dtype=torch.int32)
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len],
            device=device,
            dtype=torch.int32)
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty((1, 1),
                                             dtype=torch.int32,
                                             device=self.device)

        # Equation taken from FlashMLA/csrc/pybind.cpp
        h_q, h_k = self.num_heads, 1
        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
        q_head_groups = cdiv(h_q // h_k, 2 * 64) * s_q
        sm_parts = (sm_count // 2) / h_k // q_head_groups
        max_num_sm_parts = int(max(sm_parts, 1))
        if current_platform.is_device_capability(100):
            max_num_sm_parts *= 2

        # TileSchedulerMetaDataSize = 8
        # see: FlashMLA/csrc/params.h
        self.tile_scheduler_metadata_buffer = torch.empty(
            (max_num_sm_parts, 8), dtype=torch.int32, device=device)
        # We pack all the tokens into one batch for sparse attention.
        # Otherwise, we can exceed the sm of `get_mla_metadata`.
        self.num_splits_buffer = torch.empty((2, ),
                                             dtype=torch.int32,
                                             device=device)
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens, ),
            dtype=torch.int32,
            device=device)

    def _build_fp8_extra_metadata(
            self,
            num_tokens: int) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Compute FP8 kernel metadata and stage it in persistent buffers."""
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor,
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        # Copy to persistent buffer for full-CG support
        num_sm_parts = tile_scheduler_metadata.size(0)
        scheduler_view = self.tile_scheduler_metadata_buffer[:num_sm_parts]
        scheduler_view.copy_(tile_scheduler_metadata)
        self.num_splits_buffer.copy_(num_splits)

        return FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_view,
            num_splits=self.num_splits_buffer,
            # cache_lens and block_table are basically unused in sparse case
            # but the decode kernel will treat -1 and indices >= cache_lens
            # as invalid so we make sure cache_lens is large enough to not
            # accidentally mark indices invalid, we will use -1 exclusively
            # to mark invalid indices
            cache_lens=self.max_model_len_tensor,
            dummy_block_table=self.dummy_block_table)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashMLASparseMetadata:
        """Create the metadata object for one scheduling step.

        ``common_prefix_len`` and ``fast_build`` are accepted for
        interface compatibility and are not used here.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Expand per-request query spans into one request index per token.
        query_starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                                  dtype=np.int32)
        tokens_per_req = np.diff(query_starts)
        req_ids_np = np.repeat(
            np.arange(tokens_per_req.shape[0], dtype=np.int32),
            tokens_per_req)

        # Zero-fill for cudagraphs
        id_buffer = self.req_id_per_token_buffer
        id_buffer.fill_(0)
        id_buffer[:req_ids_np.shape[0]].copy_(torch.from_numpy(req_ids_np),
                                              non_blocking=True)
        req_id_per_token = id_buffer[:num_tokens]

        fp8_extra_metadata = (self._build_fp8_extra_metadata(num_tokens)
                              if self.use_fp8_kv_cache else None)

        return FlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
        )
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
    """Sparse MLA attention that attends only to the indexer's top-k tokens.

    The bf16 cache path calls ``flash_mla_sparse_prefill``; the
    ``"fp8_ds_mla"`` cache path calls ``flash_mla_with_kvcache``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        """Set up common MLA state and bind the indexer's top-k buffer.

        ``topk_indice_buffer`` is accepted but not read here; the buffer
        is taken from ``indexer``, which must not be ``None``.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        # Head-count granularity: 128 on device capability 100, else 64.
        if current_platform.is_device_capability(100):
            self.padding = 128
        else:
            self.padding = 64

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
        """Sparse attention over a non-FP8 cache.

        ``attn_metadata`` is not used on this path.
        """
        num_tokens = q.shape[0]
        flat_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1])

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        if self.num_heads % self.padding != 0:
            assert self.padding % self.num_heads == 0
            logger.warning_once(f"padding num_heads to {self.padding} "
                                "due to sparse attn kernel requirement")
            # Extra heads are left uninitialised; their outputs are
            # dropped by the slice on return.
            padded_q = q.new_empty((q.shape[0], self.padding, q.shape[2]))
            padded_q[:, :self.num_heads, :] = q
            q = padded_q

        per_token_indices = topk_indices.view(num_tokens, 1, -1)
        kernel_result = flash_mla_sparse_prefill(q, flat_cache,
                                                 per_token_indices,
                                                 self.softmax_scale)
        return kernel_result[0][:, :self.num_heads, :]

    def _forward_fp8_kv(self, q: torch.Tensor,
                        kv_c_and_k_pe_cache: torch.Tensor,
                        topk_indices: torch.Tensor,
                        attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
        """Sparse attention over the ``fp8_ds_mla`` cache."""
        fp8_meta = attn_metadata.fp8_extra_metadata
        assert fp8_meta is not None

        attn_out, _ = flash_mla_with_kvcache(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=fp8_meta.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=fp8_meta.cache_lens,
            tile_scheduler_metadata=fp8_meta.scheduler_metadata,
            num_splits=fp8_meta.num_splits,
            is_fp8_kvcache=True,
            indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
            softmax_scale=self.softmax_scale,
        )
        return attn_out

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Write new KV entries to the cache and compute sparse attention.

        Returns ``output``, with its first ``num_actual_tokens`` rows
        filled by ``_v_up_proj``. When ``attn_metadata`` is ``None`` the
        buffer is zero-filled and returned.

        Raises:
            AssertionError: If ``output`` is ``None``.
            NotImplementedError: If an output scale is supplied.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        # Inputs and outputs may be padded for CUDA graphs
        num_actual_toks = attn_metadata.num_actual_tokens
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # (B, N, P) -> (N, B, P); bmm with (N, P, L) -> (N, B, L);
        # then back to (B, N, L)
        ql_nope = torch.bmm(q_nope.transpose(0, 1),
                            self.W_UK_T).transpose(0, 1)

        # TODO: handle index / kv_cache correctly
        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            self.topk_indices_buffer[:num_actual_toks],
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if self.kv_cache_dtype == "fp8_ds_mla":
            attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
                                            attn_metadata)
        else:
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global,
                                             attn_metadata)

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output

View File

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
    """Backend descriptor for the DeepSeek V3.2 indexer cache."""

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes."""
        return [32, 64, 128]

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the indexer metadata class."""
        return DeepseekV32IndexerMetadata

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        """Return the indexer metadata builder class."""
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the stride order ``(0, 1, 2)``."""
        return (0, 1, 2)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` must be 1; ``cache_dtype_str`` is not consulted.
        """
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
    """Metadata describing one chunk of prefill requests for the indexer."""
    # Block table tensor for this chunk.
    block_table: torch.Tensor
    # NOTE(review): presumably per-token start / exclusive-end offsets into
    # the chunk's concatenated KV -- confirm against the builder.
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    # NOTE(review): name suggests cumulative sequence lengths with a
    # leading zero -- confirm against the builder.
    cu_seq_lens: torch.Tensor
    # Total sequence length covered by the chunk.
    # NOTE(review): verify the producer passes a Python int and not a
    # 0-dim tensor.
    total_seq_lens: int
    # Token range of this chunk.
    # NOTE(review): presumably [token_start, token_end) in the flattened
    # query -- confirm.
    token_start: int
    token_end: int
    # Number of requests in the chunk.
    num_reqs: int
@dataclass
class DeepseekV32IndexerPrefillMetadata:
    """Prefill metadata: a list of per-chunk metadata objects."""
    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    """Decode-phase metadata for the indexer."""
    # Block table and sequence lengths of the decode requests.
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    # NOTE(review): presumably the number of query tokens per decode
    # request -- confirm against the builder.
    decode_lens: torch.Tensor
    # NOTE(review): presumably True when decode requests have differing
    # query lengths -- confirm against the builder.
    requires_padding: bool
    # Scheduling metadata tensor for the decode kernel.
    schedule_metadata: torch.Tensor
@dataclass
class DeepseekV32IndexerMetadata:
    """Per-step metadata for the indexer, split into decode and prefill."""

    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Phase-specific metadata; each defaults to None.
    decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
    prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute, per selected token, its KV span in a concatenated KV cache.

    Args:
        start_seq_loc: 1D integer tensor [B+1], cumulative counts of
            selected tokens per batch, starting at 0.
            Example: [0, 2, 4, 7] ->
            batch sizes (selected) [2, 2, 3], N=7 tokens total.
        seq_len_per_batch: 1D integer tensor [B],
            full sequence length (KV length) of each batch.
            Example: [5, 9, 4].
        device: Device on which the two result tensors are placed.

    Returns:
        start_tensor: 1D int32 tensor [N], start offset in the
            concatenated KV cache for each token's batch.
        end_location: 1D int32 tensor [N],
            **exclusive** end = start + token's local position.
            (So the attended KV slice is kv[start:end].)
        Both tensors are int32 and live on ``device``, including the
        empty (N == 0) case.

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.

    For the example above the result is
    start = [0, 0, 5, 5, 14, 14, 14] and
    end = [4, 5, 13, 14, 16, 17, 18].
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and total
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens

    if N == 0:
        # Same dtype as the non-empty path so callers always get int32.
        return (torch.empty(0, dtype=torch.int32, device=device),
                torch.empty(0, dtype=torch.int32, device=device))

    # KV start offset of every batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For the token with flat index t in batch b, the exclusive end is
    #   kv_start[b] + (L[b] - counts[b]) + (t - q[b] + 1)
    # Everything except t is constant per batch, so fold it once per batch
    # and expand a single time.
    end_base_per_batch = kv_starts_per_batch + L - counts - q[:-1] + 1  # [B]

    start_tensor = torch.repeat_interleave(kv_starts_per_batch, counts)  # [N]
    # arange is created on the inputs' device so non-CPU inputs also work.
    flat_index = torch.arange(N, dtype=torch.long, device=q.device)
    end_location = torch.repeat_interleave(end_base_per_batch,
                                           counts) + flat_index  # [N]

    return start_tensor.int().to(device), end_location.int().to(device)
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    """Return the prefill buffer size: twice the model's max length."""
    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
    # May be tuned later.
    return vllm_config.model_config.max_model_len * 2
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
                         max_prefill_buffer_size: int,
                         reqs_start: int) -> list[tuple[int, int]]:
    """Group consecutive prefill requests into chunks bounded by a budget.

    Requests are taken in order starting at ``reqs_start``. A chunk is
    closed as soon as adding the next request would push the summed
    sequence length above ``max_prefill_buffer_size``, so the total of
    each chunk does not exceed the budget.

    Args:
        seq_lens_cpu: The sequence lengths of all requests.
        max_prefill_buffer_size: The maximum prefill buffer size.
        reqs_start: The index of the first prefill request.

    Returns:
        A list of ``(reqs_start, reqs_end)`` half-open request ranges.

    Raises:
        AssertionError: If a single request is longer than the budget.
    """
    lengths = seq_lens_cpu.tolist()
    num_reqs = len(lengths)

    chunks: list[tuple[int, int]] = []
    chunk_begin = reqs_start
    running_total = 0
    for idx in range(reqs_start, num_reqs):
        seq_len = lengths[idx]
        assert seq_len <= max_prefill_buffer_size
        if running_total + seq_len > max_prefill_buffer_size:
            # Close the current chunk and start a new one at this request.
            chunks.append((chunk_begin, idx))
            chunk_begin = idx
            running_total = seq_len
        else:
            running_total += seq_len

    # Emit the trailing chunk, if it holds anything.
    if running_total > 0:
        chunks.append((chunk_begin, num_reqs))
    return chunks
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    """Builds ``DeepseekV32IndexerMetadata`` for each scheduling step."""

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE)

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        """Initialise the base builder and allocate persistent buffers."""
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv. Need to double
        # check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(
            self.vllm_config)

        spec_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (spec_config.num_speculative_tokens
                                       if spec_config else 0)
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

        self.num_sms = torch.cuda.get_device_properties(
            self.device).multi_processor_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs, ),
            dtype=torch.int32,
            device=self.device)

        # See: DeepGEMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2),
                                                     dtype=torch.int32,
                                                     device=self.device)

    def build_one_prefill_chunk(self, reqs_start, reqs_end,
                                query_start_loc_cpu, seq_lens_cpu,
                                block_table):
        """Build chunk metadata for requests ``[reqs_start, reqs_end)``."""
        chunk_seq_lens = seq_lens_cpu[reqs_start:reqs_end]

        # Re-base the query offsets so the chunk starts at zero.
        chunk_query_start_loc = (
            query_start_loc_cpu[reqs_start:reqs_end + 1] -
            query_start_loc_cpu[reqs_start])
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            chunk_query_start_loc, chunk_seq_lens, self.device)

        token_start = query_start_loc_cpu[reqs_start].item()
        token_end = query_start_loc_cpu[reqs_end].item()

        total_seq_lens = chunk_seq_lens.sum()
        assert total_seq_lens <= self.max_prefill_buffer_size

        # Cumulative sequence lengths with a leading zero, as int32.
        cu_seq_lens = torch.cat(
            [torch.zeros(1, dtype=torch.int32),
             chunk_seq_lens.cumsum(dim=0)]).to(torch.int32).to(self.device)

        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            total_seq_lens=total_seq_lens,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
            token_end=token_end,
            num_reqs=reqs_end - reqs_start,
        )

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Split the batch into decodes and prefills and build metadata.

        ``common_prefix_len`` and ``fast_build`` are not used here.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split_decodes_and_prefills(
             common_attn_metadata,
             decode_threshold=self.reorder_batch_threshold)

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            # Prefill requests start right after the decode requests.
            chunk_seq_ids = split_prefill_chunks(
                common_attn_metadata.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    reqs_start, reqs_end, query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor)
                for reqs_start, reqs_end in chunk_seq_ids
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks)

        decode_metadata = None
        if num_decodes > 0:
            # Query lengths of the decode requests, written straight into
            # the persistent device buffer.
            decode_lens = self.decode_lens_buffer[:num_decodes]
            torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                       out=decode_lens)

            # Use CPU to avoid GPU sync; breaking async scheduling
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
            requires_padding = (decode_lens_cpu.max()
                                > decode_lens_cpu.min()).item()

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]
            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens, self.kv_cache_spec.block_size, self.num_sms)

            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=common_attn_metadata.
                block_table_tensor[:num_decodes, ...],
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

View File

@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils import cdiv
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
# yapf: enable
def is_aiter_mla_enabled() -> bool:
    """Report whether both the AITER and the AITER-MLA env flags are set."""
    aiter_enabled = envs.VLLM_ROCM_USE_AITER
    return aiter_enabled and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
    """Backend descriptor wiring together the ROCm AITER MLA classes."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_metadata_cls() -> type["AiterMLAMetadata"]:
        """Return the metadata class."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
        """Return the metadata builder class."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["AiterMLAImpl"]:
        """Return the attention implementation class."""
        return AiterMLAImpl
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with paged-KV tensors (all optional)."""
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The query indptr, shape : [num_decode + 1]
    qo_indptr: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """Common MLA metadata specialised to ``AiterMLADecodeMetadata``."""
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Metadata builder that adds the paged-KV index tensors for decode.

    ``_build_decode`` flattens the per-request block table into the
    ``paged_kv_indptr`` / ``paged_kv_indices`` / ``paged_kv_last_page_len``
    tensors stored on ``AiterMLADecodeMetadata``. When full CUDA graphs are
    enabled, the values are copied into persistent buffers allocated once in
    ``__init__`` and the metadata holds slices of those buffers.
    """

    # TODO(luka, lucas): audit this as part of:
    # https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Check the block size and allocate the persistent buffers.

        Args:
            kv_cache_spec: KV cache spec; its ``block_size`` must be 1.
            layer_names: forwarded to the base builder.
            vllm_config: source of the compilation, model and scheduler
                configuration read below.
            device: device on which the persistent buffers are allocated.

        Raises:
            AssertionError: if the KV cache block size is not 1.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         AiterMLAMetadata)
        assert self.kv_cache_spec.block_size == 1, (
            "AITER MLA only supports block size 1.")

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
                                     self.kv_cache_spec.block_size)
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req

        # Preparing persistent buffers
        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device=device)
            self.paged_kv_indices = torch.zeros(max_num_pages,
                                                dtype=torch.int32,
                                                device=device)
            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                      dtype=torch.int32,
                                                      device=device)

            # 0, 1, ..., max_num_reqs; sliced per batch in _build_decode.
            self.qo_indptr = torch.arange(0,
                                          max_num_reqs + 1,
                                          dtype=torch.int32,
                                          device=device)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> AiterMLADecodeMetadata:
        """Build the decode metadata from the block table and sequence lengths.

        Only ``block_table_tensor`` and ``seq_lens_device`` are read here; the
        remaining arguments are accepted but unused in this method.

        Returns:
            An ``AiterMLADecodeMetadata`` carrying the block table, the
            sequence lengths and the four paged-KV index tensors.
        """
        page_size = self.kv_cache_spec.block_size
        # Number of pages each request occupies (ceil division).
        block_table_bounds = (seq_lens_device + page_size - 1) // page_size
        device = self.device
        num_reqs = seq_lens_device.size(0)

        # Per row, keep only the first `block_table_bounds[row]` columns of
        # the block table; boolean indexing then flattens them row by row.
        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table_tensor[mask]

        # A remainder of 0 means the last page is completely full.
        paged_kv_last_page_len = seq_lens_device % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)

        # Leading 0 followed by the running total of pages per request.
        paged_kv_indptr = torch.cat([
            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # Copy into the persistent buffers and overwrite the unused tail
            # of each buffer with a filler value (-1, the final page count,
            # and 1 respectively).
            num_actual_pages = paged_kv_indices.size(0)

            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
                                                           non_blocking=True)
            self.paged_kv_indices[num_actual_pages:].fill_(-1)
            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]

            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
                                                      non_blocking=True)
            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]

            self.paged_kv_last_page_len[:num_reqs].copy_(
                paged_kv_last_page_len, non_blocking=True)
            self.paged_kv_last_page_len[num_reqs:].fill_(1)
            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

            qo_indptr = self.qo_indptr[:1 + num_reqs]
        else:
            qo_indptr = torch.arange(0,
                                     num_reqs + 1,
                                     step=1,
                                     dtype=torch.int32,
                                     device=device)

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr)

        return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA attention implementation whose decode step runs through AITER."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Forward everything to the common MLA impl, then check support.

        Raises:
            AssertionError: if ``num_heads`` is neither 16 nor 128.
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert num_heads in (16, 128), (
            f"Aiter MLA only supports 16 or 128 number of heads.\n"
            f"Provided {num_heads} number of heads.\n"
            "Try adjusting tensor_parallel_size value.")

        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        # Imported here rather than at module level.
        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        """Call AITER's varlen flash attention with renamed keywords.

        ``return_softmax_lse`` is passed on as ``return_lse``; any additional
        keyword arguments are forwarded untouched.
        """
        return self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the AITER MLA decode kernel.

        Returns:
            ``(output, None)``; no LSE tensor is produced by this path.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        # A tuple query is joined along the last dimension first.
        if type(q) is tuple:
            q = torch.cat(q, dim=-1)
        assert isinstance(q, torch.Tensor)

        out_shape = (q.shape[0], self.num_heads, self.kv_lora_rank)
        attn_out = torch.zeros(*out_shape, dtype=q.dtype, device=q.device)

        cache_with_head_dim = kv_c_and_k_pe_cache.unsqueeze(2)

        # max_seqlen_qo must be 1 except for MTP
        # TODO: Find the best value for MTP
        max_seqlen_qo = 1
        aiter_mla_decode_fwd(q, cache_with_head_dim, attn_out, self.scale,
                             decode_meta.qo_indptr, max_seqlen_qo,
                             decode_meta.paged_kv_indptr,
                             decode_meta.paged_kv_indices,
                             decode_meta.paged_kv_last_page_len)

        return attn_out, None

View File

@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
    """Backend descriptor for the Triton-based MLA implementation."""

    @staticmethod
    def get_impl_cls() -> type["TritonMLAImpl"]:
        """Return the attention implementation class."""
        return TritonMLAImpl

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TRITON_MLA"
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation that decodes with the Triton decode kernel."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and reject unsupported options.

        Raises:
            NotImplementedError: for truthy ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap``, for any attention
                type other than ``AttentionType.DECODER``, and for a
                quantized KV cache dtype.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA V1 with FP8 KV cache not yet supported")

        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        # Without Triton there is no flash-attention function to call.
        if HAS_TRITON:
            self.triton_fa_func = triton_attention
        else:
            self.triton_fa_func = None

    def _flash_attn_varlen_diff_headdims_rocm(self,
                                              q,
                                              k,
                                              v,
                                              softmax_scale=None,
                                              **kwargs):
        """Run Triton flash attention after padding V up to Q's last dim.

        ``kwargs`` must contain ``cu_seqlens_q``, ``cu_seqlens_k``,
        ``max_seqlen_q``, ``max_seqlen_k`` and ``causal``.
        """
        assert self.triton_fa_func is not None

        # Triton Attention requires a padded V
        pad_width = q.shape[-1] - v.shape[-1]
        padded_v = torch.nn.functional.pad(v, [0, pad_width], value=0)

        # The output of triton_attention is a tuple of
        # [output_tensor, encoded_softmax] where encoded_softmax is always None
        attn_result = self.triton_fa_func(
            q,
            k,
            padded_v,
            None,  # output
            kwargs["cu_seqlens_q"],
            kwargs["cu_seqlens_k"],
            kwargs["max_seqlen_q"],
            kwargs["max_seqlen_k"],
            kwargs["causal"],
            softmax_scale,
            None,  # bias
        )
        output_tensor, _ = attn_result
        return output_tensor

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        """Pick between the ROCm Triton path and the inherited one.

        The ROCm Triton path is taken only on ROCm, with Triton flash
        attention enabled, and when no softmax LSE is requested.
        """
        take_rocm_triton_path = (current_platform.is_rocm()
                                 and self.use_triton_flash_attn
                                 and not return_softmax_lse)
        if not take_rocm_triton_path:
            return super()._flash_attn_varlen_diff_headdims(
                q,
                k,
                v,
                return_softmax_lse=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs)
        return self._flash_attn_varlen_diff_headdims_rocm(
            q, k, v, softmax_scale=softmax_scale, **kwargs)

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the Triton decode attention kernel.

        Returns:
            ``(output, lse)`` as filled in by ``decode_attention_fwd``.

        Raises:
            NotImplementedError: if the KV cache dtype starts with ``fp8``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        # A tuple query is joined along the last dimension first.
        if type(q) is tuple:
            q = torch.cat(q, dim=-1)
        assert isinstance(q, torch.Tensor)

        batch_size, q_num_heads = q.shape[0], q.shape[1]
        attn_out = torch.zeros(batch_size,
                               q_num_heads,
                               self.kv_lora_rank,
                               dtype=q.dtype,
                               device=q.device)
        lse = torch.zeros(batch_size,
                          q_num_heads,
                          dtype=q.dtype,
                          device=q.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas) Allocate ahead of time
        # NOTE(lucas) idk why the +1 is here but sglang has it so we
        # just mirror that
        logits_shape = (batch_size, q_num_heads, num_kv_splits,
                        self.kv_lora_rank + 1)
        attn_logits = torch.empty(logits_shape,
                                  dtype=torch.float32,
                                  device=q.device)

        # Add a head dim of 1
        cache_with_head_dim = kv_c_and_k_pe_cache.unsqueeze(2)
        latent_cache = cache_with_head_dim[..., :self.kv_lora_rank]
        page_size = cache_with_head_dim.size(1)

        # Run MQA
        decode_attention_fwd(q, cache_with_head_dim, latent_cache, attn_out,
                             lse, decode_meta.block_table,
                             decode_meta.seq_lens, attn_logits, num_kv_splits,
                             self.scale, page_size)

        return attn_out, lse

View File

@@ -0,0 +1,409 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can use fp8 as a storage dtype but doesn't support converting
# from uint8 to fp32 directly. That's why it has a dtype mapping different
# from GPU.
# Maps a KV-cache dtype string to the torch dtype used for the cache; the
# lookup in PallasAttentionBackendImpl.__init__ lower-cases and strips the
# key first.
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}
# When tpu_commons is importable nothing below is defined. Otherwise the
# torch_xla-based `kv_cache_update_op` custom op is declared and registered
# here at import time.
try:
    import tpu_commons  # noqa: F401
except ImportError:
    # Lazy import torch_xla
    import torch_xla.core.xla_builder as xb
    import torch_xla.experimental.custom_kernel  # noqa: F401
    from torch.library import impl
    from torch_xla._internal.jax_workarounds import requires_jax
    from torch_xla.experimental.custom_kernel import XLA_LIB

    @requires_jax
    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                kv_cache: torch.Tensor,
                                num_kv_update_slices: torch.Tensor,
                                page_size: int, num_slices_per_block: int):
        """Invoke the Pallas `kv_cache_update` kernel via `xb.call_jax`.

        The four tensors are passed positionally; `page_size` and
        `num_slices_per_block` are passed as keyword arguments. Returns
        whatever the kernel call returns (the updated KV cache).
        """
        # Imported inside the function rather than at module level.
        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
        new_kv_cache = xb.call_jax(
            kv_cache_update,
            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
                "page_size": page_size,
                "num_slices_per_block": num_slices_per_block
            })
        return new_kv_cache

    # Declare the op schema on the XLA library; the two functions below
    # register implementations for it.
    XLA_LIB.define(
        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
        "int num_slices_per_block)" \
        "-> Tensor", )

    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                               kv_cache: torch.Tensor,
                               num_kv_update_slices: torch.Tensor,
                               page_size: int,
                               num_slices_per_block: int) -> torch.Tensor:
        """Implementation registered under the "XLA" key; runs the kernel."""
        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
                                               num_kv_update_slices, page_size,
                                               num_slices_per_block)
        return new_kv_cache

    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
    def kv_cache_update_op_non_xla(kv: torch.Tensor,
                                   slot_mapping: torch.Tensor,
                                   kv_cache: torch.Tensor,
                                   num_kv_update_slices: torch.Tensor,
                                   page_size: int,
                                   num_slices_per_block: int) -> torch.Tensor:
        """Implementation registered under "CompositeExplicitAutograd".

        Performs no update: the incoming `kv_cache` is returned unchanged.
        """
        return kv_cache
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor and sizing helpers for Pallas attention on TPU."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> type["PallasMetadata"]:
        """Return the metadata class consumed by the implementation."""
        return PallasMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape.

        K and V heads share one axis (``num_kv_heads * 2``) and the head size
        is rounded up to a multiple of ``TPU_HEAD_SIZE_ALIGNMENT``.
        ``cache_dtype_str`` is not used.
        """
        aligned_chunks = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT)
        padded_head_size = aligned_chunks * TPU_HEAD_SIZE_ALIGNMENT
        combined_kv_heads = num_kv_heads * 2
        return (num_blocks, block_size, combined_kv_heads, padded_head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not used on TPU."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
    # block_tables within the PallasMetadata constitute almost the entire SMEM
    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
    # we simply make sure that the size is smaller than half of SMEM capacity.
    @staticmethod
    def get_min_page_size(vllm_config: VllmConfig) -> int:
        """Return the smallest power-of-two page size within the SMEM budget."""
        half_smem_bytes = 1024 * 1024 // 2
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # 4 bytes per block-table entry.
        page_limit_per_req = half_smem_bytes // max_num_seqs // 4
        required = cdiv(vllm_config.model_config.max_model_len,
                        page_limit_per_req)
        # Round up to the next power of two.
        return 1 << (required - 1).bit_length()

    @staticmethod
    def get_max_num_seqs(model_len: int, page_size: int) -> int:
        """Return how many sequences fit in the same half-SMEM budget."""
        pages_per_req = cdiv(model_len, page_size)
        half_smem_bytes = 1024 * 1024 // 2
        return half_smem_bytes // pages_per_req // 4

    # TPU has limited SREGs (scalar registers), if page_size is too small, we
    # can spill SREGs easily which leads to bad performance. The strategy we
    # apply here is trying to split max-model-len to 16 pages which make the
    # spill less likely. Meanwhile we make sure the page size is in [16, 256].
    @staticmethod
    def get_page_size(vllm_config: VllmConfig) -> int:
        """Choose the page size from ``max_model_len``, clamped to [16, 256]."""
        max_model_len = vllm_config.model_config.max_model_len
        # TODO: This is a temporary fix for vmem OOM.
        # For long model length, we use 16 page-size to avoid too much
        # VMEM spill. A more robust solution should be implemented to
        # handle VREG spills.
        if max_model_len > 8192:
            return 16
        candidate = next_power_of_2(max_model_len) // 16
        return min(max(candidate, 16), 256)
@dataclass
class PallasMetadata:
    """Per-step attention metadata read by PallasAttentionBackendImpl."""
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    # |-- query_len ---|
    # Used in the PallasAttentionBackendImpl
    # Passed to write_to_kv_cache when the new keys/values are stored.
    slot_mapping: torch.Tensor
    # The next four are passed, in this order after the KV cache, to
    # torch.ops.xla.ragged_paged_attention.
    block_tables: torch.Tensor
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor
    # Both are forwarded to the kv_cache_update_op via write_to_kv_cache.
    num_kv_update_slices: torch.Tensor
    num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
    """Decoder attention for TPU built on the ragged paged attention kernel."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the attention hyper-parameters.

        Raises:
            NotImplementedError: if ``alibi_slopes`` is given or ``attn_type``
                is not ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        # Stays None for "auto" and for any string missing from the mapping;
        # None means keys/values are written to the cache without
        # quantization.
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
                kv_cache_dtype.lower().strip())

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads * 2, head_size]
            attn_metadata: Metadata for attention.
            output: only used when ``kv_cache`` is empty; otherwise a new
                tensor is returned.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if an output scale is supplied.
            ValueError: if the KV cache is quantized and either scale on
                ``layer`` is zero.
        """
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionBackendImpl")

        # For determine_available_memory case.
        if kv_cache.numel() == 0:
            if output is None:
                output = torch.ones_like(query)
            return output

        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
            # Zero-pad the head dimension up to the TPU alignment.
            padded_head_size = cdiv(
                self.head_size,
                TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
            query = torch.nn.functional.pad(
                query, (0, padded_head_size - self.head_size), value=0.0)
            key = torch.nn.functional.pad(
                key, (0, padded_head_size - self.head_size), value=0.0)
            value = torch.nn.functional.pad(
                value, (0, padded_head_size - self.head_size), value=0.0)

        # Check the scales before the cache write divides by them, so a zero
        # scale is rejected before anything is written to the KV cache.
        if self.kv_cache_quantized_dtype is not None and (
                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
            raise ValueError(
                "k_scale_float and v_scale_float must be non-zero")

        # The empty-cache case has already returned above, so only the
        # KV-sharing condition needs checking here.
        if self.kv_sharing_target_layer_name is None:
            # Write input keys and values to the KV cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(
                key,
                value,
                kv_cache,
                slot_mapping,
                attn_metadata.num_slices_per_kv_cache_update_block,
                attn_metadata.num_kv_update_slices,
                self.kv_cache_quantized_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

        output = torch.ops.xla.ragged_paged_attention(
            query,
            kv_cache,
            attn_metadata.context_lens,
            attn_metadata.block_tables,
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            # By default, the system utilizes optimized block size and
            # vmem_limit_bytes parameters from the kernel repository. However,
            # these can be manually adjusted for debugging if necessary.
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=self.scale,
            sliding_window=self.sliding_window,
            soft_cap=self.logits_soft_cap,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
        )

        # Drop the alignment padding again before flattening the heads.
        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
            output = output[:, :, :self.head_size]

        return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    num_slices_per_kv_cache_update_block: int,
    num_kv_update_slices: torch.Tensor,
    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> None:
    """Store new keys and values in the paged KV cache.

    When ``kv_cache_quantized_dtype`` is given, keys and values are divided
    by their scale, clamped to that dtype's finite range and cast before
    being written.

    Args:
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
        num_slices_per_kv_cache_update_block: int
    """
    _, page_size, num_combined_kv_heads, cache_head_size = kv_cache.shape
    aligned_head_size = cdiv(
        cache_head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

    if kv_cache_quantized_dtype is not None:
        limits = torch.finfo(kv_cache_quantized_dtype)
        # NOTE: clamp is added here to avoid out of range of quantized dtype
        scaled_key = key.to(torch.float32) / k_scale
        key = torch.clamp(scaled_key, limits.min,
                          limits.max).to(kv_cache_quantized_dtype)
        scaled_value = value.to(torch.float32) / v_scale
        value = torch.clamp(scaled_value, limits.min,
                            limits.max).to(kv_cache_quantized_dtype)

    # Concatenate K and V on the last axis, then fold into the combined
    # K/V head axis.
    combined = torch.cat([key, value], axis=-1)
    kv = combined.reshape(-1, num_combined_kv_heads, aligned_head_size)

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    flat_cache = kv_cache.flatten(0, 1)
    updated_cache = torch.ops.xla.kv_cache_update_op(
        kv, slot_mapping, flat_cache, num_kv_update_slices, page_size,
        num_slices_per_kv_cache_update_block)
    # NOTE: the in-place copy will be optimized away by XLA compiler.
    flat_cache.copy_(updated_cache)
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
    """Return the bit width of ``dtype``.

    Floating-point and integer dtypes are answered by ``torch.finfo`` /
    ``torch.iinfo``; the three complex dtypes are hard-coded. When those fail
    with ``TypeError``, the width is parsed from the dtype's name
    (``torch.int<N>`` / ``torch.uint<N>``).

    Raises:
        TypeError: if no bit width can be determined.
    """
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass

    str_dtype = str(dtype)
    # support torch.int4, torch.int5, torch.uint5...
    # The whole numeric suffix is parsed (not just the last character) so a
    # multi-digit width reaching this fallback is reported correctly.
    for prefix in ("torch.uint", "torch.int"):
        if str_dtype.startswith(prefix):
            width = str_dtype[len(prefix):]
            if width.isdecimal():
                return int(width)
            break
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype):
    """Return how many values of ``dtype`` fit into 32 bits.

    Raises:
        ValueError: if the dtype's bit width does not divide 32 evenly.
    """
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        raise ValueError(
            f"The bit width must evenly divide 32, but got bits={bits}, "
            f"dtype={dtype}")
    return 32 // bits
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache.

    The head size is rounded up to ``TPU_HEAD_SIZE_ALIGNMENT`` and the
    combined K/V head count is rounded up to a multiple of the dtype packing
    factor before the byte count is computed.
    """
    aligned_head_size = cdiv(
        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

    # NOTE: for the implicit padding in XLA
    packing = get_dtype_packing(kv_cache_dtype)
    padded_kv_heads = cdiv(num_kv_heads * 2, packing) * packing

    bits_per_page = (block_size * padded_kv_heads * aligned_head_size *
                     dtype_bits(kv_cache_dtype))
    return bits_per_page // 8

View File

@@ -0,0 +1,549 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
# NOTE(review): not referenced anywhere in the visible part of this file.
_PARTITION_SIZE_ROCM = 256

# Everything below (the AITER import, the Triton layout kernel and the custom
# op registration) is only defined when running on ROCm.
if current_platform.is_rocm():
    import aiter

    from vllm.triton_utils import tl, triton
    from vllm.utils import direct_register_custom_op

    @triton.jit
    def _vllm_layout_trans_kernel(
        k_buffer_ptr,
        v_buffer_ptr,
        k_values_ptr,
        v_values_ptr,
        b_query_lens_loc,
        b_seq_lens_loc,
        block_table,
        block_table_stride_0,
        k_scale,
        v_scale,
        output_dtype: tl.constexpr,
        E_DIM: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
    ):
        # One program handles one (request, cache block) pair: it copies that
        # block's K and V rows from the paged buffers into the k_values /
        # v_values output buffers.
        batch_idx = tl.program_id(0)
        block_idx = tl.program_id(1)

        # Load entries batch_idx and batch_idx + 1 and take their difference
        # as this request's query length.
        batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                      tl.arange(0, 2))
        batch_query_start, batch_query_end = tl.split(batch_query_indexes)
        query_len = batch_query_end - batch_query_start

        # Requests with at most one query token are skipped entirely.
        if query_len <= 1:
            return

        # Same pattern for the KV side: start offset and length of this
        # request's tokens in the output buffers.
        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
                                      tl.arange(0, 2))
        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
        seq_len = batch_token_end - batch_token_start

        # Blocks that start at or beyond the sequence length do nothing.
        if block_idx * BLOCK_SIZE < seq_len:
            # Mask off rows of the (possibly partial) last block.
            block_mask = (block_idx * BLOCK_SIZE +
                          tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len

            # Cache block id for this request's block_idx-th block.
            kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
                             block_idx).to(tl.int64)

            # Element offsets of a BLOCK_SIZE x E_DIM tile inside the cache.
            kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
                0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
            k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                             mask=block_mask,
                             other=0.0)
            # fp8 values are multiplied by the scale in fp32 before the cast
            # to the output dtype; other dtypes are cast directly.
            if k_vals.dtype.is_fp8():
                k_vals = (k_vals.to(tl.float32) *
                          tl.load(k_scale)).to(output_dtype)
            else:
                k_vals = k_vals.to(output_dtype)

            v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                             mask=block_mask,
                             other=0.0)
            if v_vals.dtype.is_fp8():
                v_vals = (v_vals.to(tl.float32) *
                          tl.load(v_scale)).to(output_dtype)
            else:
                v_vals = v_vals.to(output_dtype)
            # Destination offsets: request start plus this block's position.
            kv_values_off = batch_token_start * E_DIM + \
                block_idx * BLOCK_SIZE * E_DIM + \
                tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
                tl.arange(0, E_DIM)[None, :]
            tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
            tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)

    def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
                          k_cache, v_cache, max_seq_len, k_scale, v_scale,
                          output_dtype, total_tokens):
        """Gather paged K/V cache contents into two new tensors.

        Allocates `k_values` and `v_values` of shape
        (total_tokens, H_KV, D), with H_KV / D / BLOCK_SIZE read from
        `v_cache.shape[2]` / `[3]` / `[1]`, and launches
        `_vllm_layout_trans_kernel` over a grid of
        (rows of block_table, ceil(max_seq_len / BLOCK_SIZE)).

        The buffers come from `torch.empty`, and the kernel returns early for
        requests whose query length is <= 1, so their rows are never written.

        Raises:
            ValueError: if `output_dtype` is neither float16 nor bfloat16.
        """
        H_KV = v_cache.shape[2]
        D = v_cache.shape[3]
        BLOCK_SIZE = v_cache.shape[1]

        k_values = torch.empty(
            (total_tokens, H_KV, D),
            dtype=output_dtype,
            device=k_cache.device,
        )
        v_values = torch.empty(
            (total_tokens, H_KV, D),
            dtype=output_dtype,
            device=v_cache.device,
        )

        grid = (block_table.shape[0],
                (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

        # Translate the torch dtype into the Triton dtype the kernel expects.
        if output_dtype == torch.float16:
            output_dtype = tl.float16
        elif output_dtype == torch.bfloat16:
            output_dtype = tl.bfloat16
        else:
            raise ValueError(f"Unsupported output dtype: {output_dtype}")

        _vllm_layout_trans_kernel[grid](k_cache,
                                        v_cache,
                                        k_values,
                                        v_values,
                                        b_query_lens_loc,
                                        b_seq_lens_loc,
                                        block_table,
                                        block_table.stride(0),
                                        k_scale,
                                        v_scale,
                                        output_dtype=output_dtype,
                                        E_DIM=H_KV * D,
                                        BLOCK_SIZE=BLOCK_SIZE)

        return k_values, v_values

    def flash_attn_varlen_func_impl(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        out: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        window_size: Optional[list[int]],  # -1 means infinite context window
        alibi_slopes: Optional[list[float]],
        block_table: torch.Tensor,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        total_tokens: int = 0,
    ) -> torch.Tensor:
        """Gather K/V out of the paged cache, then run AITER flash attention.

        When `total_tokens` is 0 it is read from the last entry of
        `cu_seqlens_k`. Attention is always causal and writes into `out`.
        """
        if total_tokens == 0:
            total_tokens = int(cu_seqlens_k[-1].item())
        k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
                                 k_cache, v_cache, max_seqlen_k, k_scale,
                                 v_scale, q.dtype, total_tokens)

        output = aiter.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            min_seqlen_q=1,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            alibi_slopes=alibi_slopes,
            window_size=window_size,
            out=out,
        )
        return output

    def flash_attn_varlen_func_fake(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        out: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        window_size: Optional[list[int]],  # -1 means infinite context window
        alibi_slopes: Optional[list[float]],
        block_table: torch.Tensor,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        total_tokens: int = 0,
    ) -> torch.Tensor:
        """Fake counterpart of the op: only allocates an output tensor."""
        # NOTE(review): the last dimension is taken from v_cache.shape[-2];
        # confirm this matches the output shape of the real implementation.
        return torch.empty(q.shape[0],
                           q.shape[1],
                           v_cache.shape[-2],
                           dtype=q.dtype,
                           device=q.device)

    # Register the real and fake implementations as one custom op.
    # NOTE(review): ["out"] is presumably the list of mutated arguments --
    # confirm against direct_register_custom_op.
    direct_register_custom_op("flash_attn_varlen_func",
                              flash_attn_varlen_func_impl, ["out"],
                              flash_attn_varlen_func_fake,
                              dispatch_key=current_platform.dispatch_key)
logger = init_logger(__name__)
@dataclass
class AiterFlashAttentionMetadata:
    """Per-step metadata built by AiterFlashAttentionMetadataBuilder.build."""
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    # |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    # Last entry of cu_seq_lens when max_query_len > 1, otherwise 0.
    num_actual_kv_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    slot_mapping: torch.Tensor
    block_table: torch.Tensor
    # Cumulative sum of seq_lens with a leading 0; None when
    # max_query_len <= 1.
    cu_seq_lens: Optional[torch.Tensor]
    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    # Copied from the builder's total_tokens attribute, which is non-zero only
    # while build_for_cudagraph_capture is running.
    total_tokens: int
class AiterFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
    """Builds ``AiterFlashAttentionMetadata`` from the common metadata."""

    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the model, parallel and cache configs and head geometry."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None
        # Non-zero only while build_for_cudagraph_capture is running.
        self.total_tokens: int = 0

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """Build metadata with ``total_tokens`` set to a capture-time bound.

        The bound is ``max_model_len * max_num_partial_prefills``. It is
        reset to 0 afterwards even if ``build`` raises, so a failed capture
        cannot leak the bound into later regular builds.
        """
        self.total_tokens = self.model_config.max_model_len \
            * self.vllm_config.scheduler_config.max_num_partial_prefills
        try:
            return self.build(common_prefix_len=0,
                              common_attn_metadata=common_attn_metadata)
        finally:
            self.total_tokens = 0

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
        """Assemble the metadata for one step.

        Args:
            common_prefix_len: stored on the metadata; a value > 0 sets
                ``use_cascade``.
            common_attn_metadata: source of the token counts, lengths, block
                table and slot mapping.
            fast_build: accepted but not used by this builder.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        if max_query_len > 1:
            # We pre-compute cumulative seq len needed for prefill attention
            # here to avoid recomputing it for every layer
            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=seq_lens.device)
            torch.cumsum(seq_lens,
                         dim=0,
                         dtype=cu_seq_lens.dtype,
                         out=cu_seq_lens[1:])
            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
        else:
            cu_seq_lens = None
            num_actual_kv_tokens = 0

        use_cascade = common_prefix_len > 0

        attn_metadata = AiterFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_actual_kv_tokens=num_actual_kv_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            cu_seq_lens=cu_seq_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            total_tokens=self.total_tokens,
        )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Always ``False``: this builder never opts into cascade attention."""
        return False
class AiterFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the AITER flash-attention implementation."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [64, 128, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` when ``head_size`` is not supported."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            return
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return AiterFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class consumed by the implementation."""
        return AiterFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
        """Return the class that builds the per-step attention metadata."""
        return AiterFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape, with a leading axis of size 2.

        ``cache_dtype_str`` is not used.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        is_aligned = block_size % 16 == 0
        if not is_aligned:
            raise ValueError("Block size must be a multiple of 16.")
        kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape
class AiterFlashAttentionImpl(AttentionImpl):
    """Decoder attention implementation that dispatches to AITER ops.

    ``forward`` writes new keys/values into the paged KV cache, calls
    ``torch.ops.vllm.flash_attn_varlen_func`` when the batch contains
    multi-token queries, and calls ``torch.ops.aiter.paged_attention_v1``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be a supported head size.
            scale: softmax scale.
            num_kv_heads: number of key/value heads; must divide
                ``num_heads``.
            alibi_slopes: optional ALiBi slopes, converted to a float32
                tensor.
            sliding_window: optional sliding-window size.
            kv_cache_dtype: KV-cache dtype string (e.g. starts with "fp8").
            logits_soft_cap: optional soft cap; ``None`` is stored as 0.
            attn_type: only ``AttentionType.DECODER`` is supported.
            kv_sharing_target_layer_name: when set, this layer reuses the KV
                cache of another layer and skips the cache write.

        Raises:
            ValueError: if ``head_size`` is unsupported.
            NotImplementedError: if ``attn_type`` is not DECODER.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = [-1, -1]
        else:
            self.sliding_window = [sliding_window - 1, 0]
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0.
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        AiterFlashAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AiterFlashAttentionImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AiterFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with AiterFlashAttention.

        Args:
            layer: attention layer providing ``_k_scale`` / ``_v_scale``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: pre-allocated output buffer; required.
            output_scale: must be None (fused quantization unsupported).
            output_block_scale: must be None (fused quantization
                unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if fused output quantization or cascade
                attention is requested.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for AiterFlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache storage as the platform's fp8 dtype.
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                torch.ops.vllm.flash_attn_varlen_func(
                    query[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    cu_seqlens_k=attn_metadata.cu_seq_lens,
                    k_scale=layer._k_scale,
                    v_scale=layer._v_scale,
                    total_tokens=attn_metadata.num_actual_kv_tokens,
                )

            _, num_heads, head_size = query.shape
            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
            num_seqs = seqused_k.shape[0]
            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                  1) // _PARTITION_SIZE_ROCM

            # Byte-sized scratch buffer: one query-dtype block per
            # (seq, head, partition, head_size) element plus two 4-byte
            # values per (seq, head, partition).
            workspace_buffer = torch.empty(
                (num_seqs * num_heads * max_num_partitions * head_size) *
                nbytes_per_qo_elem + 2 *
                (num_seqs * num_heads * max_num_partitions) * 4,
                dtype=torch.uint8,
                device=output.device,
            )

            torch.ops.aiter.paged_attention_v1(
                output[:num_actual_tokens],
                workspace_buffer,
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                self.scale,
                block_table,
                cu_seqlens_q,
                seqused_k,
                max_seqlen_k,
                self.alibi_slopes,
                self.kv_cache_dtype,
                "NHD",
                self.logits_soft_cap,
                layer._k_scale,
                layer._v_scale,
                None,
                _PARTITION_SIZE_ROCM,
            )
            return output
        else:
            raise NotImplementedError(
                "Cascade attention is not implemented for ROCM AITER")

View File

@@ -0,0 +1,426 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@dataclass
class RocmAttentionMetadata:
    """Per-batch attention metadata for the ROCm attention backend."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Cascade-attention fields.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional ahead-of-time scheduling metadata.
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
class RocmAttentionMetadataBuilder(
        AttentionMetadataBuilder[RocmAttentionMetadata]):
    """Builds ``RocmAttentionMetadata`` from the common attention metadata."""

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Record block size, head counts and head size from the configs."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.block_size = kv_cache_spec.block_size

        model_cfg = vllm_config.model_config
        self.num_heads_q = model_cfg.get_num_attention_heads(
            vllm_config.parallel_config)
        self.num_heads_kv = model_cfg.get_num_kv_heads(
            vllm_config.parallel_config)
        self.headdim = model_cfg.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> RocmAttentionMetadata:
        """Build metadata for graph capture with all sequence lengths at 1."""
        captured = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        captured.seq_lens.fill_(1)
        return captured

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> RocmAttentionMetadata:
        """Assemble the metadata for one batch.

        Args:
            common_prefix_len: length of the prefix shared by the batch;
                a positive value marks the batch as using cascade attention.
            common_attn_metadata: backend-independent batch metadata.
            fast_build: accepted for interface compatibility; not read here.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        use_cascade = common_prefix_len > 0

        if not use_cascade:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
        else:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            # Per-request KV length remaining after the shared prefix.
            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                              common_prefix_len).to(self.device)
        # No ahead-of-time scheduling metadata is produced for the prefix.
        prefix_scheduler_metadata = None

        return RocmAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=common_attn_metadata.query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=common_attn_metadata.seq_lens,
            block_table=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
class RocmAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the ROCm attention classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported size."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the backend's registered name."""
        return "ROCM_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAttentionImpl"]:
        """Return the attention implementation class."""
        return RocmAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-batch metadata class."""
        return RocmAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape.

        ``cache_dtype_str`` is accepted so this signature matches the other
        backends' ``get_kv_cache_shape``; it does not affect the result.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        """Return the class that builds the per-batch metadata."""
        return RocmAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
    """Check if aiter unified attention should be used.

    The result is memoized, so the environment flags are read only on the
    first call.
    """
    # NOTE (from the original author): VLLM_ROCM_USE_AITER_MHA defaults to 1
    # and needs to be set to 0 as well.
    return (envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
class RocmAttentionImpl(AttentionImpl):
    """Decoder attention implementation for the ROCm attention backend.

    Depending on ``envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION`` it either runs
    ``chunked_prefill_paged_decode`` or a unified-attention kernel (the aiter
    one or vLLM's Triton one).
    """

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """Return True only for static per-tensor symmetric FP8 output."""
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer configuration and select the attention kernel.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be a supported head size.
            scale: softmax scale.
            num_kv_heads: number of key/value heads.
            alibi_slopes: optional ALiBi slopes, converted to a float32
                tensor.
            sliding_window: optional sliding-window size.
            kv_cache_dtype: KV-cache dtype string (e.g. starts with "fp8").
            logits_soft_cap: optional soft cap; ``None`` is stored as 0.
            attn_type: only ``AttentionType.DECODER`` is supported.
            kv_sharing_target_layer_name: when set, this layer reuses the KV
                cache of another layer and skips the cache write.
            sinks: optional tensor whose first dimension must equal
                ``num_heads``.

        Raises:
            ValueError: if ``head_size`` is unsupported.
            NotImplementedError: if ``attn_type`` is not DECODER.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        RocmAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "RocmAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
        self.force_prefill_decode_attn = \
            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

        if not self.force_prefill_decode_attn:
            # If not using prefill decode attention, we use the Triton
            # unified attention implementation.
            if use_aiter_unified_attention():
                logger.info_once(
                    "Using aiter unified attention for RocmAttentionImpl")
                from aiter.ops.triton.unified_attention import (
                    unified_attention)
                self.unified_attention = unified_attention
            else:
                logger.info_once(
                    "Using vllm unified attention for RocmAttentionImpl")
                from vllm.attention.ops.triton_unified_attention import (
                    unified_attention)
                self.unified_attention = unified_attention

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: RocmAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with the ROCm attention kernels.

        Args:
            layer: attention layer providing the q/k/v scale attributes.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: pre-allocated output buffer; required.
            output_scale: optional scale forwarded to the kernel for fused
                output quantization.
            output_block_scale: must be None (block-scale output
                quantization is unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if ``output_block_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        use_prefill_decode_attn = self.force_prefill_decode_attn
        num_actual_tokens = attn_metadata.num_actual_tokens

        # The two kernel families use different cache layouts.
        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if use_prefill_decode_attn:
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale_float == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if current_platform.is_cuda():
                # Skip Q quantization on ROCm and XPU, enable this on cuda
                # only, since dequantizing back to f32 in the attention kernel
                # is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            chunked_prefill_paged_decode(
                query=query[:num_actual_tokens],
                key=key[:num_actual_tokens],
                value=value[:num_actual_tokens],
                output=output[:num_actual_tokens],
                kv_cache_dtype=self.kv_cache_dtype,
                key_cache=key_cache,
                value_cache=value_cache,
                block_table=block_table,
                query_start_loc=cu_seqlens_q,
                seq_lens=seqused_k,
                max_seq_len=max_seqlen_k,
                max_query_len=max_seqlen_q,
                k_scale=layer._k_scale,
                v_scale=layer._v_scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window[0],
                sm_scale=self.scale,
                output_scale=output_scale,
                sinks=self.sinks,
            )
        else:
            # The k/v scales are expanded to one entry per
            # (sequence, kv head) without copying.
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            self.unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                sinks=self.sinks,
                output_scale=output_scale)

        return output

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class ShortConvAttentionBackend(AttentionBackend):
    """Backend descriptor that only exposes the short-conv metadata builder."""

    @staticmethod
    def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
        """Return the class that builds ``ShortConvAttentionMetadata``."""
        return ShortConvAttentionMetadataBuilder
@dataclass
class ShortConvAttentionMetadata:
    """Per-batch metadata for short-convolution layers."""

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    # None when the batch contains no prefill requests.
    has_initial_states: Optional[torch.Tensor]
    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # For causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None
class ShortConvAttentionMetadataBuilder(
        AttentionMetadataBuilder[ShortConvAttentionMetadata]):
    """Builds ``ShortConvAttentionMetadata`` from the common metadata."""

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the base builder; the spec must be a ``MambaSpec``."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> ShortConvAttentionMetadata:
        """Assemble the metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not read here.
        """
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        # The first block-table column is used as the per-request state index.
        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split_decodes_and_prefills(
             common_attn_metadata,
             decode_threshold=self.reorder_batch_threshold)

        # Defaults used when the batch has no prefill requests.
        has_initial_states = None
        # for causal_conv1d
        nums_dict = None
        batch_ptr = None
        token_chunk_offset_ptr = None

        if num_prefills > 0:
            # [batch,]: True where a prefill request already has computed
            # tokens. The last `num_prefills` requests are the prefills.
            first_prefill = num_reqs - num_prefills
            computed_cpu = common_attn_metadata.num_computed_tokens_cpu
            has_initial_states = (
                computed_cpu[first_prefill:num_reqs] > 0).to(
                    query_start_loc.device)

            # Prefill query offsets, rebased to exclude the decode tokens.
            prefill_start_loc = common_attn_metadata.query_start_loc[
                -num_prefills - 1:] - num_decode_tokens

            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(prefill_start_loc))

        return ShortConvAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )

View File

@@ -0,0 +1,451 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the tree-attention classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported size."""
        allowed_sizes = cls.get_supported_head_sizes()
        if head_size in allowed_sizes:
            return
        # Derive a readable backend label from the class name.
        backend_label = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {backend_label}. "
            f"Supported head sizes are: {allowed_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the backend's registered name."""
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-batch metadata class."""
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The leading dimension of size 2 holds the key and value caches.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the class that builds the per-batch metadata."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
@dataclass
class TreeAttentionMetadata:
    """Per-batch metadata for tree attention.

    The batch is laid out with decode requests first and prefill requests
    after them; ``prefill_metadata`` and ``decode_metadata`` lazily build
    (and cache) a view restricted to one of the two phases.
    """

    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: Optional[torch.Tensor] = None

    # Cached Prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata for the prefill requests only, or None if there are none.

        Query offsets are rebased so the first prefill starts at 0. The
        tree attention bias is not carried over.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is None:
            # First access: build the prefill-phase view and cache it.
            first_prefill = self.num_decodes
            prefill_start_loc = self.query_start_loc[first_prefill:]
            prefill_query_lens = torch.diff(prefill_start_loc)
            prefill_kv_lens = self.seq_lens[first_prefill:]
            self._cached_prefill_metadata = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(prefill_query_lens.max().item()),
                query_start_loc=prefill_start_loc - prefill_start_loc[0],
                max_seq_len=int(prefill_kv_lens.max().item()),
                seq_lens=prefill_kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata for the decode requests only, or None if there are none.

        The tree attention bias is carried over to the decode view.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is None:
            # First access: build the decode-phase view and cache it.
            decode_count = self.num_decodes
            decode_start_loc = self.query_start_loc[:decode_count + 1]
            decode_query_lens = torch.diff(decode_start_loc)
            decode_kv_lens = self.seq_lens[:decode_count]
            self._cached_decode_metadata = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(decode_query_lens.max().item()),
                query_start_loc=decode_start_loc,
                max_seq_len=int(decode_kv_lens.max().item()),
                seq_lens=decode_kv_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
        return self._cached_decode_metadata
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` and owns the tree attention bias.

    The number of rows of ``self.tree_attn_bias`` is used as the threshold
    that separates decode requests from prefill requests.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and precompute its bias.

        When no speculative token tree is configured, a single-node tree
        ``[(0, )]`` is used.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
        # The tree is configured as the string form of a Python literal.
        tree_choices: list[tuple[int,
                                 ...]] = (ast.literal_eval(spec_token_tree)
                                          if spec_token_tree is not None else
                                          [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so decodes precede prefills."""
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Assemble the metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not read here.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step of the token tree.

        The builder's bias is swapped for a draft-specific one while
        ``build`` runs and is always restored afterwards, even if ``build``
        raises.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting. Exclude
                # the root level.
                start, end = 1, 1 + common_attn_metadata.max_query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value.
            self.tree_attn_bias = orig_tree_attn_bias
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
# Count the number of choices at each depth of the tree.
depth_counts = []
prev_depth = 0
for path in sorted_tree_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
return depth_counts
def _prepare_tree_attn_bias(
sorted_tree_choices: list[tuple[int, ...]],
depth_counts: list[int],
dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> torch.Tensor:
# +1 comes from the additional root node.
tree_len = len(sorted_tree_choices) + 1
tree_attn_mask = torch.full((tree_len, tree_len),
-torch.inf,
device=device,
dtype=dtype)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val = 0
for i in range(tree_len):
tree_attn_mask[i, i] = mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask[:, 0] = mask_val
# Set all ancestors to zeros.
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
# Retrieve ancestor position.
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
start += depth_counts[i]
return tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
    """Attention implementation for the tree attention backend.

    Only decoder attention is supported. Prefill tokens are handled with
    plain causal attention, while decode tokens additionally receive the
    query-query bias (`tree_attn_bias`) carried by the decode metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Record the per-layer attention hyperparameters.

        Raises:
            NotImplementedError: If `attn_type` is anything other than
                AttentionType.DECODER.
        """
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # ALiBi slopes are kept as a float32 tensor when provided.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        # (-1, -1) is used when no sliding window is configured.
        self.sliding_window = ((-1, -1) if sliding_window is None else
                               (sliding_window - 1, 0))

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and encoder/decoder cross-attention "
                "are not implemented for TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run tree attention and write the result into `output`.

        Args:
            layer: Attention layer; its `_k_scale` and `_v_scale` are read.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. None during profiling.
            output: Pre-allocated output buffer (required).
            output_scale: Must be None (fused quantization unsupported).
            output_block_scale: Must be None (fused quantization
                unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If either output scale is provided.
        """
        assert output is not None, "Output tensor must be provided."

        fused_quant_requested = (output_scale is not None
                                 or output_block_scale is not None)
        if fused_quant_requested:
            raise NotImplementedError(
                "fused output quantization is not yet supported for "
                "TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # The leading dimension of the cache separates keys from values.
        key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Write the new keys/values into the paged cache. This is skipped
            # when the KV cache is shared with an earlier attention layer.
            # NOTE(woosuk): key and value are padded while slot_mapping is
            # not. Slicing them to num_actual_tokens is unnecessary because
            # reshape_and_cache_flash derives the number of actual tokens
            # from the shape of slot_mapping.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_seqs = attn_metadata.query_start_loc.shape[0] - 1
        descale_shape = (num_seqs, key.shape[1])

        # Prefill tokens occupy [num_decode_tokens, num_actual_tokens).
        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        # Decode tokens occupy [0, num_decode_tokens) and get the tree bias.
        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        return output

View File

@@ -0,0 +1,361 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""High-Performance Triton-only Attention layer."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
@dataclass
class TritonAttentionMetadata:
    """Per-batch metadata consumed by the Triton attention implementation."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len for a
    # request at iteration N:
    #   context_len: tokens already processed up to iteration N-1.
    #   query_len:   the new tokens handled in iteration N.
    #   seq_len:     context_len + query_len.

    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Fields used for cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional ahead-of-time scheduling metadata.
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):
    """Builds TritonAttentionMetadata from the shared per-batch metadata."""

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the block size and head geometry read from the configs."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
            vllm_config.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            vllm_config.parallel_config)
        self.headdim = model_config.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        """Build metadata for CUDA graph capture.

        NOTE: `seq_lens` is filled in place, so the tensor referenced by
        `common_attn_metadata.seq_lens` is modified as well.
        """
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)
        return attn_metadata

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> TritonAttentionMetadata:
        """Assemble the metadata object for one batch.

        Args:
            common_prefix_len: Length of the prefix shared by the batch.
                A value greater than zero enables the cascade fields.
            common_attn_metadata: The per-batch common attention metadata.
            fast_build: Accepted for interface compatibility; not used here.
        Returns:
            The populated TritonAttentionMetadata.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        use_cascade = common_prefix_len > 0

        # Bind every cascade-only value up front so that all of them,
        # including prefix_scheduler_metadata, are defined on both the
        # cascade and the non-cascade path.
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None

        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            # Per-request KV length that remains after the shared prefix.
            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                              common_prefix_len)
            suffix_kv_lens = suffix_kv_lens.to(self.device)

        attn_metadata = TritonAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata
class TritonAttentionBackend(AttentionBackend):
    """Backend descriptor wiring together the Triton attention classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16, torch.float32]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Reject head sizes below the supported minimum.

        Raises:
            ValueError: If `head_size` is smaller than 32.
        """
        # Triton Attention supports any head size of at least 32.
        if head_size < 32:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention. "
                "Head sizes need to be greater than or equal to 32 for this "
                "backend. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TRITON_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        """Return the attention implementation class."""
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class used by this backend."""
        return TritonAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape; dimension 1 separates keys and values.

        Raises:
            ValueError: If `block_size` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used with this backend."""
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TritonAttentionMetadataBuilder
class TritonAttentionImpl(AttentionImpl):
    """Decoder-only paged attention implemented with Triton kernels."""

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """Return True only for static per-tensor symmetric FP8 output."""
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Record the per-layer attention hyperparameters.

        Raises:
            ValueError: If `head_size` is rejected by
                TritonAttentionBackend.validate_head_size.
            NotImplementedError: If `attn_type` is not
                AttentionType.DECODER.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            # (-1, -1) is used when no sliding window is configured.
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        TritonAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Paged Attention impl. in Triton.

        Args:
            layer: Attention layer; its `_k_scale`, `_v_scale`, `_q_scale`
                and `_q_scale_float` attributes are read.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [num_blocks, 2, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. None during profiling.
            output: Pre-allocated output buffer (required).
            output_scale: Optional scale forwarded to the attention kernel.
            output_block_scale: Must be None (block-scale output
                quantization is unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If `output_block_scale` is provided.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for TritonAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Dimension 1 of the cache separates keys from values.
        key_cache, value_cache = kv_cache.unbind(1)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if self.kv_cache_dtype.startswith("fp8"):
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
                # triton kernel does not support uint8 kv_cache
                #  (because some explicit casts (e.g. float8_e4m3fnuz)
                #   are not supported)
            triton_reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # The caches may still have their raw dtype here when the
            # cache-update branch above was skipped (shared KV cache).
            if key_cache.dtype != self.fp8_dtype:
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale_float == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if current_platform.is_cuda():
                # Skip Q quantization on ROCm and XPU, enable this on cuda
                # only, since dequantizing back to f32 in the attention kernel
                # is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        # (number of sequences, number of KV heads)
        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )

        return output

View File

@@ -0,0 +1,990 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
Protocol, TypeVar, Union, get_args)
import numpy as np
import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__)
# Names of the KV cache layouts that are accepted.
KVCacheLayoutType = Literal["NHD", "HND"]
# Layout forced from code through set_kv_cache_layout(); None = no override.
_KV_CACHE_LAYOUT_OVERRIDE: Optional[KVCacheLayoutType] = None

# NOTE(review): presumably the slot id used for padded positions -- confirm
# against the users of this constant.
PAD_SLOT_ID = -1


def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if `value` is one of the KVCacheLayoutType literals."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
@dataclass
class CommonAttentionMetadata:
    """
    Batch-level attention metadata that is shared by all layers and backends.

    AttentionMetadataBuilder instances consume it to construct their own
    per-layer metadata. Several tensors are stored twice: once as a GPU
    version and once as a CPU version.
    """

    # GPU / CPU pair of the per-request query start offsets.
    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    # GPU / CPU pair of the per-request sequence lengths.
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length in batch"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Return the query start offsets of the requests selected by
    request_slice, shifted so that the first selected request starts at 0.

    Note: The result is a newly created tensor, not a view of
    query_start_loc. This will break cudagraph compatibility.
    """
    first_req = request_slice.start
    # Take one entry past `stop` so the end offset of the last selected
    # request is part of the result.
    selected = query_start_loc[first_req:request_slice.stop + 1]
    base_offset = query_start_loc[first_req]
    return selected - base_offset
def _make_metadata_with_slice(
        ubatch_slice: UBatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata that corresponds to the requests
    (and tokens) included in ubatch_slice.

    The token slice may begin or end in the middle of a request. In that
    case the query start offsets (and, for a truncated last request, its
    sequence length) are adjusted so that they describe only the tokens
    inside the slice. `attn_metadata` itself is not modified.

    Args:
        ubatch_slice: The request and token ranges of the micro-batch.
        attn_metadata: The metadata of the full batch.
    Returns:
        The metadata restricted to the micro-batch.
    """
    assert not ubatch_slice.is_empty(), (
        f"Ubatch slice {ubatch_slice} is empty")

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    # Absolute (full-batch) start offsets; never modified below.
    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
        "Token slice start outside of first request"
    assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \
        "Token slice end outside of last request"

    # If the "middle" request has tokens in both ubatches, we have to split it.
    # If ubatch_slice is the first ubatch then we will be splitting the last
    # request. If it's the second microbatch, then we will be splitting the
    # first request
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    # Both results are new tensors rebased to start at 0.
    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, "
        f"got {len(query_start_loc)}")

    if splits_first_request:
        # Drop the leading tokens of the first request that belong to the
        # previous ubatch.
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped
    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

    if splits_last_request:
        # Number of trailing tokens of the last request that fall outside
        # this ubatch. This must be computed from the absolute offsets in
        # start_locs: query_start_loc_cpu has already been rebased (and
        # possibly shifted above), so comparing it with the absolute
        # token_slice.stop would be off by token_slice.start.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] -
                            query_start_loc_cpu[:-1])).item())

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per UBatchSlice in ubatch_slices, each
    restricted to the requests of that slice. The results are returned in
    the same order as ubatch_slices.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(one_slice, common_attn_metadata)
        for one_slice in ubatch_slices
    ]
# Type variable for the metadata type produced by a metadata builder.
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
    """Levels of cudagraph support offered by an attention backend.

    Cascade attention is not considered here, as currently it is never
    cudagraph supported."""

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that
    are the same. This can be used for spec-decode,
    i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1
    decodes"""
    NEVER = 0
    """NO cudagraph support"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """Abstract base class of the per-backend attention metadata builders.

    Subclasses implement `build` to turn a CommonAttentionMetadata into
    their backend-specific metadata type `M`.
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: Optional[int] = None

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Keep references to the constructor arguments on the instance."""
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    def _init_reorder_batch_threshold(
            self,
            reorder_batch_threshold: int = 1,
            supports_spec_as_decode: bool = False) -> None:
        """Set `reorder_batch_threshold` for this builder.

        When the backend supports spec-as-decode kernels and the config
        defines a number of speculative tokens, the threshold becomes
        1 + num_speculative_tokens instead of the given value.
        """
        self.reorder_batch_threshold = reorder_batch_threshold

        if self.reorder_batch_threshold is None:
            return
        if not supports_spec_as_decode:
            return

        # The backend supports spec-as-decode kernels, so the threshold can
        # be derived from the number of speculative tokens in the config.
        spec_config = self.vllm_config.speculative_config
        if spec_config is None:
            return
        if spec_config.num_speculative_tokens is None:
            return
        self.reorder_batch_threshold = \
            1 + spec_config.num_speculative_tokens

    @abstractmethod
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        The base implementation is not provided and always raises.

        Args:
            input_batch: input batch
            scheduler_output: scheduler output.

        Returns:
            True if the batch was modified, False otherwise.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build (with
        fast_build=True) by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
                The default implementation does not use it.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata,
                          fast_build=True)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """Return whether to use cascade attention; False by default."""
        return False
@functools.lru_cache
def get_kv_cache_layout():
    """Return the KV cache layout to use.

    Sources are consulted in this order: the in-code override
    (`_KV_CACHE_LAYOUT_OVERRIDE`), the `VLLM_KV_CACHE_LAYOUT` setting, and
    finally the default from `get_kv_connector_cache_layout()`.

    The function is wrapped in `functools.lru_cache`, so the result of the
    first call is reused by later calls until the cache is cleared.
    """
    # Format specified by the code.
    override_layout = _KV_CACHE_LAYOUT_OVERRIDE
    if override_layout is not None:
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.", override_layout)
        return override_layout

    # Format specified by the user.
    env_layout = envs.VLLM_KV_CACHE_LAYOUT
    if env_layout is None:
        # Neither the user nor the override specified a layout: use default.
        return get_kv_connector_cache_layout()

    assert is_valid_kv_cache_layout(env_layout)
    logger.info_once(
        "`VLLM_KV_CACHE_LAYOUT` environment variable "
        "detected. Setting KV cache layout to %s.", env_layout)
    return env_layout
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """Force the KV cache layout from code.

    Args:
        cache_layout: The layout that get_kv_cache_layout() should return.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # get_kv_cache_layout() is memoized with functools.lru_cache. Without
    # clearing that cache, an override set after the first lookup would be
    # silently ignored.
    get_kv_cache_layout.cache_clear()
@dataclass
class PerLayerParameters:
    """
    Attention hyperparameters collected from a single layer.

    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for these hyperparameters. This should not
    be used for the trtllm-gen backend, since it supports different values
    per layer.
    """

    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float
    # Whether the layer's attention implementation has sinks set.
    has_sinks: bool = False
def get_per_layer_parameters(
        vllm_config: VllmConfig, layer_names: list[str],
        cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
    """
    Collect, for every layer in `layer_names`, the hyperparameters that are
    needed during `plan`.

    Each layer's `impl` must be an instance of `cls_`. The result maps the
    layer name to its PerLayerParameters.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention,
                                              layer_names)

    params_by_layer: dict[str, PerLayerParameters] = {}
    for layer_name, attn_layer in attn_layers.items():
        impl = attn_layer.impl
        assert isinstance(impl, cls_)

        # Infer hyperparameters from the attention layer. An impl without a
        # sliding window attribute (or with None) maps to window_left = -1.
        sliding_window = getattr(impl, "sliding_window", None)
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]

        params_by_layer[layer_name] = PerLayerParameters(
            window_left=window_left,
            logits_soft_cap=getattr(impl, "logits_soft_cap", None),
            sm_scale=impl.scale,
            has_sinks=getattr(impl, "sinks", None) is not None,
        )

    return params_by_layer
def infer_global_hyperparameters(
        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Currently, FlashInfer backend other than trtllm-gen
    only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values. The comparison uses the
    full PerLayerParameters equality, so `has_sinks` must match as well.

    The check is skipped when `envs.VLLM_USE_TRTLLM_ATTENTION` is set; the
    parameters of the first layer are returned in either case.

    Raises:
        ValueError: If `window_left` differs between layers.
        AssertionError: If no layers are given, or if any other
            hyperparameter differs between layers.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    # trtllm attention doesn't need global hyper params so disable the check
    if not envs.VLLM_USE_TRTLLM_ATTENTION:
        for params in param_sets:
            if params.window_left != global_params.window_left:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True")
            assert params == global_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values "
                "for the following hyperparameters: "
                "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return global_params
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """Split every request into per-chunk "virtual" requests.

    Each request is cut along its key axis at multiples of
    `attn_chunk_size`. Every resulting local attention block becomes an
    independent batch item with its own query span, key length and
    block-table row.

    Args:
        attn_chunk_size: Number of tokens in one local attention block.
        common_attn_metadata: Metadata of the original batch. The fields read
            are `query_start_loc_cpu`, `seq_lens_cpu`, `block_table_tensor`,
            `query_start_loc` (only for its device), `num_actual_tokens` and
            `slot_mapping`.
        block_size: KV-cache page size in tokens. Must be positive and must
            divide `attn_chunk_size`; the default of 0 is rejected.

    Returns:
        A new CommonAttentionMetadata with one entry per local attention
        block. `num_actual_tokens` and `slot_mapping` are passed through
        unchanged and `causal` is set to True.

    Raises:
        AssertionError: If `block_size` is not positive or does not divide
            `attn_chunk_size`.
    """
    # Validate the paging parameters up front: both are used as divisors
    # below, and a zero block_size would otherwise only surface as a numpy
    # divide-by-zero warning followed by a ZeroDivisionError.
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    # Modulo by a negative chunk size yields a value in (-chunk, 0], so the
    # sum is in [1, chunk]: a sequence ending exactly on a chunk boundary has
    # a full last block rather than an empty one.
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx, we can figure out the number of "virtual" requests we
    # have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    # (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1

    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in
    #  each batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = (block_starts[:, None] +
                     np.arange(pages_per_local_batch, dtype=np.int32))
    # Clamp so page indices never run past the last block-table column.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] -
                                                   1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch]\
        .view(virtual_batches, -1)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device,
                                               non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # Convert the numpy scalar so this matches max_seq_len's plain int.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """Rebuild batch metadata so queries cover only the logits positions.

    The query layout is recomputed from `logits_indices_padded` (truncated to
    `num_logits_indices`), so each request keeps only the tokens whose
    indices appear there. Key-side fields (`seq_lens`, block table, slot
    mapping, `max_seq_len`, `num_computed_tokens_cpu`) are passed through.

    Args:
        common_attn_metadata: Metadata of the full batch. Unless
            `max_query_len == 1`, `logits_indices_padded` and
            `num_logits_indices` must be set.

    Returns:
        The input object itself when `max_query_len == 1`; otherwise a new
        CommonAttentionMetadata with the reduced query layout.

    Raises:
        AssertionError: If the logits-index fields are None on the
            non-trivial path.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    # Example inputs
    # num_reqs: 3
    # logits_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Find which request each logits index belongs to
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices,
                                  query_start_loc[1:],
                                  right=True)

    # Figure out how many tokens are kept in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with only the tokens in logits_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(num_reqs + 1,
                                         device=query_start_loc.device,
                                         dtype=query_start_loc.dtype)
    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)

    decode_max_query_len = int(num_decode_tokens.max().item())
    total_num_decode_tokens = int(num_decode_tokens.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        # Blocking copy: a non_blocking device-to-host copy gives no
        # guarantee the host tensor is filled when it is first read, and the
        # .item() calls above have already synchronized anyway.
        query_start_loc_cpu=decode_query_start_loc.cpu(),
        seq_lens=seq_lens,
        # Reuse the host copy carried by the input instead of issuing
        # another device-to-host transfer of the same values.
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
def subclass_attention_backend(
    name_prefix: str, attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Prepended to `attention_backend_cls.__name__` to form
            the new class name.
        attention_backend_cls: The backend class to derive from.
        builder_cls: The builder class the new backend should report.

    Returns:
        The dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    def get_builder_cls():
        return builder_cls

    # Install the override as a staticmethod so it can be called on the class
    # and on instances alike; a bare function in the class dict would receive
    # `self` when accessed through an instance and fail.
    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": staticmethod(get_builder_cls)})
def split_decodes_and_prefills(
        common_attn_metadata: CommonAttentionMetadata,
        decode_threshold: int = 1,
        require_uniform: bool = False) -> tuple[int, int, int, int]:
    """
    Locate the decode/prefill boundary in a batch that is already reordered
    so that decode requests come first.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, requires that all decode requests have the
            same query length. When set, some queries may be considered
            prefills even if they are <= decode_threshold, in order to ensure
            uniformity.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    longest_query = common_attn_metadata.max_query_len
    total_reqs = common_attn_metadata.num_reqs
    total_tokens = common_attn_metadata.num_actual_tokens
    starts = common_attn_metadata.query_start_loc_cpu

    # Fast path: nothing exceeds the threshold, and uniformity either is not
    # requested or holds trivially because the threshold is at most 1.
    uniformity_is_trivial = not require_uniform or decode_threshold <= 1
    if longest_query <= decode_threshold and uniformity_is_trivial:
        return total_reqs, 0, total_tokens, 0

    lens = starts[1:] - starts[:-1]
    if lens[0].item() > decode_threshold:
        # The first request is already a prefill, so there are no decodes.
        return 0, total_reqs, 0, total_tokens

    # Under require_uniform, anything whose length differs from the first
    # request's is treated as a prefill.
    prefill_mask = (lens != lens[0]) if require_uniform \
        else (lens > decode_threshold)
    if not torch.any(prefill_mask):
        return total_reqs, 0, total_tokens, 0

    # argmax over the 0/1 mask yields the index of the first prefill.
    boundary = prefill_mask.int().argmax(dim=-1).item()
    assert torch.all(lens[:boundary] <= decode_threshold)

    decode_tokens = starts[boundary].item()
    return (boundary, total_reqs - boundary, decode_tokens,
            total_tokens - decode_tokens)
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: Batch whose `req_ids` are classified and whose
            `swap_states(i, j)` is called to exchange two positions.
        scheduler_output: Provides `num_scheduled_tokens[req_id]`.
        decode_threshold: Largest scheduled-token count still treated as a
            "decode".

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We now want to reorder the batch so that the "decode" requests are at
    # the front and the "prefill" requests are at the back using the least
    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
    # requests where attention is likely memory-bound and "prefill" to mean
    # requests where attention is likely compute-bound, TODO(lucas): figure out
    # a better naming here)
    decodes: list[int] = []
    prefills: list[int] = []

    for i, req_id in enumerate(input_batch.req_ids):
        # for now treat 1 scheduled token as "decode" even if it's not,
        # we should update this to something like < 8 in the future but
        # currently the TritonMLA._forward_decode only supports
        # num_tokens = 1
        if scheduler_output.num_scheduled_tokens[req_id] <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope that this is fairly minimal since decodes
    # should be around for a number of iterations so hopefully they are
    # relatively stationary (and new request are generally appended to the
    # persistent batch so already should be at the back)
    # To achieve this we walk the decodes in descending order and the
    # prefills in ascending order, swapping decodes from the "back" (i.e.
    # past where the last decode should be in the reordered batch) with
    # prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the above loop
    num_decodes = len(decodes)
    modified_batch = False

    for decode_idx, prefill_idx in zip(reversed(decodes), prefills):
        # Once a decode already sits inside the front region, so do all the
        # remaining (smaller-index) ones.
        if decode_idx < num_decodes:
            break

        input_batch.swap_states(prefill_idx, decode_idx)
        modified_batch = True

    return modified_batch
def reshape_query_for_spec_decode(query: torch.Tensor,
                                  batch_size: int) -> torch.Tensor:
    """
    View a flat (total_tokens, num_heads, head_dim) query as
    (batch_size, seq_len, num_heads, head_dim), where
    seq_len = total_tokens // batch_size.

    The token count must be an exact multiple of `batch_size`.
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}")
    return query.view(batch_size, total_tokens // batch_size, num_heads,
                      head_dim)
def reshape_attn_output_for_spec_decode(
        attn_output: torch.Tensor) -> torch.Tensor:
    """
    Merge the leading batch_size and seq_len dimensions of a 4D attention
    output into a single token dimension.

    A 3D input is returned untouched; any other rank fails the assertion.
    """
    rank = attn_output.dim()
    if rank == 3:
        # Nothing to merge: the token dimension is already flat.
        return attn_output
    assert rank == 4, \
        f"attn_output must be 4D, got {attn_output.dim()}D"
    batch, seq_len, num_heads, head_dim = attn_output.shape
    return attn_output.view(batch * seq_len, num_heads, head_dim)
# Extra fields for KV-sharing fast prefill, as (name, type, default) tuples.
# Presumably meant to be passed as `fields` to subclass_attention_metadata
# (verify against callers).
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[torch.Tensor], None),
    ("num_logits_indices", int, 0),
]
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Build a dataclass that derives from `metadata_cls` and adds `fields`.

    The new class is named `name_prefix + metadata_cls.__name__`; `fields`
    is handed to `make_dataclass` as-is.
    """
    return make_dataclass(
        name_prefix + metadata_cls.__name__,  # type: ignore
        fields,
        bases=(metadata_cls, ),
    )
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """Structural type for metadata carrying the fast-prefill logits fields.

    Being runtime-checkable, `isinstance` only verifies that both attributes
    exist on the object; it does not check their types.
    """

    # Logits indices, possibly with trailing padding.
    logits_indices_padded: torch.Tensor
    # Number of leading entries of `logits_indices_padded` that are real.
    num_logits_indices: int
def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Derive a backend whose builder emits fast-prefill metadata.

    The returned backend subclasses `underlying_attn_backend` (which is used
    as a base class, so a backend class is expected here) and reports a
    builder that:
      1. rewrites the common metadata with
         `make_kv_sharing_fast_prefill_common_attn_metadata`,
      2. delegates to the underlying builder's `build`, and
      3. wraps the result in a subclass of the produced metadata type that
         also carries `logits_indices_padded` and `num_logits_indices`.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            reduced_common_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(
                    common_attn_metadata))
            metadata = super().build(common_prefix_len,
                                     reduced_common_metadata, fast_build)

            class KVSharingFastPrefillAttentionMetadata(
                    metadata.__class__,  # type: ignore
                    KVSharingFastPrefillMetadata):

                def __init__(self, metadata, common_attn_metadata):
                    # Shallow-copy every dataclass field of the wrapped
                    # metadata onto this instance.
                    for field in fields(metadata.__class__):
                        setattr(self, field.name,
                                getattr(metadata, field.name))

                    # The extra fields come from the ORIGINAL common
                    # metadata (the one passed to build), and are consumed
                    # by model code.
                    assert (common_attn_metadata.logits_indices_padded
                            is not None)
                    assert (common_attn_metadata.num_logits_indices
                            is not None)
                    self.logits_indices_padded = \
                        common_attn_metadata.logits_indices_padded
                    self.num_logits_indices = \
                        common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(
                metadata, common_attn_metadata)

    return subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder)
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """Precompute the per-program lookup tables needed for causal_conv1d.

    Each sequence is cut into chunks of BLOCK_M tokens. For every chunk the
    tables record which sequence it belongs to (`batch_ptr`) and its chunk
    index within that sequence (`token_chunk_offset_ptr`).

    Args:
        query_start_loc_p: Cumulative query start offsets; consecutive
            differences give the per-sequence lengths.

    Returns:
        A tuple `(nums_dict, batch_ptr, token_chunk_offset_ptr)`:
          - nums_dict: maps each BLOCK_M to a dict with keys 'nums',
            'tot', 'mlist', 'mlist_len', 'offsetlist', 'batch_ptr' and
            'token_chunk_offset_ptr'.
          - batch_ptr / token_chunk_offset_ptr: int32 tensors on the input's
            device, padded with PAD_SLOT_ID past the first `mlist_len`
            entries.
    """
    # Needed for causal_conv1d
    seqlens = query_start_loc_p.diff().to('cpu')
    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    device = query_start_loc_p.device
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Ceiling division: number of BLOCK_M-sized chunks per sequence.
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]['nums'] = nums
        nums_dict[BLOCK_M]['tot'] = nums.sum().item()
        # Sequence index repeated once per chunk, e.g. [1, 2, 1] -> [0, 1, 1, 2]
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]['mlist'] = mlist
        mlist_len = len(nums_dict[BLOCK_M]['mlist'])
        nums_dict[BLOCK_M]['mlist_len'] = mlist_len
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
        # Chunk index within each sequence, e.g. [1, 2, 1] -> [0, 0, 1, 0].
        # Iterate plain Python ints rather than 0-d tensors.
        chunk_offsets: list[int] = []
        for num in nums.tolist():
            chunk_offsets.extend(range(num))
        offsetlist = torch.tensor(chunk_offsets, dtype=torch.int32)
        nums_dict[BLOCK_M]['offsetlist'] = offsetlist

        if batch_ptr is None:
            # First BLOCK_M: allocate the padded buffers.
            batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
            token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                                PAD_SLOT_ID,
                                                dtype=torch.int32,
                                                device=device)
        else:
            # Later BLOCK_M values reuse the buffers, growing them if needed.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[  # type: ignore
            0:mlist_len].copy_(offsetlist)
        nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
        nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
                                                        )  # type: ignore

    return nums_dict, batch_ptr, token_chunk_offset_ptr

View File

@@ -0,0 +1,438 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (
AttentionBias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask)
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
    """Backend descriptor that wires up the XFormers implementation,
    metadata and metadata-builder classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the accepted head sizes: multiples of 8 from 32 to 256."""
        return list(range(32, 257, 8))

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ValueError when `head_size` is not a supported head size."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            return
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> type["XFormersAttentionImpl"]:
        """Return the attention implementation class."""
        return XFormersAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return XFormersAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape; the leading 2 separates keys/values.

        Raises:
            ValueError: If `block_size` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return XFormersAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
@dataclass
class XFormersAttentionMetadata:
    """Per-batch attention metadata with lazily built prefill/decode views.

    Decode requests are expected at the front of the batch: the decode view
    takes the first `num_decodes` requests / `num_decode_tokens` tokens and
    the prefill view takes the remainder.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Biases for different attention types.
    attn_bias: Optional["AttentionBias"] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
    _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or None if none."""
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is None:
            first_prefill = self.num_decodes
            prefill_starts = self.query_start_loc[first_prefill:]
            prefill_q_lens = torch.diff(prefill_starts)
            prefill_kv_lens = self.seq_lens[first_prefill:]

            # Build once and remember; start offsets are rebased to zero so
            # they index into the prefill-only token slice.
            cached = XFormersAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(prefill_q_lens.max().item()),
                query_start_loc=prefill_starts - prefill_starts[0],
                max_seq_len=int(prefill_kv_lens.max().item()),
                seq_lens=prefill_kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
            self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the decode requests, or None if none."""
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is None:
            decode_count = self.num_decodes
            all_starts = self.query_start_loc
            all_q_lens = torch.diff(all_starts)
            decode_kv_lens = self.seq_lens[:decode_count]

            # Build once and remember; the decode view keeps the attention
            # bias, unlike the prefill view.
            cached = XFormersAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(all_q_lens[:decode_count].max().item()),
                query_start_loc=all_starts[:decode_count + 1],
                max_seq_len=int(decode_kv_lens.max().item()),
                seq_lens=decode_kv_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                attn_bias=self.attn_bias,
            )
            self._cached_decode_metadata = cached
        return cached
class XFormersAttentionMetadataBuilder(
        AttentionMetadataBuilder[XFormersAttentionMetadata]):
    """Builds XFormersAttentionMetadata from the common batch metadata."""

    # Requests with at most this many query tokens are treated as decodes.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        # The xformers import at module load must have succeeded.
        assert XFORMERS_AVAILABLE
        self.block_size = kv_cache_spec.block_size
        self._num_decodes = 0
        self._num_decode_tokens = 0

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests to the front; return True if anything moved."""
        threshold = self.reorder_batch_threshold
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch, scheduler_output, decode_threshold=threshold)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> XFormersAttentionMetadata:
        """Assemble the metadata, including the decode attention bias.

        `common_prefix_len` and `fast_build` are accepted but not used here.
        """
        split = split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split

        query_start_loc = common_attn_metadata.query_start_loc
        query_lens = torch.diff(query_start_loc)
        seq_lens = common_attn_metadata.seq_lens
        block_table = common_attn_metadata.block_table_tensor

        # Only the decode requests (the first `num_decodes`) get a bias.
        decode_bias = None
        if num_decodes > 0:
            mask_cls = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
            decode_bias = mask_cls.from_seqlens(
                q_seqlen=query_lens[:num_decodes].tolist(),
                kv_seqlen=seq_lens[:num_decodes].tolist(),
                page_size=self.block_size,
                block_tables=block_table[:num_decodes],
                device=block_table.device,
            )

        return XFormersAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=common_attn_metadata.slot_mapping,
            attn_bias=decode_bias,
        )
class XFormersAttentionImpl(AttentionImpl):
    """Decoder self-attention over a paged KV cache.

    Prefill requests go through `unified_attention`; decode requests go
    through `xops.memory_efficient_attention_forward` using the bias stored
    in the metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        Raises:
            NotImplementedError: If KV sharing or ALiBi slopes are requested,
                or `attn_type` is not `AttentionType.DECODER`.
            ValueError: If `head_size` is not supported by the backend.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if alibi_slopes is not None:
            raise NotImplementedError(
                "XFormers does not support alibi slopes yet.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        # ALiBi was rejected above, so there are never slopes to convert.
        self.alibi_slopes = None
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        XFormersAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "XFormersAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: XFormersAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with XFormers.

        Args:
            layer: Attention layer providing `_k_scale` and `_v_scale`.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. None during a profiling
                run, in which case `output` is returned untouched.
            output: Pre-allocated buffer that results are written into.
                Required.
            output_scale: Must be None (fused output quantization is not
                supported).
            output_block_scale: Must be None, for the same reason.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If either output scale is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens

        # Prefill tokens follow the decode tokens in the flattened batch.
        if prefill_meta := attn_metadata.prefill_metadata:
            # One descale entry per (prefill request, kv head).
            descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
                             key.shape[1])
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Query for decode. KV is not needed because it is already cached.
            decode_query = query[:num_decode_tokens]
            # Reshape query to [1, B_T, G, H, D].
            q = decode_query.view(1, -1, self.num_kv_heads,
                                  self.num_queries_per_kv, self.head_size)
            # Reshape the k and v caches to [1, Bkv_T, G, H, D]; the H axis
            # is broadcast with expand() rather than copied.
            cache_k = key_cache.view(1, -1, self.num_kv_heads, 1,
                                     self.head_size).expand(
                                         1,
                                         -1,
                                         self.num_kv_heads,
                                         self.num_queries_per_kv,
                                         self.head_size,
                                     )
            cache_v = value_cache.view(1, -1, self.num_kv_heads, 1,
                                       self.head_size).expand(
                                           1,
                                           -1,
                                           self.num_kv_heads,
                                           self.num_queries_per_kv,
                                           self.head_size,
                                       )

            attn_bias = decode_meta.attn_bias
            output[:num_decode_tokens] = \
                xops.memory_efficient_attention_forward(
                    q,
                    cache_k,
                    cache_v,
                    attn_bias=attn_bias,
                    p=0.0,
                    scale=self.scale,
                ).view(decode_query.shape)

        # Results were written into the caller's buffer in place.
        return output