444 lines
14 KiB
Python
444 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
|
|
|
import torch
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config.cache import CacheDType
|
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
|
from vllm.platforms.interface import DeviceCapability
|
|
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
|
|
|
|
|
class AttentionType:
    """Namespace of the attention kinds a layer can perform.

    The members are plain strings (not an ``enum.Enum``) so that they stay
    compatible with ``torch.compile``.
    """

    # Decoder attention between the previous layer's Q/K/V.
    DECODER = "decoder"
    # Encoder attention between the previous layer's Q/K/V, in an
    # encoder-decoder model.
    ENCODER = "encoder"
    # Encoder attention between the previous layer's Q/K/V.
    ENCODER_ONLY = "encoder_only"
    # Attention between decoder Q and encoder K/V in an encoder-decoder model.
    ENCODER_DECODER = "encoder_decoder"
|
|
|
|
|
|
class MultipleOf:
    """Block-size requirement meaning "any multiple of ``base``".

    Used as an entry of ``AttentionBackend.get_supported_kernel_block_sizes``
    (next to plain ``int`` sizes) to express that a kernel accepts every
    block size divisible by ``base``.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Without a repr, lists of supported block sizes render as opaque
        # object addresses, which makes block-size diagnostics unreadable.
        return f"{type(self).__name__}({self.base!r})"
|
|
|
|
|
|
class AttentionBackend(ABC):
    """Abstract description of an attention backend and its capabilities.

    A backend names its implementation and metadata-builder classes, defines
    the logical shape (and optionally the physical stride order) of its KV
    cache, and exposes ``supports_*`` hooks that ``validate_configuration``
    combines into a list of reasons a configuration is rejected.
    """

    # Some backends allocate the output tensor before calling the custom op.
    # With piecewise cudagraph enabled, this keeps that allocation inside the
    # captured cudagraph.
    accept_output_buffer: bool = False

    # Model dtypes accepted by ``supports_dtype``.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV cache dtypes accepted by ``supports_kv_cache_dtype``; an empty list
    # disables the check.
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend can run with.

        Entries are exact sizes (``int``) or ``MultipleOf`` divisibility
        requirements; the default ``MultipleOf(1)`` accepts any size.
        """
        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return this backend's name."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV cache shape for the given geometry."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the physical (memory layout) order of the KV cache dims.

        The result is a permutation of the logical dimensions. For example,
        with a logical shape of
        ``[2, num_blocks, block_size, num_heads, head_size]``, a return value
        of ``(1, 3, 0, 2, 4)`` means the physical layout is
        ``[num_blocks, num_heads, 2, block_size, head_size]``.

        A backend that leaves this unimplemented (raises
        ``NotImplementedError``) gets a physical layout identical to the
        logical shape.

        Args:
            include_num_layers_dimension: when True, the permutation also
                covers an extra ``num_layers`` dimension assumed to be
                prepended to the logical shape. Continuing the example,
                ``(2, 4, 0, 1, 3, 5)`` then corresponds to
                ``[num_blocks, num_heads, num_layers, 2, block_size,
                head_size]``. If the returned tuple omits that extra
                dimension, the physical layout has no layers dimension.

        Returns:
            A tuple of ints forming a permutation of ``range(len(shape))``.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified class name)`` identifying this backend."""
        module_name = cls.__module__
        qualified_name = cls.__qualname__
        return module_name, qualified_name

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes; empty means no restriction."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Return True if ``head_size`` is allowed (always, when unrestricted)."""
        allowed = cls.get_supported_head_sizes()
        if not allowed:
            return True
        return head_size in allowed

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Return True if ``dtype`` is listed in ``supported_dtypes``."""
        allowed = cls.supported_dtypes
        return dtype in allowed

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """Return True if the KV cache dtype is acceptable.

        ``None`` and an empty ``supported_kv_cache_dtypes`` both mean
        "no constraint".
        """
        if kv_cache_dtype is None:
            return True
        allowed = cls.supported_kv_cache_dtypes
        if not allowed:
            return True
        return kv_cache_dtype in allowed

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Return True if ``block_size`` is valid for this backend.

        ``None`` is always accepted. Otherwise the size must be one of the
        framework's ``BlockSize`` values and divisible by at least one of the
        kernel's supported block sizes (an empty kernel list accepts all).
        """
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True
        if block_size not in get_args(BlockSize):
            return False

        kernel_sizes = cls.get_supported_kernel_block_sizes()
        if not kernel_sizes:
            return True

        # With the hybrid_blocks feature, the framework-level block size only
        # has to be a multiple of the kernel's requirement -- even when the
        # kernel requires a fixed block size.
        return any(
            block_size % (size.base if isinstance(size, MultipleOf) else size) == 0
            for size in kernel_sizes
        )

    @classmethod
    def is_mla(cls) -> bool:
        """Return True for MLA backends; False by default."""
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        """Return True if sinks are supported; False by default."""
        return False

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """Return True if multimodal-prefix attention is supported."""
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        """Return True for sparse backends; False by default."""
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Report whether this backend can run attention of ``attn_type``.

        The base implementation accepts decoder attention only; backends
        handling other attention types should override it.
        """
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        """Return True if the device capability is supported (always, here)."""
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Return a rejection reason for an invalid combination, else None."""
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Collect every reason this backend cannot serve the configuration.

        Returns:
            A list of human-readable reasons; empty when the configuration
            is fully supported.
        """
        reasons: list[str] = []

        if not cls.supports_head_size(head_size):
            reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            reasons.append("block_size not supported")
        if use_mm_prefix and not cls.supports_mm_prefix():
            reasons.append(
                "partial multimodal token full attention not supported"
            )
        # A mismatch in either direction (requested MLA on a non-MLA backend,
        # or vice versa) disqualifies the backend; same for sparse below.
        if use_mla != cls.is_mla():
            reasons.append(
                "MLA not supported" if use_mla else "non-MLA not supported"
            )
        if has_sink and not cls.supports_sink():
            reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            reasons.append(
                "sparse not supported" if use_sparse else "non-sparse not supported"
            )
        if not cls.supports_compute_capability(device_capability):
            reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            reasons.append(f"attention type {attn_type} not supported")

        # Let the backend reject combinations that are invalid even though
        # every individual setting passed the checks above.
        extra_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if extra_reason is not None:
            reasons.append(extra_reason)
        return reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return a required KV cache layout, or None when there is none."""
        return None
|
|
|
|
|
|
class AttentionMetadata:
    """Base class for the metadata objects passed to attention forward calls.

    It carries no state of its own; backends subclass it with their fields.
    """
|
|
|
|
|
|
# Type variable for the concrete metadata type an AttentionImpl consumes;
# bounded so that only AttentionMetadata (sub)classes are accepted.
T = TypeVar("T", bound=AttentionMetadata)
|
|
|
|
|
|
class AttentionLayer(Protocol):
    """Structural interface of the layer object given to ``AttentionImpl``.

    Any object exposing these scale attributes and a compatible ``forward``
    satisfies the protocol; no inheritance is required.
    """

    # Tensor-valued scale attributes. The names suggest per-tensor
    # quantization scales for Q, K and V -- NOTE(review): confirm against
    # the concrete layer.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Python-float variants of the Q/K/V scales.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    # Tensor-valued scale for attention probabilities (per its name).
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
|
|
|
|
|
|
class AttentionImpl(ABC, Generic[T]):
    """Abstract attention implementation, generic over its metadata type.

    Every instance also records its decode/prefill context-parallel (DCP/PCP)
    topology. This is resolved in ``__new__`` so that all subclasses get it,
    regardless of how their ``__init__`` is written.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # Whether the attention impl supports Prefill Context Parallelism.
    supports_pcp: bool = False
    # Whether the attention impl (or ops) supports MTP
    # when cp_kv_cache_interleave_size > 1.
    supports_mtp_with_cp_non_trivial_interleave_size: bool = False

    # Some attention backends might not always want to return lse even if
    # they can (for efficiency reasons). Recomputed per instance in __new__.
    need_to_return_lse_for_decode: bool = False

    # Whether this attention implementation supports pre-quantized query input.
    # When True, the attention layer will quantize queries before passing them
    # to this backend, allowing torch.compile to fuse the quantization with
    # previous operations. This is typically supported when using FP8 KV cache
    # with compatible attention kernels (e.g., TRT-LLM).
    # Subclasses should set this in __init__.
    # TODO add support to more backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    # Context-parallel topology; all of these are populated in __new__.
    dcp_world_size: int
    dcp_rank: int

    pcp_world_size: int
    pcp_rank: int

    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        # Use __new__ so that every subclass gets these attributes, whether
        # or not its __init__ cooperates.
        self = super().__new__(cls)

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once instead of once per attribute.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0

        try:
            from vllm.distributed.parallel_state import get_pcp_group

            pcp_group = get_pcp_group()
            self.pcp_world_size = pcp_group.world_size
            self.pcp_rank = pcp_group.rank_in_group
        except AssertionError:
            # PCP might not be initialized either; fall back to a single rank.
            self.pcp_world_size = 1
            self.pcp_rank = 0

        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        # Flattened rank across PCP x DCP, with the DCP rank varying fastest.
        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        # Return the lse on decode only when DCP spans more than one rank and
        # this implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        """Optional post-weight-loading hook; the base implementation is a no-op."""
        pass
|
|
|
|
|
|
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract implementation contract for MLA attention backends.

    Compared with ``AttentionImpl``:

    * the constructor takes no defaults for the shared arguments and
      additionally receives MLA-specific dimensions, the ``kv_b_proj``
      layer and an optional ``indexer``;
    * ``forward`` takes ``hidden_states_or_cq``, ``kv_c_normed`` and
      ``k_pe`` in place of separate query/key/value tensors.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # --- MLA-specific arguments ---
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        # Typed as a plain object here; concrete backends define its use.
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError
|
|
|
|
|
|
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Tell whether ``kv_cache_dtype`` selects a quantized KV cache.

    ``"auto"`` is the only value treated as unquantized; any other value
    counts as a quantized cache dtype.
    """
    is_auto = kv_cache_dtype == "auto"
    return not is_auto
|