################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple

import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_br.config.compilation import SUPAGraphMode

if TYPE_CHECKING:
    pass
    # from vllm.v1.worker.gpu_model_runner import GPUModelRunner


class SUPAFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the SUPA flash-attention implementation.

    Exposes the implementation, metadata and metadata-builder classes defined
    in this module, together with helpers describing the KV-cache shape and
    memory layout.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is supported.

        Raises:
            ValueError: if ``head_size`` is not one of
                ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the 4-D KV-cache shape used for the SUPA speed-up layout.

        ``th_gran`` consecutive blocks (see
        ``get_kv_cache_usharp_alignment``) are folded into one row group and
        the head dimensions are flattened, giving
        ``(2, ceil(num_blocks / th_gran), th_gran * block_size,
        num_kv_heads * head_size)`` with at least one row group.
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, clamped so that at least one group is allocated.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-formatting: the message is only rendered when DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s], "
            "For SUPA Speed up, use [2, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in a 2048-row group.

        NOTE(review): the result is 0 when ``block_size`` exceeds 2048, which
        makes ``get_kv_cache_usharp_shape`` divide by zero - confirm that such
        block sizes are rejected upstream.
        """
        max_h_limit = 2048
        return max_h_limit // block_size

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical shape to memory layout.

        Raises:
            ValueError: if the configured cache layout is neither ``"NHD"``
                nor ``"HND"``.
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not ``"fp8"`` or
                ``"fp8_e4m3"``.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class SUPAFlashAttentionMetadata:
    """Per-step attention metadata produced by the builder in this module."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # BIREN Attention Params
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    do_cache: bool  # when use attentionsplit, do cache = False
    num_actual_reqs: torch.Tensor

    # Graph mode
    supagraph_runtime_mode: SUPAGraphMode

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    causal: bool = True

    # for local attention
    # @dataclass
    # class LocalAttentionMetadata:
    #     local_query_start_loc: torch.Tensor
    #     local_seqused_k: torch.Tensor
    #     local_block_table: torch.Tensor
    #     local_max_query_len: int
    #     local_max_seq_len: int
    #     local_scheduler_metadata: Optional[torch.Tensor]

    # local_attn_metadata: Optional[LocalAttentionMetadata] = None


class SUPAFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[SUPAFlashAttentionMetadata]):
    """Builds ``SUPAFlashAttentionMetadata`` from the common metadata."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS
    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the model/cache/compilation settings needed by ``build``."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # self.aot_schedule = (get_flash_attn_version() == 3)
        self.aot_schedule = False
        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        # if self.use_full_cuda_graph and self.aot_schedule:
        #     if self.max_cudagraph_size > 992:
        #         # This condition derives from FA3's internal heuristic.
        #         # TODO(woosuk): Support larger cudagraph sizes.
        #         raise ValueError(
        #             "Capture size larger than 992 is not supported for "
        #             "full cuda graph.")
        #     self.scheduler_metadata = torch.zeros(
        #         vllm_config.scheduler_config.max_num_seqs + 1,
        #         dtype=torch.int32,
        #         device=self.device,
        #     )
        #     # When using cuda graph, we need to set the upper bound of the
        #     # number of splits so that large enough intermediate buffers are
        #     # pre-allocated during capture.
        #     self.max_num_splits = (
        #         envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

        # model_config = runner.model_config
        # self.runner = runner
        # self.num_heads_q = model_config.get_num_attention_heads(
        #     runner.parallel_config)
        # self.num_heads_kv = model_config.get_num_kv_heads(
        #     runner.parallel_config)
        # self.headdim = model_config.get_head_size()
        # self.block_size = kv_cache_spec.block_size
        # self.kv_cache_spec = kv_cache_spec
        # self.block_table = block_table

        # self.aot_schedule = False
        # logger.warning(
        #     "AOT Schedule is disabled when using SUPAFlashAttention.")
        # # Sliding window size to be used with the AOT scheduler will be
        # # populated on first build() call.
        # self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SUPAFlashAttentionMetadata:
        """Assemble the attention metadata for one scheduling step.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value greater than zero enables the cascade fields.
            common_attn_metadata: Backend-agnostic metadata for this step.
            fast_build: Disables AOT scheduling; used when there will be few
                iterations, i.e. spec-decode.

        Returns:
            The populated ``SUPAFlashAttentionMetadata``.

        Raises:
            NotImplementedError: if AOT scheduling is enabled on the builder.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=self.reorder_batch_threshold,
                                       require_uniform=True)

        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        num_actual_reqs = common_attn_metadata.num_actual_reqs

        seq_start_loc = common_attn_metadata.seq_start_loc
        context_lens = common_attn_metadata.context_lens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call, once the layers are constructed (it cannot be
            # populated in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
            num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # AOT scheduling has no SUPA implementation; without it there is
            # no scheduler metadata to produce.
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        # for local attention
        # local_attn_metadata = None
        # if self.runner.attention_chunk_size is not None:
        #     seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
        #         virt_block_table_tensor = make_local_attention_virtual_batches(
        #             self.runner.attention_chunk_size,
        #             self.runner.query_start_loc_np[:num_reqs + 1],
        #             self.runner.seq_lens_np[:num_reqs],
        #             block_table_tensor,
        #             self.block_size,
        #         )
        #     local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
        #         self.runner.device, non_blocking=False)
        #     local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
        #         self.runner.device, non_blocking=False)
        #     local_max_query_len = seqlens_q_local_np.max()
        #     local_max_seq_len = virt_k_seqlens_np.max()
        #     local_scheduler_metadata = schedule(
        #         batch_size=local_query_start_loc.shape[0] - 1,
        #         cu_query_lens=local_query_start_loc,
        #         max_query_len=local_max_query_len,
        #         seqlens=local_seqused_k,
        #         max_seq_len=local_max_seq_len,
        #         causal=True)

        #     local_attn_metadata = SUPAFlashAttentionMetadata.LocalAttentionMetadata(
        #         local_query_start_loc=local_query_start_loc,
        #         local_seqused_k=local_seqused_k,
        #         local_block_table=virt_block_table_tensor,
        #         local_max_query_len=local_max_query_len,
        #         local_max_seq_len=local_max_seq_len,
        #         local_scheduler_metadata=local_scheduler_metadata,
        #     )

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The builder keeps no `runner` reference (its assignment in
            # __init__ is commented out), so tensors are created on
            # `self.device`, the same device used for `suffix_kv_lens` below.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # Fall back to the longest sequence in the batch when the common
        # metadata does not carry a decode-specific maximum.
        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        attn_metadata = SUPAFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            # local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs,
            supagraph_runtime_mode=common_attn_metadata.supagraph_runtime_mode)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Always return ``False``: this builder never opts into cascade."""
        return False


class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation dispatching to the ``torch_br`` kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Validate and store the layer's attention configuration.

        Raises:
            ValueError: if ``head_size`` is unsupported, or ``sinks`` has the
                wrong number of heads or is not float32.
            NotImplementedError: if ``attn_type`` is neither ``DECODER`` nor
                ``ENCODER_ONLY``, or an fp8 KV cache is requested where
                ``flash_attn_supports_fp8()`` reports no support.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="cpu")
        self.alibi_slopes = alibi_slopes
        # A falsy window (None or 0) is stored as None.
        self.sliding_window = sliding_window or None
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        SUPAFlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        if attn_type not in (AttentionType.DECODER,
                             AttentionType.ENCODER_ONLY):
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}.")
            if sinks.dtype != torch.float32:
                raise ValueError("Sinks must be of type float32, but got "
                                 f"{sinks.dtype}.")
            self.sinks = sinks

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: Must be None; output buffers are not supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is None, "Output tensor should not provided."

        if attn_metadata is None:
            # FIXME: this may lead to wrong block estimation
            # Profiling run.
            return query

        is_encoder = self.attn_type in (AttentionType.ENCODER_ONLY,
                                        AttentionType.ENCODER)

        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size] as shape
        if kv_cache is not None and attn_metadata.do_cache and not is_encoder:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        # Sinks take precedence over every other path.
        if self.sinks is not None:
            return self.forward_sw_sinks(query, kv_cache, attn_metadata)

        if is_encoder:
            assert len(query.shape) == 3
            return torch_br.supa_flash_attention_infer(  # type: ignore
                query,
                key,
                value,
                attn_metadata.query_start_loc,
                self.head_size,
                len(attn_metadata.query_start_loc),  # type: ignore
                self.alibi_slopes,
                softmax_scale=self.scale,
                is_causal=_get_causal_option(self.attn_type))

        graph_mode = attn_metadata.supagraph_runtime_mode
        decode_kernel_allowed = graph_mode is None or graph_mode in (
            SUPAGraphMode.NONE, SUPAGraphMode.FULL_DECODE_ONLY)

        # decode only (non-mtp): taken only in the modes above and only when
        # the batch carries no prefill tokens.
        if decode_kernel_allowed and attn_metadata.num_prefill_tokens <= 0:
            return torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)

        # Every remaining case (batches with prefill tokens, or any other
        # graph mode) uses the KV-cache flash-attention kernel.
        return torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=self.alibi_slopes,
            softmax_scale=self.scale,
            num_reqs=attn_metadata.num_actual_reqs)

    # sliding window with sinks impl
    def forward_sw_sinks(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the sliding-window attention kernel that supports sinks.

        Forwards ``self.sliding_window`` and ``self.sinks`` to
        ``torch_br.supa_flash_attn_cache_infer`` and returns its result.
        """
        # prefix-enabled attention
        output = torch_br.supa_flash_attn_cache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            window_size=self.sliding_window,
            sinks=self.sinks)
        return output


def _get_causal_option(attn_type: AttentionType) -> bool:
    """
    Determine whether the given attention type is suitable for causal
    attention mechanisms.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: Returns `True` if the attention type is suitable for causal
        attention (i.e., not encoder, encoder-only, or encoder-decoder),
        otherwise returns `False`.
    """
    return attn_type not in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY,
                             AttentionType.ENCODER_DECODER)