################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Attention layer with FlashAttention."""
import os
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
import torch_br

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
                                           compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
from vllm.logger import logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder)


class SUPAFlashAttentionBackend(AttentionBackend):
    """vLLM v0 attention backend that dispatches to Biren SUPA kernels."""

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes accepted by ``SUPAFlashAttentionImpl``."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V0"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The shape is ``(2, num_blocks, block_size, num_kv_heads, head_size)``,
        where the leading 2 indexes key/value.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the SUPA "usharp" KV-cache layout.

        ``th_gran`` consecutive logical blocks are packed into one row of
        ``th_gran * block_size`` tokens. The head dimensions are flattened
        into ``num_kv_heads * head_size``. The number of packed blocks is
        ``ceil(num_blocks / th_gran)``, but at least 1.
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceil-divide, but never produce an empty cache.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            f'Origin kv cache shape is [2, {num_blocks}, {block_size}, {num_kv_heads}, {head_size}, For SUPA Speed up, use [2, {n_block}, {th_gran * block_size}, {num_kv_heads * head_size}]'  # noqa: G004
        )
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many ``block_size`` blocks fit into 2048 token rows.

        NOTE(review): this returns 0 when ``block_size > 2048``. That would
        make ``get_kv_cache_usharp_shape`` divide by zero.
        """
        max_h_limit = 2048
        return max_h_limit // block_size


@dataclass
class SUPAFlashAttentionMetadata:
    """Per-batch attention metadata consumed by ``SUPAFlashAttentionImpl``."""
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    # The builder fills this with the longest *prefill* sequence length.
    max_seq_len: int
    # NOTE(review): the builder stores the Python list of seq lens here.
    seq_lens: torch.Tensor
    seq_lens_tensor: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # BIREN Attention Params
    # NOTE(review): the builder stores a Python list (cumulative seq lens).
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    num_prefills: int
    # The builder stores the number of decode *tokens* here; only "> 0" is
    # checked (see do_decode).
    num_decodes: int
    num_prefills_tokens: int
    # False when attention split is used. forward() only writes K/V into the
    # cache when this is True.
    do_cache: bool

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["SUPAFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["SUPAFlashAttentionMetadata"] = None

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None

    @property
    def do_prefill(self) -> bool:
        """True when the batch contains at least one prefill sequence."""
        return self.num_prefills > 0

    @property
    def do_decode(self) -> bool:
        """True when the batch contains at least one decode token."""
        return self.num_decodes > 0

    @property
    def prefill_metadata(self) -> Optional["SUPAFlashAttentionMetadata"]:
        """Return the cached prefill-only metadata.

        Returns None when there are no prefills or nothing was cached.
        """
        if self.num_prefills == 0:
            return None
        return self._cached_prefill_metadata


class SUPAFlashAttentionMetadataBuilder:
    """Accumulates per-sequence-group data and builds the SUPA metadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset all per-batch accumulators before a new ``build``."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        # Loop-invariant: depends only on this group's block tables.
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids,
                 [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks,
                 strict=False):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int, block_tables: List[List[int]]) -> torch.Tensor:
        """Copy ``block_tables`` into the runner's fixed-size graph table.

        ``self.runner.graph_block_tables[:num_seqs]`` is a NumPy slice, so
        it is written in place. Positions not overwritten here keep their
        previous contents. Empty block tables are skipped.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A ``SUPAFlashAttentionMetadata``. ``do_cache`` is always False and
            cascade attention is disabled.
        """
        prefix_cache_hit = any(
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list)
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded sequence. Empty tables are
            # skipped when copying into the graph block table. (The previous
            # `[] * n` was always an empty list, i.e. a no-op.)
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc,
                                                  torch.int32, device,
                                                  self.runner.pin_memory)

        return SUPAFlashAttentionMetadata(
            num_actual_tokens=batch_size,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc_tensor,
            max_seq_len=max_prefill_seq_len,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            block_table=block_tables,
            slot_mapping=slot_mapping_tensor,
            use_cascade=False,
            common_prefix_len=0,
            # Field is Optional[torch.Tensor]; None means "not provided".
            scheduler_metadata=None,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
            local_attn_metadata=None,
            prefix_scheduler_metadata=None,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=self.num_prefills,
            num_decodes=num_decode_tokens,
            num_prefills_tokens=self.num_prefill_tokens,
            do_cache=False)


class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by ``torch_br`` SUPA kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate configuration and store per-layer attention parameters.

        Raises:
            ValueError: for block-sparse params or an unsupported head size.
            NotImplementedError: for a quantized KV cache when fp8 is not
                supported by flash-attn on this device.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.attn_type = attn_type
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = (
            SUPAFlashAttentionBackend.get_supported_head_sizes())
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=1 to use another attention backend.")

        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
                and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        NOTE(review): the SUPA path slices q/k/v along dim 1 (tokens), i.e.
              it treats them as [batch_size, num_tokens, hidden]. The shapes
              listed above follow the upstream docstring; confirm.
        """
        assert output is None, "Output tensor should not provided."

        if attn_metadata is None:
            # Profiling run: return the query unchanged.
            # FIXME: this may lead to wrong block estimation
            return query

        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size]
        # as shape
        if kv_cache is not None and attn_metadata.do_cache:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        num_prefill_tokens = attn_metadata.num_prefills_tokens
        is_mixed_batch = attn_metadata.do_prefill and attn_metadata.do_decode
        output_prefill = output_decode = None

        if is_mixed_batch:
            # Chunked prefill: prefill tokens come first, then decode tokens.
            # Only this case needs a separate buffer to stitch both halves.
            output = torch.empty_like(query)
            decode_query = query[:, num_prefill_tokens:]
            query = query[:, :num_prefill_tokens]
            key = key[:, :num_prefill_tokens]
            value = value[:, :num_prefill_tokens]
        elif attn_metadata.do_decode:
            decode_query = query

        if attn_metadata.do_prefill:
            if kv_cache is None or attn_metadata.block_table.numel() == 0:
                # has do_decode should go into prefix-enabled branch
                assert not attn_metadata.do_decode
                # in this branch, query_start_loc = seq_start_loc
                is_causal = _get_causal_option(self.attn_type)
                if os.getenv('USE_BR_SUEAGER_SDPA',
                             'False').lower() not in {'false', '0', ''}:
                    output_prefill, _ = (
                        torch_br.sueager_scaled_dot_product_attention_fwd(
                            query=query,
                            key=key,
                            value=value,
                            mask=None,
                            dropout_prob=0.0,
                            is_causal=is_causal,
                            scale=self.scale,
                            algorithm="FMHA",
                        ))
                    # NOTE(review): reshaped with num_kv_heads, not
                    # num_heads; confirm this is intended for GQA models.
                    output_prefill = torch_br.supa_shape_transform_qkv(
                        output_prefill, 1, query.shape[1], self.num_kv_heads,
                        self.head_size)
                else:
                    output_prefill = torch_br.supa_flash_attention_infer(  # type: ignore
                        query,
                        key,
                        value,
                        attn_metadata.query_start_loc,
                        self.head_size,
                        len(attn_metadata.query_start_loc),  # type: ignore
                        self.alibi_slopes,
                        softmax_scale=self.scale,
                        is_causal=is_causal)
            else:
                # prefix-enabled attention
                output_prefill = torch_br.supa_flash_attn_cache_infer(  # type: ignore
                    query,
                    kv_cache,
                    attn_metadata.query_start_loc,
                    attn_metadata.seq_start_loc,
                    attn_metadata.block_table,
                    attn_metadata.context_lens,
                    attn_metadata.slot_mapping,
                    attn_metadata.max_seq_len,
                    self.head_size,
                    self.alibi_slopes,
                    softmax_scale=self.scale)

        if attn_metadata.do_decode:
            output_decode = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                decode_query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)

        if is_mixed_batch:
            output[:, :num_prefill_tokens] = output_prefill
            output[:, num_prefill_tokens:] = output_decode
            return output
        if attn_metadata.do_prefill:
            return output_prefill
        return output_decode


def _get_causal_option(attn_type: str) -> bool:
    """
    Determine whether the given attention type is suitable for causal
    attention mechanisms.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: Returns `True` if the attention type is suitable for causal
        attention (i.e., not encoder, encoder-only, or encoder-decoder),
        otherwise returns `False`.
    """
    return attn_type not in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY,
                             AttentionType.ENCODER_DECODER)