# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" from collections.abc import Callable from typing import cast import torch import torch.nn as nn import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.config import CacheConfig, get_current_vllm_config from vllm.config.multimodal import MultiModalConfig from vllm.config.vllm import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import ( ColumnParallelLinear, UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils.torch_utils import ( direct_register_custom_op, kv_cache_dtype_str_to_dtype, ) from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowSpec, ) from ixformer.core import config _USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS if current_platform.is_rocm(): from vllm.platforms.rocm import on_gfx9 else: on_gfx9 = lambda *args, **kwargs: False FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) USE_XFORMERS_OPS = None def check_xformers_availability(): global USE_XFORMERS_OPS if USE_XFORMERS_OPS is not None: return USE_XFORMERS_OPS if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: try: from importlib.util import find_spec find_spec("xformers.ops") USE_XFORMERS_OPS = True except ImportError: USE_XFORMERS_OPS = False # the warning only needs to be shown once if not USE_XFORMERS_OPS: logger.warning("Xformers is not available, falling back.") return USE_XFORMERS_OPS import ixformer.contrib.vllm_flash_attn as ops def check_upstream_fa_availability(dtype: torch.dtype): if ( dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda() and current_platform.has_device_capability(80) ): from transformers.utils import is_flash_attn_2_available return is_flash_attn_2_available() if current_platform.is_rocm(): from importlib.util import find_spec return find_spec("flash_attn") is not None return False def maybe_get_vit_flash_attn_backend( attn_backend: AttentionBackendEnum, use_upstream_fa: bool, attn_backend_override: AttentionBackendEnum | None = None, ) -> tuple[AttentionBackendEnum, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = AttentionBackendEnum.ROCM_AITER_FA elif ( check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None ): attn_backend = 
AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: return AttentionBackendEnum.TORCH_SDPA, None elif current_platform.is_cuda(): if ( attn_backend != AttentionBackendEnum.FLASH_ATTN and check_upstream_fa_availability(torch.get_default_dtype()) ): attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True elif current_platform.is_xpu(): assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( "XPU platform only supports FLASH_ATTN as vision attention backend." ) use_upstream_fa = False else: return AttentionBackendEnum.TORCH_SDPA, None if attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: if use_upstream_fa: from flash_attn import flash_attn_varlen_func else: from vllm.attention.utils.fa_utils import flash_attn_varlen_func else: flash_attn_varlen_func = None return attn_backend, flash_attn_varlen_func def _init_kv_cache_quant( layer: nn.Module, quant_config: QuantizationConfig | None, prefix: str, kv_cache_dtype: str, calculate_kv_scales: bool, ) -> None: """Initializes KV cache scaling factors and quantization method. This helper function sets up the KV cache quantization attributes that are shared between Attention and MLAAttention layers. It initializes scale tensors for query, key, value, and probability, and configures the quantization method if applicable. Args: layer: The attention layer instance to initialize. quant_config: Optional quantization configuration. prefix: Layer name prefix for quantization method lookup. kv_cache_dtype: The KV cache data type string. calculate_kv_scales: Whether to calculate KV scales dynamically. """ # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we # expect the pre-quantized k/v_scale to be loaded along # with the model weights. layer.kv_cache_dtype = kv_cache_dtype layer.calculate_kv_scales = calculate_kv_scales layer._k_scale = torch.tensor(1.0, dtype=torch.float32) layer._v_scale = torch.tensor(1.0, dtype=torch.float32) layer._q_scale = torch.tensor(1.0, dtype=torch.float32) layer._prob_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep q/k/v_scale on host (cpu) memory for attention # backends that require the scales to be on host instead of on device. # e.g. Flashinfer layer._q_scale_float = 1.0 layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 # The output scale on host memory. This should be the input scale of # the quant op after this attention layer. layer._o_scale_float = None quant_method = ( quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None ) if quant_method is not None and not isinstance( quant_method, UnquantizedLinearMethod ): assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior if kv_cache_dtype == "fp8_e5m2": raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.") # If quantization is enabled, we make "k_scale" and "v_scale" # parameters so that it can be loaded from the model checkpoint. # The k/v_scale will then be converted back to native float32 # values after weight loading. layer.quant_method = quant_method layer.quant_method.create_weights(layer) class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. 
The input tensors can either contain prompt tokens or generation tokens. The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. 3. Return the output tensor. """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int | None = None, alibi_slopes: list[float] | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, logits_soft_cap: float | None = None, per_layer_sliding_window: int | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, **extra_impl_args, ) -> None: """ The KV cache is stored inside this class and is accessed via `self.kv_cache`. """ super().__init__() if per_layer_sliding_window is not None: # per-layer sliding window sliding_window = per_layer_sliding_window elif cache_config is not None: # model-level sliding window sliding_window = cache_config.sliding_window else: sliding_window = None vllm_config = get_current_vllm_config() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( kv_cache_dtype, vllm_config.model_config ) if num_kv_heads is None: num_kv_heads = num_heads assert num_heads % num_kv_heads == 0, ( f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" ) # Initialize KV cache quantization attributes _init_kv_cache_quant( self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales ) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() if attn_backend is None: self.attn_backend = get_attn_backend( head_size, dtype, kv_cache_dtype, block_size, use_mla=False, has_sink=self.has_sink, attn_type=attn_type, ) else: self.attn_backend = attn_backend impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls( num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **extra_impl_args, ) self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. 
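        # In this build, the choice between the two paths below is driven by
        # the ixformer IXFORMER_USE_TORCH_OPS flag (imported at module scope
        # as _USE_TORCH_OPS) rather than by platform. Rough sketch of the two
        # dispatch paths taken later in forward() (illustration only, not
        # executed here):
        #
        #     if not self.use_direct_call:  # _USE_TORCH_OPS is True
        #         # one opaque custom op; torch.compile does not trace inside it
        #         out = torch.ops.vllm.unified_attention(q, k, v, self.layer_name)
        #     else:                         # _USE_TORCH_OPS is False
        #         # call the backend implementation directly
        #         out = self.impl.forward(self, q, k, v, kv_cache, attn_metadata)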
if _USE_TORCH_OPS: self.use_direct_call = False else: self.use_direct_call = True self.use_output = self.attn_backend.accept_output_buffer compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.layer_name = prefix self.attn_type = attn_type if kv_sharing_target_layer_name is not None: validate_kv_sharing_target( prefix, kv_sharing_target_layer_name, compilation_config.static_forward_context, ) self.kv_sharing_target_layer_name = kv_sharing_target_layer_name # use a placeholder kv cache tensor during init, which will be replaced # by bind_kv_cache # this variable will not be accessed if use_direct_call is True self.kv_cache = [ torch.tensor([]) for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] # Initialize q/k/v range constants. self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) # for attn backends supporting query quantization self.query_quant = None if ( self.kv_cache_dtype.startswith("fp8") and self.impl.supports_quant_query_input() ): self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, # For some alternate attention backends like MLA the attention output # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. output_shape: torch.Size | None = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via `self.kv_cache`. Attention metadata (`attn_metadata`) is set using a context manager in the model runner's `execute_model` method. It is accessed via forward context using `vllm.forward_context.get_forward_context().attn_metadata`. """ if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) output_dtype = query.dtype if self.query_quant is not None: # quantizing with a simple torch operation enables # torch.compile to fuse this into previous ops # which reduces overheads during decoding. # Otherwise queries are quantized using custom ops # which causes decoding overheads assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} # check if query quantization is supported if self.impl.supports_quant_query_input(): query, _ = self.query_quant(query, self._q_scale) if self.use_output: output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] # Reshape the query, key, and value tensors. # NOTE(woosuk): We do this outside the custom op to minimize the # CPU overheads from the non-CUDA-graph regions. 
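        # Shape walkthrough with hypothetical sizes (illustration only):
        #   query:      [num_tokens, num_heads * head_size]     e.g. [8192, 32 * 128]
        #     .view(-1, num_heads, head_size)                -> [8192, 32, 128]
        #   key/value:  [num_tokens, num_kv_heads * head_size]  e.g. [8192, 8 * 128]
        #     .view(-1, num_kv_heads, head_size)             -> [8192, 8, 128]
        # (num_kv_heads < num_heads under GQA/MQA.)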
query = query.view(-1, self.num_heads, self.head_size) output = output.view(-1, self.num_heads, self.head_size) if key is not None: key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: def fun(layer_name: str, output: torch.Tensor): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward( self, query, key, value, self_kv_cache, attn_metadata, output=output ) return output if envs.VLLM_SUPPORT_IXSERVER: return maybe_transfer_kv_layer(fun)(self.layer_name, output) else: return fun(self.layer_name, output) else: torch.ops.vllm.unified_attention_with_output( query, key, value, output, self.layer_name ) return output.view(-1, self.num_heads * self.head_size) else: if self.use_direct_call: def fun(layer_name: str): forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward( self, query, key, value, self_kv_cache, attn_metadata ) if envs.VLLM_SUPPORT_IXSERVER: return maybe_transfer_kv_layer(fun)(self.layer_name) else: return fun(self.layer_name) else: return torch.ops.vllm.unified_attention( query, key, value, self.layer_name ) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once self.calculate_kv_scales = False def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore s += f", num_heads={self.impl.num_heads}" # type: ignore s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore s += f", backend={self.impl.__class__.__name__}" return s def process_weights_after_loading(self, act_dtype: torch.dtype): self.impl.process_weights_after_loading(act_dtype) def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: # Block size may get updated after model loading, refresh it block_size = vllm_config.cache_config.block_size # Should not be called for enc-dec or encoder-only attention. 
assert self.attn_type == AttentionType.DECODER if self.sliding_window is not None: assert not vllm_config.model_config.use_mla, ( "MLA is not supported with sliding window" ) return SlidingWindowSpec( block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, sliding_window=self.sliding_window, ) else: return FullAttentionSpec( block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, ) class MultiHeadAttention(nn.Module): """Multi-headed attention without any cache, used for ViT.""" def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int | None = None, # This has no effect, it is only here to make it easier to swap # between Attention and MultiHeadAttention prefix: str = "", multimodal_config: MultiModalConfig | None = None, ) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.layer_name = prefix assert self.num_heads % self.num_kv_heads == 0, ( f"num_heads ({self.num_heads}) is not " f"divisible by num_kv_heads ({self.num_kv_heads})" ) self.num_queries_per_kv = self.num_heads // self.num_kv_heads # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() # Determine the attention backend attn_backend_override = None if multimodal_config is not None: attn_backend_override = multimodal_config.mm_encoder_attn_backend backend = get_vit_attn_backend( head_size=head_size, dtype=dtype, attn_backend_override=attn_backend_override, ) # Some auto-selected backends can be upgraded # to upstream flash attention if available. # If vLLM's native FA is selected, we use it directly.
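        # For example (illustrative, depends on the installed environment):
        # on CUDA with upstream flash-attn available, an auto-selected
        # TORCH_SDPA backend is upgraded to AttentionBackendEnum.FLASH_ATTN
        # using flash_attn.flash_attn_varlen_func; on ROCm gfx9 with AITER MHA
        # enabled it resolves to AttentionBackendEnum.ROCM_AITER_FA with
        # aiter.flash_attn_varlen_func.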
use_upstream_fa = False self.attn_backend = ( backend if backend in { AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.XFORMERS, AttentionBackendEnum.PALLAS, AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.FLASH_ATTN, } else AttentionBackendEnum.TORCH_SDPA ) self.attn_backend, self._flash_attn_varlen_func = ( maybe_get_vit_flash_attn_backend( self.attn_backend, use_upstream_fa, attn_backend_override=attn_backend_override, ) ) if ( self.attn_backend == AttentionBackendEnum.XFORMERS and not check_xformers_availability() ): self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, } # This check only ensures that use_upstream_fa is reported # correctly in the log message below if ( current_platform.is_rocm() and self.attn_backend == AttentionBackendEnum.FLASH_ATTN ): use_upstream_fa = True logger.info_once( f"MultiHeadAttention attn_backend: {self.attn_backend}, " f"use_upstream_fa: {use_upstream_fa}" ) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, ) -> torch.Tensor: """Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size) """ bsz, q_len = query.size()[:2] kv_len = key.size(1) query = query.view(bsz * q_len, self.num_heads, self.head_size) key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) cu_q = torch.tensor([0] + [q_len] * bsz, device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) cu_kv = torch.tensor([0] + [kv_len] * bsz, device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) out = ops.flash_attn_varlen_func( query, key, value, cu_q, cu_kv, q_len, kv_len, softmax_scale=self.scale, causal=False, ) return out.view(bsz, q_len, -1) class MLAAttention(nn.Module, AttentionLayerBase): """Multi-Head Latent Attention layer. This class takes query and compressed key/value tensors as input. The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. 3. Return the output tensor.
""" def __init__( self, num_heads: int, scale: float, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int | None, kv_lora_rank: int, kv_b_proj: ColumnParallelLinear, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", use_sparse: bool = False, indexer: object | None = None, **extra_impl_args, ): super().__init__() self.num_heads = num_heads self.scale = scale self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.head_size = kv_lora_rank + qk_rope_head_dim self.layer_name = prefix if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False # Initialize KV cache quantization attributes _init_kv_cache_quant( self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales ) dtype = torch.get_default_dtype() self.attn_backend = get_attn_backend( self.head_size, dtype, kv_cache_dtype, block_size, use_mla=True, use_sparse=use_sparse, ) impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) self.impl = impl_cls( num_heads=self.num_heads, head_size=self.head_size, scale=self.scale, num_kv_heads=1, alibi_slopes=None, sliding_window=None, kv_cache_dtype=self.kv_cache_dtype, logits_soft_cap=None, attn_type=AttentionType.DECODER, kv_sharing_target_layer_name=None, # MLA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, v_head_dim=self.v_head_dim, kv_b_proj=kv_b_proj, indexer=indexer, **extra_impl_args, ) self.use_direct_call = True compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.kv_cache = [ torch.tensor([]) for _ in range( get_current_vllm_config().parallel_config.pipeline_parallel_size ) ] if envs.VLLM_USE_INT8_MLA: self.kv_cache_scale = [ torch.tensor([]) for _ in range(get_current_vllm_config( ).parallel_config.pipeline_parallel_size) ] self.is_int8_mla = envs.VLLM_USE_INT8_MLA self.use_sparse = use_sparse # Initialize q/k/v range constants. 
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) def forward( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: optional_args = {} if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] if self.is_int8_mla: optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine] if self.attn_backend.accept_output_buffer: output_shape = (output_shape if output_shape is not None else q.shape) output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) output = self.impl.forward( self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata, output=output, **optional_args ) return output else: return self.impl.forward( self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata ) else: if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) torch.ops.vllm.unified_mla_attention_with_output( q, kv_c_normed, k_pe, output, self.layer_name, ) return output else: return torch.ops.vllm.unified_mla_attention( q, kv_c_normed, k_pe, self.layer_name, ) def process_weights_after_loading(self, act_dtype: torch.dtype): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) def calc_kv_scales( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor ) -> None: """Optional scale calculation for MLA inputs. Mirrors Attention.calc_kv_scales. 
Not all MLA backends require this """ # Use safe defaults if ranges are not present q_range = getattr(self, "q_range", torch.tensor(1.0)) k_range = getattr(self, "k_range", torch.tensor(1.0)) v_range = getattr(self, "v_range", torch.tensor(1.0)) self._q_scale.copy_(torch.abs(q).max() / q_range) # kv_c_normed is the compressed KV representation; use it for k/v kv_abs_max = torch.abs(kv_c_normed).max() self._k_scale.copy_(kv_abs_max / k_range) self._v_scale.copy_(kv_abs_max / v_range) self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() self.calculate_kv_scales = False def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: kv_cache_dtype = kv_cache_dtype_str_to_dtype( self.kv_cache_dtype, vllm_config.model_config ) return MLAAttentionSpec( block_size=vllm_config.cache_config.block_size, num_kv_heads=1, head_size=self.head_size, dtype=kv_cache_dtype, cache_dtype_str=vllm_config.cache_config.cache_dtype, ) def maybe_calc_kv_scales( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] # Only calculate if the layer's calculate_kv_scales flag is True # This flag gets set to False after the first forward pass if not self.calculate_kv_scales: return self.calc_kv_scales(query, key, value) def maybe_calc_kv_scales_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> None: return direct_register_custom_op( op_name="maybe_calc_kv_scales", op_func=maybe_calc_kv_scales, mutates_args=["query", "key", "value"], fake_impl=maybe_calc_kv_scales_fake, ) def get_attention_context( layer_name: str, ) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]: """Extract attention context for a given layer. This helper function extracts the attention metadata, attention layer instance, and KV cache tensor for a specific layer. Args: layer_name: The name/identifier of the attention layer. Returns: A tuple containing: - attn_metadata: Attention metadata for this specific layer, or None if no metadata available - attn_layer: The attention layer instance (Attention or MLAAttention) - kv_cache: The KV cache tensor for current virtual engine Note: attn_metadata may be None, but attn_layer and kv_cache are always extracted from the forward context. 
""" forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] return attn_metadata, attn_layer, kv_cache @maybe_transfer_kv_layer def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: attn_metadata, self, kv_cache = get_attention_context(layer_name) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) return output def unified_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: return torch.empty_like(query).contiguous() direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, fake_impl=unified_attention_fake, ) @maybe_transfer_kv_layer def unified_attention_with_output( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) self.impl.forward( self, query, key, value, kv_cache, attn_metadata, output=output, output_scale=output_scale, output_block_scale=output_block_scale, ) def unified_attention_with_output_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: return direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, ) @maybe_transfer_kv_layer def unified_mla_attention( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, ) -> torch.Tensor: attn_metadata, self, kv_cache = get_attention_context(layer_name) output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) return output def unified_mla_attention_fake( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, ) -> torch.Tensor: return torch.empty_like(q).contiguous() direct_register_custom_op( op_name="unified_mla_attention", op_func=unified_mla_attention, mutates_args=[], fake_impl=unified_mla_attention_fake, dispatch_key=current_platform.dispatch_key, ) @maybe_transfer_kv_layer def unified_mla_attention_with_output( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) self.impl.forward( self, q, kv_c_normed, k_pe, kv_cache, attn_metadata, output=output, output_scale=output_scale, output_block_scale=output_block_scale, ) def unified_mla_attention_with_output_fake( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: return direct_register_custom_op( op_name="unified_mla_attention_with_output", op_func=unified_mla_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_mla_attention_with_output_fake, 
dispatch_key=current_platform.dispatch_key, )
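# ---------------------------------------------------------------------------
# Illustrative sketch (not called by vLLM): how the varlen metadata consumed
# by MultiHeadAttention.forward is laid out. The sizes below are hypothetical
# and the helper exists purely as documentation of the packed layout.
# ---------------------------------------------------------------------------
def _example_vit_varlen_metadata(
    bsz: int = 2, q_len: int = 16, kv_len: int = 16
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build cu_seqlens tensors the way MultiHeadAttention.forward does.

    For a batch of equal-length sequences, cu_seqlens is the cumulative sum of
    per-sequence lengths prefixed with 0 (e.g. [0, 16, 32] for bsz=2,
    q_len=16). flash_attn_varlen_func uses these offsets to slice the packed
    (bsz * seq_len, num_heads, head_size) q/k/v tensors back into sequences.
    """
    cu_q = torch.tensor([0] + [q_len] * bsz, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    cu_kv = torch.tensor([0] + [kv_len] * bsz, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    return cu_q, cu_kv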