# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""

import ast
from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)


class TreeAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        return TreeAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


@dataclass
class TreeAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: torch.Tensor | None = None

    # Cached prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure.
            return self._cached_prefill_metadata

        q_start_loc = self.query_start_loc[self.num_decodes :]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[self.num_decodes :]

        # Construct & cache prefill-phase attention metadata structure.
        self._cached_prefill_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_prefill_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc - q_start_loc[0],
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[self.num_decodes :],
            slot_mapping=self.slot_mapping[self.num_decode_tokens :],
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure.
            return self._cached_decode_metadata

        q_start_loc = self.query_start_loc[: self.num_decodes + 1]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[: self.num_decodes]

        # Construct & cache decode-phase attention metadata structure.
        self._cached_decode_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_decode_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc,
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[: self.num_decodes],
            slot_mapping=self.slot_mapping[: self.num_decode_tokens],
            tree_attn_bias=self.tree_attn_bias,
        )
        return self._cached_decode_metadata


class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
        )

        # Construct the tree attention bias.
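        # For example (hypothetical tree, listed in depth order),
        # tree_choices = [(0,), (1,), (0, 0)] describes two first-level drafts
        # plus one child of the first draft. The resulting bias is a 4x4 mask
        # over [root, (0,), (1,), (0, 0)] in which 0.0 keeps a position
        # visible (self, root, and ancestors) and -inf masks it out:
        #
        #     [[  0., -inf, -inf, -inf],
        #      [  0.,   0., -inf, -inf],
        #      [  0., -inf,   0., -inf],
        #      [  0.,   0., -inf,   0.]]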
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )
        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=decode_threshold
            )
        )

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        if draft_index == 0:
            # Use prefill for drafting at the root level.
            self.tree_attn_bias = torch.empty(0)
        else:
            # Slice the tree attention bias for drafting. Exclude
            # the root level.
            start, end = 1, 1 + common_attn_metadata.max_query_len
            self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()

        # Build the attention metadata.
        attn_metadata = self.build(0, common_attn_metadata, fast_build=True)

        # Reset the tree attention bias to the original value.
        self.tree_attn_bias = orig_tree_attn_bias
        return attn_metadata


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: torch.dtype | None,
    device: torch.device | None,
) -> torch.Tensor:
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full(
        (tree_len, tree_len), -torch.inf, device=device, dtype=dtype
    )

    # Set diagonal to all zeros. Each token should
    # attend to itself.
    mask_val = 0
    for i in range(tree_len):
        tree_attn_mask[i, i] = mask_val

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Set all ancestors to zeros.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # Retrieve ancestor position.
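            # Root-level drafts (path length 1) have no ancestor other than
            # the root, which is already unmasked for every row above. For
            # deeper nodes, each proper prefix of the path is an ancestor;
            # its column is the prefix's index in sorted_tree_choices shifted
            # by +1 to account for the root occupying position 0.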
            if len(cur_tree_choice) == 1:
                continue

            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
                )
            tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask


class TreeAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TreeAttentionImpl."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for TreeAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
        if prefill_meta := attn_metadata.prefill_metadata:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        return output
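

# Minimal illustrative sketch: running this file directly builds the tree
# attention bias for a small hypothetical token tree using only the helpers
# defined above. This is an example for inspecting what mask a given
# speculative_token_tree produces, not part of the backend's runtime path.
if __name__ == "__main__":
    example_tree: list[tuple[int, ...]] = [(0,), (1,), (0, 0)]
    example_bias = _prepare_tree_attn_bias(
        example_tree,
        _get_depth_counts(example_tree),
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    # Rows/columns are ordered [root, (0,), (1,), (0, 0)]; 0.0 marks positions
    # a draft token may attend to, -inf marks masked positions.
    print(example_bias)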