update
This commit is contained in:
26
vllm/model_executor/layers/attention/__init__.py
Normal file
26
vllm/model_executor/layers/attention/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.layers.attention.attention import Attention
|
||||
from vllm.model_executor.layers.attention.chunked_local_attention import (
|
||||
ChunkedLocalAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.attention.cross_attention import CrossAttention
|
||||
from vllm.model_executor.layers.attention.encoder_only_attention import (
|
||||
EncoderOnlyAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
|
||||
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.attention.static_sink_attention import (
|
||||
StaticSinkAttention,
|
||||
)
|
||||
|
||||
# Public API of the attention layer package.
__all__ = [
    "Attention",
    "ChunkedLocalAttention",
    "CrossAttention",
    "EncoderOnlyAttention",
    "MLAAttention",
    "MMEncoderAttention",
    "StaticSinkAttention",
]
|
||||
733
vllm/model_executor/layers/attention/attention.py
Normal file
733
vllm/model_executor/layers/attention/attention.py
Normal file
@@ -0,0 +1,733 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
||||
maybe_transfer_kv_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.model_executor.layers.linear import (
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
kv_cache_dtype_str_to_dtype,
|
||||
)
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.attention import MLAAttention
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def validate_kv_sharing_target(
    current_layer_name, target_layer_name, static_forward_context
):
    """Raise ValueError unless ``target_layer_name`` is a legal KV-sharing target.

    A valid target is a different layer, registered earlier in the static
    forward context, with the same attention type as the current layer.
    """
    # Common prefix shared by every failure message below.
    prefix = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} "
    )

    if target_layer_name == current_layer_name:
        raise ValueError(prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        if extract_layer_index(current_layer_name) <= extract_layer_index(
            target_layer_name
        ):
            raise ValueError(prefix + "must come before the current layer.")
        raise ValueError(prefix + "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type
    expected = static_forward_context[current_layer_name].attn_type
    actual = static_forward_context[target_layer_name].attn_type
    if actual != expected:
        raise ValueError(
            prefix + f"must be the same type as the current layer ({expected})."
        )
|
||||
|
||||
|
||||
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Returns whether the quantization method should load quantized weights."""
    # No quant method at all means nothing quantized to load; an unquantized
    # linear method likewise carries no quantized checkpoint weights.
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)
|
||||
|
||||
|
||||
def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Sets default quantization scales for the layer."""
    scale_names = ("_k_scale", "_v_scale", "_q_scale", "_prob_scale")
    if register_buffer:
        # First-time setup: register unit scales as buffers so they are part
        # of the state dict and move with the module across devices.
        for name in scale_names:
            layer.register_buffer(name, torch.tensor(1.0, dtype=torch.float32))
    else:
        # Buffers already exist; reset them in place.
        for name in scale_names:
            getattr(layer, name).fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._prob_scale_float = 1.0

    # Initialize q/k/v range constants used by calc_kv_scales
    layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
|
||||
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
    """
    # Note [Register q/k/v/prob scales in state dict]
    # When calling model.to(device), only parameters/buffers in state dict are
    # moved. If not registering q/k/v/prob scales in state dict, there would
    # be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
    # on cpu.
    # Registering in state dict means it interacts with weight loading. One edge
    # case is when quant_method is None, or quant_method is UnquantizedLinearMethod
    # (i.e., should_load_quant_weights(quant_method) == False).
    # In this case, the checkpoint does not have the scales. We need to
    # initialize the scales to 1.0 and update the scales after weight loading.
    # This is especially important when we load dummy weights first (providing
    # wrong scales) and then load real weights (which misses scales and keeps the
    # wrong scales from dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    # NOTE: the original code computed quant_method twice with identical
    # arguments; the lookup is performed once here.
    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )

    # See [Note: Register q/k/v/prob scales in state dict]
    if should_load_quant_weights(quant_method):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if layer.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)
|
||||
|
||||
|
||||
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Defaults used when no cache config is provided (e.g. some tests).
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor mdls need to set cache_dtype to "fp8" manually.
        kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
        if kv_cache_scheme is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        # Check if per-head quant scales are required based on kv_cache_scheme
        use_per_head_quant_scales = (
            kv_cache_scheme is not None
            and kv_cache_scheme.get("strategy") == "attn_head"
        )

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                use_per_head_quant_scales=use_per_head_quant_scales,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        # Treat None (and any falsy value) as "disabled".
        use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            # NOTE(review): this checks "q_scale" (no underscore) while the
            # registered buffers are "_q_scale" etc. — presumably set by the
            # quant method's create_weights; verify against that code path.
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            kv_cache_dummy_dep = None
            if self.use_direct_call:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            else:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            return output.view(-1, hidden_size)
        else:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Compute q/k/v scales from the abs-max of the given tensors (once)."""
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize impl configuration for ``repr`` of the module."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-weight-loading hook: delegate to impl and reset default scales
        when no quantized weights were loaded."""
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the KV cache spec (sliding-window or full) for this layer."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )
|
||||
|
||||
|
||||
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute KV scales for ``layer_name`` if the layer still requests it."""
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Only calculate if the layer's calculate_kv_scales flag is True
    # This flag gets set to False after the first forward pass
    if not self.calculate_kv_scales:
        return

    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    # Fake (meta) implementation used during tracing: no-op.
    return


# Register as a custom op so torch.compile treats the scale calculation as
# opaque; the tensors are declared mutated since scales are written in place.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)
|
||||
|
||||
|
||||
def get_attention_context(
    layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, KV cache tensor, and slot mapping for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
          no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine
        - slot_mapping: The slot mapping for this specific layer

    Note: attn_metadata may be None, but attn_layer and kv_cache are always
    extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Per-layer metadata is keyed by layer name when provided as a dict.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    # May be None if no slot mapping was recorded for this layer.
    layer_slot_mapping = slot_mapping.get(layer_name)
    return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention impl for ``layer_name`` and return its output."""
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    # Fake (meta) implementation: output mirrors the query's shape and dtype.
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)
|
||||
|
||||
|
||||
def unified_kv_cache_update(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Returns a dummy that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            key,
            value,
            kv_cache,
            layer_slot_mapping,
        )

    # Empty tensor whose only purpose is to carry the data dependency.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)


def unified_kv_cache_update_fake(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    # Fake (meta) implementation: matching empty dummy tensor.
    return torch.empty(0, device=key.device, dtype=key.dtype)


direct_register_custom_op(
    op_name="unified_kv_cache_update",
    op_func=unified_kv_cache_update,
    fake_impl=unified_kv_cache_update_fake,
    mutates_args=[],
)
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    """Run the attention impl for ``layer_name``, writing into ``output``."""
    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
    # that ensures torch.compile preserves ordering between KV cache update and
    # attention forward.
    del kv_cache_dummy_dep
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)

    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    # Fake (meta) implementation: no-op; result lives in the output buffer.
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)
|
||||
130
vllm/model_executor/layers/attention/chunked_local_attention.py
Normal file
130
vllm/model_executor/layers/attention/chunked_local_attention.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
make_local_attention_virtual_batches,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
KVCacheSpec,
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Create (and cache) a backend subclass whose metadata builder applies
    chunked local attention on top of ``underlying_attn_backend``.

    Cached per (backend, chunk size, block size) combination via lru_cache.
    """
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    underlying_builder = underlying_attn_backend.get_builder_cls()
    assert issubclass(underlying_builder, AttentionMetadataBuilder)

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        @classmethod
        def get_cudagraph_support(
            cls: type["AttentionMetadataBuilder"],
            vllm_config: VllmConfig,
            kv_cache_spec: AttentionSpec,
        ) -> AttentionCGSupport:
            # Explicit override in case the underlying builder specialized this getter.
            # @override omitted only because of mypy limitation due to type variable.
            return AttentionCGSupport.NEVER

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ):
            # Rewrite the metadata into local-attention virtual batches before
            # delegating to the underlying builder.
            cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size
            )
            metadata = super().build(common_prefix_len, cm, fast_build)
            metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
            return metadata

        def update_block_table(
            self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
        ):
            # Map the block table through the virtual-batch transform saved in
            # build() before handing off to the underlying builder.
            blk_table = metadata.make_virtual_batches_block_table(blk_table)
            return super().update_block_table(metadata, blk_table, slot_mapping)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder,
    )

    return attn_backend
|
||||
|
||||
|
||||
class ChunkedLocalAttention(Attention):
    """Attention over fixed-size local chunks of the sequence.

    Wraps the regular attention backend with a builder that splits each
    request into virtual batches of ``attention_chunk_size`` tokens
    (see create_chunked_local_attention_backend).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        attention_chunk_size: int,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
        self.attention_chunk_size = attention_chunk_size

        # During model initialization the default dtype is the model dtype.
        dtype = torch.get_default_dtype()
        if cache_config is None:
            # Defaults used when no cache config is available (e.g. tests).
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        backend = create_chunked_local_attention_backend(
            get_attn_backend(head_size, dtype, kv_cache_dtype, block_size),
            attention_chunk_size,
            block_size,
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=backend,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the chunked-local KV-cache spec for this layer."""
        assert self.attention_chunk_size
        return ChunkedLocalAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            attention_chunk_size=self.attention_chunk_size,
        )
|
||||
226
vllm/model_executor/layers/attention/cross_attention.py
Normal file
226
vllm/model_executor/layers/attention/cross_attention.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_cross_slot_mapping(
|
||||
encoder_seq_lens: np.ndarray,
|
||||
block_table_tensor: torch.Tensor,
|
||||
kv_cache_spec: CrossAttentionSpec,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Get cross-attention slot mappings."""
|
||||
|
||||
block_size = kv_cache_spec.block_size
|
||||
slot_mappings = []
|
||||
|
||||
# Find indices with non-zero encoder sequence lengths
|
||||
# The majority of parallel requests will be running the
|
||||
# decoder, so this list should be relatively small.
|
||||
active_indices = np.nonzero(encoder_seq_lens)[0]
|
||||
|
||||
for req_index in active_indices:
|
||||
encoder_seq_len = encoder_seq_lens[req_index].item()
|
||||
|
||||
# Calculate the number of blocks needed for this request
|
||||
num_blocks_needed = cdiv(encoder_seq_len, block_size)
|
||||
|
||||
# Get the block IDs for this request from the tensor
|
||||
req_block_ids = block_table_tensor[req_index]
|
||||
|
||||
# Get only the blocks we need (first num_blocks_needed blocks)
|
||||
needed_block_ids = req_block_ids[:num_blocks_needed]
|
||||
|
||||
# All needed blocks are allocated
|
||||
i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device)
|
||||
block_indices = i_values // block_size
|
||||
block_offsets = i_values % block_size
|
||||
block_numbers = needed_block_ids[block_indices]
|
||||
slot_mapping = block_numbers * block_size + block_offsets
|
||||
|
||||
slot_mappings.append(slot_mapping)
|
||||
|
||||
if slot_mappings:
|
||||
return torch.cat(slot_mappings)
|
||||
else:
|
||||
return torch.empty(0, dtype=torch.int64, device=device)
|
||||
|
||||
|
||||
@functools.lru_cache
def create_cross_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Derive an encoder-decoder cross-attention backend.

    Wraps the underlying backend's builder and impl so that (a) attention is
    non-causal with sequence lengths taken from the encoder, and (b) the KV
    cache is written using a slot mapping computed from the encoder lengths
    rather than the decoder block table. Cached per underlying backend.
    """
    prefix = "CrossAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()
    underlying_impl = underlying_attn_backend.get_impl_cls()

    class CrossAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Work on a shallow copy so the caller's metadata is not mutated.
            new_metadata = copy(common_attn_metadata)
            # Cross-attention is bidirectional over the encoder output.
            new_metadata.causal = False
            max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
            new_metadata.max_seq_len = max_encoder_len
            # Any computed tokens indicated decode step>1 (no chunked prefill)
            num_cache_decodes = (
                (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
            )
            if num_cache_decodes > 0:
                # CrossAttn KV cache has already been populated on first decoder step,
                # skip slot_mapping calculation for requests that do not need
                # reshape_and_cache.
                num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
                new_metadata.encoder_seq_lens_cpu = np.where(
                    num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
                )

            # seq_lens is provided by model runner: initial encoder input length is
            # needed here to know how many tokens to attend to from the cached
            # cross-attention KV cache.
            new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
            new_metadata._seq_lens_cpu = torch.from_numpy(
                common_attn_metadata.encoder_seq_lens_cpu
            )

            # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
            slot_mapping = _get_cross_slot_mapping(
                new_metadata.encoder_seq_lens_cpu,
                new_metadata.block_table_tensor,
                self.kv_cache_spec,
                self.device,
            )
            attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
            # Override whatever slot mapping the underlying builder produced.
            attn_metadata.slot_mapping = slot_mapping
            return attn_metadata

    # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
    # `CrossAttentionBuilder` instead of the one computed by `BlockTable`
    # (gpu_model_runner)
    class CrossAttentionImpl(underlying_impl):  # type: ignore[valid-type,misc]
        def forward(
            self,
            layer: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
            output: torch.Tensor | None = None,
            output_scale: torch.Tensor | None = None,
            output_block_scale: torch.Tensor | None = None,
        ) -> torch.Tensor:
            # If the underlying backend does not update the KV cache inside its
            # forward, do it here with the builder-computed slot mapping. Skipped
            # for KV-sharing layers and when key/value are absent.
            if (
                not underlying_attn_backend.forward_includes_kv_cache_update
                and attn_metadata is not None
                and layer.kv_sharing_target_layer_name is None
                and key is not None
                and value is not None
            ):
                self.do_kv_cache_update(
                    layer, key, value, kv_cache, attn_metadata.slot_mapping
                )

            return super().forward(
                layer,
                query,
                key,
                value,
                kv_cache,
                attn_metadata,
                output,
                output_scale,
                output_block_scale,
            )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: CrossAttentionBuilder,
            "get_impl_cls": lambda: CrossAttentionImpl,
            # The wrapper impl now performs the cache update itself.
            "forward_includes_kv_cache_update": True,
        },
    )

    return attn_backend
|
||||
|
||||
|
||||
class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.

    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # Only the encoder-decoder attention type is meaningful here.
        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER"
            )

        if cache_config is None:
            # Defaults used when no cache config is available (e.g. tests).
            kv_cache_dtype = "auto"
            block_size = 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        backend = create_cross_attention_backend(
            get_attn_backend(
                head_size,
                torch.get_default_dtype(),
                kv_cache_dtype,
                block_size,
                attn_type=AttentionType.ENCODER_DECODER,
            )
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=backend,
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the KV-cache spec for this cross-attention layer."""
        return CrossAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
|
||||
101
vllm/model_executor/layers/attention/encoder_only_attention.py
Normal file
101
vllm/model_executor/layers/attention/encoder_only_attention.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
|
||||
@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Wrap *underlying_attn_backend* so its builder disables causal masking.

    Cached so repeated calls with the same backend reuse one subclass.
    """
    base_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(base_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Encoder-only attention is bidirectional: clear the causal flag
            # on a copy so the caller's metadata is left untouched.
            bidirectional_metadata = copy(common_attn_metadata)
            bidirectional_metadata.causal = False
            return super().build(
                common_prefix_len, bidirectional_metadata, fast_build
            )

    return subclass_attention_backend(
        name_prefix="EncoderOnlyAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )
|
||||
|
||||
|
||||
class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # Validate attn_type up-front (fail fast, before backend selection).
        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        # During model initialization the default dtype is the model dtype.
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            # Defaults used when no cache config is available (e.g. tests).
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )

        # Wrap the backend so metadata is built with causal masking disabled.
        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Fix: the annotation previously claimed `KVCacheSpec` but this method
        # always returns None — encoder-only attention keeps no KV cache.
        return None
|
||||
60
vllm/model_executor/layers/attention/kv_transfer_utils.py
Normal file
60
vllm/model_executor/layers/attention/kv_transfer_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
|
||||
|
||||
def maybe_transfer_kv_layer(func: Callable) -> Callable:
    """Decorator that handles KV layer transfer prior and after execution of
    an attention layer, if enabled. Otherwise, the wrapper is a no-op.

    On entry: waits for the KV layer from the connector.
    On exit: saves the KV layer to the connector.

    The wrapped function must declare a ``layer_name`` parameter; it may be
    passed positionally or as a keyword.
    """
    # Import at runtime to avoid circular dependency
    from vllm.model_executor.layers.attention.attention import get_attention_context

    # Inspect the signature ONCE when the decorator is applied.
    sig = inspect.signature(func)
    param_names = list(sig.parameters.keys())

    # Find the index of 'layer_name' parameter.
    try:
        layer_name_index = param_names.index("layer_name")
    except ValueError as e:
        raise TypeError(
            f"Function {func.__name__} must have a 'layer_name' parameter"
        ) from e

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return func(*args, **kwargs)

        # Bug fix: previously only positional args were inspected, which
        # mis-bound (or raised IndexError) when callers passed layer_name
        # as a keyword argument.
        if "layer_name" in kwargs:
            layer_name: str = kwargs["layer_name"]
        else:
            layer_name = args[layer_name_index]

        # Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
        attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
        connector = get_kv_transfer_group()
        if attn_metadata is None or not connector.has_connector_metadata():
            return func(*args, **kwargs)

        # Wait for KV layer on entry
        connector.wait_for_layer_load(layer_name)

        # Execute the function
        result = func(*args, **kwargs)

        # Save KV cache layer on exit
        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

        return result

    return wrapper
|
||||
3006
vllm/model_executor/layers/attention/mla_attention.py
Normal file
3006
vllm/model_executor/layers/attention/mla_attention.py
Normal file
File diff suppressed because it is too large
Load Diff
262
vllm/model_executor/layers/attention/mm_encoder_attention.py
Normal file
262
vllm/model_executor/layers/attention/mm_encoder_attention.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
vit_triton_attn_wrapper,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# --8<-- [start:mm_encoder_attn]
|
||||
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    # --8<-- [end:mm_encoder_attn]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor.
            num_kv_heads: number of kv heads; defaults to num_heads (MHA).
            prefix: This has no effect, it is only here to make it easier to
                swap between Attention and MultiHeadAttention
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        # GQA requires the query heads to be an exact multiple of KV heads.
        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
        )

        # Both CUDA flash-attn and the ROCm AITER port share the FA code path.
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        # CustomOp registry hook: this op is always available.
        return True

    def view_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Torch SDPA path. Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        # 3D input means heads are still packed in the hidden dimension.
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            enable_gqa=self.num_heads > self.num_kv_heads,
        )
        if is_reshaped:
            # Restore the caller's packed (bsz, q_len, hidden) layout.
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Flash-attention path (CUDA FA or ROCm AITER FA). Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        # 3D input means heads are still packed in the hidden dimension.
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            # Restore the caller's packed (bsz, q_len, hidden) layout.
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_triton(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Triton attention path. Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        # 3D input means heads are still packed in the hidden dimension.
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_triton_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            # Restore the caller's packed (bsz, q_len, hidden) layout.
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Fallback path: always torch SDPA."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Dispatch to the selected CUDA/ROCm backend."""
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CPU path: always torch SDPA."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Dispatch to the selected XPU backend."""
        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for XPU: "
                f"{self.attn_backend}."
            )
|
||||
252
vllm/model_executor/layers/attention/static_sink_attention.py
Normal file
252
vllm/model_executor/layers/attention/static_sink_attention.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
SinkFullAttentionSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    """Derive a backend whose builder prepends static sink blocks.

    The wrapped builder maintains a persistent block table whose first
    ``sink_len // block_size`` columns point at the shared sink blocks, and
    extends each request's sequence length by ``sink_len`` so the underlying
    backend attends over the sink tokens as well. Cached per
    (backend, sink_len) pair.
    """
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            model_config = vllm_config.model_config
            scheduler_config = vllm_config.scheduler_config
            self.sink_len = sink_len
            self.block_size = vllm_config.cache_config.block_size
            # NOTE(review): assumes sink_len is a multiple of block_size;
            # otherwise the trailing sink tokens are dropped — confirm upstream.
            self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
            self.max_num_blocks = cdiv(
                model_config.max_model_len, vllm_config.cache_config.block_size
            )
            # Persistent table: sink columns first, then per-request blocks.
            self.block_table_with_sink = torch.zeros(
                (
                    scheduler_config.max_num_seqs,
                    self.max_num_blocks + self.num_sink_blocks,
                ),
                device=device,
                dtype=torch.int32,
            )
            # Sink blocks occupy block ids 1..num_sink_blocks (id 0 is skipped;
            # presumably reserved — confirm against the cache layout).
            self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Extend every sequence by the sink length (mutated in place).
            common_attn_metadata.seq_lens[:] = (
                common_attn_metadata.seq_lens + self.sink_len
            )
            # Requests that originally had length 0 stay at 0 (inactive slots
            # must not attend over the sink blocks).
            common_attn_metadata.seq_lens[
                common_attn_metadata.seq_lens == self.sink_len
            ] = 0
            common_attn_metadata.max_seq_len = (
                common_attn_metadata.max_seq_len + self.sink_len
            )
            max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
            num_reqs = common_attn_metadata.num_reqs
            # Copy the per-request blocks after the fixed sink columns and
            # swap in the combined table.
            self.block_table_with_sink[
                :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
            ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
            common_attn_metadata.block_table_tensor = self.block_table_with_sink[
                :num_reqs
            ]

            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )

    return attn_backend
|
||||
|
||||
|
||||
@CustomOp.register("static_sink_attention")
|
||||
class StaticSinkAttention(Attention, CustomOp):
|
||||
"""
|
||||
Attention with static sink tokens
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
sink_len: int,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if attn_backend is not None:
|
||||
underlying_attn_backend = attn_backend
|
||||
else:
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_static_sink_attention_backend(
|
||||
underlying_attn_backend, # type: ignore[arg-type]
|
||||
sink_len=sink_len,
|
||||
)
|
||||
Attention.__init__(
|
||||
self=self,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
**kwargs,
|
||||
)
|
||||
CustomOp.__init__(self)
|
||||
|
||||
self.sink_len = sink_len
|
||||
self.block_size = block_size
|
||||
self.sink_populated = False
|
||||
self.sink_key = None
|
||||
self.sink_value = None
|
||||
|
||||
def update_sink_kv(self, sink_key, sink_value) -> None:
|
||||
self.sink_key = sink_key
|
||||
self.sink_value = sink_value
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.sink_key is not None and self.sink_value is not None, (
|
||||
"sink_key and sink_value have not been prepared"
|
||||
)
|
||||
if not self.sink_populated:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
|
||||
|
||||
return super().forward(query, key, value, output_shape)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward_native(query, key, value, output_shape)
|
||||
|
||||
def forward(self, *args, **kwargs):
    """Dispatch to the backend-specific implementation selected by CustomOp."""
    impl = self._forward_method
    return impl(*args, **kwargs)
|
||||
|
||||
def populate_sink_kv(self, self_kv_cache):
    """Write the sink key/value tensors into this layer's paged KV cache.

    The sink tokens are placed at slot indices
    ``[block_size, block_size + sink_len)``, i.e. the first cache block is
    skipped.
    """
    start = self.block_size
    slot_mapping = torch.arange(
        start,
        start + self.sink_len,
        device=torch.cuda.current_device(),
        dtype=torch.long,
    )
    triton_reshape_and_cache_flash_diffkv(
        self.sink_key,
        self.sink_value,
        self_kv_cache,
        slot_mapping,
        self.kv_cache_dtype,
        self._k_scale,
        self._v_scale,
    )
    # Population happens at most once; this flag short-circuits later calls.
    self.sink_populated = True
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Describe the KV-cache layout this layer requires, sink slots included."""
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER
    # Read the block size from the live config: it may have been updated
    # after model loading.
    return SinkFullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        head_size_v=self.head_size_v,
        sink_len=self.sink_len,
        dtype=self.kv_cache_torch_dtype,
    )
|
||||
|
||||
|
||||
def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op entry point: populate sink KV for ``layer_name`` if needed.

    No-op when the layer has already populated its sinks or when the cache
    tensor is empty (``numel() == 0``).
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    if not layer.sink_populated and self_kv_cache.numel() > 0:
        layer.populate_sink_kv(self_kv_cache)
|
||||
|
||||
|
||||
def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op fake implementation registered for the custom op.

    Population only mutates the real KV cache, so there is nothing to
    model here.
    """
    pass
|
||||
|
||||
|
||||
# Register maybe_populate_sink as a custom op so it is callable as
# torch.ops.vllm.maybe_populate_sink (used in forward_native). The op
# mutates the passed KV cache in place; the fake impl is a no-op.
direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    mutates_args=["self_kv_cache"],
    fake_impl=maybe_populate_sink_fake,
)
|
||||
Reference in New Issue
Block a user