# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union, Tuple, List

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
# Use relative import to support build-from-source installation in vLLM

try:
    from . import _vllm_fa2_C  # noqa: F401
    FA2_UNAVAILABLE_REASON = None
    FA2_AVAILABLE = True
except ImportError as e:
    FA2_UNAVAILABLE_REASON = str(e)
    FA2_AVAILABLE = False

try:
    from . import _vllm_fa3_C  # noqa: F401
    FA3_UNAVAILABLE_REASON = None
    FA3_AVAILABLE = True
except ImportError as e:
    FA3_UNAVAILABLE_REASON = str(e)
    FA3_AVAILABLE = False

# isort: on

DEFAULT_FA_VERSION = 2


def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
    """Report whether the FA2 kernels can run on ``device``.

    Args:
        device: forwarded to ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)``.
        FA2 needs the ``_vllm_fa2_C`` extension to have imported and a device
        whose major compute capability is at least 8.
    """
    if not FA2_AVAILABLE:
        return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
    if torch.cuda.get_device_capability(device)[0] < 8:
        return False, \
            "FA2 is only supported on devices with compute capability >= 8"
    return True, None


def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
    """Report whether the FA3 kernels can run on ``device``.

    Args:
        device: forwarded to ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)``.
        FA3 needs the ``_vllm_fa3_C`` extension to have imported and a device
        with compute capability in [8, 10), excluding 8.6 and 8.9.
    """
    if not FA3_AVAILABLE:
        return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
    # Query the capability once instead of once per comparison.
    capability = torch.cuda.get_device_capability(device)
    major = capability[0]
    if major < 8 or major >= 10 or capability in ((8, 6), (8, 9)):
        return False, \
            "FA3 is only supported on devices with compute capability >= 8" \
            " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
    return True, None


def _check_fa_version(fa_version: int,
                      device=None) -> Tuple[bool, Optional[str]]:
    """Dispatch to the support check for ``fa_version`` (2 or 3).

    Raises:
        AssertionError: for any other version (normal interpreter mode).
        ValueError: for any other version when asserts are stripped (-O).
    """
    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
    if fa_version == 2:
        return _is_fa2_supported(device)
    if fa_version == 3:
        return _is_fa3_supported(device)
    # Only reachable under `python -O`, where the assert above is removed;
    # fail loudly rather than silently returning None.
    raise ValueError(f"Unsupported FA version: {fa_version}")


def is_fa_version_supported(fa_version: int, device=None) -> bool:
    """Return True if FlashAttention ``fa_version`` (2 or 3) can run on
    ``device``."""
    return _check_fa_version(fa_version, device)[0]


def fa_version_unsupported_reason(fa_version: int, device=None) \
        -> Optional[str]:
    """Return why FlashAttention ``fa_version`` (2 or 3) cannot run on
    ``device``, or None if it is supported."""
    return _check_fa_version(fa_version, device)[1]


#
# For vLLM we only care about `flash_attn_varlen_func` and
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
#


def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` and tensors whose last dimension is already unit-stride are
    returned unchanged (no copy); anything else is made contiguous.
    """
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _normalize_window_size(window_size) -> Tuple[int, int]:
    """Turn ``window_size`` into a plain ``(left, right)`` tuple.

    ``None`` means "no sliding window" and maps to ``(-1, -1)``. Lists are
    accepted as well; the first two entries are used.
    """
    if window_size is None:
        return (-1, -1)
    return (window_size[0], window_size[1])


# NOTE only used in FA3
def get_scheduler_metadata(
    batch_size,
    max_seqlen_q,
    max_seqlen_k,
    num_heads_q,
    num_heads_kv,
    headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 scheduler metadata.

    Thin wrapper over ``torch.ops._vllm_fa3_C.get_scheduler_metadata``; the
    result is meant to be passed as ``scheduler_metadata`` to the FA3 path of
    the attention wrappers below.

    Args:
        headdim_v: value head dim; defaults to ``headdim`` when None.
        window_size: ``(left, right)``, or None (treated as ``(-1, -1)``).
        All other arguments are forwarded to the FA3 op unchanged.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    real_window_size = _normalize_window_size(window_size)
    scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv,
        headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        real_window_size[0], real_window_size[1],
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata


def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size: Optional[List[int]] = None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # FA3 Only
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    num_splits: int = 0,
    # Version selector
    fa_version: int = DEFAULT_FA_VERSION,
    s_aux=None,
):
    """dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in
    K, V with fewer heads than Q. Note that the number of heads in Q must be
    divisible by the number of heads in KV. For example, if Q has 6 heads and
    K, V have 2 heads, head 0, 1, 2 of Q will attend to head 0 of K, V, and
    head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the
    causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention.
    Query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query
            tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into q.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into kv. Exactly one of cu_seqlens_k / seqused_k must be given.
        seqused_k: per-sequence number of keys to use, forwarded to the
            kernel (presumably (batch_size,) int32 -- TODO confirm). Exactly
            one of cu_seqlens_k / seqused_k must be given; required when
            block_table is given.
        q_v: FA3 only; forwarded to the FA3 kernel.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right), or None (treated as (-1, -1)). If not
            (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the
            attention score of query i and key j. Not supported by FA3.
        deterministic: accepted for API compatibility; unused by this
            forward-only wrapper.
        return_attn_probs: accepted for API compatibility; unused by this
            wrapper.
        block_table: paged-KV block table; requires seqused_k.
        return_softmax_lse: bool. Also return the softmax log-sum-exp.
        out: optional preallocated output tensor passed to the kernel.
        scheduler_metadata, q_descale, k_descale, v_descale: FA3 only (see
            get_scheduler_metadata). The FA2 path rejects them only when all
            four are provided; otherwise FA2 ignores them.
        num_splits: forwarded to FA3; FA2 raises for values > 1.
        fa_version: 2 or 3 selects the kernel backend.
        s_aux: FA3 only; FA2 raises NotImplementedError.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        AssertionError: on invalid cu_seqlens_k / seqused_k / block_table
            combinations, a window_size not of length 2, or alibi on FA3.
        NotImplementedError: FA3-only features requested on FA2.
        ValueError: unsupported fa_version.
    """
    assert cu_seqlens_k is not None or seqused_k is not None, \
        "cu_seqlens_k or seqused_k must be provided"
    assert cu_seqlens_k is None or seqused_k is None, \
        "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert block_table is None or seqused_k is not None, \
        "seqused_k must be provided if block_table is provided"

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # custom op does not support non-tuple input
    if window_size is not None:
        assert len(window_size) == 2
    real_window_size = _normalize_window_size(window_size)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    if fa_version == 2:
        # NOTE(review): only the case where *all four* FA3-only arguments are
        # set is rejected; any subset is silently ignored by FA2. This looks
        # deliberate (callers may pass descale tensors unconditionally) --
        # confirm with callers before tightening this to `or`.
        if scheduler_metadata is not None and q_descale is not None \
                and k_descale is not None and v_descale is not None:
            raise NotImplementedError(
                "FA2 does not support scheduler_metadata, q_descale, "
                "k_descale, v_descale"
            )
        if s_aux is not None:
            raise NotImplementedError("FA2 does not support s_aux")
        if num_splits > 1:
            raise NotImplementedError("FA2 does not support num_splits > 1")
        if cu_seqlens_k is None:
            # cu_seqlens_k is not used since we use seqused_k, but
            # flash_api.cpp still wants it, so pass all zeros. Allocated only
            # on this path and zero-filled so the kernel never receives
            # uninitialized memory.
            cu_seqlens_k = torch.zeros_like(cu_seqlens_q)
        out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
            q, k, v,
            out,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            # NOTE(review): gated on return_softmax_lse rather than
            # return_attn_probs (the sparse wrappers use return_attn_probs);
            # with dropout_p == 0.0 this is always False -- confirm intent.
            return_softmax_lse and dropout_p > 0,
            None,
        )
    elif fa_version == 3:
        assert alibi_slopes is None, "Alibi is not supported in FA3"
        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
            q, k, v,
            None, None,       # k_new, v_new
            q_v,
            out,
            cu_seqlens_q,
            cu_seqlens_k,     # cu_seqlens_k
            None,             # cu_seqlens_k_new
            None, seqused_k,  # seqused_q, seqused_k
            max_seqlen_q, max_seqlen_k,
            block_table,
            None,             # kv_batch_idx
            None,             # leftpad_k
            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            real_window_size[0], real_window_size[1],
            softcap,
            True,             # rotary_interleaved
            scheduler_metadata,
            num_splits,
            None,             # pack_gqa
            0,                # sm_margin
            s_aux             # s_aux
        )
    else:
        raise ValueError(f"Unsupported FA version: {fa_version}")
    return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
    # FA3 Only
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    # Version selector
    fa_version: int = DEFAULT_FA_VERSION,
    s_aux=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace*
    with the new values from k and v. This is useful for incremental decoding:
    you can pass in the cached keys/values from the previous step, and update
    them with the new keys/values from the current step, and do attention
    with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to
    hold the new values. For example, the KV cache could be pre-allocated with
    the max sequence length, and you can use cache_seqlens to keep track of
    the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in.
    The key @k will be rotated by rotary_cos and rotary_sin at indices
    cache_seqlens, cache_seqlens + 1, etc. If causal or local (i.e.,
    window_size != (-1, -1)), the query @q will be rotated by rotary_cos and
    rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. If not causal
    and not local, the query @q will be rotated by rotary_cos and rotary_sin
    at indices cache_seqlens only (i.e. we consider all tokens in @q to be at
    position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how
    to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in
    KV with fewer heads than Q. Note that the number of heads in Q must be
    divisible by the number of heads in KV. For example, if Q has 6 heads and
    K, V have 2 heads, head 0, 1, 2 of Q will attend to head 0 of K, V, and
    head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the
    causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention.
    Query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, with
            batch_size taken from q.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: forwarded to the kernel unchanged (presumably per-sequence
            left padding of the cache -- TODO confirm).
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right), or None (treated as (-1, -1)). If not (-1, -1), implements
            sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        out: optional preallocated output tensor passed to the kernel.
        scheduler_metadata, q_descale, k_descale, v_descale: FA3-only. Rejected only when all
            four are provided; otherwise ignored.
        fa_version: accepted for API symmetry, but this wrapper always dispatches to the FA2
            kernel (``_vllm_fa2_C.fwd_kvcache``).
        s_aux: not supported; raises NotImplementedError.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        AssertionError: if k_cache / v_cache do not have a unit-stride last dimension.
        NotImplementedError: for s_aux or the FA3-only arguments described above.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # cache_seqlens must be (batch_size,) with batch_size = q.shape[0].
        # k_cache.shape[0] is num_blocks for a paged cache (and
        # batch_size_cache with cache_batch_idx), so it cannot be used here.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32,
            device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    # Accept None / lists like flash_attn_varlen_func does.
    real_window_size = _normalize_window_size(window_size)

    # NOTE(review): fa_version is not consulted; only the FA2 kernel is
    # wired up here.
    if s_aux is not None:
        raise NotImplementedError("FA2 does not support s_aux")
    if scheduler_metadata is not None and q_descale is not None \
            and k_descale is not None and v_descale is not None:
        raise NotImplementedError(
            "FA2 does not support scheduler_metadata, q_descale, "
            "k_descale, v_descale"
        )
    out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
        q, k_cache, v_cache,
        k, v,  # k_new, v_new
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        real_window_size[0],
        real_window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: accepted for API compatibility; unused by this forward-only wrapper.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        return_softmax_lse: bool. Also return the softmax log-sum-exp.
        out: optional preallocated output tensor passed to the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: accepted for API compatibility; unused by this forward-only wrapper.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        return_softmax_lse: bool. Also return the softmax log-sum-exp.
        out: optional preallocated output tensor passed to the kernel.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out