Sync from v0.13
428 vllm/v1/attention/backends/tree_attn.py Normal file
@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""

import ast
from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)


class TreeAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        return TreeAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
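        # The leading 2 packs the key and value caches along dim 0; the
        # forward pass recovers them with kv_cache.unbind(0).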
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


@dataclass
class TreeAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0
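
    # Additive (tree_len, tree_len) attention bias applied among a request's
    # query tokens during tree decoding; the prefill sub-metadata leaves it
    # as None.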
    tree_attn_bias: torch.Tensor | None = None

    # Cached prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Return the cached prefill-phase attention metadata.
            return self._cached_prefill_metadata

        q_start_loc = self.query_start_loc[self.num_decodes :]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[self.num_decodes :]
        # Construct & cache the prefill-phase attention metadata.
        self._cached_prefill_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_prefill_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc - q_start_loc[0],
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[self.num_decodes :],
            slot_mapping=self.slot_mapping[self.num_decode_tokens :],
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Return the cached decode-phase attention metadata.
            return self._cached_decode_metadata

        q_start_loc = self.query_start_loc[: self.num_decodes + 1]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[: self.num_decodes]
        # Construct & cache the decode-phase attention metadata.
        self._cached_decode_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_decode_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc,
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[: self.num_decodes],
            slot_mapping=self.slot_mapping[: self.num_decode_tokens],
            tree_attn_bias=self.tree_attn_bias,
        )
        return self._cached_decode_metadata


class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = spec_config and spec_config.speculative_token_tree
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
        )
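        # Illustrative example: a token-tree string such as "[(0,), (1,), (0, 0)]"
        # lists one tuple of child indices per draft node (expected sorted by
        # depth); (0, 0) denotes the first child of the first child of the root.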
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )
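
        # Query lengths up to the tree size are classified as decodes by
        # split_decodes_and_prefills() and are expected to come before any
        # prefills in the batch.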
        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=decode_threshold
            )
        )

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        if draft_index == 0:
            # Use prefill for drafting at the root level.
            self.tree_attn_bias = torch.empty(0)
        else:
            # Slice the tree attention bias for drafting. Exclude
            # the root level.
            start, end = 1, 1 + common_attn_metadata.max_query_len
            self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()

        # Build the attention metadata.
        attn_metadata = self.build(0, common_attn_metadata, fast_build=True)

        # Reset the tree attention bias to the original value.
        self.tree_attn_bias = orig_tree_attn_bias
        return attn_metadata


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
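    # Illustrative example: [(0,), (1,), (0, 0)] -> [2, 1]
    # (two nodes at depth 1, one node at depth 2).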
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: torch.dtype | None,
    device: torch.device | None,
) -> torch.Tensor:
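    # The returned (tree_len, tree_len) additive bias lets each node attend to
    # itself, the root, and its ancestors. Illustrative example for
    # sorted_tree_choices = [(0,), (1,), (0, 0)] (root prepended at index 0):
    #   [[  0., -inf, -inf, -inf],
    #    [  0.,   0., -inf, -inf],
    #    [  0., -inf,   0., -inf],
    #    [  0.,   0., -inf,   0.]]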
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full(
        (tree_len, tree_len), -torch.inf, device=device, dtype=dtype
    )

    # Set diagonal to all zeros. Each token should attend to itself.
    mask_val = 0
    for i in range(tree_len):
        tree_attn_mask[i, i] = mask_val

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Set all ancestors to zeros.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # Retrieve ancestor position.
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
                )
            tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask


class TreeAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
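        # Window is (left, right) context in the FlashAttention convention;
        # (-1, -1) disables sliding-window masking.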
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TreeAttentionImpl."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for TreeAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
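        # Decode requests are ordered first in the batch: prefill tokens use
        # standard causal attention, while decode tokens additionally apply the
        # tree attention bias between a request's query tokens via `qq_bias`.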
        if prefill_meta := attn_metadata.prefill_metadata:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output