update
This commit is contained in:
885
vllm/v1/attention/backend.py
Normal file
885
vllm/v1/attention/backend.py
Normal file
@@ -0,0 +1,885 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, replace
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AttentionType(str, Enum):
|
||||
"""
|
||||
Attention type.
|
||||
Use string to be compatible with `torch.compile`.
|
||||
"""
|
||||
|
||||
DECODER = "decoder"
|
||||
"""Decoder attention between previous layer Q/K/V."""
|
||||
ENCODER = "encoder"
|
||||
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
|
||||
ENCODER_ONLY = "encoder_only"
|
||||
"""Encoder attention between previous layer Q/K/V."""
|
||||
ENCODER_DECODER = "encoder_decoder"
|
||||
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
|
||||
|
||||
|
||||
class MultipleOf:
|
||||
base: int
|
||||
|
||||
def __init__(self, base: int):
|
||||
self.base = base
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
"""Abstract class for attention backends."""
|
||||
|
||||
# For some attention backends, we allocate an output tensor before
|
||||
# calling the custom op. When piecewise cudagraph is enabled, this
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
|
||||
|
||||
# Does attention's forward() include kv cache update?
|
||||
forward_includes_kv_cache_update: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(1)]
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_name() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_impl_cls() -> type["AttentionImplBase"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
"""
|
||||
Get the physical (memory layout) ordering of the kv cache dimensions.
|
||||
e.g. if the KV cache shape is
|
||||
[2, num_blocks, block_size, num_heads, head_size],
|
||||
and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
|
||||
ordering of dimensions is
|
||||
[num_blocks, num_heads, 2, block_size, head_size].
|
||||
|
||||
If this function is unimplemented / raises NotImplementedError,
|
||||
the physical layout of the KV cache will match the logical shape.
|
||||
|
||||
Args:
|
||||
include_num_layers_dimension: if True, includes an additional
|
||||
num_layers dimension, which is assumed to be prepended
|
||||
to the logical KV cache shape.
|
||||
With the above example, a return value (2, 4, 0, 1, 3, 5)
|
||||
corresponds to
|
||||
[num_blocks, num_heads, num_layers, 2, block_size, head_size].
|
||||
|
||||
If an additional dimension is NOT included in the returned
|
||||
tuple, the physical layout will not include a layers dimension.
|
||||
|
||||
Returns:
|
||||
A tuple of ints which is a permutation of range(len(shape)).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def full_cls_name(cls) -> tuple[str, str]:
|
||||
return (cls.__module__, cls.__qualname__)
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
return (not supported_head_sizes) or head_size in supported_head_sizes
|
||||
|
||||
@classmethod
|
||||
def supports_dtype(cls, dtype: torch.dtype) -> bool:
|
||||
return dtype in cls.supported_dtypes
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
|
||||
if kv_cache_dtype is None:
|
||||
return True
|
||||
return (not cls.supported_kv_cache_dtypes) or (
|
||||
kv_cache_dtype in cls.supported_kv_cache_dtypes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||
from vllm.config.cache import BlockSize
|
||||
|
||||
if block_size is None:
|
||||
return True
|
||||
|
||||
valid_sizes = get_args(BlockSize)
|
||||
if block_size not in valid_sizes:
|
||||
return False
|
||||
|
||||
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
|
||||
if not supported_kernel_block_sizes:
|
||||
return True
|
||||
|
||||
for supported_size in supported_kernel_block_sizes:
|
||||
if isinstance(supported_size, MultipleOf):
|
||||
supported_size = supported_size.base
|
||||
# With hybrid_blocks feature, the framework-level block size
|
||||
# only needs to be a multiple of the kernel's requirement,
|
||||
# even if the kernel requires a fixed block_size.
|
||||
if block_size % supported_size == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_alibi_sqrt(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_per_head_quant_scales(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
|
||||
By default, only supports decoder attention.
|
||||
Backends should override this to support other attention types.
|
||||
"""
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
) -> str | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
use_per_head_quant_scales: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
attn_type: str,
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
invalid_reasons.append("head_size not supported")
|
||||
if not cls.supports_dtype(dtype):
|
||||
invalid_reasons.append("dtype not supported")
|
||||
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
|
||||
invalid_reasons.append("kv_cache_dtype not supported")
|
||||
if not cls.supports_block_size(block_size):
|
||||
invalid_reasons.append("block_size not supported")
|
||||
if use_mm_prefix and not cls.supports_mm_prefix():
|
||||
invalid_reasons.append(
|
||||
"partial multimodal token full attention not supported"
|
||||
)
|
||||
if use_mla != cls.is_mla():
|
||||
if use_mla:
|
||||
invalid_reasons.append("MLA not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-MLA not supported")
|
||||
if has_sink and not cls.supports_sink():
|
||||
invalid_reasons.append("sink setting not supported")
|
||||
if use_sparse != cls.is_sparse():
|
||||
if use_sparse:
|
||||
invalid_reasons.append("sparse not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-sparse not supported")
|
||||
if use_per_head_quant_scales and not cls.supports_per_head_quant_scales():
|
||||
invalid_reasons.append("per-head quant scales not supported")
|
||||
if not cls.supports_compute_capability(device_capability):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
if not cls.supports_attn_type(attn_type):
|
||||
invalid_reasons.append(f"attention type {attn_type} not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
if combination_reason is not None:
|
||||
invalid_reasons.append(combination_reason)
|
||||
return invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return None
|
||||
|
||||
|
||||
class AttentionMetadata:
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T", bound=AttentionMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
max_seq_len: int
|
||||
"""Longest context length (may be an upper bound)"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
causal: bool = True
|
||||
|
||||
# Needed by FastPrefillAttentionBuilder
|
||||
logits_indices_padded: torch.Tensor | None = None
|
||||
num_logits_indices: int | None = None
|
||||
|
||||
# Needed by CrossAttentionBuilder
|
||||
encoder_seq_lens: torch.Tensor | None = None
|
||||
encoder_seq_lens_cpu: np.ndarray | None = None
|
||||
|
||||
dcp_local_seq_lens: torch.Tensor | None = None
|
||||
dcp_local_seq_lens_cpu: torch.Tensor | None = None
|
||||
"""Sequence lengths of the local rank in decode context parallelism world"""
|
||||
|
||||
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
|
||||
_seq_lens_cpu: torch.Tensor | None = None
|
||||
_num_computed_tokens_cpu: torch.Tensor | None = None
|
||||
|
||||
_num_computed_tokens_cache: torch.Tensor | None = None
|
||||
|
||||
def batch_size(self) -> int:
|
||||
return self.seq_lens.shape[0]
|
||||
|
||||
def naive_query_lens(self) -> torch.Tensor:
|
||||
"""Naive because it assumes that query ends where the next query starts."""
|
||||
return self.query_start_loc[1:] - self.query_start_loc[:-1]
|
||||
|
||||
def replace(self, **kwargs) -> "CommonAttentionMetadata":
|
||||
return replace(self, **kwargs)
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync.
|
||||
If a CPU copy is needed, use `seq_lens.cpu()` instead.
|
||||
Will be removed in a future release, please migrate as soon as possible.
|
||||
"""
|
||||
)
|
||||
def seq_lens_cpu(self) -> torch.Tensor:
|
||||
if self._seq_lens_cpu is None:
|
||||
self._seq_lens_cpu = self.seq_lens.to("cpu")
|
||||
return self._seq_lens_cpu
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
|
||||
async scheduling. If a CPU copy is needed, it can be derived from
|
||||
query_start_loc_cpu and seq_lens.
|
||||
Will be removed in a future release, please migrate as soon as possible.
|
||||
"""
|
||||
)
|
||||
def num_computed_tokens_cpu(self) -> torch.Tensor:
|
||||
if self._num_computed_tokens_cpu is None:
|
||||
query_seq_lens = (
|
||||
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
|
||||
)
|
||||
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
|
||||
return self._num_computed_tokens_cpu
|
||||
|
||||
def compute_num_computed_tokens(self) -> torch.Tensor:
|
||||
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
|
||||
if self._num_computed_tokens_cache is None:
|
||||
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
|
||||
self._num_computed_tokens_cache = self.seq_lens - query_lens
|
||||
return self._num_computed_tokens_cache
|
||||
|
||||
# TODO(lucas): remove once we have FULL-CG spec-decode support
|
||||
def unpadded(
|
||||
self, num_actual_tokens: int, num_actual_reqs: int
|
||||
) -> "CommonAttentionMetadata":
|
||||
maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
|
||||
seq_lens=self.seq_lens[:num_actual_reqs],
|
||||
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
|
||||
if self._seq_lens_cpu is not None
|
||||
else None,
|
||||
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
|
||||
if self._num_computed_tokens_cpu is not None
|
||||
else None,
|
||||
num_reqs=num_actual_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_seq_len=self.max_seq_len,
|
||||
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
|
||||
slot_mapping=self.slot_mapping[:num_actual_tokens],
|
||||
causal=self.causal,
|
||||
logits_indices_padded=self.logits_indices_padded,
|
||||
num_logits_indices=self.num_logits_indices,
|
||||
encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
|
||||
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
|
||||
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
|
||||
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class AttentionCGSupport(Enum):
|
||||
"""Constants for the cudagraph support of the attention backend
|
||||
Here we do not consider the cascade attention, as currently
|
||||
it is never cudagraph supported."""
|
||||
|
||||
ALWAYS = 3
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
NEVER = 0
|
||||
"""NO cudagraph support"""
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
# Do not access directly. Call get_cudagraph_support() instead.
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: int | None = None
|
||||
# Does this backend/builder support updating the block table in existing
|
||||
# metadata
|
||||
supports_update_block_table: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: "AttentionSpec",
|
||||
layer_names: list[str],
|
||||
vllm_config: "VllmConfig",
|
||||
device: torch.device,
|
||||
):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.layer_names = layer_names
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AttentionMetadataBuilder"],
|
||||
vllm_config: "VllmConfig",
|
||||
kv_cache_spec: "AttentionSpec",
|
||||
) -> AttentionCGSupport:
|
||||
"""Get the cudagraph support level of this builder class."""
|
||||
return cls._cudagraph_support
|
||||
|
||||
def _init_reorder_batch_threshold(
|
||||
self,
|
||||
reorder_batch_threshold: int | None = 1,
|
||||
supports_spec_as_decode: bool = False,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
) -> None:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold
|
||||
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||
# If the backend supports spec-as-decode kernels, then we can set
|
||||
# the reorder_batch_threshold based on the number of speculative
|
||||
# tokens from the config.
|
||||
speculative_config = self.vllm_config.speculative_config
|
||||
if (
|
||||
speculative_config is not None
|
||||
and speculative_config.num_speculative_tokens is not None
|
||||
):
|
||||
max_num_queries_for_spec = (
|
||||
1
|
||||
+ (2 if speculative_config.parallel_drafting else 1)
|
||||
* speculative_config.num_speculative_tokens
|
||||
)
|
||||
self.reorder_batch_threshold = max(
|
||||
self.reorder_batch_threshold,
|
||||
max_num_queries_for_spec,
|
||||
)
|
||||
|
||||
if (
|
||||
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and not supports_dcp_with_varlen
|
||||
):
|
||||
self.reorder_batch_threshold = 1
|
||||
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> M:
|
||||
"""
|
||||
Central method that builds attention metadata.
|
||||
Some builders (MLA) require reorder_batch to be called prior to build.
|
||||
|
||||
Args:
|
||||
common_prefix_len: The length of the common prefix of the batch.
|
||||
common_attn_metadata: The common attention metadata.
|
||||
fast_build: The meta-data will prioritize speed of building over
|
||||
then speed at execution. Can be used for spec-decode where the
|
||||
result of a build call may only be used for few layers/iters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
"""
|
||||
Update the block table for the attention metadata.
|
||||
Faster when theres multiple kv-cache groups that create virtually the
|
||||
same metadata but just with different block tables.
|
||||
|
||||
Only needs to be implemented if supports_update_block_table is True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for CUDA graph capture. Uses build by default.
|
||||
Subclasses that override this method should call self.build or
|
||||
super().build_for_cudagraph_capture.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
def build_for_drafting(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
draft_index: int,
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for draft model. Uses build by default.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: The common attention metadata.
|
||||
draft_index: The index of the current draft operation.
|
||||
When speculating a chain of tokens, this index refers to the
|
||||
draft attempt for the i-th token.
|
||||
For tree-based attention, this index instead refers to the
|
||||
draft attempt for the i-th level in the tree of tokens.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
fast_build=True,
|
||||
)
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
use_local_attention: bool,
|
||||
num_sms: int,
|
||||
dcp_world_size: int,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AttentionLayer(Protocol):
|
||||
_q_scale: torch.Tensor
|
||||
_k_scale: torch.Tensor
|
||||
_v_scale: torch.Tensor
|
||||
_q_scale_float: float
|
||||
_k_scale_float: float
|
||||
_v_scale_float: float
|
||||
_prob_scale: torch.Tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class AttentionImplBase(ABC, Generic[T]):
|
||||
"""Base class for attention implementations.
|
||||
|
||||
Contains common attributes and initialization logic shared by both
|
||||
standard AttentionImpl and MLAAttentionImpl. Does not define a forward
|
||||
method - subclasses define their own forward interfaces.
|
||||
"""
|
||||
|
||||
# Required attributes that all impls should have
|
||||
num_heads: int
|
||||
head_size: int
|
||||
scale: float
|
||||
|
||||
# Whether the attention impl can return the softmax lse for decode.
|
||||
# Some features like decode context parallelism require the softmax lse.
|
||||
can_return_lse_for_decode: bool = False
|
||||
|
||||
# Whether the attention impl supports Prefill Context Parallelism.
|
||||
supports_pcp: bool = False
|
||||
# Whether the attention impl(or ops) supports MTP
|
||||
# when cp_kv_cache_interleave_size > 1
|
||||
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
|
||||
|
||||
# some attention backends might not always want to return lse
|
||||
# even if they can return lse (for efficiency reasons)
|
||||
need_to_return_lse_for_decode: bool = False
|
||||
|
||||
# Whether this attention implementation supports pre-quantized query input.
|
||||
# When True, the attention layer will quantize queries before passing them
|
||||
# to this backend, allowing torch.compile to fuse the quantization with
|
||||
# previous operations. This is typically supported when using FP8 KV cache
|
||||
# with compatible attention kernels (e.g., TRT-LLM).
|
||||
# Subclasses should set this in __init__.
|
||||
# TODO add support to more backends:
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
pcp_world_size: int
|
||||
pcp_rank: int
|
||||
|
||||
total_cp_world_size: int
|
||||
total_cp_rank: int
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# use __new__ so that all subclasses will call this
|
||||
self = super().__new__(cls)
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_pcp_group
|
||||
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
|
||||
self.need_to_return_lse_for_decode = (
|
||||
self.dcp_world_size > 1 and self.can_return_lse_for_decode
|
||||
)
|
||||
return self
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
pass
|
||||
|
||||
|
||||
class AttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Standard attention implementation with forward method."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
sliding_window: int | None = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: "QuantKey"):
|
||||
"""
|
||||
Does this attention implementation support fused output quantization.
|
||||
This is used by the AttnFusionPass to only fuse output quantization
|
||||
onto implementations that support it.
|
||||
|
||||
:param quant_key: QuantKey object that describes the quantization op
|
||||
:return: is fusion supported for this type of quantization
|
||||
"""
|
||||
return False
|
||||
|
||||
def fused_rope_kvcache_supported(self):
|
||||
"""
|
||||
Does this attention implementation support RoPE+KVCache fusion.
|
||||
This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops
|
||||
with the KV cache update for implementations that support it.
|
||||
"""
|
||||
return False
|
||||
|
||||
def do_rope_and_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
If `fused_rope_kvcache_supported` returns True, this method will be called
|
||||
by torch.ops.vllm.fused_rope_and_unified_kv_cache_update
|
||||
to perform the inplace RoPE and KV cache update.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""MLA attention implementation with forward_mqa and forward_mha methods."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mha(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
k_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
"""MHA-style prefill forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Sparse MLA attention implementation with only forward_mqa method.
|
||||
|
||||
Sparse MLA implementations only support decode (MQA-style) attention.
|
||||
They do not support prefill (MHA-style) attention.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype.startswith("fp8")
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
||||
) -> type[AttentionBackend]:
|
||||
"""
|
||||
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
||||
"""
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
|
||||
return type(
|
||||
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
||||
)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
overrides: dict[str, Any],
|
||||
) -> type[AttentionBackend]:
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
return type(name, (attention_backend_cls,), overrides)
|
||||
Reference in New Issue
Block a user