# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
    """FlashAttention metadata builder with ``aot_schedule`` turned off."""

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        # Override whatever the parent constructor chose.
        self.aot_schedule = False


class TritonAttentionBackend(AttentionBackend):
    """Backend descriptor wiring together the Triton attention classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise if ``head_size`` is not a supported head size.

        Raises:
            ValueError: ``head_size`` is not in
                ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # E.g. "TritonAttentionBackend" -> "TritonAttention".
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TRITON_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        """Return the attention implementation class."""
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class (shared with the FlashAttention
        backend)."""
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache shape for the given geometry.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always return False; arguments are ignored."""
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TritonAttentionMetadataBuilder


@cache
def use_aiter_unified_attention() -> bool:
    """Check if aiter unified attention should be used.

    The result is memoized by ``functools.cache``, so the two flags are only
    evaluated on the first call.
    """
    # VLLM_ROCM_USE_AITER_MHA needs to be set to 0 as well, since it
    # defaults to 1.
    return (envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)


class TritonAttentionImpl(AttentionImpl):
    """Decoder attention that dispatches to Triton kernels.

    Depending on ``VLLM_V1_USE_PREFILL_DECODE_ATTENTION`` the forward pass
    uses either ``chunked_prefill_paged_decode`` or a ``unified_attention``
    kernel (from aiter or from vLLM, see ``use_aiter_unified_attention``).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer configuration and select the attention kernel.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be one of
                ``TritonAttentionBackend.get_supported_head_sizes()``.
            scale: Softmax scale, stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional slopes, converted to a float32 tensor.
            sliding_window: Optional window size; stored as the tuple
                ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when None.
            kv_cache_dtype: KV-cache dtype name; names starting with "fp8"
                enable the fp8 handling in ``forward``.
            logits_soft_cap: Optional soft cap; None is stored as 0.
            attn_type: Only ``AttentionType.DECODER`` is accepted.
            kv_sharing_target_layer_name: When not None, ``forward`` skips
                writing keys/values into the KV cache.
            sinks: Optional tensor whose first dimension must equal
                ``num_heads``.

        Raises:
            ValueError: ``head_size`` is not supported.
            NotImplementedError: ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        TritonAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
        self.force_prefill_decode_attn = \
            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

        if not self.force_prefill_decode_attn:
            # If not using prefill decode attention, we use the Triton
            # unified attention implementation.
            # The kernels are imported lazily so that only the selected
            # implementation has to be importable.
            if use_aiter_unified_attention():
                logger.info_once(
                    "Using aiter unified attention for TritonAttentionImpl")
                from aiter.ops.triton.unified_attention import (
                    unified_attention)
                self.unified_attention = unified_attention
            else:
                logger.info_once(
                    "Using vllm unified attention for TritonAttentionImpl")
                from vllm.attention.ops.triton_unified_attention import (
                    unified_attention)
                self.unified_attention = unified_attention

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with the Triton attention kernels.

        Writes ``key``/``value`` into ``kv_cache`` (unless this layer shares
        the KV cache of another layer) and writes the attention result into
        ``output[:num_actual_tokens]``.

        Args:
            layer: Module providing ``_k_scale``, ``_v_scale`` and
                ``_q_scale``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` during the
                profiling run, in which case ``output`` is returned
                untouched.
            output: Pre-allocated output buffer; required.
            output_scale: Must be None (fused output quantization is not
                supported).
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: ``output_scale`` is not None.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TritonAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        use_prefill_decode_attn = self.force_prefill_decode_attn
        num_actual_tokens = attn_metadata.num_actual_tokens

        # The two kernel families take the cache halves via different
        # splitting routines.
        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if use_prefill_decode_attn:
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache storage as the platform's fp8 dtype.
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if not current_platform.is_rocm():
                # Skip Q quantization on ROCm, since dequantizing back to
                # f32 in the attention kernel is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            # NOTE(review): `self.logits_soft_cap` is not forwarded on this
            # path (only the unified-attention path receives `softcap`) --
            # confirm this is intended.
            chunked_prefill_paged_decode(
                query=query[:num_actual_tokens],
                key=key[:num_actual_tokens],
                value=value[:num_actual_tokens],
                output=output[:num_actual_tokens],
                kv_cache_dtype=self.kv_cache_dtype,
                key_cache=key_cache,
                value_cache=value_cache,
                block_table=block_table,
                query_start_loc=cu_seqlens_q,
                seq_lens=seqused_k,
                max_seq_len=max_seqlen_k,
                max_query_len=max_seqlen_q,
                k_scale=layer._k_scale,
                v_scale=layer._v_scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window[0],
                sm_scale=self.scale,
                sinks=self.sinks,
            )

        else:
            # One descale entry per (sequence, kv head): `cu_seqlens_q` has
            # one more element than there are sequences.
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            self.unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                sinks=self.sinks,
            )

        return output