"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder, AttentionType) from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from ixformer.contrib.vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Optional[List[int]] = None, softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, use_sqrt_alibi: Optional[bool] = False ) -> torch.Tensor: # custom op does not support tuple input real_window_size: Tuple[int, int] if window_size is None: real_window_size = (-1, -1) else: assert len(window_size) == 2 real_window_size = (window_size[0], window_size[1]) return _flash_attn_varlen_func( q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, softmax_scale=softmax_scale, causal=causal, window_size=real_window_size, softcap=softcap, alibi_slopes=alibi_slopes, block_table=block_table, out=out, sqrt_alibi=use_sqrt_alibi, ) def flash_attn_with_kvcache( decode_query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, seq_lens_cpu_tensors: torch.Tensor, max_context_len: int, cache_seqlens: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, alibi_slopes: Optional[torch.Tensor] = None, softcap: float = 0.0, out: Optional[torch.Tensor] = None, use_sqrt_alibi: bool = False ) -> torch.Tensor: return _flash_attn_with_kvcache( decode_query, key_cache, value_cache, cache_seqlens=cache_seqlens, block_table=block_table, softmax_scale=softmax_scale, causal=causal, alibi_slopes=alibi_slopes, softcap=softcap, max_context_len=max_context_len, cache_seqlens_cpu=seq_lens_cpu_tensors, out=out, use_sqrt_alibi=use_sqrt_alibi ) def reshape_and_cache_flash( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, k_scale: float, v_scale: float, ) -> None: """Inductor cannot deal with inplace operations on views. See https://github.com/pytorch/pytorch/issues/131192 and https://github.com/pytorch/pytorch/issues/130174 This is a workaround to hide the view operation from the inductor. """ return ops.reshape_and_cache_flash( key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, k_scale, v_scale) class FlashAttentionBackend(AttentionBackend): @staticmethod def get_supported_head_sizes() -> List[int]: return [32, 64, 80, 96, 128, 160, 192, 224, 256] @staticmethod def get_name() -> str: return "flash-attn" @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata @staticmethod def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder @staticmethod def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, num_kv_heads, block_size, head_size) @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass class FlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is cuda-graph replayed. If you have values that need to be changed dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ---------------------| # |-- query_len ---| # Maximum query length in the batch. max_query_len: Optional[int] # Max number of query tokens among request in the batch. max_decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None # Number of tokens input to encoder num_encoder_tokens: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence on encoder-decoder encoder_seq_start_loc: Optional[torch.Tensor] = None # Our impl need info fields... seq_lens_cpu_tensors: Optional[torch.Tensor] = None encoder_seq_lens_cpu_tensor: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: return None if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata assert self.seq_lens is not None assert self.seq_lens_tensor is not None assert self.query_start_loc is not None assert self.context_lens_tensor is not None assert self.block_tables is not None assert self.seq_start_loc is not None self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, encoder_seq_start_loc=self.encoder_seq_start_loc, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables ) return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_decode_tokens == 0: return None if self._cached_decode_metadata is not None: return self._cached_decode_metadata assert self.block_tables is not None assert self.seq_lens_tensor is not None seq_lens_cpu_tensors = torch.tensor(self.seq_lens[self.num_prefills:],dtype=torch.int32,device="cpu") encoder_seq_lens_cpu_tensor = torch.tensor(self.encoder_seq_lens,dtype=torch.int32,device="cpu") if self.encoder_seq_lens is not None else None max_seq_len = self.seq_lens_tensor.max().item() self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, max_query_len=max_seq_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, query_start_loc=self.query_start_loc[self.num_prefills:] if self.query_start_loc is not None else None, seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, seq_lens_cpu_tensors=seq_lens_cpu_tensors, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, encoder_seq_lens_cpu_tensor=encoder_seq_lens_cpu_tensor, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables ) return self._cached_decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], block_size: int, num_seqs: int, num_queries: int, turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries if num_seqs != num_queries: assert num_seqs > num_queries assert self.use_cuda_graph if turn_prefills_into_decodes: # When Mutli-Step is enabled with Chunked-Prefill, prefills and # decodes are scheduled together. In the first step, all the # prefills turn into decodes. This update reflects that # conversion. assert self.num_decode_tokens + self.num_prefills == num_seqs self.num_decode_tokens += self.num_prefills self.num_prefills = 0 self.num_prefill_tokens = 0 self.max_prefill_seq_len = 0 self.max_query_len = 1 self.slot_mapping = self.slot_mapping[:num_seqs] else: assert self.seq_lens is not None assert self.max_decode_seq_len == max(self.seq_lens) assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 assert self.num_decode_tokens == num_seqs assert self.slot_mapping.shape == (num_seqs, ) assert self.seq_lens is not None assert len(self.seq_lens) == num_seqs assert self.seq_lens_tensor is not None assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.max_query_len == 1 assert self.max_prefill_seq_len == 0 assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) assert self.seq_start_loc is not None assert self.seq_start_loc.shape == (num_seqs + 1, ) assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) assert self.block_tables is not None assert self.block_tables.shape[0] == num_seqs # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) ops.advance_step_flashattn(num_seqs=num_seqs, num_queries=num_queries, block_size=block_size, input_tokens=model_input.input_tokens, sampled_token_ids=sampled_token_ids, input_positions=model_input.input_positions, seq_lens=self.seq_lens_tensor, slot_mapping=self.slot_mapping, block_tables=self.block_tables) class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 self.has_prefix_cache_hit = False self.input_builder = input_builder self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size self.use_v2_block_manager = ( input_builder.scheduler_config.use_v2_block_manager) def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool, prefix_cache_hit: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 3. slot mapping. """ is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) # Compute block table. # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. block_table = [] if prefix_cache_hit: # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. block_table = block_tables[seq_id] elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): if curr_sliding_window_block == 0: block_table = block_tables[seq_id] else: block_table = block_tables[seq_id][ -curr_sliding_window_block:] self.block_tables.append(block_table) # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) start_idx = compute_slot_mapping_start_idx( is_prompt, query_len, context_len, self.sliding_window, self.use_v2_block_manager) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: List[List[int]]) -> torch.Tensor: # The shape of graph_block_tables is # [max batch size, max context len // block size]. max_batch_size, max_blocks = self.runner.graph_block_tables.shape assert max_batch_size >= num_seqs graph_block_tables = self.runner.graph_block_tables[:num_seqs] for i, block_table in enumerate(block_tables): if block_table: num_blocks = len(block_table) if num_blocks <= max_blocks: graph_block_tables[i, :num_blocks] = block_table else: # It may be possible to have more blocks allocated due # to lookahead slots of multi-step, however, they are # not used anyway, so can be safely ignored. graph_block_tables[ i, :max_blocks] = block_table[:max_blocks] return torch.from_numpy(graph_block_tables).to( device=self.runner.device, non_blocking=True) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. cuda_graph_pad_size: The padding size for cuda graph. -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ prefix_cache_hit = any([ inter_data.prefix_cache_hit for inter_data in self.input_builder.inter_data_list ]) for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: max_decode_query_len = max(decode_query_lens) else: max_decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, pad=0, dtype=torch.int, device=device, ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype, out=query_start_loc[1:]) return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: |<----------------- num_decode_tokens ------------------>| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. If chunked prefill is enabled, prefill tokens and decode tokens can be batched together in a flattened 1D query. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| Currently, cuda graph is disabled for chunked prefill, meaning there's no padding between prefill and decode tokens. """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, use_sqrt_alibi: bool = None ) -> None: if blocksparse_params is not None: raise ValueError( "FlashAttention does not support block-sparse attention yet, we will support soon") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes self.use_sqrt_alibi = use_sqrt_alibi self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if sliding_window is not None: # NOTE(woosuk): flash-attn's sliding window does not work with # paged KV cache. # TODO will support on next week. self.sliding_window = None support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() if head_size not in support_head_sizes: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") output = unified_flash_attention( query, key, value, self.num_heads, self.head_size, self.num_kv_heads, kv_cache, self.kv_cache_dtype, k_scale, v_scale, self.scale, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, attn_type=attn_type, use_sqrt_alibi=self.use_sqrt_alibi, ) return output def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, num_heads: int, head_size: int, num_kv_heads: int, kv_cache: torch.Tensor, kv_cache_dtype: str, k_scale: float, v_scale: float, softmax_scale: float, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, use_sqrt_alibi: bool = False ) -> torch.Tensor: current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) if key is not None: assert value is not None key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) else: assert value is None if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. if (key is not None) and (value is not None): if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running updated_slot_mapping = attn_metadata.cross_slot_mapping.flatten() else: # Update self-attention KV cache (prefill/decode) updated_slot_mapping = attn_metadata.slot_mapping.flatten() ops.reshape_and_cache_flash( key, value, key_cache, value_cache, updated_slot_mapping, kv_cache_dtype, k_scale, v_scale, ) if attn_type == AttentionType.ENCODER: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 elif attn_type == AttentionType.DECODER: # Decoder self-attention supports chunked prefill. num_prefill_tokens = attn_metadata.num_prefill_tokens num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa else: # attn_type == AttentionType.ENCODER_DECODER # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) num_prefill_tokens = attn_metadata.num_prefill_tokens if attn_metadata.num_encoder_tokens is not None: num_encoder_tokens = attn_metadata.num_encoder_tokens else: num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens # Query for decode. KV is not needed because it is already cached. output = torch.empty_like(query) decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] if key is not None and value is not None: key = key[:num_encoder_tokens] value = value[:num_encoder_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens # prefill_output: Optional[torch.Tensor] = None # decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. flash_attn_varlen_func( q=query, k=key, v=value, cu_seqlens_q=prefill_meta.encoder_seq_start_loc if attn_type == AttentionType.ENCODER else prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc if attn_type == AttentionType.DECODER else prefill_meta.encoder_seq_start_loc, max_seqlen_q=prefill_meta.max_encoder_seq_len if attn_type == AttentionType.ENCODER else prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len if attn_type == AttentionType.DECODER else prefill_meta.max_encoder_seq_len, softmax_scale=softmax_scale, causal=attn_type == AttentionType.DECODER, window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=output[:num_prefill_tokens], use_sqrt_alibi=use_sqrt_alibi ) else: # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) flash_attn_varlen_func( # noqa q=query, k=key_cache, v=value_cache, cu_seqlens_q=prefill_meta.query_start_loc, max_seqlen_q=prefill_meta.max_query_len, cu_seqlens_k=prefill_meta.seq_start_loc, max_seqlen_k=max_seq_len, softmax_scale=softmax_scale, causal=attn_type == AttentionType.DECODER, alibi_slopes=alibi_slopes, block_table=prefill_meta.block_tables, softcap=logits_soft_cap, out=output[:num_prefill_tokens], use_sqrt_alibi=use_sqrt_alibi ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: flash_attn_varlen_func( q=decode_query, k=key_cache, v=value_cache, cu_seqlens_q=decode_meta.query_start_loc, max_seqlen_q=decode_meta.max_decode_query_len, cu_seqlens_k=decode_meta.seq_start_loc, max_seqlen_k=decode_meta.max_decode_seq_len, softmax_scale=softmax_scale, causal=True, alibi_slopes=alibi_slopes, softcap=0.0, block_table=decode_meta.block_tables, out=output[num_prefill_tokens:], ) else: # Use flash_attn_with_kvcache for normal decoding. flash_attn_with_kvcache( decode_query.unsqueeze(1), key_cache, value_cache, seq_lens_cpu_tensors=decode_meta.seq_lens_cpu_tensors if attn_type == AttentionType.DECODER else decode_meta.encoder_seq_lens_cpu_tensor, max_context_len=decode_meta.max_query_len if attn_type == AttentionType.DECODER else decode_meta.max_encoder_seq_len, block_table=decode_meta.block_tables if attn_type == AttentionType.DECODER else decode_meta.cross_block_tables, cache_seqlens=decode_meta.seq_lens_tensor if attn_type == AttentionType.DECODER else decode_meta.encoder_seq_lens_tensor, softmax_scale=softmax_scale, causal=True, alibi_slopes=alibi_slopes, softcap=0.0, out=output[num_prefill_tokens:].unsqueeze(1), use_sqrt_alibi=use_sqrt_alibi ).squeeze(1) # TODO mv this to flash_attn_with_kvcache when supported. if logits_soft_cap != 0.0: output[num_prefill_tokens:] = logits_soft_cap * torch.tanh(output[num_prefill_tokens:] / logits_soft_cap) return output.view(-1, num_heads * head_size) # @unified_flash_attention.register_fake def _( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, num_heads: int, head_size: int, num_kv_heads: int, kv_cache: torch.Tensor, kv_cache_dtype: str, k_scale: float, v_scale: float, softmax_scale: float, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query)