# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""

import functools
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)

logger = init_logger(__name__)


def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether to calculate KV scales dynamically.

    Raises:
        ValueError: If a (non-unquantized) quantization method is found for
            the layer while `kv_cache_dtype` is "fp8_e5m2".
    """
    # The default k/v_scale is set to 1.0. This is ignored
    # when kv-cache is not fp8, and should be used with
    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
    # expect the pre-quantized k/v_scale to be loaded along
    # with the model weights.
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales
    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )
    if quant_method is not None and not isinstance(
        quant_method, UnquantizedLinearMethod
    ):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)


class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Besides building the backend implementation, construction registers
        this layer under `prefix` in the current compilation config's
        `static_forward_context`.

        Raises:
            ValueError: If `prefix` is already registered in
                `static_forward_context` (duplicate layer name).
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            # NOTE: this mutates the caller-provided cache_config in place.
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        backend_name = self.attn_backend.get_name()
        # `.get` yields None when the backend name is not an enum member.
        self.backend = AttentionBackendEnum.__members__.get(backend_name)
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor.
            key: Key tensor; may be None (it is only reshaped when present).
            value: Value tensor; may be None (it is only reshaped when present).
            output_shape: Shape of the preallocated output buffer. Only used
                when the backend accepts an output buffer; defaults to
                `query.shape`.

        Returns:
            When the backend accepts an output buffer, the buffer viewed as
            `(-1, output_shape[-1])`; otherwise whatever the backend
            implementation (or the `unified_attention` op) returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Captured before the optional query quantization below, so the output
        # buffer keeps the dtype of the incoming query.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed per layer; pick this layer's entry.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Compute q/k/v scales from the given tensors, once.

        Each scale is the absolute maximum of the corresponding tensor divided
        by its range constant. The float copies of the scales are refreshed
        and `calculate_kv_scales` is cleared afterwards.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary of the backend implementation's key attributes."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward the post-weight-loading hook to the backend implementation."""
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the KV cache spec for this layer.

        Returns a `SlidingWindowSpec` when the layer has a sliding window and
        a `FullAttentionSpec` otherwise.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )


class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # This has no effect, it is only here to make it easier to swap
        # between Attention and MultiHeadAttention
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Select the ViT attention backend and cache the kernel entry point.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`.
            prefix: Stored as `layer_name`; not otherwise used here.
            multimodal_config: Optional config whose `mm_encoder_attn_backend`
                overrides the backend choice.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.fa_version = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            and current_platform.is_cuda()
        ):
            self.fa_version = get_flash_attn_version()
            assert self._flash_attn_varlen_func is not None
            # Bind the flash-attn version once so forward() need not pass it.
            self._flash_attn_varlen_func = functools.partial(
                self._flash_attn_varlen_func, fa_version=self.fa_version
            )

        logger.info_once(
            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns:
            The attention output reshaped to (batch_size, q_len, -1).

        Raises:
            NotImplementedError: If the selected backend is not one of the
                flash-attention backends, TORCH_SDPA, or PALLAS.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            assert self._flash_attn_varlen_func is not None
            # Every sequence in the batch has the same length, so the
            # cumulative sequence offsets are evenly spaced.
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # SDPA is called with the head and sequence dimensions swapped.
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == AttentionBackendEnum.PALLAS:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            # Imported lazily so torch_xla is only required for this backend.
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(bsz, q_len, -1)


class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the MLA backend implementation and register the layer.

        The layer is registered under `prefix` in the current compilation
        config's `static_forward_context`.

        Raises:
            ValueError: If `prefix` is already registered in
                `static_forward_context` (duplicate layer name).
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "TRITON_MLA"
                or self.attn_backend.get_name() == "FLASHINFER"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            # NOTE: this mutates the caller-provided cache_config in place.
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        # Fetch the current config once and reuse it for both lookups below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # Placeholder KV cache tensors, one per pipeline-parallel stage.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention, directly or through the registered custom op.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed key/value tensor.
            k_pe: Third input forwarded to the backend alongside
                `kv_c_normed`.
            output_shape: Shape of the output buffer allocated when the
                backend accepts an output buffer.

        Returns:
            The preallocated output buffer when the backend accepts one,
            otherwise the value returned by the backend implementation.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            # Metadata may be keyed per layer; pick this layer's entry.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                # NOTE(review): there is no fallback for output_shape=None on
                # this path (unlike Attention.forward); callers presumably
                # always pass it for output-buffer backends -- confirm.
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward the post-weight-loading hook to the impl, if it defines one."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are only calculated once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the `MLAAttentionSpec` describing this layer's KV cache."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )


def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op body: compute a layer's KV scales if it still needs them.

    The layer is looked up by `layer_name` in the forward context's
    `no_compile_layers`.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Only calculate if the layer's calculate_kv_scales flag is True
    # This flag gets set to False after the first forward pass
    if not self.calculate_kv_scales:
        return

    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of `maybe_calc_kv_scales`; does nothing."""
    return


direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, and KV cache tensor for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op body: run the named layer's attention and return its output."""
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an uninitialized contiguous tensor shaped like `query`."""
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Custom-op body: run the named layer's attention, writing into `output`."""
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Fake implementation of `unified_attention_with_output`; does nothing."""
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op body: run the named MLA layer's attention and return its output."""
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an uninitialized contiguous tensor shaped like `q`."""
    return torch.empty_like(q).contiguous()


direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Custom-op body: run the named MLA layer's attention, writing into `output`."""
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Fake implementation of `unified_mla_attention_with_output`; does nothing."""
    return


direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)