# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""

import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm import _custom_ops as ops

logger = init_logger(__name__)


class TreeAttentionBackend(AttentionBackend):
    """Backend descriptor that wires together the tree-attention classes.

    It exposes the implementation, metadata and metadata-builder classes
    defined in this module, plus the static capabilities (dtypes, head
    sizes, KV cache layout) of the backend.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check ``head_size`` against ``get_supported_head_sizes()``.

        Raises:
            ValueError: If ``head_size`` is not in the supported list.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache tensor shape.

        The result is ``(2, num_blocks, block_size, num_kv_heads,
        head_size)``. ``cache_dtype_str`` is accepted but does not affect
        the shape.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False


@dataclass
class TreeAttentionMetadata:
    """Per-batch attention metadata, with lazily derived per-phase views.

    The ``prefill_metadata`` / ``decode_metadata`` properties slice the
    batch assuming decode requests come first and prefill requests follow
    (the slicing below uses ``num_decodes`` / ``num_decode_tokens`` as the
    split point).
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Attention bias forwarded to the decode pass only (see
    # ``decode_metadata``); the prefill view does not carry it.
    tree_attn_bias: Optional[torch.Tensor] = None

    # Cached Prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the prefill requests of the batch.

        Returns ``None`` when the batch has no prefills. The view is built
        on first access and cached on the instance afterwards.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        q_start_loc = self.query_start_loc[self.num_decodes:]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[self.num_decodes:]
        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_prefill_tokens,
            max_query_len=int(q_seqlens.max().item()),
            # Rebase so the prefill offsets start at zero.
            query_start_loc=q_start_loc - q_start_loc[0],
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[self.num_decodes:],
            slot_mapping=self.slot_mapping[self.num_decode_tokens:],
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the decode requests of the batch.

        Returns ``None`` when the batch has no decode tokens. The view is
        built on first access and cached on the instance afterwards; it
        carries ``tree_attn_bias`` along.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        q_start_loc = self.query_start_loc[:self.num_decodes + 1]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[:self.num_decodes]
        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_decode_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc,
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[:self.num_decodes],
            slot_mapping=self.slot_mapping[:self.num_decode_tokens],
            tree_attn_bias=self.tree_attn_bias,
        )
        return self._cached_decode_metadata


class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` and owns the tree attention bias."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and precompute its bias.

        The tree is read from
        ``vllm_config.speculative_config.speculative_token_tree`` and parsed
        with ``ast.literal_eval``. When there is no speculative config, or
        no tree is configured, the single-path tree ``[(0,)]`` is used.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec_config.speculative_token_tree
                           if spec_config is not None else None)
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch into decodes followed by prefills.

        Delegates to ``reorder_batch_to_split_decodes_and_prefills`` using
        the tree size (rows of the bias) as the decode threshold.
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Build metadata for one batch from the common attention metadata.

        ``common_prefix_len`` and ``fast_build`` are accepted but not used
        by this builder. The decode/prefill split uses the current number
        of rows in ``self.tree_attn_bias`` as the decode threshold.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step.

        ``self.tree_attn_bias`` is swapped for the duration of the
        ``build`` call and always restored afterwards, even when ``build``
        raises, so a failed drafting step cannot leave the builder with a
        truncated bias (and therefore a wrong decode threshold).

        Args:
            common_attn_metadata: Batch metadata forwarded to ``build``.
            draft_index: ``0`` selects the root level, for which an empty
                bias is used; any other value selects a slice of the bias.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting. Exclude
                # the root level.
                start, end = 1, 1 + common_attn_metadata.max_query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value.
            self.tree_attn_bias = orig_tree_attn_bias


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    """Count the number of choices at each depth of the tree.

    Entry ``d - 1`` of the result is the number of paths of length ``d``.
    The input is expected to be ordered by non-decreasing path length,
    starting at length 1; a new counter is opened each time the length
    changes from the previous path.
    """
    depth_counts: list[int] = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> torch.Tensor:
    """Build the additive attention bias for a token tree.

    The result is a square tensor with one row/column per tree node, where
    index 0 is an implicit root and node ``i`` of ``sorted_tree_choices``
    sits at index ``i + 1``. Entries are ``0`` where attention is allowed
    (self, root, and every proper ancestor) and ``-inf`` elsewhere.

    Args:
        sorted_tree_choices: Tree nodes, each given as its path from the
            root.
        depth_counts: Number of nodes per depth; the nodes are visited in
            consecutive runs of these sizes.
        dtype: dtype of the returned tensor.
        device: Device of the returned tensor.

    Raises:
        ValueError: If a node's ancestor path is not itself present in
            ``sorted_tree_choices``.
    """
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full((tree_len, tree_len),
                                -torch.inf,
                                device=device,
                                dtype=dtype)

    mask_val = 0
    # Set diagonal to all zeros. Each token should
    # attend to itself.
    tree_attn_mask.fill_diagonal_(mask_val)

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Map each path to its row in the mask, built once so ancestor lookups
    # do not rescan the list. ``setdefault`` keeps the first occurrence of
    # a duplicated path, matching what ``list.index`` would return.
    choice_rows: dict[tuple[int, ...], int] = {}
    for idx, choice in enumerate(sorted_tree_choices):
        choice_rows.setdefault(tuple(choice), idx + 1)

    # Set all ancestors to zeros.
    start = 0
    for count in depth_counts:
        for j in range(count):
            path = tuple(sorted_tree_choices[start + j])
            # Depth-1 nodes only have the root as ancestor, which is
            # already unmasked above.
            if len(path) <= 1:
                continue
            # Retrieve ancestor positions: every proper prefix of the path.
            try:
                ancestor_idx = [
                    choice_rows[path[:k]] for k in range(1, len(path))
                ]
            except KeyError as err:
                raise ValueError(
                    f"Ancestor {err.args[0]} of tree choice {path} is not "
                    "present in the tree choices.") from err
            tree_attn_mask[start + j + 1, ancestor_idx] = mask_val
        start += count
    return tree_attn_mask


class TreeAttentionImpl(AttentionImpl):
    """Attention implementation that applies a tree bias on decode tokens."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Raises:
            ValueError: If ``head_size`` is not supported by
                ``TreeAttentionBackend``.
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale``
                tensors used when writing to and reading from the cache.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` is treated as
                a profiling run and ``output`` is returned untouched.
            output: Pre-allocated buffer the result is written into;
                required.
            output_scale: Must be ``None`` (fused output quantization is
                not supported).
            output_block_scale: Must be ``None`` (fused output
                quantization is not supported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If ``output_scale`` or
                ``output_block_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        # Prefill tokens follow the decode tokens in the batch; they use
        # plain causal attention without the tree bias.
        if prefill_meta := attn_metadata.prefill_metadata:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        # Decode tokens come first in the batch and additionally receive
        # the tree attention bias through ``qq_bias``.
        if decode_meta := attn_metadata.decode_metadata:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output