#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                    Type)

import numpy as np
import torch
import xtorch_ops

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
                                           compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.kv_cache_interface import AttentionSpec
    from vllm.v1.worker.block_table import BlockTable
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


class KunlunAttentionBackend(AttentionBackend):
    """Kunlun attention backend for the vLLM v1 engine."""

    # Crucial for cuda graph capture: the backend writes into a
    # caller-provided output buffer instead of allocating its own.
    accept_output_buffer = True

    @staticmethod
    def get_name() -> str:
        """get_name"""
        return "Kunlun_v1"

    @staticmethod
    def get_impl_cls() -> Type["KunlunAttentionImpl"]:
        """get_impl_cls"""
        return KunlunAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["KunlunMetadata"]:
        """get_metadata_cls"""
        return KunlunMetadata

    @staticmethod
    def get_builder_cls() -> Type["KunlunAttentionMetadataBuilder"]:
        """get_builder_cls"""
        return KunlunAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """get_state_cls"""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """get_kv_cache_shape"""
        # return (2, num_blocks, block_size, num_kv_heads * head_size)
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """swap_blocks"""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """copy_blocks"""
        raise NotImplementedError
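

# NOTE(illustrative): a worked example of the length bookkeeping used by
# KunlunMetadata below (numbers are arbitrary, not taken from a real trace):
# a chunked-prefill request that already has 512 tokens in the KV cache
# (context_len = 512) and is scheduled with 128 new tokens this step has
# query_len = 128 and seq_len = 512 + 128 = 640; a pure decode request has
# query_len = 1 and seq_len = context_len + 1.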


@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata for the Kunlun backend."""

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    num_actual_tokens: int

    # Whether or not cuda graph is enabled.
    # Cuda graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding-only batch.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Maximum key/value length in the batch, used in particular for prefix
    # caching.
    max_kv_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    query_start_loc_host: Optional[torch.Tensor] = None

    # Serve only for prefix caching.
    kv_prefix_start_loc_host: Optional[torch.Tensor] = None
    kv_prefix_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["KunlunMetadata"] = None
    _cached_decode_metadata: Optional["KunlunMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    input_positions: Optional[torch.Tensor] = None

    use_cascade: Optional[bool] = False
    seq_lens_tensor_cpu: Optional[torch.Tensor] = None

    def __post_init__(self):
        """__post_init__"""
        # Lazily-populated attention bias caches.
        self.attn_bias: Optional[List[Any]] = None
        self.encoder_attn_bias: Optional[List[Any]] = None
        self.cross_attn_bias: Optional[List[Any]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        """is_all_encoder_attn_metadata_set"""
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        """is_all_cross_attn_metadata_set"""
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["KunlunMetadata"]:
        """prefill_metadata"""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[-(self.num_prefills + 1):] -
                           self.query_start_loc[-(self.num_prefills + 1)])
        # Flash attention needs the lod information both on host and device.
        query_start_loc_host = (
            None if self.query_start_loc_host is None else
            self.query_start_loc_host[-(self.num_prefills + 1):] -
            self.query_start_loc_host[-(self.num_prefills + 1)])
        # TODO(chengruichang): how to support prefix caching
        kv_prefix_start_loc_host = None
        kv_prefix_start_loc = None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[-self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[-self.num_prefills:])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[-self.num_prefills:])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[-self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[-self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[-self.num_prefills:])
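
        # NOTE(illustrative): the negative slices above rely on the batch
        # layout produced by KunlunAttentionMetadataBuilder.reorder_batch,
        # which places decode requests first and prefill requests last.
        # For example (arbitrary numbers), with 3 decodes and 2 prefills whose
        # query lens are [1, 1, 1, 5, 7]: num_prefills = 2,
        # num_prefill_tokens = 12, so slot_mapping[-12:] selects exactly the
        # prefill tokens and seq_lens_tensor[-2:] the prefill requests.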

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = KunlunMetadata(
            num_actual_tokens=self.num_actual_tokens,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=None,
            max_query_len=self.max_query_len,
            max_kv_len=self.max_kv_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            query_start_loc_host=query_start_loc_host,
            input_positions=input_positions,
            kv_prefix_start_loc=kv_prefix_start_loc,
            kv_prefix_start_loc_host=kv_prefix_start_loc_host,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False,
            use_cascade=self.use_cascade)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["KunlunMetadata"]:
        """decode_metadata"""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention metadata structure
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        if self.num_prefills != 0:
            # Compute some attn_metadata fields which default to None
            slot_mapping = (None if self.slot_mapping is None else
                            self.slot_mapping[:-self.num_prefill_tokens])
            seq_lens_tensor = (None if self.seq_lens_tensor is None else
                               self.seq_lens_tensor[:-self.num_prefills])
            seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None
                                   else
                                   self.seq_lens_tensor_cpu[:-self.num_prefills])
            block_tables = (None if self.block_tables is None else
                            self.block_tables[:-self.num_prefills])
        else:
            # Compute some attn_metadata fields which default to None
            slot_mapping = (None if self.slot_mapping is None else
                            self.slot_mapping)
            seq_lens_tensor = (None if self.seq_lens_tensor is None else
                               self.seq_lens_tensor)
            seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None
                                   else self.seq_lens_tensor_cpu)
            block_tables = (None if self.block_tables is None else
                            self.block_tables)

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = KunlunMetadata(
            num_actual_tokens=self.num_actual_tokens,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            seq_lens_tensor_cpu=seq_lens_tensor_cpu,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False,
            use_cascade=self.use_cascade)
        return self._cached_decode_metadata


class KunlunAttentionMetadataBuilder:
    """KunlunAttentionMetadataBuilder"""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    reorder_batch_threshold: ClassVar[Optional[int]] = 1

    def __init__(self, kv_cache_spec: "AttentionSpec", layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """__init__"""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.kv_cache_spec = kv_cache_spec
        self.device = device

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so that decode requests come first."""
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # TODO: how to handle a prefill sequence that is scheduled with
            # only one token (it is treated as a decode here).
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the i-th decode (counting from the back) currently sits in
            # the prefill region, swap it with the first misplaced prefill.
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """build"""
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())

        query_start_loc_host = common_attn_metadata.query_start_loc_cpu[
            :num_reqs + 1]
        query_start_loc = common_attn_metadata.query_start_loc_cpu[
            :num_reqs + 1].to(self.device, non_blocking=True)
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu

        seq_start_loc = list(accumulate(seq_lens, initial=0))
        if len(seq_start_loc) != num_reqs + 1:
            seq_start_loc = query_start_loc_host.tolist()
        if seq_start_loc[-1] != num_actual_tokens:
            seq_start_loc = query_start_loc_host.tolist()
        seq_start_loc_tensor = torch.empty(len(seq_start_loc),
                                           dtype=torch.int32,
                                           device=self.device)
        seq_start_loc_tensor.copy_(
            torch.as_tensor(seq_start_loc, dtype=torch.int32))

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata)

        num_scheduled_tokens = np.diff(
            common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
        tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
        if num_decode_tokens == 0:
            max_decode_seq_len = 0
        else:
            max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)

        tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs]
        if num_prefill_tokens == 0:
            max_prefill_seq_len = 0
        else:
            max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)

        use_cascade = False

        attn_metadata = KunlunMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens_tensor=seq_lens,
            seq_lens_tensor_cpu=seq_lens_cpu,
            max_query_len=max_prefill_seq_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            query_start_loc_host=query_start_loc_host,
            context_lens_tensor=None,
            block_tables=block_table_tensor,
            use_cuda_graph=False,
            use_cascade=use_cascade,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """can_run_in_cudagraph"""
        # Full CUDA Graph always supported (FA2 support checked separately)
        return True

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """use_cascade_attention"""
        return use_cascade_attention(*args, **kwargs)


class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
    """KunlunAttentionImpl"""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """__init__"""
        if blocksparse_params is not None:
            raise ValueError(
                "KunlunAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.use_irope = use_irope

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: Optional[KunlunMetadata],
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """forward"""
        query = query.view(-1, self.num_heads, self.head_size)
        if output is None:
            output = torch.empty_like(query)
        if attn_metadata is None:
            # Profiling run.
            return output.view(-1, self.num_heads * self.head_size)
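
        # NOTE(illustrative): at this point `query` is
        # [num_actual_tokens, num_heads, head_size], with decode tokens first
        # and prefill tokens last, matching the order established by
        # reorder_batch()/split_decodes_and_prefills() in the metadata builder.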
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize
        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                updated_slot_mapping = attn_metadata.slot_mapping
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                value = value.contiguous()
                xtorch_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, updated_slot_mapping)

        assert attn_type == AttentionType.DECODER
        # Decoder self-attention supports chunked prefill.
        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_query = query[
                num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_key = key[
                num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_value = value[
                num_decode_tokens:attn_metadata.num_actual_tokens]
            # Only enforce this shape constraint for decoder self-attention.
            assert prefill_query.shape[0] == num_prefill_tokens
            output[num_decode_tokens:attn_metadata.num_actual_tokens] = \
                KunlunOps.multi_query_kv_attention(
                    prefill_meta.query_start_loc,
                    prefill_meta.query_start_loc_host,
                    prefill_query,
                    prefill_key,
                    prefill_value,
                    alibi_slopes=self.alibi_slopes).view_as(prefill_query)

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            decode_query = query[:num_decode_tokens]
            xtorch_ops.paged_attention(
                x=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                block_tables=decode_meta.block_tables,
                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
                context_lens_xpu=decode_meta.seq_lens_tensor,
                is_context=False,
                is_causal=True,
                out=output[:num_decode_tokens],
                vo_head_dim=self.head_size)

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
    use_local_attention: bool = False,
) -> bool:
    """Decide whether to use cascade attention.

    TODO: Not yet supported on the Kunlun platform.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time
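

# NOTE(illustrative): a rough worked example of the performance model above,
# with arbitrary numbers that do not describe real Kunlun hardware:
# common_prefix_len=1024, 32 decode requests (all query_lens == 1),
# num_query_heads=32, num_kv_heads=8, num_sms=108.
#   num_prefix_tiles    = cdiv(1024, 128) = 8
#   cascade_ctas        = 32 * cdiv(32, 128) = 32
#   cascade_time        = cdiv(32, 108) * 8 = 8
#   flash_decoding_ctas = 32 * 8 * cdiv(4, 128) * 8 = 2048
#   flash_decoding_time = cdiv(2048, 108) = 19
# Since 8 < 19, cascade attention would be chosen in this configuration.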