# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey


class AttentionType:
    """
    Attention type.
    Use strings to be compatible with `torch.compile`.
    """
    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # This still needs to be worked through for all backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)
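

# A minimal usage sketch (illustrative only, not part of the vLLM API):
# calling code lets the backend class decide the physical KV-cache layout
# and then allocates the tensor accordingly. ``backend_cls`` is assumed to
# be some concrete ``AttentionBackend`` subclass selected elsewhere.
def _sketch_allocate_kv_cache(
    backend_cls: Type[AttentionBackend],
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    # The backend owns the cache layout, e.g. (2, num_blocks, block_size,
    # num_kv_heads, head_size) for many paged-attention backends.
    shape = backend_cls.get_kv_cache_shape(num_blocks, block_size,
                                           num_kv_heads, head_size)
    return torch.empty(shape, dtype=dtype, device=device)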


class AttentionMetadata:
    pass


T = TypeVar("T", bound=AttentionMetadata)


class AttentionLayer(Protocol):

    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...
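
    # Note: this is a structural ``Protocol`` -- any object exposing the
    # scale attributes and ``forward`` above satisfies it without inheriting
    # from it (in vLLM this role is typically played by the ``Attention``
    # module in ``vllm.attention.layer``).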


class AttentionImpl(ABC, Generic[T]):

    # Whether the attention impl can return the softmax LSE (log-sum-exp)
    # for decode. Some features like decode context parallelism require
    # the softmax LSE.
    can_return_lse_for_decode: bool = False

    # Some attention backends might not always want to return the LSE,
    # even if they can, for efficiency reasons.
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # Use __new__ so that all subclasses will run this initialization.
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP (decode context parallelism) might not be initialized
            # in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
            and self.can_return_lse_for_decode
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Whether this attention implementation supports fused output
        quantization. This is used by the AttnFusionPass to only fuse
        output quantization onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: whether fusion is supported for this type of quantization
        """
        return False
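

# A minimal illustrative sketch (not a real vLLM backend) of how a concrete
# ``AttentionImpl`` subclass fills in the abstract interface. It assumes
# query/key/value arrive as ``(num_tokens, num_heads, head_size)`` tensors,
# requires ``num_kv_heads == num_heads``, and ignores the KV cache and the
# attention metadata entirely -- real implementations must handle all of
# these.
class _SketchAttentionImpl(AttentionImpl[AttentionMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads or num_heads
        assert self.num_kv_heads == self.num_heads, "sketch: no GQA support"

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Treat the whole token batch as one causal sequence; a real backend
        # would consult ``attn_metadata`` for per-request boundaries and
        # read/write the paged KV cache instead.
        q = query.transpose(0, 1).unsqueeze(0)  # (1, heads, tokens, head)
        k = key.transpose(0, 1).unsqueeze(0)
        v = value.transpose(0, 1).unsqueeze(0)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True, scale=self.scale)
        out = out.squeeze(0).transpose(0, 1)  # (tokens, heads, head)
        if output is not None:
            # Backends that set ``accept_output_buffer`` write into the
            # pre-allocated buffer so the result stays inside the cudagraph.
            output.copy_(out.reshape(output.shape))
            return output
        return out

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        # A real implementation would return True only for the quantization
        # schemes its kernel can fuse into the attention epilogue.
        return False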


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"
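
# Illustrative note: "auto" is the default, meaning the KV cache follows the
# model dtype, so e.g. is_quantized_kv_cache("auto") is False while an
# explicit fp8 cache dtype such as "fp8_e4m3" counts as quantized.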