Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,13 +0,0 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
]

View File

@@ -1,22 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
DECODER = "decoder"
"""Decoder attention between previous layer Q/K/V."""
ENCODER = "encoder"
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
ENCODER_ONLY = "encoder_only"
"""Encoder attention between previous layer Q/K/V."""
ENCODER_DECODER = "encoder_decoder"
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
class MultipleOf:
base: int
def __init__(self, base: int):
self.base = base
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@staticmethod
@@ -26,95 +73,212 @@ class AttentionBackend(ABC):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
"""
Get the physical (memory layout) ordering of the kv cache dimensions.
e.g. if the KV cache shape is
[2, num_blocks, block_size, num_heads, head_size],
and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
ordering of dimensions is
[num_blocks, num_heads, 2, block_size, head_size].
If this function is unimplemented / raises NotImplementedError,
the physical layout of the KV cache will match the logical shape.
Args:
include_num_layers_dimension: if True, includes an additional
num_layers dimension, which is assumed to be prepended
to the logical KV cache shape.
With the above example, a return value (2, 4, 0, 1, 3, 5)
corresponds to
[num_blocks, num_heads, num_layers, 2, block_size, head_size].
If an additional dimension is NOT included in the returned
tuple, the physical layout will not include a layers dimension.
Returns:
A tuple of ints which is a permutation of range(len(shape)).
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
return dtype in cls.supported_dtypes
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
if kv_cache_dtype is None:
return True
return (not cls.supported_kv_cache_dtypes) or (
kv_cache_dtype in cls.supported_kv_cache_dtypes
)
T = TypeVar("T", bound=AttentionMetadataPerStage)
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes:
return True
for supported_size in supported_kernel_block_sizes:
if isinstance(supported_size, MultipleOf):
supported_size = supported_size.base
# With hybrid_blocks feature, the framework-level block size
# only needs to be a multiple of the kernel's requirement,
# even if the kernel requires a fixed block_size.
if block_size % supported_size == 0:
return True
return False
class AttentionImpl(ABC):
@classmethod
def is_mla(cls) -> bool:
return False
@abstractmethod
def __init__(
self,
num_heads: int,
@classmethod
def supports_sink(cls) -> bool:
return False
@classmethod
def supports_mm_prefix(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
return attn_type == AttentionType.DECODER
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@classmethod
def supports_combination(
cls,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
raise NotImplementedError
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> str | None:
return None
@classmethod
def validate_configuration(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
invalid_reasons.append("head_size not supported")
if not cls.supports_dtype(dtype):
invalid_reasons.append("dtype not supported")
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mm_prefix and not cls.supports_mm_prefix():
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
if combination_reason is not None:
invalid_reasons.append(combination_reason)
return invalid_reasons
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return None
class AttentionMetadata:
pass
T = TypeVar("T", bound=AttentionMetadata)
class AttentionLayer(Protocol):
_q_scale: torch.Tensor
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
@abstractmethod
def forward(
self,
query: torch.Tensor,
@@ -122,6 +286,158 @@ class AttentionImpl(ABC):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor: ...
class AttentionImpl(ABC, Generic[T]):
# Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False
# Whether the attention impl supports Prefill Context Parallelism.
supports_pcp: bool = False
# Whether the attention impl(or ops) supports MTP
# when cp_kv_cache_interleave_size > 1
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
dcp_world_size: int
dcp_rank: int
pcp_world_size: int
pcp_rank: int
total_cp_world_size: int
total_cp_rank: int
def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
try:
from vllm.distributed.parallel_state import get_dcp_group
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
try:
from vllm.distributed.parallel_state import get_pcp_group
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
self.pcp_world_size = 1
self.pcp_rank = 0
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)
return self
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
sliding_window: int | None = None,
kv_cache_dtype: str = "auto",
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, quant_key: "QuantKey"):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
q_lora_rank: int | None,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"

View File

@@ -1,283 +0,0 @@
"""Attention layer with Flash and PagedAttention.
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch_musa
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = -1
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# enable musa flash attention
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
query = query.movedim(0, query.dim() - 2).unsqueeze(0)
key = key.movedim(0, key.dim() - 2).unsqueeze(0)
value = value.movedim(0, value.dim() - 2).unsqueeze(0)
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
tensor = torch.full(
(1, 1, num_tokens, num_tokens),
dtype=torch.bool,
fill_value=1,
device=query.device)
att_mask = torch.tril(tensor, diagonal=0)
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_output = scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=att_mask.contiguous(),
dropout_p=0.0,
is_causal=False,
)
attn_output = attn_output.squeeze(0).permute(1, 0, 2).contiguous()
assert output[:num_prefill_tokens].shape == attn_output.shape
output[:num_prefill_tokens] = attn_output
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)

View File

@@ -1,220 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
try:
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
use_cuda_graph: bool = False
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The number of query/output heads
num_qo_heads: Optional[int] = None
# The number of key/value heads
num_kv_heads: Optional[int] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
# Block size of vllm
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if not self.is_prompt:
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
self.paged_kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
)
if prefill_meta := attn_metadata.prefill_metadata:
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
return output.view(num_tokens, hidden_size)

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
class _AttentionBackendEnumMeta(EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name)
except KeyError:
members = cast("dict[str, Enum]", cls.__members__).keys()
valid_backends = ", ".join(members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
"""
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (
"vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
)
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
ROCM_AITER_MLA_SPARSE = (
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
)
TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
ROCM_AITER_UNIFIED_ATTN = (
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_ATTN_OVERRIDES.pop(self, None)
class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported mamba attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
"""
MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _MAMBA_ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_MAMBA_ATTN_OVERRIDES.pop(self, None)
MAMBA_TYPE_TO_BACKEND_MAP = {
"mamba1": MambaAttentionBackendEnum.MAMBA1.name,
"mamba2": MambaAttentionBackendEnum.MAMBA2.name,
"short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
"linear_attention": MambaAttentionBackendEnum.LINEAR.name,
"gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
"custom": MambaAttentionBackendEnum.CUSTOM.name,
}
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
def register_backend(
backend: AttentionBackendEnum | MambaAttentionBackendEnum,
is_mamba: bool = False,
class_path: str | None = None,
) -> Callable[[type], type]:
"""Register or override a backend implementation.
Args:
backend: The AttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns:
Decorator function if class_path is None, otherwise a no-op
Examples:
# Override an existing attention backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Override an existing mamba attention backend
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
class MyCustomMambaAttn:
...
# Register a custom third-party attention backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
# Direct registration
register_backend(
AttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
def decorator(cls: type) -> type:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
else:
_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
return cls
if class_path is not None:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
else:
_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
return lambda x: x
return decorator

View File

@@ -1,374 +0,0 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
# if not using triton, navi3x not use flash-attn either
if torch.cuda.get_device_capability()[0] == 11:
self.use_naive_attn = True
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
if self.use_naive_attn:
self.attn_func = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
return (x[:, :,
None, :].expand(tokens, n_kv_heads, n_rep,
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
True,
self.scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
out = self.attn_func(
query,
key,
value,
prefill_meta.seq_lens,
self.scale,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
)
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
seq_lens: List[int],
scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for _, seq_len in enumerate(seq_lens):
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
value[start:end],
scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += seq_len
return output
def _naive_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
seq_len, head_size, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out

View File

@@ -1,253 +0,0 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch_musa
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class TorchSDPAMetadata(PagedAttentionMetadata, AttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
class TorchSDPABackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prefill_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prefill_metadata.seq_lens)
attn_metadata.prefill_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2).unsqueeze(0)
key = key.movedim(0, key.dim() - 2).unsqueeze(0)
value = value.movedim(0, value.dim() - 2).unsqueeze(0)
start = 0
output = torch.empty(
(1, num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.prefill_metadata.seq_lens,
attn_metadata.prefill_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, :, start:end, :],
key[:, :, start:end, :],
value[:, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out[0]
start = end
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
else:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from vllm.config import ModelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
PAD_SLOT_ID = -1
@dataclass
class MLADims:
q_lora_rank: int | None
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)

View File

@@ -1,393 +0,0 @@
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class XFormersBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
class XFormersImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
out = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformers_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: XFormersMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_metadata.attn_bias = [attn_bias]
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_metadata.attn_bias[0],
p=0.0,
scale=self.scale)
return out.view_as(original_query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query)
start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_metadata.attn_bias[i],
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += seq_len
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_biases

View File

@@ -1,15 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import List, Optional
import functools
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
logger = init_logger(__name__)
class Attention(nn.Module):
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
kv_cache_dtype: str,
calculate_kv_scales: bool,
) -> None:
"""Initializes KV cache scaling factors and quantization method.
This helper function sets up the KV cache quantization attributes that are
shared between Attention and MLAAttention layers. It initializes scale
tensors for query, key, value, and probability, and configures the
quantization method if applicable.
Args:
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
kv_cache_dtype: The KV cache data type string.
calculate_kv_scales: Whether to calculate KV scales dynamically.
"""
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
layer._o_scale_float = None
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
layer.quant_method = quant_method
layer.quant_method.create_weights(layer)
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
@@ -26,31 +132,838 @@ class Attention(nn.Module):
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
)
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
self.attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**extra_impl_args,
)
backend_name = self.attn_backend.get_name()
self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
# for attn backends supporting query quantization
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
Attention metadata (`attn_metadata`) is set using a context manager in
the model runner's `execute_model` method. It is accessed via forward
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self, act_dtype: torch.dtype):
self.impl.process_weights_after_loading(act_dtype)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
)
return SlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
else:
return FullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Determine the attention backend
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.fa_version = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
self._flash_attn_varlen_func, fa_version=self.fa_version
)
logger.info_once(
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.is_flash_attn_backend:
assert self._flash_attn_varlen_func is not None
cu_seqlens_q = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
)
cu_seqlens_k = torch.arange(
0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
)
out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
f"ViT attention hasn't supported {self.attn_backend} backend yet."
)
return out.reshape(bsz, q_len, -1)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
**extra_impl_args,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)
def get_attention_context(
layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
instance, and KV cache tensor for a specific layer.
Args:
layer_name: The name/identifier of the attention layer.
Returns:
A tuple containing:
- attn_metadata: Attention metadata for this specific layer, or None if
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine
Note: attn_metadata may be None, but attn_layer and kv_cache are always
extracted from the forward context.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
return attn_metadata, attn_layer, kv_cache
@maybe_transfer_kv_layer
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
)
@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
)
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
make_local_attention_virtual_batches,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
KVCacheSpec,
)
@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.NEVER
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
)
return super().build(common_prefix_len, common_attn_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=ChunkedLocalAttentionBuilder,
)
return attn_backend
class ChunkedLocalAttention(Attention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
attention_chunk_size: int,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
kv_sharing_target_layer_name: str | None = None,
prefix: str = "",
):
self.attention_chunk_size = attention_chunk_size
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
assert self.attention_chunk_size
return ChunkedLocalAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
attention_chunk_size=self.attention_chunk_size,
)

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
import numpy as np
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
def _get_cross_slot_mapping(
encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device,
) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
)
if num_cache_decodes > 0:
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
)
# seq_lens is provided by model runner: initial encoder input length is
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder,
)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER"
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return CrossAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)

View File

@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import KVCacheSpec
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder,
)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, (
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Does not need KV cache
return None

View File

@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
"""
Args:
num_heads: number of attention heads per partition.
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention
multimodal_config: configs for multi-modal.
"""
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
def enabled(cls) -> bool:
return True
def reshape_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 4D tensors:
(batch_size, seq_len, num_heads, head_size)
"""
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
return query, key, value
def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
def _forward_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query, key, value = self.reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_torch_sdpa_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
)
return output
def _forward_fa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, (
"Flash attention function is not set."
)
# # TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None and max_seqlen is not None
bsz = query.shape[0]
output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
)
return output
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
f"{self.attn_backend}."
)
def forward_cpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
logger.warning_once(
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
"Falling back to SDPA implementation.",
)
return self._forward_sdpa(query, key, value, cu_seqlens)

View File

@@ -0,0 +1,401 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Authors:
# - Burkhard Ringlein <ngl@zurich.ibm.com>
# - Jan van Lunteren <jvl@zurich.ibm.com>
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
# - Thomas Parnell <tpa@zurich.ibm.com>
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def kernel_paged_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
if filter_by_query_len:
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if cur_batch_query_len > 1:
return
else:
cur_batch_in_all_start_index = seq_idx
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
0, num_queries_per_kv_padded
)
query_offset = (
cur_batch_in_all_start_index * query_stride_0
+ query_head_idx[:, None] * query_stride_1
)
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
head_mask = head_mask & (query_head_idx < num_query_heads)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)
# Q : (num_queries_per_kv, HEAD_SIZE,)
Q = tl.load(
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
mask=dim_mask[None, :] & head_mask[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(
alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
)
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
v_offset = (
physical_block_idx * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ offs_n[:, None] * stride_v_cache_3
)
k_offset = (
physical_block_idx * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ (offs_d[:, None] // x) * stride_k_cache_2
+ offs_n[None, :] * stride_k_cache_3
+ (offs_d[:, None] % x) * stride_k_cache_4
)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)
if K_load.dtype.is_fp8():
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)
if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
seq_mask = seq_offset[None, :] < boundary
# S : (num_queries_per_kv, BLOCK_SIZE,)
S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
S += scale * tl.dot(Q, K)
context_len = seq_len - 1
if SLIDING_WINDOW > 0:
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
# compute running maximum
# m_j : (num_queries_per_kv,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# P : (num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None])
# l_j : (num_queries_per_kv,)
l_j = tl.sum(P, axis=1)
# alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j)
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc += tl.dot(P.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (
cur_batch_in_all_start_index * output_stride_0
+ query_head_idx * output_stride_1
)
tl.store(
output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
acc,
mask=dim_mask[None, :] & head_mask[:, None],
)
def chunked_prefill_paged_decode(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_table,
query_start_loc,
seq_lens,
max_seq_len,
max_query_len,
k_scale,
v_scale,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
output_scale=None,
# Optional tensor for sinks
sinks=None,
):
if sm_scale is None:
sm_scale = 1.0 / (query.shape[1] ** 0.5)
use_alibi_slopes = alibi_slopes is not None
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if max_query_len > 1:
context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
kv_cache_dtype=kv_cache_dtype,
k_cache=key_cache,
v_cache=value_cache,
b_loc=block_table,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_seq_len=max_seq_len,
max_input_len=max_query_len,
k_scale=k_scale,
v_scale=v_scale,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
)
block_size = value_cache.shape[3]
num_seqs = len(seq_lens)
num_query_heads = query.shape[1]
num_kv_heads = key.shape[1]
num_queries_per_kv = query.shape[1] // key.shape[1]
head_size = query.shape[2]
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
key_cache = key_cache.view(target_dtype)
value_cache = value_cache.view(target_dtype)
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention(
query.dtype,
head_size,
block_size,
num_queries_per_kv,
max_seq_len,
sliding_window,
kv_cache_dtype,
alibi_slopes,
sinks,
)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = (
max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = block_table.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale=sm_scale,
block_tables=block_table,
seq_lens=seq_lens,
query_start_loc=query_start_loc,
block_size=block_size,
max_seq_len=max_seq_len,
alibi_slopes=alibi_slopes,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=output_scale,
)
else:
kernel_paged_attention_2d[
(
num_seqs,
num_kv_heads,
)
](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
)

View File

@@ -0,0 +1,469 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(
outputs_ptr,
new_output_ptr,
lses_ptr,
vlse_ptr,
outputs_stride_B,
outputs_stride_H,
outputs_stride_D,
lses_stride_N,
lses_stride_B,
lses_stride_H,
lse_idx,
HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr,
IS_BASE_E: tl.constexpr,
):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
outputs_ptr (triton.PointerType):
Pointer to input tensor of shape [ B, H, D ]
lses_ptr (triton.PointerType):
Pointer to input tensor of shape [ N, B, H ]
new_output_ptr (triton.PointerType):
Pointer to output tensor of shape [ B, H, D ]
vlse_ptr (triton.PointerType):
Pointer to output tensor of shape [ B, H ]
"""
batch_idx = tl.program_id(axis=0).to(tl.int64)
head_idx = tl.program_id(axis=1).to(tl.int64)
d_offsets = tl.arange(0, HEAD_DIM)
num_n_offsets = tl.arange(0, N_ROUNDED)
# shape = [N]
lse_offsets = (
num_n_offsets * lses_stride_N
+ batch_idx * lses_stride_B
+ head_idx * lses_stride_H
)
# calc final lse
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
if IS_BASE_E:
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
else:
lse_exp = tl.exp2(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log2(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
tl.store(vlse_ptr + lse_offsets, lse)
# shape = [D]
output_offsets = (
batch_idx * outputs_stride_B
+ head_idx * outputs_stride_H
+ d_offsets * outputs_stride_D
)
# correct output
lse_offset = (
lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
)
lse_tmp = tl.load(lses_ptr + lse_offset)
lse_finally = lse_tmp - lse
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float("inf")),
-float("inf"),
lse_finally,
)
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
"""The CPTritonContext is used to avoid recompilation of the Triton JIT."""
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
def correct_attn_out(
out: torch.Tensor,
lses: torch.Tensor,
cp_rank: int,
ctx: CPTritonContext,
is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Correct the attention output using the all-gathered lses.
Args:
out: Tensor of shape [ B, H, D ]
lses: Tensor of shape [ N, B, H ]
cp_rank: Current rank in the context-parallel group
ctx: Triton context to avoid recompilation
Returns:
Tuple of (out, lse) with corrected attention and final log-sum-exp.
"""
if ctx is None:
ctx = CPTritonContext()
# --- Normalize to 3D views ---
if out.ndim == 4 and out.shape[1] == 1:
out = out.squeeze(1)
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
if lses.ndim == 4 and lses.shape[-1] == 1:
lses = lses.squeeze(-1)
if lses.ndim == 4 and lses.shape[1] == 1:
lses = lses.squeeze(1)
assert lses.ndim == 3, (
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
f"got {tuple(lses.shape)}"
)
B, H, D = out.shape
N = lses.shape[0]
# Strides after we normalized shapes to 3-D views. The kernel computes
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
# have the same B/H stride layout as a slice of `lses`.
o_sB, o_sH, o_sD = out.stride()
l_sN, l_sB, l_sH = lses.stride()
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
lse = torch.empty_strided(
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
# Kernel launch config
grid = (B, H, 1)
regular_args = (
out,
out,
lses,
lse,
o_sB,
o_sH,
o_sD,
l_sN,
l_sB,
l_sH,
cp_rank,
)
const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
return out, lse
def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
if cp_group.world_size == 1:
return cp_attn_out
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty(
(cp_group.world_size,) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device,
)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(
cp_attn_out,
lses,
cp_group.rank_in_group,
ctx,
is_lse_base_on_e=is_lse_base_on_e,
)
return out, lse
def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
)
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
cp_num_heads = lse.shape[1] // cp_group.world_size
cp_rank = cp_group.rank_in_group
lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
return out, lse
return out
def cp_lse_ag_out_ar(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
)
out = cp_group.all_reduce(out)
if return_lse:
return out, lse
return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
out_ptr, # [B, Lmax, D]
lengths_ptr, # *i32, [B]
N: tl.constexpr,
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr, # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# Compute start index and sequence length from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
# compute input row indices for valid (b, t)
in_row = in_start + off_t
valid_row = (off_t < seq_len) & t_mask
# Pointers
# x_ptr: row-major [N, D]
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(
x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float("inf"),
block_t: int = 64,
block_d: int = 64,
) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
packed: [B, Lmax, ...] - packed tensor
"""
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
if len(original_shape) > 2:
N = original_shape[0]
x_reshaped = x.reshape(N, -1)
D = x_reshaped.shape[1]
else:
N, D = x.shape
x_reshaped = x
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_pack_seq_kernel[grid](
x_reshaped,
out,
lengths.int(),
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
return out
@triton.jit
def _unpack_seq_triton_kernel(
packed_ptr, # [B, Lmax, D]
out_ptr, # [N, D]
lengths_ptr, # *i32, [B]
B: tl.constexpr,
Lmax: tl.constexpr,
D: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr, # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# bounds: compute start from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
valid_row = (off_t < seq_len) & t_mask
# compute output row indices for valid (b, t)
out_row = in_start + off_t
# Pointers
# packed_ptr: row-major [B, Lmax, D]
packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
# out_ptr: row-major [N, D]
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
# Load from packed tensor and store to output
d_mask = off_d[None, :] < D
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(
packed_tensor: torch.Tensor,
lengths: torch.Tensor,
block_t: int = 64,
block_d: int = 64,
) -> torch.Tensor:
"""
Unpack a packed decode query tensor back to the original format.
Efficient Triton implementation.
Args:
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
lengths: [B] - sequence lengths for each batch
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
unpacked_tensor: [N, ...] where N = sum(lengths)
"""
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
original_shape = packed_tensor.shape
if len(original_shape) > 3:
B, Lmax = original_shape[:2]
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
D = packed_reshaped.shape[2]
else:
B, Lmax, D = packed_tensor.shape
packed_reshaped = packed_tensor
# Calculate total number of elements
N = int(lengths.sum().item())
out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_unpack_seq_triton_kernel[grid](
packed_reshaped,
out,
lengths.int(),
B,
Lmax,
D,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 3:
output_shape = (N,) + original_shape[2:]
out = out.reshape(output_shape)
return out

View File

@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
if current_platform.is_cuda():
try:
import vllm._flashmla_C # noqa: F401
_flashmla_C_AVAILABLE = True
except ImportError:
_flashmla_C_AVAILABLE = False
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def _is_flashmla_available() -> tuple[bool, str | None]:
if not _flashmla_C_AVAILABLE:
return (
False,
"vllm._flashmla_C is not available, likely was not "
"compiled due to insufficient nvcc version or a supported arch "
"was not in the list of target arches to compile for.",
)
if not _flashmla_extension_C_AVAILABLE:
return (
False,
"vllm._flashmla_extension_C is not available, likely "
"was not compiled due to a build error.",
)
return True, None
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] not in (9, 10):
return (
False,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
)
return True, None
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: int | None = None,
is_fp8_kvcache: bool = False,
topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if is_fp8_kvcache and topk is None:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
num_heads_q,
is_fp8_kvcache,
topk,
)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
descale_q: torch.Tensor | None = None,
descale_k: torch.Tensor | None = None,
is_fp8_kvcache: bool = False,
indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal
# since it only attend to the tokens before
# but here `causal` should not be specified
assert not causal, "causal must be `false` if sparse attention is enabled."
assert (descale_q is None) == (descale_k is None), (
"descale_q and descale_k should be both None or both not None"
)
if indices is None and q.element_size() == 1:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
return results
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
) -> None:
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# is not support for FP8 dtype, fallback to use Triton kernel.
def supported_dtypes(o: torch.Tensor) -> bool:
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
# kernel load/store 128b(16 bytes) per memory issue within
# thread. Namely, the headsize(headdim) must be multiple of
# pack_size (float32 -> 4, half/bfloat16 -> 8).
def supported_headdim(o: torch.Tensor) -> bool:
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if o.dtype == torch.float32:
return headdim % 4 == 0
return headdim % 8 == 0
if (
current_platform.is_cuda()
and supported_dtypes(output)
and supported_headdim(output)
):
from vllm._custom_ops import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
)
else:
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
)

View File

@@ -1,59 +1,29 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@@ -66,7 +36,8 @@ class PagedAttention:
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
@@ -75,142 +46,6 @@ class PagedAttention:
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,
alibi_slopes,
sliding_window,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from vllm.utils.math_utils import cdiv
def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
num_slices_ref, # [1]
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem,
):
async_copies = []
block_idx = pl.program_id(0)
num_slices_per_block = scratch.shape[0]
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
)
length = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
)
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
# Copy from scratch to kv_cache_hbm_ref
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
)
length = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
)
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
@functools.partial(
jax.jit,
static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
# [total_num_token, num_combined_kv_heads, head_dim]
new_kv: jax.Array,
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
slices: jax.Array,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
kv_cache: jax.Array,
# [1]
num_kv_update_slices: jax.Array,
*,
page_size: int = 32,
num_slices_per_block: int = 8,
):
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
assert head_dim % 128 == 0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs = [
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices, num_kv_update_slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
)
scratch_shapes = [
scratch,
pltpu.SemaphoreType.DMA,
]
kernel = pl.pallas_call(
_kv_cache_update_kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
)
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from functools import lru_cache
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv, scale = kv
seq_len_kv = kv.shape[0]
k = kv.to(torch.bfloat16)
q = q.to(torch.bfloat16)
mask_lo = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
def rocm_fp8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
# TODO(ganyi): Temporarily workaround, will remove the module check and reference
# path after aiter merge this kernel into main
@lru_cache
def has_mqa_logits_module():
return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None
if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
kv, scale = kv
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
else:
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
):
from vllm.utils.math_utils import cdiv
fp8_dtype = current_platform.fp8_dtype()
batch_size, next_n, _, dim = q.size()
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
scale = scale.contiguous().view(torch.float)
q = q.float()
kv_cache = kv_cache.view(fp8_dtype).float() * scale
num_block, block_size, _, dim = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
context_lens = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens[i]
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
k_offsets = torch.arange(
block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
)
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
)
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype
),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n : (i + 1) * next_n,
block_rk * block_size : (block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
return logits
def rocm_fp8_paged_mqa_logits(
q_fp8: torch.Tensor,
kv_cache_fp8: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
schedule_metadata: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
if rocm_aiter_ops.is_enabled():
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
device="cuda",
dtype=torch.float32,
)
deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8,
kv_cache_fp8,
weights,
out_qk,
context_lens,
block_tables,
max_model_len,
)
return out_qk.sum(dim=0)
else:
return fp8_paged_mqa_logits_torch(
q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
)

View File

@@ -0,0 +1,712 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import logging
from packaging import version
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
is_hip_ = current_platform.is_rocm()
logger = logging.getLogger(__name__)
# Only print the following warnings when triton version < 3.2.0.
# The issue won't affect performance or accuracy.
if version.parse(triton.__version__) < version.parse("3.2.0"):
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored."
)
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_kv_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = -float("inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
+ stride_req_to_tokens_b * cur_batch_req_idx
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[None, :]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
acc *= re_scale
acc += tl.sum(p[:, None] * v, 0)
e_sum = e_sum * re_scale + tl.sum(p, 0)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum,
mask=(mask_dv),
)
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
)
def _decode_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 64 if not is_hip_ else 8
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]
num_warps = 4
if kv_group_num != 1:
num_warps = 1 if is_hip_ else 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
_fwd_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=2,
Lk=Lk,
Lv=Lv,
)
@triton.jit
def _fwd_grouped_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
+ stride_req_to_tokens_b * cur_batch_req_idx
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:, None] * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv[None, :]
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
def _decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
BLOCK = 16
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
extra_kargs = {}
num_stages = 2
if is_hip_:
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
num_stages = 1
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
)
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
o,
lse,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
stride_lse_bs,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
lse_val = e_max + tl.log(e_sum)
tl.store(
lse + cur_batch * stride_lse_bs + cur_head,
lse_val,
)
def _decode_softmax_reducev_fwd(
logits,
q,
o,
lse,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
o,
lse,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
num_stages=2,
**extra_kargs,
)
def decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
)
def decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
)
def decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size=1,
logit_cap=0.0,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
# GQA/MQA/MLA
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)

View File

@@ -1,810 +0,0 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M],
actual_seqlen_k,
dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims:
padded_d_model = None
for i in unpadded_head_dims:
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# We assume the output stride on num_head is not always as same as the
# `suffix_output` and `prefix_output`, as them might be padded by the attention
# backend.
prefix_head_stride = prefix_output.stride(1)
output_head_stride = output.stride(1)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
prefix_head_stride,
output_head_stride,
head_size,
padded_head_size,
output_lse is not None,
)
@triton.jit
def merge_attn_states_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
output_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_head_stride,
output_head_stride,
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
# If we see an inf assume FA2 and convert inf to -inf for consistency
# and correctness. Inf generally doesn't make sense in this context outside
# of undefined-behavior/FA2-case, so I think this a safe assumption.
p_lse = float("-inf") if p_lse == float("inf") else p_lse
s_lse = float("-inf") if s_lse == float("inf") else s_lse
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
# Will reuse precomputed Exp values for scale factor computation.
p_se = tl.exp(p_lse)
s_se = tl.exp(s_lse)
out_se = p_se + s_se
if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output
+ token_idx * num_heads * output_head_stride
+ head_idx * output_head_stride
+ head_arange,
out,
mask=head_mask,
)

View File

@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):
token_idx = tl.program_id(axis=0)
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
# Padding token that should be ignored.
return
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride
tgt_idx = block_idx * block_stride + block_offset * page_stride
# [TILE_SIZE]
key_load = tl.load(
key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
)
if FP8_KV_CACHE:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
else:
key_tile = key_load
# [TILE_SIZE]
value_load = tl.load(
value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
)
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load
tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
return
def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
# [num_blocks, block_size, num_heads, head_size]
key_cache: torch.Tensor,
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
kv_cache_dtype: str, # "auto", "fp8"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size
key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]
head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only continous heads are supported"
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
kv_cache_torch_dtype = (
current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8")
else key_cache.dtype
)
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
assert kv_cache_dtype != torch.uint8, (
"explicit fp8 cast and store to "
"uint8 is not supported by triton reshape_and_cache_flash"
)
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint8,
torch.float8_e4m3fnuz,
], (
"unsupported dtype of KV cache tensor, got "
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
)
# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if current_platform.is_rocm() or current_platform.is_xpu():
num_stages = 4
num_warps = 8
else: # cuda
num_stages = 10
num_warps = 16
if torch.cuda.get_device_capability(key.device)[0] < 9:
TILE_SIZE = min(512, TILE_SIZE)
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs
grid = lambda meta: (
slot_mapping.shape[0],
triton.cdiv(n, meta["TILE_SIZE"]),
)
reshape_and_cache_kernel_flash[grid](
key_ptr=key,
value_ptr=value,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""
import einops
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(),
max_seqlen_k=max_seqlen.item(),
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
def flash_attn_maxseqlen_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
return torch.empty_like(q)
direct_register_custom_op(
op_name="flash_attn_maxseqlen_wrapper",
op_func=flash_attn_maxseqlen_wrapper,
fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
def vit_flash_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
)
# TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops
def torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
return context_layer
def torch_sdpa_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(q)
direct_register_custom_op(
op_name="torch_sdpa_wrapper",
op_func=torch_sdpa_wrapper,
fake_impl=torch_sdpa_wrapper_fake,
)
def vit_torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)

View File

@@ -1,94 +1,145 @@
import enum
from functools import lru_cache
from typing import Type
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import NamedTuple, cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip, is_musa
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
class AttentionSelectorConfig(NamedTuple):
head_size: int
dtype: torch.dtype
kv_cache_dtype: CacheDType | None
block_size: int | None
use_mla: bool = False
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
return (
f"AttentionSelectorConfig(head_size={self.head_size}, "
f"dtype={self.dtype}, "
f"kv_cache_dtype={self.kv_cache_dtype}, "
f"block_size={self.block_size}, "
f"use_mla={self.use_mla}, "
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"attn_type={self.attn_type})"
)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
elif backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
XFormersBackend)
return XFormersBackend
elif backend == _Backend.ROCM_FLASH:
logger.info("Using ROCmFlashAttention backend.")
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"""Returns which flash attention backend to use."""
if is_cpu():
return _Backend.TORCH_SDPA
if is_musa():
return _Backend.FLASH_ATTN
@cache
def _cached_get_attn_backend(
backend,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
if is_hip():
# AMD GPUs.
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs.")
return _Backend.ROCM_FLASH
attention_cls = current_platform.get_attn_backend_cls(
backend,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
backend = resolve_obj_by_qualname(attention_cls)
# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
try:
import flash_attn # noqa: F401
except ImportError:
set_kv_cache_layout(required_layout)
logger.info(
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
return _Backend.XFORMERS
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
return backend
# Default case.
return _Backend.FLASH_ATTN
def get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
"""Select which mamba attention backend to use and lazily import it."""
return _cached_get_mamba_attn_backend(mamba_type)
@cache
def _cached_get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
assert mamba_type and isinstance(mamba_type, str)
selected_backend = None
try:
backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
selected_backend = MambaAttentionBackendEnum[backend_name]
except KeyError as e:
raise ValueError(
f"Invalid mamba attention backend type: '{backend_name}'. Valid "
f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
) from e
mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend

View File

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
if current_platform.is_cuda():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
) from e
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)
device_capability = current_platform.get_device_capability()
assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform "
"defaulting to FA version 2."
)
fa_version = 2
if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
fa_version,
fa_version_unsupported_reason(fa_version),
)
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
def flash_attn_supports_fp8() -> bool:
return (
get_flash_attn_version() == 3
and current_platform.get_device_capability().major == 9
)
def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu():
return True
else:
return get_flash_attn_version() == 3
def flash_attn_supports_mla():
from vllm.platforms import current_platform
if current_platform.is_cuda():
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
is_fa_version_supported,
)
return (
is_fa_version_supported(3)
and current_platform.get_device_capability()[0] == 9
)
except (ImportError, AssertionError):
pass
return False
def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu()

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def validate_kv_sharing_target(
current_layer_name, target_layer_name, static_forward_context
):
error_msg = (
f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} "
)
if current_layer_name == target_layer_name:
raise ValueError(error_msg + "cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg + "is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg + f"must be the same type as the current layer ({expected})."
)

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
def maybe_transfer_kv_layer(func: Callable) -> Callable:
"""Decorator that handles KV layer transfer prior and after execution of
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
On entry: waits for the KV layer from the connector.
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from vllm.attention.layer import get_attention_context
# Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
# Find the index of 'layer_name' parameter.
try:
layer_name_index = param_names.index("layer_name")
except ValueError as e:
raise TypeError(
f"Function {func.__name__} must have a 'layer_name' parameter"
) from e
@wraps(func)
def wrapper(*args, **kwargs):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)
layer_name: str = args[layer_name_index]
# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)
# Wait for KV layer on entry
connector.wait_for_layer_load(layer_name)
# Execute the function
result = func(*args, **kwargs)
# Save KV cache layer on exit
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
return result
return wrapper