from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUBuilder

# Placeholder attention backend for models like Mamba and embedding models
# that lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        return "placeholder-attn"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

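    # This backend manages no KV cache: get_kv_cache_shape returns a dummy
    # unit shape, and the block swap/copy hooks below are deliberate no-ops.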
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (1, 1, 1, 1, 1)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        return

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        return


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Maximum number of query tokens among decode requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks.)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

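    # Lazily built, cached views of this metadata restricted to the
    # prefill-only and decode-only portions of the batch (see the
    # prefill_metadata and decode_metadata properties below).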
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

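        # Prefill requests are laid out before decode requests in the batch,
        # so slicing with [:num_prefills] keeps only the prefill entries.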
        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

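        # The decode entries follow the prefill entries, so [num_prefills:]
        # selects the decode-only slice of the batch.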
        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata


class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata.

        Specifically, append the per-sequence context lengths and update the
        prefill/decode token and sequence-length counters.
        """
        is_prompt = inter_data.is_prompt

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

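        # A captured CUDA graph runs decode with one token per (padded)
        # sequence, so the padded batch size equals the decode token count.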
        if use_captured_graph:
            num_decode_tokens = batch_size

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
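        # query_start_loc and seq_start_loc are exclusive prefix sums:
        # entry i + 1 holds the cumulative query/sequence length up to and
        # including request i, with a leading zero.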
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


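# The impl is intentionally inert: models that select this backend never run
# an attention kernel, so __init__ ignores its arguments and forward() is
# left unimplemented.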
class PlaceholderAttentionImpl(AttentionImpl):

    def __init__(self, *args, **kwargs) -> None:
        return

    def forward(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError