# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import copy
from dataclasses import dataclass
from typing import ClassVar, List, Optional, Union

import numpy as np
import torch

from vllm.model_executor.layers.attention import Attention
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.v1.attention.backends.fa_utils import (
    flash_attn_supports_fp8,
    get_flash_attn_version,
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from ixformer.contrib.vllm_flash_attn import merge_attn_states

if is_flash_attn_varlen_func_available():
    from vllm.v1.attention.backends.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        flash_attn_varlen_int8_func,
        flash_attn_with_kvcache,
        reshape_and_cache_flash,
    )
from vllm.config import (
    VllmConfig,
    get_current_vllm_config,
    get_layers_from_vllm_config,
)
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
    get_dcp_local_seq_lens,
    get_kv_cache_layout,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops

logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [16, 32, 64]

    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """FlashAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        fa_version = get_flash_attn_version()
        return fa_version is not None and fa_version >= 3

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
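        # Layout note (derived from FlashAttentionImpl.forward below): at
        # VLLM_ATTN_OPT_LEVEL == 2 the cache holds three int8-sized planes,
        # one int8 K plane plus two planes that are reinterpreted together as
        # a single fp16 V plane; otherwise it is the usual (K, V) pair.
        # Illustrative shape: num_blocks=1024, num_kv_heads=8, block_size=16,
        # head_size=128 -> (3, 1024, 8, 16, 128) at level 2, (2, ...) otherwise.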
        if envs.VLLM_ATTN_OPT_LEVEL == 2:
            return (3, num_blocks, num_kv_heads, block_size, head_size)
        return (2, num_blocks, num_kv_heads, block_size, head_size)
    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
            return (2, 0, 1, 3, 4, 5)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 2, 3, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order
    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size % 8 == 0 and head_size <= 256

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        if kv_cache_dtype is None:
            return True
        if kv_cache_dtype.startswith("fp8"):
            return flash_attn_supports_fp8()
        return kv_cache_dtype in ["auto", "bfloat16"]

    @classmethod
    def supports_sink(cls) -> bool:
        if not is_flash_attn_varlen_func_available():
            return False
        return flash_attn_supports_sinks()

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability >= DeviceCapability(8, 0)

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        if has_sink and device_capability < DeviceCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None


@dataclass
class FlashAttentionPrefillMetadata:
    """Prefill-specific metadata."""

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_query_len: int


@dataclass
class FlashAttentionDecodeMetadata:
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    max_query_len: int
    max_decode_seq_len: int
    use_graph: bool


@dataclass
class FlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None
    cu_prefix_kv_lens: torch.Tensor | None
    cu_suffix_kv_lens: torch.Tensor | None

    # For GQA DCP
    max_dcp_context_kv_len: int | None = None
    dcp_context_kv_lens: torch.Tensor | None = None

    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0

    prefill: FlashAttentionPrefillMetadata | None = None
    decode: FlashAttentionDecodeMetadata | None = None
    causal: bool = True


def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Get the set of all sliding window configs used in the model."""
    sliding_window_configs: set[tuple[int, int] | None] = set()
    layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        sliding_window_configs.add(layer.impl.sliding_window)
    return sliding_window_configs


class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    # https://github.com/vllm-project/vllm/issues/22945
    _cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if get_flash_attn_version() == 3
        else AttentionCGSupport.UNIFORM_BATCH
    )
    supports_update_block_table: bool = True

    reorder_batch_threshold: ClassVar[int] = 1

    @classmethod
    def get_cudagraph_support(
        cls,
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
    ) -> AttentionCGSupport:
        return cls._cudagraph_support

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.attention_config = vllm_config.attention_config
        self.decode_use_graph = (
            vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
        )
        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = False

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.cp_kv_cache_interleave_size = (
            self.parallel_config.cp_kv_cache_interleave_size
        )

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

        # Align decode/prefill split threshold with speculative decode query length
        # when backend supports treating spec requests as decode.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        if self.use_full_cuda_graph and self.aot_schedule:
            # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
            # The +1 is for the tile_count_semaphore (synchronization).
            # The 4 slots per batch element (num_prepare_batch_vectors) are:
            # prepare_varlen + dynamic_split + sort_batches + head_swizzle
            # See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
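            # Illustrative sizing (assumed numbers): with max_batch_size=256 the
            # buffer below holds 1 + round_up(256, 4) * 4 = 1025 int32 entries.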
            max_batch_size = max(
                vllm_config.scheduler_config.max_num_seqs,
                self.max_cudagraph_size or 0,
            )
            self.scheduler_metadata = torch.zeros(
                1 + round_up(max_batch_size, 4) * 4,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                self.attention_config.flash_attn_max_num_splits_for_cuda_graph
            )

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: tuple[int, int] | None = None
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashAttentionMetadata:
        """
        fast_build disables AOT scheduling, used when there will be few
        iterations, i.e. spec-decode.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        key_start_loc = common_attn_metadata.key_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_np = common_attn_metadata.seq_lens_np
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
            )
        )
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers too. We have to populate this on the
            # first build() call so the layers are constructed (cannot
            # populate in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if (
            self.use_full_cuda_graph
            and self.max_cudagraph_size is not None
            and num_actual_tokens <= self.max_cudagraph_size
        ):
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        if vllm_is_batch_invariant():
            max_num_splits = 1
        def schedule(
            batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
        ):
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype
                )
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None
        use_cascade = common_prefix_len > 0

        max_dcp_context_kv_len = 0
        dcp_context_kv_lens = None

        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None
        cu_prefix_kv_lens = None
        cu_suffix_kv_lens = None

        if self.dcp_world_size > 1:
            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
            dcp_context_kv_lens = seq_lens - query_kv_lens
            dcp_context_kv_lens = get_dcp_local_seq_lens(
                dcp_context_kv_lens,
                self.dcp_world_size,
                self.dcp_rank,
                self.cp_kv_cache_interleave_size,
            )
            # After DCP distribution, the maximum number of tokens for any rank is
            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
            # and I is cp_kv_cache_interleave_size.
            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
            max_dcp_context_kv_len = (
                (max_seq_len + num_partitions - 1) // num_partitions
            ) * self.cp_kv_cache_interleave_size
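            # Worked example (illustrative numbers): max_seq_len=1000,
            # dcp_world_size=4, interleave_size=16 -> num_partitions=64,
            # ceil(1000 / 64) * 16 = 16 * 16 = 256 tokens per rank at most.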
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=dcp_context_kv_lens,
                max_seq_len=max_dcp_context_kv_len,
                causal=False,
            )
        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            cu_prefix_kv_lens = torch.tensor(
                [0, common_prefix_len], dtype=torch.int32, device=self.device
            )
            cu_suffix_kv_lens = torch.tensor(
                [0] + suffix_kv_lens.tolist(), dtype=torch.int32, device=self.device
            ).cumsum_(dim=0, dtype=torch.int32)
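            # cu_suffix_kv_lens holds cumulative per-request suffix KV lengths
            # with a leading zero, matching the cu_seqlens_k layout expected by
            # the varlen kernel. Note that the .tolist() above does synchronize
            # with the GPU; only the suffix_kv_lens subtraction itself is
            # sync-free.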
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False,
            )
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=suffix_kv_lens,
                max_seq_len=max_seq_len - common_prefix_len,
                causal=True,
            )
        else:
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=seq_lens,
                max_seq_len=max_seq_len,
                causal=causal,
            )

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_actual_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits
        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill requests start after the decodes
            prefill_query_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )
            prefill_key_start_loc = (
                key_start_loc[reqs_start:] - key_start_loc[reqs_start]
            )
            prefill_metadata = FlashAttentionPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                key_start_loc=prefill_key_start_loc,
                max_query_len=max_query_len,
            )

        decode_metadata = None
        if num_decodes > 0:
            reqs_start = num_decodes  # decode requests come first in the batch
            decode_query_start_loc = query_start_loc[: reqs_start + 1]
            decode_query_lens = (
                decode_query_start_loc[1:] - decode_query_start_loc[:-1]
            )
            decode_metadata = FlashAttentionDecodeMetadata(
                block_table=block_table_tensor[:reqs_start, ...],
                query_start_loc=decode_query_start_loc,
                seq_lens=seq_lens[:reqs_start],
                max_query_len=decode_query_lens.max().item(),
                max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
                use_graph=num_prefills == 0 and self.decode_use_graph,
            )
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            key_start_loc=key_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            cu_prefix_kv_lens=cu_prefix_kv_lens,
            cu_suffix_kv_lens=cu_suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
        return attn_metadata
    def update_block_table(
        self,
        metadata: FlashAttentionMetadata,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> FlashAttentionMetadata:
        new_metadata = copy.copy(metadata)
        new_metadata.block_table = blk_table
        new_metadata.slot_mapping = slot_mapping
        # Keep nested prefill/decode block tables in sync. Decode path consumes
        # `attn_metadata.decode.block_table`, so updating only the top-level
        # `block_table` is insufficient when metadata is reused across groups.
        if metadata.decode is not None:
            new_decode = copy.copy(metadata.decode)
            reqs_start = metadata.num_decodes
            new_decode.block_table = blk_table[:reqs_start, ...]
            new_metadata.decode = new_decode
        if metadata.prefill is not None:
            new_prefill = copy.copy(metadata.prefill)
            reqs_start = metadata.num_decodes
            new_prefill.block_table = blk_table[reqs_start:, ...]
            new_metadata.prefill = new_prefill
        return new_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return use_cascade_attention(*args, **kwargs)


class FlashAttentionImpl(AttentionImpl):
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.attn_type = attn_type

        self.vllm_flash_attn_version = get_flash_attn_version(
            requires_alibi=alibi_slopes is not None,
            head_size=head_size,
        )
        logger.info_once(
            "Using FlashAttention version %s",
            self.vllm_flash_attn_version,
            scope="local",
        )
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

        if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device."
            )

        self.sinks = sinks
        if self.sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
            )
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

        self.supports_quant_query_input = True
        self.supports_per_head_quant_scales = (
            self.vllm_flash_attn_version >= 3
            if self.vllm_flash_attn_version is not None
            else False
        )

        assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], (
            "VLLM_ATTN_OPT_LEVEL only supports [0 for non-quant, 1 for "
            "I8Q_I8K_I8V, 2 for I8Q_I8K_F16V], but got {}".format(
                envs.VLLM_ATTN_OPT_LEVEL
            )
        )
        # quant_type = 0:
        #   attention: f16 q/k/v
        #   cache: f16 kv cache
        # quant_type = 1:
        #   attention: int8 q, int8 k, int8 v
        #   cache:
        #     int8 k cache + fp32 k cache scale
        #     int8 v cache + fp32 v cache scale (loaded from file, not updated)
        # quant_type = 2:
        #   attention: int8 q, int8 k, fp16 v
        #   cache:
        #     int8 k cache + fp32 k cache scale
        #     fp16 v cache
        self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        sqrt_alibi: bool = False,
        kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.
            kv_cache_scale: [num_blocks, num_kv_heads, block_size] key scales
                plus [num_kv_heads, head_size] value scales.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
        {q,k,v}_descale to be (num_sequences, num_kv_heads).
        We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.view(-1, self.num_heads * self.head_size)

        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes: torch.Tensor = self.alibi_slopes
        logits_soft_cap: float = self.logits_soft_cap
        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            ).view(-1, self.num_heads * self.head_size)

        # For decoder and cross-attention, use KV cache as before
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        decode_only = has_decode and not has_prefill
        num_decode_tokens = attn_metadata.num_decode_tokens

        if self.quant_type == 0:
            key_cache, value_cache = kv_cache.unbind(0)
        elif self.quant_type == 1:
            i8_key_cache, i8_value_cache = kv_cache.unbind(0)
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
            key_scale_cache, value_scale_cache = kv_cache_scale
            assert (
                key_scale_cache.shape == (num_blocks, num_kv_heads, block_size)
                and key_scale_cache.dtype == torch.float32
            ), (
                f"key_scale_cache.shape {key_scale_cache.shape} != "
                f"(num_blocks, num_kv_heads, block_size) or key_scale_cache.dtype "
                f"{key_scale_cache.dtype} != torch.float32"
            )
            assert (
                value_scale_cache.shape == (num_kv_heads, head_size)
                and value_scale_cache.dtype == torch.float32
            ), (
                f"value_scale_cache.shape {value_scale_cache.shape} != "
                f"(num_kv_heads, head_size) or value_scale_cache.dtype "
                f"{value_scale_cache.dtype} != torch.float32"
            )
            value_cache_info = (i8_value_cache, value_scale_cache)
        elif self.quant_type == 2:
            # Key cache is int8; the remaining two int8 planes of kv_cache are
            # reinterpreted as the fp16 value cache.
            i8_key_cache = kv_cache[0]
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
            value_cache = (
                kv_cache[1:]
                .view(query.dtype)
                .reshape(num_blocks, num_kv_heads, block_size, head_size)
            )
            key_scale_cache = kv_cache_scale
            value_cache_info = (value_cache, None)
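        # Summary of the cache views prepared above (see the quant_type note in
        # __init__): level 0 keeps fp16 key/value caches; level 1 pairs int8
        # key/value caches with fp32 scale tensors; level 2 keeps an int8 key
        # cache plus fp32 key scales and an fp16 value cache.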
        decode_q = query[:num_decode_tokens]
        prefill_q = query[num_decode_tokens:]
        prefill_output = output[num_decode_tokens:]
        decode_output = output[:num_decode_tokens]

        if self.quant_type == 1:
            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
        elif self.quant_type == 2:
            # Original key cache layout:
            #   num_blocks, num_kv_heads, block_size, head_size (f16)
            # Reformatted key cache:
            #   key_cache_i8:    num_blocks, num_kv_heads, block_size, head_size (int8)
            #   key_scale_cache: num_blocks, num_kv_heads, block_size (fp32)
            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
        else:
            if layer.quant_manager is not None and layer.quant_manager.check_enable():
                i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False
                )
                layer.quant_manager.update_data(value_scale)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            if self.quant_type == 1:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_int8(
                        key=i8_key,
                        value=i8_value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )
            elif self.quant_type == 2:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_mix(
                        key=i8_key,
                        value=value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )
            else:
                ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                )
                return output.view(-1, self.num_heads * self.head_size)
            else:
                if has_prefill:
                    # key = key[num_decode_tokens:]
                    # value = value[num_decode_tokens:]
                    # int8 attn
                    if self.quant_type > 0:
                        flash_attn_varlen_int8_func(
                            q=int8_query[num_decode_tokens:],
                            k=i8_key_cache,
                            v=value_cache_info[0],
                            q_scale=query_scale[:, num_decode_tokens:],
                            k_scale=key_scale_cache,
                            v_scale=value_cache_info[1],
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                            output_dtype=query.dtype,
                        )
                    else:
                        flash_attn_varlen_func(
                            q=prefill_q,
                            k=key_cache,
                            v=value_cache,
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                        )
                if has_decode:
                    # for mtp + cuda graph
                    max_q_len = (
                        attn_metadata.decode.max_query_len
                        if attn_metadata.decode is not None
                        else attn_metadata.max_query_len
                    )
                    max_ct_len = (
                        attn_metadata.decode.max_decode_seq_len
                        if attn_metadata.decode is not None
                        else attn_metadata.max_seq_len
                    )
                    if self.quant_type in [1, 2]:
                        para_dict = dict(
                            output=decode_output,
                            query=int8_query[:num_decode_tokens],
                            key_cache=i8_key_cache,
                            query_scale=query_scale[:num_decode_tokens]
                            if decode_only
                            else query_scale[:, :num_decode_tokens].t().contiguous(),
                            key_scale_cache=key_scale_cache,
                            num_kv_heads=self.num_kv_heads,
                            scale=softmax_scale,
                            block_tables=attn_metadata.decode.block_table,
                            context_lens=attn_metadata.decode.seq_lens,
                            block_size=i8_key_cache.shape[-2],
                            softcap=logits_soft_cap,
                            alibi_slopes=alibi_slopes,
                            causal=True,
                            window_left=window_size[0],
                            window_right=window_size[1],
                            use_sqrt_alibi=sqrt_alibi,
                            use_cuda_graph=attn_metadata.decode.use_graph
                            if decode_only
                            else False,
                            max_context_len=max_ct_len,
                            # mtp
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            max_query_len=max_q_len,
                        )
                        if self.quant_type == 1:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache_info[0],
                                    value_scale_cache=value_cache_info[1],
                                )
                            )
                            # for kv + k_scale write fusion
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=i8_value[:num_decode_tokens],
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_int8(**para_dict)
                        elif self.quant_type == 2:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache,
                                )
                            )
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=value[:num_decode_tokens].contiguous(),
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_mix(**para_dict)
                    else:
                        flash_attn_with_kvcache(
                            q=decode_q.unsqueeze(1),
                            k_cache=key_cache,
                            v_cache=value_cache,
                            block_table=attn_metadata.decode.block_table,
                            cache_seqlens=attn_metadata.decode.seq_lens,
                            max_query_len=max_q_len,
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            use_sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=decode_output.unsqueeze(1),
                            use_cuda_graph=attn_metadata.decode.use_graph,
                            max_context_len=max_ct_len,
                        )

            # Compute attention and update output up to `num_actual_tokens`.
            return output.view(-1, self.num_heads * self.head_size)
        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
            cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output.view(-1, self.num_heads * self.head_size)
    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return
        key_cache, value_cache = kv_cache.unbind(0)
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )
    def _forward_with_dcp(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        q_descale: torch.Tensor | None = None,
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )
        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        block_table = attn_metadata.block_table

        query = query.contiguous()
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)

        cu_dcp_kv_lens = attn_metadata.dcp_context_kv_lens.cumsum(
            dim=0, dtype=torch.int32
        )
        zero = torch.tensor(
            [0],
            device=attn_metadata.dcp_context_kv_lens.device,
            dtype=attn_metadata.dcp_context_kv_lens.dtype,
        )
        cu_seqlens_k = torch.cat([zero, cu_dcp_kv_lens])

        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        context_attn_out, context_lse = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
            alibi_slopes=self.alibi_slopes,
            window_size=sliding_window_size,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
        )
        # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
            context_attn_out,
            context_lse.transpose(0, 1),
            get_dcp_group(),
            return_lse=True,
        )
        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

        query_attn_out, query_lse = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_k=max_seqlen_q,
            softmax_scale=self.scale,
            causal=attn_metadata.causal,
            alibi_slopes=self.alibi_slopes,
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
        )
        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
        merge_attn_states(
            context_attn_out_cor,
            context_lse_cor,
            query_attn_out,
            query_lse,
            output,
        )
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )
        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Use encoder-specific metadata for sequence information
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads,
        )

        # Call flash attention directly on Q, K, V tensors
        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
        )
        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
    # disable cascade attention for DCP
    if dcp_world_size > 1:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (
        num_queries_per_kv > 1
        and not use_sliding_window
        and not use_alibi
        and np.all(query_lens == 1)
    )
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    # NOTE(woosuk): The performance model is very rough and may not be
    # accurate.
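    # Worked example with assumed numbers: 32 query heads, 8 KV heads
    # (num_queries_per_kv=4), 64 decode requests, common_prefix_len=2048,
    # num_sms=108. Then num_prefix_tiles=16, cascade_ctas=32, cascade_waves=1,
    # cascade_time=16, while flash_decoding_ctas=64*8*1*16=8192 and
    # flash_decoding_time=cdiv(8192, 108)=76, so cascade attention wins.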
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (
        num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
    )
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    max_num_splits: int,
    fa_version: int,
    prefix_scheduler_metadata: torch.Tensor | None = None,
    suffix_scheduler_metadata: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    s_aux: torch.Tensor | None = None,
) -> torch.Tensor:
    assert alibi_slopes is None, "Cascade attention does not support ALiBi."
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window."
    )

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-2]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    assert q_descale is None or q_descale == 1, (
        f"Cascade attention expects q_descale to be None or 1, got {q_descale}"
    )
    assert k_descale is None or k_descale == 1, (
        f"Cascade attention expects k_descale to be None or 1, got {k_descale}"
    )
    assert v_descale is None or v_descale == 1, (
        f"Cascade attention expects v_descale to be None or 1, got {v_descale}"
    )

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=list(sliding_window),
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=list(sliding_window),
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)