# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer using the AITER unified attention kernel on ROCm."""

import torch

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
)
from vllm.v1.attention.backend import AttentionLayer, AttentionType
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
    RocmAttentionBackend,
    RocmAttentionImpl,
    RocmAttentionMetadataBuilder,
)

logger = init_logger(__name__)


class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_UNIFIED_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
        return RocmAiterUnifiedAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        return RocmAttentionMetadataBuilder

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """RocmAiterUnifiedAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )


class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            sinks,
        )
        logger.info_once(
            "Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
        )
        from aiter.ops.triton.unified_attention import unified_attention

        self.unified_attention = unified_attention

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with the AITER unified attention kernel.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
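
        # Fused output quantization: `fused_output_quant_supported` above only
        # accepts per-tensor static FP8 (kFp8StaticTensorSym), so `output_scale`
        # is only ever provided for that scheme; the block-scale variant is
        # rejected just below.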
        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAiterUnifiedAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently: no KV cache is needed.
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention, we use the Q, K, V tensors directly,
            # without caching.
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,
            key.shape[1] if key is not None else self.num_kv_heads,
        )

        self.unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )
        return output

    def do_kv_cache_update(
        self,
        layer: AttentionLayer,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention, we use the Q, K, V tensors directly,
            # without caching.
            return

        key_cache, value_cache = kv_cache.unbind(0)

        # Reshape the input keys and values and store them in the cache.
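        # The caches use the flash layout from `get_kv_cache_shape`:
        # [num_blocks, block_size, num_kv_heads, head_size] each, with
        # `slot_mapping` giving every token its flat destination slot
        # (block_id * block_size + in-block offset); `_k_scale`/`_v_scale`
        # are applied when the cache dtype is fp8.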
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    def fused_rope_kvcache_supported(self):
        return rocm_aiter_ops.is_enabled()

    def do_rope_and_kv_cache_update(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        is_neox: bool,
        kv_cache: torch.Tensor,
        layer_slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention, we use the Q, K, V tensors directly,
            # without caching.
            return

        key_cache, value_cache = kv_cache.unbind(0)

        flash_layout = True
        is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
        if is_fp8_kv_cache:
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)

        rocm_aiter_ops.triton_rope_and_cache(
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            key_cache,
            value_cache,
            layer_slot_mapping,
            layer._k_scale,
            layer._v_scale,
            flash_layout,
            is_fp8_kv_cache,
        )
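
# Shape walkthrough (illustrative sizes, not taken from this file): with
# num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128,
# `get_kv_cache_shape` returns (2, 1024, 16, 8, 128). `kv_cache.unbind(0)`
# then yields key and value caches of shape (1024, 16, 8, 128) each: the
# flash layout consumed by `reshape_and_cache_flash`, the fused RoPE+cache
# path above, and the unified attention kernel via `block_table`.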