Files
2026-01-19 10:38:50 +08:00

444 lines
14 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType:
    """
    Namespace of attention-type identifiers.

    Every member is a plain string so that the values stay compatible with
    `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention over the previous layer's Q/K/V."""

    ENCODER = "encoder"
    """Encoder attention over the previous layer's Q/K/V (encoder-decoder)."""

    ENCODER_ONLY = "encoder_only"
    """Encoder attention over the previous layer's Q/K/V."""

    ENCODER_DECODER = "encoder_decoder"
    """Attention between decoder Q and encoder K/V (encoder-decoder)."""
class MultipleOf:
    """Block-size constraint meaning "any multiple of ``base``".

    Instances appear next to plain ints in the list returned by
    ``AttentionBackend.get_supported_kernel_block_sizes``; consumers unwrap
    ``base`` and test divisibility against it.

    Args:
        base: The integer that a block size has to be a multiple of.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Without this, lists of supported sizes show up in logs and error
        # messages as opaque "<MultipleOf object at 0x...>" entries.
        return f"{type(self).__name__}({self.base!r})"
class AttentionBackend(ABC):
    """Abstract interface implemented by every attention backend.

    The abstract static methods are factory hooks (name, implementation
    class, metadata builder class, KV cache shape). The ``supports_*`` /
    ``is_*`` class methods are capability predicates with defaults that
    subclasses may override, and ``validate_configuration`` folds those
    predicates into a list of human-readable rejection reasons.
    """

    # Some attention backends allocate the output tensor before the custom op
    # is called. With piecewise cudagraph enabled, this makes sure the output
    # tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the block sizes accepted by the kernel.

        Entries are either plain ints or ``MultipleOf`` markers. The default
        is ``MultipleOf(1)``, which every integer block size is a multiple of.
        """
        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` class belonging to this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class belonging to this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the (logical) KV cache shape for the given dimensions."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Return the physical (memory layout) ordering of the KV cache dims.

        Example: for a KV cache of shape
        [2, num_blocks, block_size, num_heads, head_size],
        a return value of (1, 3, 0, 2, 4) means the dimensions are physically
        ordered as
        [num_blocks, num_heads, 2, block_size, head_size].

        When this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache matches the logical shape.

        Args:
            include_num_layers_dimension: if True, an additional num_layers
                dimension is included, which is assumed to be prepended
                to the logical KV cache shape.
                Continuing the example above, a return value of
                (2, 4, 0, 1, 3, 5) corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If the returned tuple does NOT include the additional
                dimension, the physical layout has no layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified name)`` of this class."""
        module_name = cls.__module__
        qualified_name = cls.__qualname__
        return (module_name, qualified_name)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes; an empty list means no limit."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Return True if ``head_size`` is accepted by this backend."""
        allowed = cls.get_supported_head_sizes()
        if not allowed:
            # An empty list places no restriction on the head size.
            return True
        return head_size in allowed

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Return True if ``dtype`` is listed in ``supported_dtypes``."""
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """Return True if ``kv_cache_dtype`` is accepted by this backend.

        ``None`` is always accepted, and an empty
        ``supported_kv_cache_dtypes`` list accepts every value.
        """
        if kv_cache_dtype is None:
            return True
        if not cls.supported_kv_cache_dtypes:
            return True
        return kv_cache_dtype in cls.supported_kv_cache_dtypes

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Return True if ``block_size`` is usable with this backend.

        ``None`` is always accepted. Any other value has to be one of the
        ``BlockSize`` literals and a multiple of at least one entry of
        ``get_supported_kernel_block_sizes()`` (an empty list accepts all).
        """
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True
        if block_size not in get_args(BlockSize):
            return False

        kernel_sizes = cls.get_supported_kernel_block_sizes()
        if not kernel_sizes:
            return True

        # With hybrid_blocks feature, the framework-level block size
        # only needs to be a multiple of the kernel's requirement,
        # even if the kernel requires a fixed block_size. Fixed ints and
        # MultipleOf markers are therefore both checked by divisibility.
        divisors = (
            entry.base if isinstance(entry, MultipleOf) else entry
            for entry in kernel_sizes
        )
        return any(block_size % divisor == 0 for divisor in divisors)

    @classmethod
    def is_mla(cls) -> bool:
        """Return True if this is an MLA backend (default: False)."""
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        """Return True if the sink setting is supported (default: False)."""
        return False

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """Return True if partial multimodal token full attention is
        supported (default: False)."""
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        """Return True if this is a sparse backend (default: False)."""
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Return True if the backend handles the given attention type.

        The default accepts decoder attention only; backends that handle
        other attention types should override this method.
        """
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        """Return True if the device capability is supported (default: True)."""
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Hook for checks that involve several settings at once.

        Returns:
            ``None`` when the combination is acceptable (the default), else a
            reason string that ``validate_configuration`` adds to its result.
        """
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Collect every reason why this backend rejects the configuration.

        Each capability predicate is consulted in turn, followed by
        ``supports_combination``.

        Returns:
            The list of reason strings; it is empty when nothing is rejected.
        """
        reasons: list[str] = []

        if not cls.supports_head_size(head_size):
            reasons.append("head_size not supported")

        if not cls.supports_dtype(dtype):
            reasons.append("dtype not supported")

        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            reasons.append("kv_cache_dtype not supported")

        if not cls.supports_block_size(block_size):
            reasons.append("block_size not supported")

        if use_mm_prefix and not cls.supports_mm_prefix():
            reasons.append("partial multimodal token full attention not supported")

        # MLA and sparsity have to match exactly: a mismatch in either
        # direction is a rejection.
        if use_mla != cls.is_mla():
            reasons.append("MLA not supported" if use_mla else "non-MLA not supported")

        if has_sink and not cls.supports_sink():
            reasons.append("sink setting not supported")

        if use_sparse != cls.is_sparse():
            reasons.append(
                "sparse not supported" if use_sparse else "non-sparse not supported"
            )

        if not cls.supports_compute_capability(device_capability):
            reasons.append("compute capability not supported")

        if not cls.supports_attn_type(attn_type):
            reasons.append(f"attention type {attn_type} not supported")

        combo_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combo_reason is not None:
            reasons.append(combo_reason)

        return reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return the KV cache layout this backend requires, or ``None``
        (the default)."""
        return None
class AttentionMetadata:
    """Base type for attention metadata objects.

    The class declares no fields or methods of its own. In this file it is
    used as the bound of the ``T`` type variable and as the type of the
    ``attn_metadata`` argument of ``AttentionLayer.forward``.
    """
# Type variable standing for a concrete metadata class; it is bound to
# AttentionMetadata and parameterizes the generic AttentionImpl classes below.
T = TypeVar("T", bound=AttentionMetadata)
class AttentionLayer(Protocol):
    """Structural type of the ``layer`` argument of ``AttentionImpl.forward``.

    Any object providing the attributes and the ``forward`` method declared
    here satisfies the protocol; no inheritance is required.
    """

    # Scales for query / key / value, held as tensors.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor

    # Float-typed counterparts of the scales above.
    # NOTE(review): presumably these mirror the tensor values - confirm.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float

    # NOTE(review): presumably the scale of the attention probabilities -
    # confirm against the implementing layer.
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on query/key/value with the given KV cache and
        attention metadata, returning a tensor."""
        ...
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class of attention implementations.

    ``T`` is the ``AttentionMetadata`` subclass consumed by ``forward``.
    Context-parallel world sizes and ranks are attached to every instance in
    ``__new__``; subclasses provide ``__init__`` and ``forward``.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # Whether the attention impl supports Prefill Context Parallelism.
    supports_pcp: bool = False

    # Whether the attention impl (or ops) supports MTP
    # when cp_kv_cache_interleave_size > 1
    supports_mtp_with_cp_non_trivial_interleave_size: bool = False

    # Some attention backends might not always want to return lse even if
    # they can return lse (for efficiency reasons). __new__ overwrites this
    # per instance.
    need_to_return_lse_for_decode: bool = False

    # Whether this attention implementation supports pre-quantized query input.
    # When True, the attention layer will quantize queries before passing them
    # to this backend, allowing torch.compile to fuse the quantization with
    # previous operations. This is typically supported when using FP8 KV cache
    # with compatible attention kernels (e.g., TRT-LLM).
    # Subclasses should set this in __init__.
    # TODO add support to more backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    # Populated per instance by __new__.
    dcp_world_size: int
    dcp_rank: int
    pcp_world_size: int
    pcp_rank: int
    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and record context-parallel group info on it.

        The DCP and PCP world size / rank are read from the distributed
        state; when a group lookup raises ``AssertionError``, a world size
        of 1 and rank 0 are used instead. Constructor arguments are accepted
        but not inspected here.
        """
        # This lives in __new__ so that it runs for all subclasses: the base
        # __init__ is abstract and is never invoked by them.
        instance = super().__new__(cls)

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            dcp_group = get_dcp_group()
            instance.dcp_world_size = dcp_group.world_size
            instance.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            instance.dcp_world_size = 1
            instance.dcp_rank = 0

        try:
            from vllm.distributed.parallel_state import get_pcp_group

            pcp_group = get_pcp_group()
            instance.pcp_world_size = pcp_group.world_size
            instance.pcp_rank = pcp_group.rank_in_group
        except AssertionError:
            instance.pcp_world_size = 1
            instance.pcp_rank = 0

        # Combined context-parallel view: PCP is the outer index, DCP the
        # inner one.
        dcp_size = instance.dcp_world_size
        instance.total_cp_world_size = instance.pcp_world_size * dcp_size
        instance.total_cp_rank = instance.pcp_rank * dcp_size + instance.dcp_rank

        # The lse is only requested when DCP is actually in use and the
        # implementation is able to return it.
        instance.need_to_return_lse_for_decode = (
            dcp_size > 1 and instance.can_return_lse_for_decode
        )
        return instance

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Constructor signature that implementations follow.

        The base version only raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention for the given layer and inputs.

        The base version only raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
        """
        Tell whether this implementation supports fused output quantization.

        The AttnFusionPass uses this to fuse output quantization only onto
        implementations that support it.

        Args:
            quant_key: QuantKey object that describes the quantization op.

        Returns:
            Whether fusion is supported for this type of quantization; the
            default is False.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        """Hook taking the activation dtype; the default does nothing."""
        pass
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract base class of MLA attention implementations.

    Compared with ``AttentionImpl``, the constructor has no defaults for the
    shared arguments and takes additional MLA-specific ones, and ``forward``
    receives ``hidden_states_or_cq`` / ``kv_c_normed`` / ``k_pe`` in place of
    query / key / value.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
    ) -> None:
        """Constructor signature that MLA implementations follow.

        The base version only raises ``NotImplementedError``.
        NOTE(review): the meaning of the MLA-specific arguments is defined by
        the concrete implementations - consult them before relying on it.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute MLA attention for the given layer and inputs.

        The base version only raises ``NotImplementedError``.
        """
        raise NotImplementedError
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return whether ``kv_cache_dtype`` denotes a quantized KV cache.

    Every value other than the literal ``"auto"`` counts as quantized.
    """
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype