#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is a part of the vllm-kunlun project.
#

from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                    Type)

import numpy as np
import torch
import xtorch_ops

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata,
                                              AttentionType)
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              CommonAttentionMetadata,
                                              split_decodes_and_prefills)

from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.kv_cache_interface import AttentionSpec
    from vllm.v1.worker.block_table import BlockTable
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


class KunlunAttentionBackend(AttentionBackend):
    """KunlunAttentionBackend"""

    # Crucial for CUDA graph support: the output is written into a
    # caller-provided buffer instead of a freshly allocated tensor.
    accept_output_buffer = True

    @staticmethod
    def get_name() -> str:
        """get_name"""
        return "Kunlun_v1"

    @staticmethod
    def get_impl_cls() -> Type["KunlunAttentionImpl"]:
        """get_impl_cls"""
        return KunlunAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["KunlunMetadata"]:
        """get_metadata_cls"""
        return KunlunMetadata

    @staticmethod
    def get_builder_cls() -> Type["KunlunAttentionMetadataBuilder"]:
        """get_builder_cls"""
        return KunlunAttentionMetadataBuilder

    # @staticmethod
    # def get_state_cls() -> Type["CommonAttentionState"]:
    #     """get_state_cls"""
    #     return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        """get_kv_cache_shape"""
        # return (2, num_blocks, block_size, num_kv_heads * head_size)
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """swap_blocks"""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """copy_blocks"""
        raise NotImplementedError
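
# NOTE: Batch layout convention used throughout this backend (summarized here
# from the code below; see reorder_batch, prefill_metadata/decode_metadata and
# KunlunAttentionImpl.forward).
#
# `reorder_batch` moves decode requests (one scheduled token each) to the
# front of the batch and prefill requests to the back, so the flattened token
# tensors look like:
#
#     [ decode tokens ............ | prefill tokens ............ ]
#       <-- num_decode_tokens -->    <-- num_prefill_tokens -->
#
# Accordingly, `prefill_metadata` slices the *tail* of the shared tensors
# (e.g. seq_lens_tensor[-num_prefills:]), `decode_metadata` slices the *head*
# (e.g. slot_mapping[:-num_prefill_tokens]), and
# `KunlunAttentionImpl.forward` splits `query` at `num_decode_tokens`.
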
@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    """KunlunMetadata"""

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    num_actual_tokens: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    slot_mapping: torch.Tensor
    block_tables: torch.Tensor

    multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Prefix-cache KV lod (cumulative length offsets), host and device copies.
    kv_lod_cpu: Optional[torch.Tensor] = None
    kv_lod_xpu: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Maximum key/value length in the batch, used mainly for prefix caching.
    max_kv_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    query_start_loc_host: Optional[torch.Tensor] = None

    # Serve only for prefix cache.
    kv_prefix_start_loc_host: Optional[torch.Tensor] = None
    kv_prefix_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["KunlunMetadata"] = None
    _cached_decode_metadata: Optional["KunlunMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    enable_kv_scales_calculation: Optional[bool] = False

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    input_positions: Optional[torch.Tensor] = None

    use_cascade: Optional[bool] = False
    seq_lens_tensor_cpu: Optional[torch.Tensor] = None

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    def __post_init__(self):
        """__post_init__"""
        # NOTE: "AttentionBias" is not imported in this module, so keep these
        # as string annotations; otherwise they would be evaluated at runtime
        # and raise a NameError.
        self.attn_bias: Optional[List["AttentionBias"]] = None
        self.encoder_attn_bias: Optional[List["AttentionBias"]] = None
        self.cross_attn_bias: Optional[List["AttentionBias"]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        """is_all_encoder_attn_metadata_set"""
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        """is_all_cross_attn_metadata_set"""
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["KunlunMetadata"]:
        """prefill_metadata"""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[-(self.num_prefills + 1):] -
                           self.query_start_loc[-(self.num_prefills + 1)])
        # Flash attention needs the lod information on both host and device.
        query_start_loc_host = (
            None if self.query_start_loc_host is None else
            self.query_start_loc_host[-(self.num_prefills + 1):] -
            self.query_start_loc_host[-(self.num_prefills + 1)])
        # TODO(chengruichang): how to support prefix cache
        kv_prefix_start_loc_host = None
        kv_prefix_start_loc = None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[-self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[-self.num_prefills:])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[-self.num_prefills:])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[-self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[-self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[-self.num_prefills:])
        if self.kv_lod_cpu is None:
            kv_lod_cpu = None
            kv_lod_xpu = None
        else:
            start = -(self.num_prefills + 1)
            base_cpu = self.kv_lod_cpu[start]
            kv_lod_cpu = self.kv_lod_cpu[start:] - base_cpu
            base_xpu = self.kv_lod_xpu[start]
            kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = KunlunMetadata(
            num_actual_tokens=self.num_actual_tokens,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=None,
            kv_lod_cpu=kv_lod_cpu,
            kv_lod_xpu=kv_lod_xpu,
            max_query_len=self.max_query_len,
            max_kv_len=self.max_kv_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            query_start_loc_host=query_start_loc_host,
            input_positions=input_positions,
            kv_prefix_start_loc=kv_prefix_start_loc,
            kv_prefix_start_loc_host=kv_prefix_start_loc_host,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False,
            use_cascade=self.use_cascade)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["KunlunMetadata"]:
        """decode_metadata"""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        if self.num_prefills != 0:
            # Compute some attn_metadata fields which default to None
            slot_mapping = (None if self.slot_mapping is None else
                            self.slot_mapping[:-self.num_prefill_tokens])
            seq_lens_tensor = (None if self.seq_lens_tensor is None else
                               self.seq_lens_tensor[:-self.num_prefills])
            seq_lens_tensor_cpu = (
                None if self.seq_lens_tensor_cpu is None else
                self.seq_lens_tensor_cpu[:-self.num_prefills])
            block_tables = (None if self.block_tables is None else
                            self.block_tables[:-self.num_prefills])
        else:
            # Compute some attn_metadata fields which default to None
            slot_mapping = (None if self.slot_mapping is None else
                            self.slot_mapping)
            seq_lens_tensor = (None if self.seq_lens_tensor is None else
                               self.seq_lens_tensor)
            seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None
                                   else self.seq_lens_tensor_cpu)
            block_tables = (None if self.block_tables is None else
                            self.block_tables)

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = KunlunMetadata(
            num_actual_tokens=self.num_actual_tokens,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            seq_lens_tensor_cpu=seq_lens_tensor_cpu,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False,
            use_cascade=self.use_cascade)
        return self._cached_decode_metadata


class KunlunAttentionMetadataBuilder:
    """KunlunAttentionMetadataBuilder"""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    reorder_batch_threshold: ClassVar[Optional[int]] = 1

    def __init__(self, kv_cache_spec: "AttentionSpec", layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """__init__"""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.device = device

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """reorder_batch"""
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # TODO: handle a prefill request that schedules only one token;
            # it is currently classified as a decode.
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata
    ) -> KunlunMetadata:
        """build_for_cudagraph_capture"""
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens_tensor.fill_(1)
        return attn_metadata

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """build"""
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc_host = \
            common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
        query_start_loc = \
            common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
                self.device, non_blocking=True)
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu

        seq_start_loc = list(accumulate(seq_lens, initial=0))
        seq_start_loc_tensor = torch.empty(len(seq_start_loc),
                                           dtype=torch.int32,
                                           device=self.device)
        seq_start_loc_tensor.copy_(
            torch.as_tensor(seq_start_loc, dtype=torch.int32))

        kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32,
                                 device="cpu")
        kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
        kv_lod_xpu = kv_lod_cpu.to(self.device)

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata)

        num_scheduled_tokens = np.diff(
            common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
        tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
        if num_decode_tokens == 0:
            max_decode_seq_len = 0
        else:
            max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
        tmp_prefill_scheduled_tokens = \
            num_scheduled_tokens[num_decodes:num_reqs]
        if num_prefill_tokens == 0:
            max_prefill_seq_len = 0
        else:
            max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)

        use_cascade = False

        attn_metadata = KunlunMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens_tensor=seq_lens,
            seq_lens_tensor_cpu=seq_lens_cpu,
            kv_lod_xpu=kv_lod_xpu,
            kv_lod_cpu=kv_lod_cpu,
            max_query_len=max_prefill_seq_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            query_start_loc_host=query_start_loc_host,
            context_lens_tensor=None,
            block_tables=block_table_tensor,
            use_cuda_graph=False,
            use_cascade=use_cascade,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """can_run_in_cudagraph"""
        # Full CUDA Graph always supported (FA2 support checked separately)
        return True

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """use_cascade_attention"""
        return use_cascade_attention(*args, **kwargs)
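
# Example of the lod (cumulative length offset) tensors produced by
# KunlunAttentionMetadataBuilder.build, with illustrative (not measured)
# numbers. For per-request sequence lengths seq_lens_cpu = [5, 7, 3]:
#
#     kv_lod_cpu = [0, 5, 12, 15]      # leading zero + cumulative sum
#     kv_lod_xpu = kv_lod_cpu.to(device)
#
# and if the same requests schedule [1, 7, 3] new tokens this step, then
#
#     query_start_loc = [0, 1, 8, 11]
#
# These cumulative offsets are what xtorch_ops.prefill_attention consumes as
# context_qlen_lod_* / context_kvlen_lod_* below.
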
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
    """KunlunAttentionImpl"""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
        multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
    ) -> None:
        """__init__"""
        if blocksparse_params is not None:
            raise ValueError(
                "kunlunAttention does not support block-sparse attention.")
        # if logits_soft_cap is not None:
        #     raise ValueError(
        #         "kunlunAttention does not support attention logits "
        #         "soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.use_irope = use_irope

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")
        self.multi_modal_placeholder_index_maps = \
            multi_modal_placeholder_index_maps

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: Optional[KunlunMetadata],
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """forward"""
        query = query.view(-1, self.num_heads, self.head_size)
        if output is None:
            output = torch.empty_like(query)

        if attn_metadata is None:
            # Profiling run.
            return output.view(-1, self.num_heads * self.head_size)

        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize
        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                updated_slot_mapping = attn_metadata.slot_mapping
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                value = value.contiguous()
                if key_cache.is_contiguous():
                    xtorch_ops.reshape_and_cache(key, value, key_cache,
                                                 value_cache,
                                                 updated_slot_mapping)
                else:
                    cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
                    cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
                    xtorch_ops.reshape_and_cache_flash(key, value,
                                                       cast_key_cache,
                                                       cast_value_cache,
                                                       updated_slot_mapping)

        assert attn_type == AttentionType.DECODER
        # Decoder self-attention supports chunked prefill.
        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Only enforce this shape-constraint for decoder
        # self-attention

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_query = query[
                num_decode_tokens:attn_metadata.num_actual_tokens]
            if key_cache.is_contiguous():
                tmp_block_tables = prefill_meta.block_tables
            else:
                # Only tested with Qwen3-Next.
                tmp_block_tables = prefill_meta.block_tables * 2
            xtorch_ops.prefill_attention(
                q=prefill_query,
                k=key_cache,  # Key cache (block_num, head, block_size, dim)
                v=value_cache,
                out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                is_causal=True,
                is_prefix_cache=True,
                block_table=tmp_block_tables,
                context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                context_qlen_lod_xpu=prefill_meta.query_start_loc,
                context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
                context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
                alibi_slopes=self.alibi_slopes,
                softmax_lse=None,
                sink=self.sinks)

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            decode_query = query[:num_decode_tokens]
            if key_cache.is_contiguous():
                tmp_block_tables = decode_meta.block_tables
            else:
                # Only tested with Qwen3-Next.
                tmp_block_tables = decode_meta.block_tables * 2
            xtorch_ops.paged_attention(
                x=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                block_tables=tmp_block_tables,
                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
                context_lens_xpu=decode_meta.seq_lens_tensor,
                is_context=False,
                is_causal=True,
                out=output[:num_decode_tokens],
                vo_head_dim=self.head_size)

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
    use_local_attention: bool = False,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # NOTE: this backend currently returns False unconditionally; the
    # heuristic below is kept for reference but is unreachable.
    return False

    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade
    #    attention is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    # We use a simple performance model to compare the two methods.
    # NOTE(woosuk): The performance model is very rough and may not be
    # accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time
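
# Worked example of the cost model above (illustrative numbers, not measured):
# with common_prefix_len=1024, num_query_heads=32, num_kv_heads=8
# (num_queries_per_kv=4), num_reqs=num_tokens=16 and num_sms=48:
#
#     num_prefix_tiles    = cdiv(1024, 128)           = 8
#     cascade_ctas        = 32 * cdiv(16, 128)        = 32
#     cascade_time        = cdiv(32, 48) * 8          = 8
#     flash_decoding_ctas = 16 * 8 * cdiv(4, 128) * 8 = 1024
#     flash_decoding_time = cdiv(1024, 48)            = 22
#
# so cascade attention (8 < 22) would be chosen if this backend did not
# return False unconditionally above.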