Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
733
vllm/model_executor/layers/attention/attention.py
Normal file
733
vllm/model_executor/layers/attention/attention.py
Normal file
@@ -0,0 +1,733 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
||||
maybe_transfer_kv_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.model_executor.layers.linear import (
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
kv_cache_dtype_str_to_dtype,
|
||||
)
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.attention import MLAAttention
|
||||
|
||||
# Module-level logger, created from this module's `__name__`.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def validate_kv_sharing_target(
    current_layer_name, target_layer_name, static_forward_context
):
    """Check that a layer may share the KV cache of ``target_layer_name``.

    Args:
        current_layer_name: Name of the layer that wants to reuse a KV cache.
        target_layer_name: Name of the layer whose KV cache would be reused.
        static_forward_context: Mapping from layer name to layer object. Each
            looked-up object must expose an ``attn_type`` attribute.

    Raises:
        ValueError: If the target is the current layer itself, is missing
            from ``static_forward_context``, or has a different ``attn_type``
            than the current layer.
    """
    error_msg = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} "
    )

    if current_layer_name == target_layer_name:
        raise ValueError(error_msg + "cannot be the same as the current layer.")

    target_is_registered = target_layer_name in static_forward_context
    if not target_is_registered:
        from vllm.model_executor.models.utils import extract_layer_index

        # A target that is absent from the static forward context either
        #   a) does not come BEFORE the current layer, or
        #   b) is not an Attention layer that exists in the model.
        # The layer indices tell the two cases apart.
        current_layer_idx = extract_layer_index(current_layer_name)
        target_layer_idx = extract_layer_index(target_layer_name)
        reason = (
            "must come before the current layer."
            if current_layer_idx <= target_layer_idx
            else "is not a valid Attention layer in the model."
        )
        raise ValueError(error_msg + reason)

    # KV sharing is currently only supported between layers of the same type.
    target_layer = static_forward_context[target_layer_name]
    target_layer_attn_type = target_layer.attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_layer_attn_type == expected:
        return
    raise ValueError(
        error_msg + f"must be the same type as the current layer ({expected})."
    )
|
||||
|
||||
|
||||
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether ``quant_method`` implies loading quantized weights.

    Returns:
        ``False`` when ``quant_method`` is ``None`` or is an
        ``UnquantizedLinearMethod`` instance; ``True`` otherwise.
    """
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)
|
||||
|
||||
|
||||
def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: Module that carries the scale attributes.
        register_buffer: When ``True``, register the four scale tensors as
            buffers on ``layer``. When ``False``, the buffers must already
            exist and are overwritten in place.
    """
    scale_buffer_names = ("_k_scale", "_v_scale", "_q_scale", "_prob_scale")
    if register_buffer:
        for buffer_name in scale_buffer_names:
            layer.register_buffer(
                buffer_name, torch.tensor(1.0, dtype=torch.float32)
            )
    else:
        for buffer_name in scale_buffer_names:
            getattr(layer, buffer_name).fill_(1.0)

    # Host-side (cpu) copies of the scales, for attention backends that
    # need the scales on host instead of on device, e.g. Flashinfer.
    for float_name in (
        "_q_scale_float",
        "_k_scale_float",
        "_v_scale_float",
        "_prob_scale_float",
    ):
        setattr(layer, float_name, 1.0)

    # q/k/v range constants used by calc_kv_scales.
    range_constants = (
        ("q_range", envs.Q_SCALE_CONSTANT),
        ("k_range", envs.K_SCALE_CONSTANT),
        ("v_range", envs.V_SCALE_CONSTANT),
    )
    for range_name, constant in range_constants:
        setattr(layer, range_name, torch.tensor(constant, dtype=torch.float32))
|
||||
|
||||
|
||||
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    ``layer.kv_cache_dtype`` must already be set when a quantization method
    that loads quantized weights applies, because it is read here.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.

    Raises:
        ValueError: If quantized weights are to be loaded and
            ``layer.kv_cache_dtype`` is ``"fp8_e5m2"``.
    """
    # Note [Register q/k/v/prob scales in state dict]
    # When calling model.to(device), only parameters/buffers in state dict are
    # moved. If not registering q/k/v/prob scales in state dict, there would
    # be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
    # on cpu.
    # Registering in state dict means it interacts with weight loading. One edge
    # case is when quant_method is None, or quant_method is UnquantizedLinearMethod
    # (i.e., should_load_quant_weights(quant_method) == False).
    # In this case, the checkpoint does not have the scales. We need to
    # initialize the scales to 1.0 and update the scales after weight loading.
    # This is especially important when we load dummy weights first (providing
    # wrong scales) and then load real weights (which misses scales and keeps the
    # wrong scales from dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    # Look the quantization method up once, after the default scales exist.
    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )

    # See Note [Register q/k/v/prob scales in state dict]
    if should_load_quant_weights(quant_method):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if layer.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)
|
||||
|
||||
|
||||
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size used for query and key.
            scale: Forwarded unchanged to the backend implementation.
            num_kv_heads: Number of KV heads; defaults to `num_heads`, which
                must be divisible by it.
            alibi_slopes: Forwarded unchanged to the backend implementation.
            use_alibi_sqrt: Falsy values (including None) are treated as
                False; a truthy value requires backend support.
            cache_config: Source of the KV cache dtype, block size, sliding
                window and `calculate_kv_scales`. May be mutated here.
            quant_config: Optional quantization configuration.
            logits_soft_cap: Forwarded unchanged to the backend implementation.
            per_layer_sliding_window: Takes precedence over
                `cache_config.sliding_window` when not None.
            prefix: Layer name; also the key under which this layer is
                registered in the static forward context.
            attn_type: Attention type, forwarded to backend selection and to
                the backend implementation.
            kv_sharing_target_layer_name: If set, validated with
                `validate_kv_sharing_target`.
            attn_backend: Backend class to use; selected via
                `get_attn_backend` when None.
            head_size_v: Per-head size of value/output; defaults to
                `head_size`.
            **extra_impl_args: Forwarded to the backend implementation
                constructor. The "sinks" entry is also inspected here.

        Raises:
            ValueError: If `use_alibi_sqrt` is requested but unsupported by
                the backend, or if `prefix` is already registered.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Defaults used when no cache config is supplied.
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor mdls need to set cache_dtype to "fp8" manually.
        kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
        if kv_cache_scheme is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                # NOTE: this writes back into the caller's cache_config.
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        # Check if per-head quant scales are required based on kv_cache_scheme
        use_per_head_quant_scales = (
            kv_cache_scheme is not None
            and kv_cache_scheme.get("strategy") == "attn_head"
        )

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                use_per_head_quant_scales=use_per_head_quant_scales,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        # Normalize None (and any other falsy value) to False.
        use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if backend_supports_alibi_sqrt:
            # Only backends that support the option receive the keyword.
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        # Register this layer under its name before validating KV sharing;
        # validate_kv_sharing_target looks the current layer up by `prefix`.
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            # NOTE: `block_size` is rebound here as the quantization group
            # width; it no longer refers to the KV cache block size.
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        When the backend accepts an output buffer, the result is written into
        a freshly allocated tensor and returned as a 2D view whose last
        dimension is `output_shape[-1]`. Otherwise the return value of
        `unified_attention` is returned as is.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Remember the dtype before the query is possibly quantized below.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            kv_cache_dummy_dep = None
            if self.use_direct_call:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            else:
                # Same sequence as above, but through the registered custom ops.
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            return output.view(-1, hidden_size)
        else:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Set the q/k/v scales from the given tensors, then disable recalculation.

        Each scale becomes the maximum absolute value of the tensor divided by
        the matching `*_range` constant. The host-side float copies are
        refreshed from the updated scale tensors.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary string built from attributes of `self.impl`."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run the backend's post-load hook, then reset scales when unquantized."""
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the KV cache spec for this layer.

        Returns a `SlidingWindowSpec` when `self.sliding_window` is set and a
        `FullAttentionSpec` otherwise. Asserts that the layer is a decoder
        attention layer.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )
|
||||
|
||||
|
||||
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute KV scales for the named layer if it still requests them.

    The layer is looked up by ``layer_name`` in the current forward
    context's ``no_compile_layers``.
    """
    layer = get_forward_context().no_compile_layers[layer_name]

    # The layer's calculate_kv_scales flag gates the computation; it gets
    # set to False after the first forward pass.
    if layer.calculate_kv_scales:
        layer.calc_kv_scales(query, key, value)
|
||||
|
||||
|
||||
def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_calc_kv_scales``; does nothing."""
    return None
|
||||
|
||||
|
||||
# Register `maybe_calc_kv_scales` as a custom op, with
# `maybe_calc_kv_scales_fake` as its fake implementation. `Attention.forward`
# invokes it as `torch.ops.vllm.maybe_calc_kv_scales`.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)
|
||||
|
||||
|
||||
def get_attention_context(
    layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor | None]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, KV cache tensor, and slot mapping for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine
        - slot_mapping: The slot mapping for this specific layer, or None if
            the forward context's slot mapping has no entry for this layer

    Raises:
        AssertionError: If the forward context's slot mapping is not a dict.

    Note: attn_metadata may be None, but attn_layer and kv_cache are always
    extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be stored per layer; pick this layer's entry in that case.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    # `.get` yields None for layers without a slot mapping entry; callers
    # such as unified_kv_cache_update check for that.
    layer_slot_mapping = slot_mapping.get(layer_name)
    return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named layer's backend forward and return its result.

    The layer, its KV cache and its attention metadata are fetched with
    ``get_attention_context(layer_name)``.
    """
    metadata, layer, cache, _slot_mapping = get_attention_context(layer_name)
    return layer.impl.forward(layer, query, key, value, cache, metadata)
|
||||
|
||||
|
||||
def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    Returns an uninitialized, contiguous tensor shaped like ``query``.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()
|
||||
|
||||
|
||||
# Register `unified_attention` as a custom op, with `unified_attention_fake`
# as its fake implementation. `Attention.forward` invokes it as
# `torch.ops.vllm.unified_attention`.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)
|
||||
|
||||
|
||||
def unified_kv_cache_update(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Write key/value into the named layer's KV cache when the layer has a slot
    mapping, and return an empty dummy tensor.

    The dummy is passed to unified_attention to signal a side effect and the
    data dependency between them, so torch.compile preserves ordering.
    """
    context = get_attention_context(layer_name)
    attn_layer, kv_cache, layer_slot_mapping = context[1], context[2], context[3]

    has_slot_mapping = layer_slot_mapping is not None
    if has_slot_mapping:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer, key, value, kv_cache, layer_slot_mapping
        )

    # Zero-element tensor on the cache's device and dtype.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
def unified_kv_cache_update_fake(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_kv_cache_update``.

    Returns a zero-element tensor with the device and dtype of ``key``.
    """
    dummy_device = key.device
    dummy_dtype = key.dtype
    return torch.empty(0, device=dummy_device, dtype=dummy_dtype)
|
||||
|
||||
|
||||
# Register `unified_kv_cache_update` as a custom op, with
# `unified_kv_cache_update_fake` as its fake implementation. No tensor
# argument is declared as mutated. `Attention.forward` invokes it as
# `torch.ops.vllm.unified_kv_cache_update`.
direct_register_custom_op(
    op_name="unified_kv_cache_update",
    op_func=unified_kv_cache_update,
    fake_impl=unified_kv_cache_update_fake,
    mutates_args=[],
)
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    """Run the named layer's backend forward with a caller-provided ``output``.

    ``output``, ``output_scale`` and ``output_block_scale`` are passed to the
    backend's ``forward`` as keyword arguments. Nothing is returned.
    """
    # kv_cache_dummy_dep is never read. Accepting it creates a data
    # dependency so torch.compile keeps the KV cache update ordered before
    # the attention forward.
    del kv_cache_dummy_dep

    metadata, layer, cache, _slot_mapping = get_attention_context(layer_name)
    forward_kwargs = {
        "output": output,
        "output_scale": output_scale,
        "output_block_scale": output_block_scale,
    }
    layer.impl.forward(layer, query, key, value, cache, metadata, **forward_kwargs)
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Register `unified_attention_with_output` as a custom op, with
# `unified_attention_with_output_fake` as its fake implementation. `output`
# and `output_block_scale` are declared as mutated arguments.
# `Attention.forward` invokes it as
# `torch.ops.vllm.unified_attention_with_output`.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)
|
||||
Reference in New Issue
Block a user