forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
0
vllm-v0.6.2/vllm/attention/backends/__init__.py
Normal file
0
vllm-v0.6.2/vllm/attention/backends/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
246
vllm-v0.6.2/vllm/attention/backends/abstract.py
Normal file
246
vllm-v0.6.2/vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum, auto
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
||||
Tuple, Type, TypeVar)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
ModelRunnerInputBase,
|
||||
ModelRunnerInputBuilderBase)
|
||||
|
||||
|
||||
class AttentionType(Enum):
    """Kinds of attention a layer can ask a backend to compute.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # Decoder attention between previous layer Q/K/V
    DECODER = auto()
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER = auto()
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = auto()
    # Attention between dec. Q and enc. K/V for encoder-decoder
    ENCODER_DECODER = auto()
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is a bundle of static/class-level hooks: it names itself,
    exposes the implementation / metadata / state / builder classes that
    belong together, and provides KV-cache block operations.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the ``AttentionState`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Construct an instance of the backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by ``get_metadata_cls()``.
        """
        metadata_cls = cls.get_metadata_cls()
        return metadata_cls(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the ``AttentionMetadataBuilder`` subclass for this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        """Construct an instance of the backend's metadata builder class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by ``get_builder_cls()``.
        """
        builder_cls = cls.get_builder_cls()
        return builder_cls(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape, as a tuple of ints, for these sizes."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from ``src_kv_cache`` to ``dst_kv_cache``.

        ``src_to_dst`` describes which source block goes to which
        destination block; its exact layout is backend-defined.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks inside each cache in ``kv_caches``.

        ``src_to_dists`` describes the source/destination block pairs; its
        exact layout is backend-defined.
        """
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        """Optional hook; the base version always raises.

        This method is not abstract, so backends that do not override it
        simply raise ``NotImplementedError`` when it is called.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""

    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        ...

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        ...

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the dataclass fields as a dict without deep-copying values.

        Similar to ``dataclasses.asdict``, but each value is the very object
        stored on the instance. Fields whose names appear in ``skip_fields``
        are left out.
        """
        excluded = set() if skip_fields is None else skip_fields
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        result: Dict[str, Any] = {}
        for entry in fields(self):
            if entry.name in excluded:
                continue
            result[entry.name] = getattr(self, entry.name)
        return result


# Type variable for "some concrete AttentionMetadata subclass".
T = TypeVar("T", bound=AttentionMetadata)
|
||||
|
||||
|
||||
class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    Every method here is abstract; concrete backends supply the behaviour.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay."""
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    A builder is created from a model-runner input builder and produces one
    metadata object of type ``T`` per call to ``build``.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create a builder bound to ``input_builder``."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
    """Abstract interface of an attention implementation.

    Subclasses must provide both the constructor and ``forward``; the
    versions here only raise ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Configure the implementation for one attention layer."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention for ``query``/``key``/``value`` and return the output.

        ``attn_type`` defaults to decoder attention.
        """
        raise NotImplementedError
|
||||
447
vllm-v0.6.2/vllm/attention/backends/blocksparse_attn.py
Normal file
447
vllm-v0.6.2/vllm/attention/backends/blocksparse_attn.py
Normal file
@@ -0,0 +1,447 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn, get_head_sliding_step)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
|
||||
|
||||
@dataclass
class BlocksparseParams:
    """Parameters describing the blocksparse attention pattern.

    ``head_sliding_step`` and ``active_head_range`` are not constructor
    arguments; they are derived in ``__post_init__``.
    """

    max_seqlen: int

    # Num q heads per tensor-parallel rank/partition
    num_heads: int  # per TP partition
    # Num kv heads per tensor-parallel rank/partition
    num_kv_heads: int

    # block size used for blocksparse attention.
    # This is the block_size used in `local_blocks`, `vert_stride`.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # local attended tokens / `sparse_block_size`
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks.
    # Controlling the sparsity
    vert_stride: int

    # If to use the same vertical stride offset for all heads,
    # i.e., attend to the same block of tokens on all heads.
    # By default, it is False, i.e., attention on the non-local
    # blocks depends on the `head_idx`, that is on
    # blocks satisfying
    # `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    # where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    # `block_idx = position_id // sparse_block_size`.
    # See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    # for more detail.
    homo_head: bool = False

    # If within a group, the kv offsets that each q attends is the same or no.
    homo_head_group: bool = False

    # Decided by homo_head and homo_head group
    head_sliding_step: int = field(init=False)

    # range of q heads to for a TP rank
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        """Validate the sizes and derive the per-rank fields."""
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1
        assert self.num_heads % self.num_kv_heads == 0

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        if self.homo_head:
            sliding_step = 0
        elif self.homo_head_group:
            # negative indicates sliding along kv heads, i.e., homo q group
            sliding_step = -get_head_sliding_step(
                tp_size * self.num_kv_heads, self.vert_stride)
        else:
            sliding_step = get_head_sliding_step(tp_size * self.num_heads,
                                                 self.vert_stride)
        self.head_sliding_step = sliding_step

        # Half-open range [first, last) of q heads owned by this rank.
        first_head = tp_rank * self.num_heads
        self.active_head_range = (first_head, first_head + self.num_heads)
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionBackend(AttentionBackend):
    """Attention backend wiring for blocksparse flash attention.

    KV-cache shape and block operations are delegated to ``PagedAttention``.
    """

    # NOTE(review): `AttentionBackend.get_name` is abstract but is not
    # overridden here, so calling `get_name()` on this backend raises
    # NotImplementedError -- confirm whether a name should be provided.

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        """Return the implementation class for this backend."""
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class for this backend."""
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    # NOTE(review): the annotations of `src_to_dst` / `src_to_dists` below
    # are dict types, while the `AttentionBackend` base class annotates the
    # same parameters as `torch.Tensor` -- confirm which one is current.

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward the swap request to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward the copy request to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """A copy of Metadata for FlashAttentionBackend,
    to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built views returned by the two properties below.
    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return metadata restricted to the prefill part of the batch.

        Returns None when there are no prefill requests. The result is built
        once and cached on the instance.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        # Prefill entries are the leading slice of every per-sequence and
        # per-token field.
        n_seqs = self.num_prefills
        n_tokens = self.num_prefill_tokens
        cached = BlocksparseFlashAttentionMetadata(
            num_prefills=n_seqs,
            num_prefill_tokens=n_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:n_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            seq_lens=self.seq_lens[:n_seqs],
            seq_lens_tensor=self.seq_lens_tensor[:n_seqs],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:n_seqs + 1],
            seq_start_loc=self.seq_start_loc[:n_seqs + 1],
            context_lens_tensor=self.context_lens_tensor[:n_seqs],
            block_tables=self.block_tables[:n_seqs],
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return metadata restricted to the decode part of the batch.

        Returns None when there are no decode tokens. The result is built
        once and cached on the instance.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        # Decode entries are whatever follows the prefill slice.
        n_seqs = self.num_prefills
        cached = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[n_seqs:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[n_seqs:],
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = cached
        return cached
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """Metadata builder for the blocksparse backend.

    Only sets the metadata class; everything else comes from
    ``CommonMetadataBuilder``.
    """

    _metadata_cls = BlocksparseFlashAttentionMetadata
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Set up blocksparse attention for one layer.

        Args:
            num_heads: number of query heads; also used as the default for
                ``blocksparse_params["num_heads"]``.
            head_size: size of each head; must be one of
                ``PagedAttention.get_supported_head_sizes()``.
            scale: softmax scale, stored as a float.
            num_kv_heads: number of KV heads; falls back to ``num_heads``
                when falsy/None.
            alibi_slopes: must be None.
            sliding_window: must be None.
            kv_cache_dtype: stored and later passed to ``PagedAttention``.
            blocksparse_params: keyword arguments for ``BlocksparseParams``;
                must not be None. The dict is copied, never modified.
            logits_soft_cap: must be None.

        Raises:
            AssertionError: if ``blocksparse_params`` is None or any of the
                unsupported options is set.
            ValueError: if ``head_size`` is not supported.
        """
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi not support for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")
        assert logits_soft_cap is None, ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

        # Fill in the head counts on a private shallow copy so that the
        # caller's dict (which may be shared between layers) is not mutated.
        blocksparse_params = dict(blocksparse_params)
        blocksparse_params.setdefault("num_heads", num_heads)
        blocksparse_params.setdefault("num_kv_heads",
                                      num_kv_heads or num_heads)
        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # The sparse-attention op is configured with the head count across
        # all tensor-parallel ranks plus this rank's active head range.
        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: passed through to the paged-cache write and decode.
            v_scale: passed through to the paged-cache write and decode.
            attn_type: only ``AttentionType.DECODER`` is supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if ``attn_type`` is not DECODER.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:

            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.

            assert kv_cache.numel() == 0 \
                or prefill_meta.block_tables is None \
                or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        # NOTE(review): if a batch carries both prefill and decode tokens,
        # the decode result below replaces the prefill result instead of
        # being merged with it. Also, `key_cache`/`value_cache` are only
        # bound when `kv_cache` is non-empty, and `output` is only bound
        # when at least one of the two branches runs.
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|
||||
989
vllm-v0.6.2/vllm/attention/backends/flash_attn.py
Normal file
989
vllm-v0.6.2/vllm/attention/backends/flash_attn.py
Normal file
@@ -0,0 +1,989 @@
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
make_tensor_with_pad)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
    """Attention backend wiring for FlashAttention."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the backend name, ``"FLASH_ATTN"``."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the implementation class for this backend."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class for this backend."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks for the key half, then the value half, of the cache."""
        # Index 0 of the leading dimension is the key cache, index 1 the
        # value cache; both use the same block mapping.
        for half in (0, 1):
            ops.swap_blocks(src_kv_cache[half], dst_kv_cache[half],
                            src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks in every cache of ``kv_caches`` via ``ops.copy_blocks``."""
        key_caches = []
        value_caches = []
        for kv_cache in kv_caches:
            key_caches.append(kv_cache[0])
            value_caches.append(kv_cache[1])

        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
def is_all_encoder_attn_metadata_set(self):
    """Report whether the encoder-attention metadata is fully populated.

    The check itself is delegated to the module-level helper of the same
    name, which receives this metadata object.
    """
    encoder_fields_ready = is_all_encoder_attn_metadata_set(self)
    return encoder_fields_ready
|
||||
|
||||
@property
def is_all_cross_attn_metadata_set(self):
    """Report whether the enc/dec cross-attention metadata is fully set.

    The cross-attention requirements are a superset of the metadata that
    encoder attention requires. The check itself is delegated to the
    module-level helper of the same name.
    """
    cross_fields_ready = is_all_cross_attn_metadata_set(self)
    return cross_fields_ready
|
||||
|
||||
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the prefill-only view of this metadata, or None.

    None is returned when the batch contains no prefill requests. The view
    is built on first access and memoized in ``_cached_prefill_metadata``.
    Per-sequence fields keep their leading ``num_prefills`` entries and
    ``slot_mapping`` keeps its leading ``num_prefill_tokens`` entries.
    """
    prefill_count = self.num_prefills
    if prefill_count == 0:
        return None

    memoized = self._cached_prefill_metadata
    if memoized is not None:
        return memoized

    # At least one of the decoder / encoder length representations must
    # be available, both as a list and as a tensor.
    assert not (self.seq_lens is None and self.encoder_seq_lens is None)
    assert not (self.seq_lens_tensor is None
                and self.encoder_seq_lens_tensor is None)

    def _leading(field, count):
        # Unset optional fields stay None; set ones are truncated.
        return None if field is None else field[:count]

    prefill_token_count = self.num_prefill_tokens
    prefill_view = FlashAttentionMetadata(
        num_prefills=prefill_count,
        num_prefill_tokens=prefill_token_count,
        num_decode_tokens=0,
        slot_mapping=_leading(self.slot_mapping, prefill_token_count),
        multi_modal_placeholder_index_maps=(
            self.multi_modal_placeholder_index_maps),
        seq_lens=_leading(self.seq_lens, prefill_count),
        seq_lens_tensor=_leading(self.seq_lens_tensor, prefill_count),
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_query_len=0,
        max_decode_seq_len=0,
        # The cumulative start-location tensors hold one more entry than
        # there are sequences, hence the "+ 1".
        query_start_loc=_leading(self.query_start_loc, prefill_count + 1),
        seq_start_loc=_leading(self.seq_start_loc, prefill_count + 1),
        context_lens_tensor=_leading(self.context_lens_tensor,
                                     prefill_count),
        block_tables=_leading(self.block_tables, prefill_count),
        use_cuda_graph=False,
        # Encoder and cross-attention fields are forwarded untouched.
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        encoder_seq_start_loc=self.encoder_seq_start_loc,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables)
    self._cached_prefill_metadata = prefill_view
    return prefill_view
|
||||
|
||||
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the decode-only view of this metadata, or None.

    None is returned when the batch contains no decode tokens. The view is
    built on first access and memoized in ``_cached_decode_metadata``.
    Per-sequence fields drop their leading ``num_prefills`` entries and
    ``slot_mapping`` drops its leading ``num_prefill_tokens`` entries.
    """
    if self.num_decode_tokens == 0:
        return None

    memoized = self._cached_decode_metadata
    if memoized is not None:
        return memoized

    assert not (self.seq_lens_tensor is None
                and self.encoder_seq_lens_tensor is None)

    first_decode = self.num_prefills

    # Optional fields: leave them as None when unset, otherwise keep only
    # the decode portion.
    decode_slots = None
    if self.slot_mapping is not None:
        decode_slots = self.slot_mapping[self.num_prefill_tokens:]

    decode_seq_lens = None
    if self.seq_lens_tensor is not None:
        decode_seq_lens = self.seq_lens_tensor[first_decode:]

    decode_block_tables = None
    if self.block_tables is not None:
        decode_block_tables = self.block_tables[first_decode:]

    # A batch may be laid out as prefills followed by decodes, so the
    # query start offsets are shifted to begin at the first decode. E.g.
    # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
    decode_query_start = None
    if self.query_start_loc is not None:
        decode_query_start = (self.query_start_loc[first_decode:] -
                              self.query_start_loc[first_decode])

    # The sequence start offsets are sliced but not shifted.
    decode_seq_start = None
    if self.seq_start_loc is not None:
        decode_seq_start = self.seq_start_loc[first_decode:]

    decode_view = FlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=decode_slots,
        multi_modal_placeholder_index_maps=None,
        seq_lens=None,
        seq_lens_tensor=decode_seq_lens,
        max_decode_query_len=self.max_decode_query_len,
        max_query_len=self.max_query_len,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=decode_query_start,
        seq_start_loc=decode_seq_start,
        context_lens_tensor=None,
        block_tables=decode_block_tables,
        use_cuda_graph=self.use_cuda_graph,
        # Encoder and cross-attention fields are forwarded untouched.
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        encoder_seq_start_loc=self.encoder_seq_start_loc,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables)
    self._cached_decode_metadata = decode_view
    return decode_view
|
||||
|
||||
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Move this metadata forward by one decode step, mutating it in place.

    ``num_seqs`` is the batch size after any cuda-graph padding, while
    ``num_queries`` counts the real requests; the two are equal in
    --enforce-eager mode. With ``turn_prefills_into_decodes`` set, the
    prefills recorded here are first re-labelled as decodes.
    """
    # A padded batch (num_seqs above num_queries) is only valid when the
    # batch was padded up to a captured cuda-graph batch size.
    batch_is_padded = num_seqs != num_queries
    if batch_is_padded:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    if not turn_prefills_into_decodes:
        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)
    else:
        # With multi-step plus chunked prefill, prefills and decodes are
        # scheduled together and every prefill becomes a decode after the
        # first step. Reflect that conversion in the counters.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

        self.slot_mapping = self.slot_mapping[:num_seqs]

    # From here on the metadata must describe a pure decode batch.
    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Only the real queries grow by one token; padded tail entries are
    # left untouched.
    for query_idx in range(num_queries):
        self.seq_lens[query_idx] = self.seq_lens[query_idx] + 1
    self.max_decode_seq_len = max(self.seq_lens)

    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Collects per-sequence data and builds a FlashAttentionMetadata.

    `build` walks the input builder's `inter_data_list`, accumulates slot
    mappings, block tables and length statistics through `_add_seq_group`,
    and then moves the results onto the runner's device.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Start with empty accumulators and remember the input builder."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.

        Args:
            inter_data: Intermediate data of one sequence group.
            chunked_prefill_enabled: Whether chunked prefill is enabled.
            prefix_cache_hit: Whether any sequence group of the batch hit
                the prefix cache (computed once by `build`).
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Copy block tables into the runner's graph buffer and upload it.

        Args:
            num_seqs: Number of rows of the graph buffer to use.
            block_tables: Per-sequence block tables; empty entries are
                skipped, leaving the corresponding buffer row as it was.

        Returns:
            The first `num_seqs` rows of the buffer as a tensor on the
            runner's device.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A FlashAttentionMetadata describing the whole batch.
        """
        # A hit in any sequence group switches the whole batch to the
        # prefix-cache block-table handling.
        prefix_cache_hit = any(
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list)
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded sequence. The previous
            # `[] * cuda_graph_pad_size` evaluated to `[]`, so nothing was
            # appended. Empty tables are skipped when filling the graph
            # buffer, so the resulting tensor is unchanged.
            self.block_tables.extend(
                [[] for _ in range(cuda_graph_pad_size)])
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
    """
    Attention implementation backed by FlashAttention.

    Layout of the input when it carries prompt tokens:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Layout of the input otherwise:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Padding can appear among generation tokens when cuda-graph is in use;
    prompt tokens currently never carry padding.

    Prompts may differ in length, whereas each generation token always has
    length 1.

    With chunked prefill enabled, prefill and decode tokens may share one
    flattened 1D query:

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Cuda graph is currently disabled for chunked prefill, so no padding sits
    between the prefill and the decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and store it on the instance.

        Raises:
            ValueError: If block-sparse parameters are supplied, or if
                `head_size` is not one of the sizes reported by
                `FlashAttentionBackend.get_supported_head_sizes()`.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # Slopes are kept as a float32 tensor when they are provided.
        if alibi_slopes is None:
            self.alibi_slopes = None
        else:
            self.alibi_slopes = torch.tensor(alibi_slopes,
                                             dtype=torch.float32)

        # Stored as a 2-tuple; (-1, -1) is used when no window is given.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        self.kv_cache_dtype = kv_cache_dtype

        # In flash-attn, a logits_soft_cap of 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        allowed_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in allowed_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {allowed_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run the attention forward pass through FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: Must be 1.0.
            v_scale: Must be 1.0.
            attn_type: Which kind of attention to compute.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            AttributeError: If encoder or cross-attention is requested but
                the matching metadata attributes are not all set.
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        # Each attention flavour that needs extra metadata checks only its
        # own readiness flag.
        if attn_type == AttentionType.ENCODER:
            if not attn_metadata.is_all_encoder_attn_metadata_set:
                raise AttributeError("Encoder attention requires setting "
                                     "encoder metadata attributes.")
        elif attn_type == AttentionType.ENCODER_DECODER:
            if not attn_metadata.is_all_cross_attn_metadata_set:
                raise AttributeError("Encoder/decoder cross-attention "
                                     "requires setting cross-attention "
                                     "metadata attributes.")

        # The enum is passed to the custom op as its integer value.
        return torch.ops.vllm.unified_flash_attention(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            attn_type.value,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
        )
|
||||
|
||||
|
||||
def _get_query_key_seq_metadata(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    """
    Pick the query-side and key-side sequence metadata for an attention type.

    Args:
        attn_metadata: The attention metadata object
        is_prompt (bool): Whether the input is a prompt; selects between the
            prefill and decode maximum sequence length.
        attn_type (AttentionType): The type of attention being used.

    Returns:
        tuple: Four entries, in this order:
            - Start locations for the query sequences.
            - Maximum query sequence length.
            - Start locations for the key sequences.
            - Maximum key sequence length.

    Raises:
        AttributeError: If an invalid attention type is provided.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention: query and key are both the encoder sequence.
        encoder_start = attn_metadata.encoder_seq_start_loc
        encoder_max = attn_metadata.max_encoder_seq_len
        return (encoder_start, encoder_max, encoder_start, encoder_max)

    if attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."
        start_loc = attn_metadata.seq_start_loc
        prefill_max = attn_metadata.max_prefill_seq_len
        return (start_loc, prefill_max, start_loc, prefill_max)

    uses_decoder_query = (attn_type == AttentionType.DECODER
                          or attn_type == AttentionType.ENCODER_DECODER)
    if not uses_decoder_query:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    # The query is the decoder input; its maximum length depends on whether
    # this is a prompt run.
    query_max = (attn_metadata.max_prefill_seq_len
                 if is_prompt else attn_metadata.max_decode_seq_len)
    query_start = attn_metadata.seq_start_loc

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the key side mirrors the query side.
        return (query_start, query_max, query_start, query_max)

    # Cross attention: the key is the precomputed encoder output.
    return (query_start, query_max, attn_metadata.encoder_seq_start_loc,
            attn_metadata.max_encoder_seq_len)
|
||||
|
||||
|
||||
def _get_causal_option(attn_type: AttentionType) -> bool:
    """
    Decide whether causal attention applies to the given attention type.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention, `True` for every other type.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types
|
||||
|
||||
|
||||
def unified_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute attention for a batch of prefill and/or decode tokens.

    The attention metadata is taken from `get_forward_context()` and must
    be a FlashAttentionMetadata. When `kv_cache` is non-empty, new keys and
    values are first written into it (except for encoder attention, or when
    key/value are None). Prefill tokens and decode tokens are then
    processed separately and their outputs concatenated.

    Args:
        query: 2-D tensor; reshaped here to [-1, num_heads, head_size].
        key: Tensor reshaped to [-1, num_kv_heads, head_size], or None.
        value: Tensor reshaped to [-1, num_kv_heads, head_size], or None.
        num_heads: Number of query heads.
        head_size: Size of each head.
        num_kv_heads: Number of key/value heads.
        kv_cache: Index 0 is used as the key cache and index 1 as the value
            cache; an empty tensor disables caching.
        kv_cache_dtype: Forwarded to the cache-write op.
        k_scale: Forwarded to the cache-write op.
        v_scale: Forwarded to the cache-write op.
        softmax_scale: Forwarded to the flash-attention kernels.
        attn_type_int_val: Integer value of an AttentionType member.
        window_size: Forwarded to the flash-attention kernels.
        alibi_slopes: Forwarded to the flash-attention kernels.
        logits_soft_cap: Forwarded to the kernels as `softcap`.

    Returns:
        A 2-D tensor whose second dimension equals that of `query`.

    Raises:
        AttributeError: If `attn_type_int_val` is not a valid AttentionType
            value.
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, FlashAttentionMetadata)
    attn_metadata: FlashAttentionMetadata = current_metadata

    num_tokens, hidden_size = query.shape

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

    if kv_cache.numel() > 0:
        # NOTE: key_cache / value_cache are only bound on this path; the
        # prefix-enabled prefill branch and the decode branch below rely
        # on them.
        key_cache = kv_cache[0]
        value_cache = kv_cache[1]
        # We skip updating the KV cache under two conditions:
        #  a. When the Attention Type is ENCODER. In this phase, we compute
        #     only the encoder attention without updating the cache.
        #  b. When both Key and Value are None. This occurs during
        #     cross-attention computation in the decoding phase, where the KV
        #     cache is already populated with the cross-attention tensor.
        #     Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[0],
                kv_cache[1],
                updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    # Prefill tokens come first in the flattened query, decode tokens after.
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None
    if prefill_meta := attn_metadata.prefill_metadata:
        # Prompt run.
        if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)

            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

            prefill_output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=q_seq_start_loc,
                cu_seqlens_k=k_seq_start_loc,
                max_seqlen_q=q_seq_len,
                max_seqlen_k=k_seq_len,
                softmax_scale=softmax_scale,
                causal=_get_causal_option(attn_type),
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            )
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            prefill_output = flash_attn_varlen_func(  # noqa
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_k=max_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                block_table=prefill_meta.block_tables,
                softcap=logits_soft_cap,
            )

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.

        assert decode_meta.max_decode_query_len is not None
        # use only for actual varlen decoding
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            decode_output = flash_attn_varlen_func(
                q=decode_query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_decode_query_len,
                cu_seqlens_k=decode_meta.seq_start_loc,
                max_seqlen_k=decode_meta.max_decode_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
                block_table=decode_meta.block_tables,
            )
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                _,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            # A length-1 sequence axis is inserted for the kernel and
            # removed again from its output.
            decode_output = flash_attn_with_kvcache(
                q=decode_query.unsqueeze(1),
                k_cache=key_cache,
                v_cache=value_cache,
                block_table=block_tables_arg,
                cache_seqlens=seq_lens_arg,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            ).squeeze(1)

    if prefill_output is None:
        assert decode_output is not None
        return decode_output.view(num_decode_query_tokens, hidden_size)
    if decode_output is None:
        assert prefill_output is not None
        return prefill_output.view(num_prefill_query_tokens, hidden_size)

    assert decode_meta is not None
    # Both decode branches already yield [tokens, heads, head_size]. The
    # extra `squeeze(1)` that used to sit here dropped the head axis when
    # num_heads == 1, which made the concatenation with the 3-D prefill
    # output fail; it was a no-op for any other head count.
    output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake counterpart of `unified_flash_attention`.

    Takes the same parameters but computes nothing: the result is an
    uninitialized tensor shaped like `query`. Every other argument is
    ignored.
    """
    placeholder_output = torch.empty_like(query)
    return placeholder_output
|
||||
|
||||
|
||||
# Register the real implementation together with its fake counterpart.
# `kv_cache` is declared as the argument the op mutates; forward() invokes
# the op as torch.ops.vllm.unified_flash_attention.
direct_register_custom_op(
    op_name="unified_flash_attention",
    op_func=unified_flash_attention,
    fake_impl=unified_flash_attention_fake,
    mutates_args=["kv_cache"],
)
|
||||
952
vllm-v0.6.2/vllm/attention/backends/flashinfer.py
Normal file
952
vllm-v0.6.2/vllm/attention/backends/flashinfer.py
Normal file
@@ -0,0 +1,952 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
||||
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||
BatchPrefillWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
get_kv_cache_torch_dtype, make_tensor_with_pad)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
    """Backend descriptor for the FlashInfer attention implementation.

    Every member is a static method: the backend name, the classes that make
    up the backend (impl, metadata, builder, state), the KV-cache shape, and
    block swap/copy helpers that delegate to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the name of this backend."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        """Return the attention state class."""
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The second dimension has size 2; ``unified_flash_infer`` indexes it
        as ``kv_cache[:, 0]`` and ``kv_cache[:, 1]`` for key and value.
        """
        kv_cache_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two KV caches via ``PagedAttention``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks inside the given KV caches via ``PagedAttention``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by this backend."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Translate an fp8 KV-cache dtype string into a torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not one of ``"fp8"``,
                ``"fp8_e4m3"`` or ``"fp8_e5m2"``.
        """
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
class FlashInferState(AttentionState):
    """Attention state holding FlashInfer wrappers and graph-capture buffers.

    The workspace buffer and the prefill/decode wrappers are created lazily
    on first use. While ``graph_capture`` is active, extra ``_graph_*``
    buffers are attached to the instance and used to build per-batch-size
    metadata and cloned states.
    """

    def __init__(self, runner):
        """Store ``runner``; all FlashInfer objects start uninitialized."""
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_workspace_buffer(self):
        """Return the uint8 workspace buffer, allocating it on first call."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first call."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the decode wrapper, creating it on first call.

        Tensor cores are enabled when forced through
        ``VLLM_FLASHINFER_FORCE_TENSOR_CORES`` or when the integer ratio of
        query heads to KV heads exceeds 4.
        """
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that sets up buffers for graph capture.

        On entry, the capture flag is set and the ``_graph_*`` buffers are
        allocated, sized by ``max_batch_size``. On exit the flag is cleared
        and the buffers are removed from the instance. Cleanup runs even if
        the body of the ``with`` block raises.
        """
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=self.runner.device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=self.runner.device)
        self._graph_last_page_len_buffer = torch.empty(
            max_batch_size, dtype=torch.int32, device=self.runner.device)
        try:
            yield
        finally:
            # Always leave capture mode and drop the capture-only buffers,
            # so a failed capture does not leave this state stuck with
            # `_is_graph_capturing` set.
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._graph_decode_workspace_buffer
            del self._graph_indices_buffer
            del self._graph_indptr_buffer
            del self._graph_last_page_len_buffer
            del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        """Return a new state sharing the current capture-time objects.

        The clone reuses the capture workspace buffer, the most recently
        created graph decode wrapper, and this state's prefill wrapper.
        ``batch_size`` is not used here. Must be called during capture.
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata for capturing a graph of ``batch_size``.

        A ``CUDAGraphBatchDecodeWithPagedKVCacheWrapper`` is created over
        slices of the capture buffers and stored as the current graph decode
        wrapper. The paged-KV tensors are dummies describing one full page
        per request. ``begin_forward`` is called on the metadata before it
        is returned. ``is_encoder_decoder_model`` is not used here.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
            use_tensor_cores)
        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Dummy host-side paging data: request i owns exactly page i, and
        # that page is reported as full (block_size entries).
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the graph input buffers: only ``slot_mapping``."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """No-op for this backend."""
        return

    def begin_forward(self, model_input):
        """Attach wrappers to the metadata and start its forward pass.

        When the metadata uses a captured graph, the wrappers come from the
        state stored on the graph runner for this virtual engine and batch
        size; otherwise they come from this state. Must not be called during
        graph capture.
        """
        assert not self._is_graph_capturing
        state = self
        if model_input.attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()
|
||||
|
||||
|
||||
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata for the FlashInfer backend.

    Carries the paged-KV description (indptr / indices / last page length),
    the head and page geometry, and the prefill/decode wrappers that
    ``begin_forward`` configures.
    """

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = 1

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # The data type of the query
    q_data_type: Optional[torch.dtype] = None
    device: torch.device = torch.device("cuda")
    is_profile_run: bool = False

    def __post_init__(self):
        """Validate ``head_dim`` against the backend's supported head sizes.

        Raises:
            ValueError: if ``head_dim`` is set and is not supported.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # One message string: a comma between the two fragments would
            # turn them into two separate ValueError arguments.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        """Move paging tensors to ``device`` and configure the wrappers.

        Prefill: returns immediately when no paged-KV indices exist;
        otherwise truncates ``query_start_loc`` to the prefill requests and,
        unless this is a profile run, moves the tensors and calls
        ``end_forward`` / ``begin_forward`` on the prefill wrapper.

        Decode: moves the tensors and calls ``end_forward`` /
        ``begin_forward`` on the decode wrapper, using the entries after the
        first ``num_prefills`` requests.
        """
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.end_forward()
                self.prefill_wrapper.begin_forward(
                    self.query_start_loc,
                    self.paged_kv_indptr[:self.num_prefills + 1],
                    self.paged_kv_indices,
                    self.paged_kv_last_page_len[:self.num_prefills],
                    self.num_qo_heads, self.num_kv_heads, self.head_dim,
                    self.page_size)
        if self.num_decode_tokens > 0:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr[self.num_prefills:],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[self.num_prefills:],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                # kv-cache data type.
                data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the fields as a dict, always omitting the two wrappers.

        Args:
            skip_fields: extra field names to omit. The caller's set is
                copied, not modified.
        """
        skip_fields = set() if skip_fields is None else set(skip_fields)
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` if the batch has prefills, otherwise ``None``."""
        if self.num_prefills == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` if the batch has decode tokens, else ``None``."""
        if self.num_decode_tokens == 0:
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        The sampled token ids are written into the first ``num_queries``
        entries of ``model_input.input_tokens``, then the remaining tensors
        are updated by ``ops.advance_step_flashinfer``.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with flashinfer yet. "
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Collect per-sequence data and build a ``FlashInferMetadata``.

    ``build`` walks the input builder's sequence groups, accumulating slot
    mappings, block tables and the paged-KV lists, then converts them to
    tensors.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Initialize empty accumulators and copy settings from the builder."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []
        self.total_blocks = 0
        self.is_profile_run: bool = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.

        Unless this is a profile run, the paged-KV lists are also extended
        for every sequence. On a profile run the method records that fact
        and returns after the first sequence's slot mapping is computed.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                return

            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one request's pages to the paged-KV lists.

        Only the blocks needed to hold ``seq_len`` tokens are added to
        ``paged_kv_indices``; ``total_blocks`` counts the full block table.
        """
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        self.total_blocks += len(block_table)
        full_pages, remainder = divmod(seq_len, self.block_size)
        block_table_bound = full_pages + 1 if remainder != 0 else full_pages
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A remainder of 0 means the last page is completely filled.
        last_page_len = remainder
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        decode_query_len = max(query_lens[self.num_prefills:], default=1)

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padded slot. `[] * n` is
            # always `[]`, so it would add nothing.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    num_blocks = min(len(block_table), max_blocks)
                    input_block_tables[
                        i, :num_blocks] = block_table[:num_blocks]

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Padded requests own no pages: repeat the last indptr value and
            # report a last-page length of 0 for them.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Element 0 of each *_start_loc stays 0; the running sums are
        # written into the remaining elements.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device="cpu",
                                                   dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            decode_query_len=decode_query_len,
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run)
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
    """Attention layer implementation that dispatches to FlashInfer.

    Holds the per-layer configuration and forwards every call to the
    ``unified_flash_infer`` custom op.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the layer configuration.

        ``blocksparse_params`` is accepted but not used.

        Raises:
            ValueError: if ``sliding_window`` is not ``None``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.scale = float(scale)

        # ALiBi slopes, when given, are stored as a float32 tensor.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)

        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        # Query heads must split evenly across the KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention through the ``unified_flash_infer`` custom op.

        ``attn_metadata`` is not passed to the op by this method.

        Raises:
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        is_decoder_attention = attn_type == AttentionType.DECODER
        if not is_decoder_attention:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

        attn_output = torch.ops.vllm.unified_flash_infer(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
        )
        return attn_output
|
||||
|
||||
|
||||
def unified_flash_infer(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute attention for a batch laid out as prefill tokens then decode tokens.

    The attention metadata is taken from ``get_forward_context()`` and must be
    a ``FlashInferMetadata``.

    Args:
        query: 2-D tensor, unpacked as ``(num_tokens, hidden_size)`` and
            viewed as ``(-1, num_heads, head_size)``.
        key: viewed as ``(-1, num_kv_heads, head_size)``.
        value: viewed as ``(-1, num_kv_heads, head_size)``.
        num_heads: number of query heads.
        head_size: size of each head.
        num_kv_heads: number of key/value heads.
        kv_cache: KV cache; when non-empty, ``key``/``value`` are written into
            ``kv_cache[:, 0]`` / ``kv_cache[:, 1]``. An empty cache selects
            the flash-attention prefill path.
        kv_cache_dtype: cache dtype name; values starting with ``"fp8"``
            cause the cache to be re-viewed with the matching torch dtype.
        k_scale: key scale forwarded to the cache write and the wrappers.
        v_scale: value scale forwarded to the cache write and the wrappers.
        softmax_scale: softmax scaling applied by every attention path.
        window_size: forwarded to ``flash_attn_varlen_func`` only.
        alibi_slopes: forwarded to ``flash_attn_varlen_func`` only.
        logits_soft_cap: forwarded to the FlashInfer wrappers.

    Returns:
        Attention output viewed as ``(num_tokens, hidden_size)``.
    """

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, FlashInferMetadata)
    attn_metadata: FlashInferMetadata = current_metadata

    num_tokens, hidden_size = query.shape
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)

    if kv_cache.numel() > 0:
        # Use the same reshape and cache kernel as flash attention.
        ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

    num_prefill_tokens = attn_metadata.num_prefill_tokens
    num_decode_tokens = attn_metadata.num_decode_tokens
    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
        f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
    assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
        f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa
    query = query.contiguous()  # Flashinfer requires query to be contiguous
    # Query for decode. KV is not needed because it is already cached.
    # QKV for prefill.
    decode_query = query[num_prefill_tokens:]
    query = query[:num_prefill_tokens]

    key = key[:num_prefill_tokens]
    value = value[:num_prefill_tokens]

    assert query.shape[0] == num_prefill_tokens
    assert decode_query.shape[0] == num_decode_tokens

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None
    if prefill_meta := attn_metadata.prefill_metadata:
        # We will use flash attention for prefill
        # when kv_cache is not provided.
        # This happens when vllm runs the profiling to
        # determine the number of blocks.
        if kv_cache.numel() == 0:
            prefill_output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_q=prefill_meta.max_prefill_seq_len,
                max_seqlen_k=prefill_meta.max_prefill_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
            )
        else:
            assert prefill_meta is not None
            assert prefill_meta.prefill_wrapper is not None
            # Pass the model's softmax scale explicitly so that paged prefill
            # agrees with the flash-attn prefill path above and the decode
            # path below instead of falling back to the wrapper's default.
            prefill_output = prefill_meta.prefill_wrapper.forward(
                query,
                kv_cache,
                sm_scale=softmax_scale,
                logits_soft_cap=logits_soft_cap,
                causal=True,
                k_scale=k_scale,
                v_scale=v_scale)
    if decode_meta := attn_metadata.decode_metadata:
        assert decode_meta.decode_wrapper is not None
        decode_output = decode_meta.decode_wrapper.forward(
            decode_query,
            kv_cache,
            sm_scale=softmax_scale,
            logits_soft_cap=logits_soft_cap,
            k_scale=k_scale,
            v_scale=v_scale)

    if prefill_output is None and decode_output is not None:
        # Decode only batch.
        output, num_tokens = decode_output, num_decode_tokens
    elif decode_output is None and prefill_output is not None:
        # Prefill only batch.
        output, num_tokens = prefill_output, num_prefill_tokens
    else:
        # Chunked prefill batch does not work with speculative decoding in
        # FlashInfer backend, so the query length for decode should be 1.
        assert prefill_output is not None
        assert decode_output is not None
        assert decode_meta is not None
        assert decode_meta.decode_query_len == 1
        decode_output = decode_output.squeeze(1)
        # Prefill rows come first, matching the input token layout.
        output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_infer_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation of ``unified_flash_infer``.

    Only ``query`` is used: the result is an uninitialised, contiguous tensor
    with the same shape, dtype and device as ``query``.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()
|
||||
|
||||
|
||||
# Register `unified_flash_infer` as a custom op together with its fake
# implementation. FlashInferImpl.forward invokes it through
# `torch.ops.vllm.unified_flash_infer`. `kv_cache` is declared as mutated
# because the op writes the new keys/values into it.
direct_register_custom_op(
    op_name="unified_flash_infer",
    op_func=unified_flash_infer,
    mutates_args=["kv_cache"],
    fake_impl=unified_flash_infer_fake,
)
|
||||
264
vllm-v0.6.2/vllm/attention/backends/hpu_attn.py
Normal file
264
vllm-v0.6.2/vllm/attention/backends/hpu_attn.py
Normal file
@@ -0,0 +1,264 @@
|
||||
###############################################################################
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
|
||||
###############################################################################
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import vllm_hpu_extension.ops as ops
|
||||
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
|
||||
HPUPagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HPUAttentionBackend(AttentionBackend):
    """Attention backend descriptor for HPU, delegating to HPUPagedAttention."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier.

        ``AttentionBackend.get_name`` is abstract; without this override a
        call would hit the base implementation and raise NotImplementedError.
        """
        return "HPU_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class consumed by ``HPUAttentionImpl``."""
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class used with this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``HPUPagedAttention``."""
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                    num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate block swapping to ``HPUPagedAttention.swap_blocks``."""
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Delegate block copying to ``HPUPagedAttention.copy_blocks``."""
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionbackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # Read by HPUAttentionImpl.forward: used as the attention bias of the
    # prompt path and passed as `block_bias` on the decode path.
    attn_bias: Optional[torch.Tensor]
    # NOTE(review): not read by HPUAttentionImpl in this file; presumably
    # per-sequence lengths consumed elsewhere -- confirm against the caller.
    seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
    ) -> None:
        """Set up the attention hyper-parameters and the HPU op modules.

        NOTE(review): `blocksparse_params` and `max_seq_len` are accepted but
        not referenced anywhere in this constructor.

        Raises:
            ValueError: if `head_size` is not among the sizes reported by
                `HPUPagedAttention.get_supported_head_sizes()`.
        """
        # Starts the MRO lookup *after* AttentionImpl, so AttentionImpl's own
        # __init__ is skipped.
        super(AttentionImpl, self).__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.k_cache = VLLMKVCache()
        self.v_cache = VLLMKVCache()
        # Fall back to one KV head per query head when none is given.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        if alibi_slopes is not None:
            # Stored as a bfloat16 tensor when slopes are provided.
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Opt-in switch read from the environment ('1' or 'true').
        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                              '0').lower() in ['1', 'true']
        if self.prefill_usefusedsdpa:
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'

        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with HPU prompt attention and HPUPagedAttention.

        Args:
            query: unpacked as [batch_size, seq_len, hidden_size]
            key: unpacked as three dimensions; the middle one is used as the
                key/value sequence length
            value: viewed the same way as key
            kv_cache: split by HPUPagedAttention.split_kv_cache; may be None,
                in which case nothing is cached
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, hidden_size]
        Raises:
            NotImplementedError: for any attn_type other than DECODER.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")
        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        if attn_metadata.is_prompt:
            # Regroup the flattened key/value rows into one group per entry
            # of block_indices.
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            if not self.prefill_usefusedsdpa:
                # TODO: move this outside of model
                assert attn_metadata.attn_bias is not None, \
                    'attn_bias must be set before calling model.forward!'
                attn_bias = attn_metadata.attn_bias
                if self.alibi_slopes is not None:
                    position_bias = _make_alibi_bias(self.alibi_slopes,
                                                     self.num_kv_heads,
                                                     attn_bias.dtype,
                                                     attn_bias.shape[-1])
                    # tile() yields a new tensor, so the in-place add below
                    # does not modify attn_metadata.attn_bias.
                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                    attn_bias.add_(position_bias)
            else:
                # The FusedSDPA path is invoked without an explicit bias.
                attn_bias = None

            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)
            out = ops.prompt_attention(
                query.view(query_shape),
                key.view(kv_shape),
                value.view(kv_shape),
                attn_bias=attn_bias,
                p=0.0,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                softmax_op=self.softmax,
                matmul_av_op=self.matmul_av,
            )
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            # NOTE(review): key_cache/value_cache are only bound when
            # kv_cache is not None; this branch presumably never runs
            # without a cache -- confirm against the caller.
            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_list=attn_metadata.block_list,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_scales=attn_metadata.block_scales,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
|
||||
|
||||
|
||||
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
) -> torch.Tensor:
    """Build the per-head ALiBi bias for one ``seq_len`` x ``seq_len`` window.

    Entry ``[i, j]`` of each head's matrix is ``(j - i) * slope``. The result
    has shape ``(1, num_heads, seq_len, seq_len)``; when ``num_heads`` differs
    from ``num_kv_heads`` the head axis is split into
    ``(num_kv_heads, num_heads // num_kv_heads)``.
    """
    positions = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Relative offsets: column index minus row index.
    relative = positions.unsqueeze(0) - positions.unsqueeze(1)

    # The backing storage has its last dimension rounded up to a multiple
    # of 8; only the first seq_len columns are exposed.
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    storage = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )
    bias = storage[:, :, :, :seq_len]
    bias.copy_(relative)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads == num_kv_heads:
        return bias
    return bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
|
||||
385
vllm-v0.6.2/vllm/attention/backends/ipex_attn.py
Normal file
385
vllm-v0.6.2/vllm/attention/backends/ipex_attn.py
Normal file
@@ -0,0 +1,385 @@
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
class IpexAttnBackend(AttentionBackend):
    """Attention backend descriptor for IPEX."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier."""
        return "IPEX"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        """Return the attention implementation class."""
        return IpexAttnBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        """Return the metadata class."""
        return IpexAttnMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward the swap request to the IPEX ops."""
        from vllm._ipex_ops import ipex_ops as xpu_ops
        xpu_ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Split every cache into its key/value halves and copy blocks."""
        from vllm._ipex_ops import ipex_ops as xpu_ops
        key_caches = []
        value_caches = []
        for layer_cache in kv_caches:
            # Index 0 holds the keys, index 1 the values.
            key_caches.append(layer_cache[0])
            value_caches.append(layer_cache[1])
        xpu_ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for IpexAttnBackend.
    """
    # A batch is either entirely prompts or entirely decoding;
    # this flag is True in the all-prompts case.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        # Filled in by the first attention op that runs.
        # Kept as a list because, with alibi slopes, a bias is needed
        # per prompt (a limitation inherited from the xformer API).
        # Being set here, it is excluded from __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a prompt-only batch, otherwise ``None``."""
        # Currently chunked prefill is not supported
        if self.num_decode_tokens != 0:
            return None

        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a batch without prefills, otherwise ``None``."""
        # Currently chunked prefill is not supported
        if not self.num_prefills > 0:
            return self

        assert self.num_decode_tokens == 0
        return None
|
||||
|
||||
|
||||
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
    """Attention implementation built on `ipex_ops` kernels.

    Prompts go through `ipex_ops.varlen_attention`; decoding goes through
    `ipex_ops.paged_attention_v1` / `paged_attention_v2`.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate and store the attention hyper-parameters.

        Raises:
            ValueError: if `blocksparse_params` is given or `head_size` is
                not among `PagedAttention.get_supported_head_sizes()`.
            NotImplementedError: if `kv_cache_dtype` is not "auto".
        """
        if blocksparse_params is not None:
            raise ValueError(
                "IPEX backend does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # NOTE(review): need_mask is set here but not read anywhere else in
        # this class.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)
        # A missing soft cap is passed to the kernel as 0.
        if logits_soft_cap is None:
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split `kv_cache` into separately viewed key and value caches.

        `kv_cache[0]` is viewed as
        (num_blocks, num_kv_heads, head_size // x, -1, x) with x == 1, and
        `kv_cache[1]` as (num_blocks, num_kv_heads, head_size, -1), where
        num_blocks is `kv_cache.shape[1]`.
        """
        x = 1
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: IpexAttnMetadata,  # type: ignore
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: must be 1.0 (asserted).
            v_scale: must be 1.0 (asserted).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: for any attn_type other than DECODER.
            RuntimeError: for a prompt batch with a non-empty cache and
                non-empty block tables (prefix decoding).
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache.numel() == 0
                    or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    # Expand the KV heads so every query head has its own
                    # key/value head.
                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                # NOTE(review): the bias built here is stored on the metadata
                # but is not passed to the varlen_attention call below --
                # confirm whether ALiBi / sliding-window masks are meant to
                # take effect on this path.
                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.seq_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, None, dtype=query.dtype)
                    attn_metadata.attn_bias = att_masks

                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype,
                    device=query.device)
                # The same seqlen_q / max_seqlen values are passed for both
                # the query side and the key side.
                ipex_ops.varlen_attention(
                    query,
                    key,
                    value,
                    output,
                    attn_metadata.seqlen_q,
                    attn_metadata.seqlen_q,
                    attn_metadata.max_seqlen,
                    attn_metadata.max_seqlen,
                    pdropout=0.0,
                    softmax_scale=self.scale,
                    zero_tensors=False,
                    is_causal=True,
                    return_softmax=False,
                    gen_=None,
                    logits_soft_cap=self.logits_soft_cap,
                )
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")

        else:
            # Decoding run.
            # NOTE(review): key_cache/value_cache are only bound when
            # kv_cache is non-empty; this branch presumably never runs with
            # an empty cache -- confirm against the caller.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            num_seqs, num_heads, head_size = query.shape
            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                                  _PARTITION_SIZE)
            # NOTE(woosuk): We use a simple heuristic to decide whether to use
            # PagedAttention V1 or V2. If the number of partitions is 1, we use
            # V1 to avoid the overhead of reduction. Also, if the number of
            # sequences or heads is large, we use V1 since there is enough work
            # to parallelize.
            # TODO(woosuk): Tune this heuristic.
            # For context len > 8192, use V2 kernel to avoid shared memory
            # shortage.
            use_v1 = (max_seq_len <= 8192 and
                      (max_num_partitions == 1 or num_seqs * num_heads > 512))
            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                # Scratch buffers, one slot per partition, handed to the
                # V2 kernel.
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ipex_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    """Build one causal ALiBi bias tensor per sequence length.

    Each returned tensor has shape ``(num_heads, seq_len, seq_len)``: entry
    ``[h, i, j]`` is ``(j - i) * alibi_slopes[h]`` on and below the diagonal
    and ``-inf`` above it.
    """

    def _bias_for(seq_len: int) -> torch.Tensor:
        device = alibi_slopes.device
        positions = torch.arange(seq_len, dtype=dtype, device=device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        distance = positions[None, :] - positions[:, None]

        num_heads = alibi_slopes.shape[0]
        per_head = distance[None, :].repeat((num_heads, 1, 1))
        per_head.mul_(alibi_slopes[:, None, None])
        # -inf strictly above the diagonal, zero elsewhere.
        causal_mask = torch.empty((1, seq_len, seq_len),
                                  dtype=per_head.dtype,
                                  device=device)
        causal_mask.fill_(-torch.inf).triu_(diagonal=1)
        return (per_head + causal_mask).to(dtype)

    return [_bias_for(seq_len) for seq_len in seq_lens]
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Build one log-space causal mask per sequence length.

    Each mask has shape ``(1, seq_len, seq_len)`` with ``0`` where attention
    is allowed and ``-inf`` elsewhere. With ``window_size`` set, only the
    last ``window_size`` positions up to the diagonal stay allowed.
    """
    shift = 0

    def _mask_for(seq_len: int) -> torch.Tensor:
        ones = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        allowed = torch.tril(ones, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            allowed = torch.triu(allowed, diagonal=shift - window_size + 1)
        # log(1) == 0 keeps a position, log(0) == -inf masks it.
        return torch.log(allowed).to(dtype)

    return [_mask_for(seq_len) for seq_len in seq_lens]
|
||||
1090
vllm-v0.6.2/vllm/attention/backends/mlu_attn.py
Executable file
1090
vllm-v0.6.2/vllm/attention/backends/mlu_attn.py
Executable file
File diff suppressed because it is too large
Load Diff
140
vllm-v0.6.2/vllm/attention/backends/openvino.py
Normal file
140
vllm-v0.6.2/vllm/attention/backends/openvino.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import openvino as ov
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
|
||||
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy block ``src_offset`` of ``src_tensor`` into block ``dst_offset``
    of ``dst_tensor``.

    A block is the slice ``[n, n + 1)`` along the first dimension.
    """

    def _block_view(tensor: ov.Tensor, block_number: int) -> ov.Tensor:
        # Region of interest covering the whole tensor except along the
        # first axis, which is narrowed to the single requested block.
        begin = ov.runtime.Coordinate([0, 0, 0, 0])
        end = ov.runtime.Coordinate(tensor.get_shape())

        begin[0] = block_number
        end[0] = block_number + 1

        view_cls = (ov.Tensor
                    if isinstance(tensor, ov.Tensor) else ov.RemoteTensor)
        return view_cls(tensor, begin, end)

    source_view = _block_view(src_tensor, src_offset)
    target_view = _block_view(dst_tensor, dst_offset)
    source_view.copy_to(target_view)
|
||||
|
||||
|
||||
class OpenVINOAttentionBackend(AttentionBackend):
    """Attention backend descriptor for OpenVINO."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier."""
        return "OPENVINO"

    @staticmethod
    def get_impl_cls():
        """Not available: always raises NotImplementedError."""
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Not available: always raises NotImplementedError."""
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        """Construct an ``OpenVINOAttentionMetadata`` from the arguments."""
        metadata = OpenVINOAttentionMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(2, num_blocks, num_kv_heads, block_size, head_size)``."""
        per_block = (num_kv_heads, block_size, head_size)
        return (2, num_blocks) + per_block

    @staticmethod
    def swap_blocks(
        src_tensor: ov.Tensor,
        dst_tensor: ov.Tensor,
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Copy each listed source block of ``src_tensor`` into the paired
        destination block of ``dst_tensor``."""
        for src_block, dst_block in src_to_dists:
            copy_cache_block(src_tensor, dst_tensor, src_block, dst_block)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Duplicate blocks inside every key cache and value cache."""
        for src_block, dst_block in src_to_dists:
            for key_cache, value_cache in kv_caches:
                # Key first, then value, each copied within its own tensor.
                for cache in (key_cache, value_cache):
                    copy_cache_block(cache, cache, src_block, dst_block)
|
||||
|
||||
|
||||
@dataclass
class OpenVINOAttentionMetadata:
    """Per-batch attention inputs for OpenVINOAttentionBackend.

    Terminology used in the field comments:
    - batch_size_in_sequences: total number of sequences to execute
    - prompt_lens: number of scheduled tokens for each sequence
    - batch_size_in_tokens = sum(prompt_lens)
    - max_context_len = max(context_lens)
    - max_num_blocks = ceil(max_context_len / BLOCK_SIZE)
    - num_blocks: total number of blocks in block_indices
    """

    # Size of the past KV cache for every sequence in the batch.
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Start indices of the input / speculative tokens of each sequence
    # within the batch sequence.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Block tables of all sequences in the batch: indices along the 0th
    # dimension of the key_cache and value_cache inputs.
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # For the i-th sequence, the index in block_indices of the first block
    # that belongs to that sequence.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Maximum context length.
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor

    # Index maps relating multi-modal embeddings to their placeholders.
    #
    # N.B. These are not really attention-related and do not belong on this
    # type; keeping them here is a temporary way to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]
|
||||
323
vllm-v0.6.2/vllm/attention/backends/pallas.py
Normal file
323
vllm-v0.6.2/vllm/attention/backends/pallas.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
    """Attention backend descriptor for the Pallas (TPU) kernels."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(num_kv_heads, num_blocks, block_size, head_size)``."""
        shape = (num_kv_heads, num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not used by this backend."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place along the block dimension (dim 1).

        ``src_to_dists`` is a pair ``(src_indices, dst_indices)``; for every
        key and value cache, the blocks at ``src_indices`` are written to
        ``dst_indices``.
        """
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Key cache first, then value cache; each is marked as a buffer
            # donor right before it is updated.
            for cache in (k_cache, v_cache):
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, dst_indices] = cache[:, src_indices]
|
||||
|
||||
|
||||
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch is either all prefills or all decodes, so the prefill and
    decode views both return this same object (or ``None``).
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` otherwise."""
        if self.num_prefills != 0:
            # A prefill batch carries no decode tokens and no paged-cache
            # lookup information.
            assert self.num_decode_tokens == 0
            assert self.block_tables is None
            assert self.context_lens is None
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` otherwise."""
        if self.num_decode_tokens != 0:
            # A decode batch carries no prefill work and must provide the
            # block tables and context lengths.
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation backed by the torch_xla Pallas kernels.

    Prefill batches go through ``torch.ops.xla.flash_attention``; decode
    batches go through the module-level ``paged_attention`` helper.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and select the megacore mode.

        Raises:
            NotImplementedError: if ``head_size`` is not a multiple of 128,
                if ALiBi slopes, a sliding window, a non-"auto" KV cache
                dtype, blocksparse parameters or logits soft-capping are
                requested, or if the TPU version is below 4.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # Fall back to one KV head per query head when none is given.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is looked up under several possible keys;
        # the first non-empty value wins.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore mode is only selected when the type string contains
        # neither "lite" nor "v6".
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: Must be 1.0 (asserted).
            v_scale: Must be 1.0 (asserted).
            attn_type: Only AttentionType.DECODER is implemented.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: if attn_type is not AttentionType.DECODER.
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")
        batch_size, seq_len, hidden_size = query.shape
        # Split the fused hidden dimension into (heads, head_size).
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Skip the cache write when the cache is empty (profiling run).
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # The softmax scale is folded into the query up front.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                # NOTE(review): presumably the causal flag; confirm against
                # the torch_xla flash_attention signature.
                True,
            )
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            # key_cache / value_cache are only bound by the cache-write
            # branch above, hence the non-empty cache requirement here.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # NOTE(review): the factor 4 is presumably the byte size of one
            # block table entry; confirm the block table dtype.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                # Ceiling division: the last chunk may be partial.
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Scatter new key/value rows into the caches at ``slot_mapping``.

    Both caches are marked as XLA buffer donors. Inputs and caches each
    have their first three dimensions flattened into one, and the rows of
    ``key`` / ``value`` are then copied into the flattened caches at the
    row indices given by ``slot_mapping``.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    flat_key = key.flatten(0, 2)
    flat_value = value.flatten(0, 2)
    flat_key_cache = key_cache.flatten(0, 2)
    flat_value_cache = value_cache.flatten(0, 2)

    # In-place row writes on the flattened caches.
    flat_key_cache.index_copy_(0, slot_mapping, flat_key)
    flat_value_cache.index_copy_(0, slot_mapping, flat_value)
|
||||
|
||||
|
||||
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    """Run the XLA paged-attention op and return its output.

    The "batch" megacore mode is dropped when the batch size (first
    dimension of ``query``) is odd.
    """
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is None:
        return torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
    )
|
||||
341
vllm-v0.6.2/vllm/attention/backends/placeholder_attn.py
Normal file
341
vllm-v0.6.2/vllm/attention/backends/placeholder_attn.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
# Placeholder attention backend for models like Mamba and embedding models that
|
||||
# lack attention.
|
||||
|
||||
|
||||
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the (stub) attention implementation class."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the attention metadata class."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed one-element shape; the arguments are ignored."""
        return (1, ) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None
|
||||
|
||||
|
||||
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""

    # (batch_size,). Per-sequence length, i.e. computed tokens plus new
    # tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are
    # prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., subquery lengths
    # [4, 6] give [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., sequence lengths
    # [4, 6] give [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). Context lengths (tokens that are computed so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd
    # blocks of the kv cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the prefill-only view, built once and then cached.

        Returns ``None`` when the batch contains no prefills.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Prefill entries are the first `num_prefills` of each per-sequence
        # field; slot_mapping and block_tables are empty placeholders.
        num_prefills = self.num_prefills
        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=torch.empty(0),
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            seq_lens=self.seq_lens[:num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:num_prefills],
            block_tables=torch.empty(0),
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the decode-only view, built once and then cached.

        Returns ``None`` when the batch contains no decode tokens.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.seq_lens_tensor is not None

        # Decode entries follow the prefill entries in seq_lens_tensor;
        # slot_mapping and block_tables are empty placeholders.
        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=torch.empty(0),
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=torch.empty(0),
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
|
||||
|
||||
|
||||
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Builds ``PlaceholderAttentionMetadata`` from an input builder."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Start with empty per-batch accumulators for ``input_builder``."""
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the running counters.

        Every sequence contributes its context length. Prompt sequences
        also update the prefill counters and multi-modal placeholder maps;
        all other sequences update the decode counters.
        """
        is_prompt = inter_data.is_prompt

        # All seven columns stay in the zip so iteration stops at the
        # shortest one; the first and last values are not used.
        per_sequence = zip(
            inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
            inter_data.orig_seq_lens, inter_data.seq_lens,
            inter_data.query_lens, inter_data.context_lens,
            inter_data.curr_sliding_window_blocks)

        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _sliding_window_block) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Raises:
            ValueError: if the model config declares
                ``attn_logit_softcapping``.
        """
        chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data, chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        # Decode entries follow the prefill entries in query_lens.
        decode_query_lens = query_lens[self.num_prefills:]
        max_decode_query_len = (max(decode_query_lens)
                                if len(decode_query_lens) > 0 else 1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        # With a captured graph the decode token count is the (maybe
        # padded) batch size.
        num_decode_tokens = (batch_size
                             if use_captured_graph else self.num_decode_tokens)

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)

        # Start-location tensors carry a leading zero followed by the
        # running sums of the corresponding lengths.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # slot_mapping and block_tables are empty placeholders.
        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=torch.empty(0),
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=torch.empty(0),
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class PlaceholderAttentionImpl(AttentionImpl):
    """Stub attention implementation that cannot be executed."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError
|
||||
681
vllm-v0.6.2/vllm/attention/backends/rocm_flash_attn.py
Normal file
681
vllm-v0.6.2/vllm/attention/backends/rocm_flash_attn.py
Normal file
@@ -0,0 +1,681 @@
|
||||
"""Attention layer ROCm GPUs."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PARTITION_SIZE_ROCM = 512

# Architecture name reported for the "cuda" device, queried once at import.
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName

# Architecture family flags, derived by substring match on the name.
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(
    arch in _GPU_ARCH for arch in ("gfx90a", "gfx940", "gfx941", "gfx942"))
|
||||
|
||||
|
||||
class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend descriptor for ROCm flash attention.

    KV cache layout and block movement are delegated to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``PagedAttention``."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
    """Return the metadata restricted to the prefill part of the batch.

    Returns ``None`` when the batch holds no prefill requests. The view is
    built once and memoised in ``_cached_prefill_metadata``; later calls
    return the memoised object.
    """
    prefill_count = self.num_prefills
    if prefill_count == 0:
        return None

    memoised = self._cached_prefill_metadata
    if memoised is not None:
        return memoised

    # Every per-sequence container must be populated before slicing.
    assert self.seq_lens is not None
    assert self.seq_lens_tensor is not None
    assert self.query_start_loc is not None
    assert self.context_lens_tensor is not None
    assert self.block_tables is not None
    assert self.seq_start_loc is not None

    prefill_token_count = self.num_prefill_tokens
    # Cumulative-offset tensors carry one more entry than there are
    # sequences, hence the "+ 1" bound.
    offsets_end = prefill_count + 1

    prefill_view = ROCmFlashAttentionMetadata(
        num_prefills=prefill_count,
        num_prefill_tokens=prefill_token_count,
        num_decode_tokens=0,
        slot_mapping=self.slot_mapping[:prefill_token_count],
        multi_modal_placeholder_index_maps=(
            self.multi_modal_placeholder_index_maps),
        seq_lens=self.seq_lens[:prefill_count],
        seq_lens_tensor=self.seq_lens_tensor[:prefill_count],
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_seq_len=0,
        query_start_loc=self.query_start_loc[:offsets_end],
        seq_start_loc=self.seq_start_loc[:offsets_end],
        context_lens_tensor=self.context_lens_tensor[:prefill_count],
        block_tables=self.block_tables[:prefill_count],
        use_cuda_graph=False,
    )
    self._cached_prefill_metadata = prefill_view
    return prefill_view
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
    """Return the metadata restricted to the decode part of the batch.

    Returns ``None`` when the batch holds no decode tokens. The view is
    built once and memoised in ``_cached_decode_metadata``; later calls
    return the memoised object.
    """
    if self.num_decode_tokens == 0:
        return None

    memoised = self._cached_decode_metadata
    if memoised is not None:
        return memoised

    assert self.block_tables is not None
    assert self.seq_lens_tensor is not None

    # Decode entries follow the prefill entries in every per-sequence and
    # per-token container, so slicing starts where the prefills end.
    first_decode_seq = self.num_prefills
    first_decode_token = self.num_prefill_tokens

    decode_view = ROCmFlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=self.slot_mapping[first_decode_token:],
        multi_modal_placeholder_index_maps=None,
        seq_lens=None,
        seq_lens_tensor=self.seq_lens_tensor[first_decode_seq:],
        max_query_len=None,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self.block_tables[first_decode_seq:],
        use_cuda_graph=self.use_cuda_graph,
    )
    self._cached_decode_metadata = decode_view

    # Batch may be composed of prefill|decodes, adjust query start indices
    # to refer to the start of decodes when the two are split apart.
    # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
    # NOTE(review): query_start_loc is passed as None just above, so this
    # only triggers if the constructor fills it in -- confirm.
    decode_starts = decode_view.query_start_loc
    if decode_starts is not None:
        decode_view.query_start_loc = decode_starts - decode_starts[0]
    return decode_view
|
||||
|
||||
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    Args:
        model_input: Model input whose ``input_tokens`` and
            ``input_positions`` are handed to the advance-step kernel.
        sampled_token_ids: Token ids sampled in the previous step, forwarded
            to the kernel unchanged.
        block_size: KV-cache block size, forwarded to the kernel.
        num_seqs: Number of sequences in the (possibly padded) batch.
        num_queries: Number of actual requests in the batch.
        turn_prefills_into_decodes: Must be False; not supported by this
            backend.

    Raises:
        AssertionError: If ``turn_prefills_into_decodes`` is set or the
            metadata is not in the pure-decode state checked below.
    """

    # The two sentences are separate literals, so the first needs a trailing
    # space to keep the concatenated message readable.
    assert not turn_prefills_into_decodes, \
        ("Chunked prefill is not supported with rocm_flash_attn yet. "
         "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
         "specific parameter.")

    # When using cudagraph, the num_seqs is padded to the next captured
    # batch size, but num_queries tracks the actual number of requests in
    # the batch. For --enforce-eager mode, num_seqs == num_queries
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    # The metadata must describe a decode-only batch.
    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0
    assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Update query lengths. Note that we update only queries and not seqs,
    # since tensors may be padded due to captured cuda graph batch size
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    # The tensors handed over here are the metadata's own, so the kernel is
    # expected to update them in place -- TODO confirm against the op.
    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
|
||||
|
||||
|
||||
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder bound to ``ROCmFlashAttentionMetadata``."""

    # Metadata class associated with this builder; presumably the type the
    # CommonMetadataBuilder base instantiates -- verify against the base.
    _metadata_cls = ROCmFlashAttentionMetadata
|
||||
|
||||
|
||||
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    """Build one ALiBi bias tensor per sequence length.

    Args:
        alibi_slopes: 1-D tensor holding one slope per attention head.
        dtype: dtype of the returned bias tensors.
        seq_lens: Sequence lengths; ``None`` or an empty list yields ``[]``.
        make_attn_mask: When True, a causal mask (``-inf`` strictly above
            the diagonal) is added to each bias.

    Returns:
        A list with one ``(num_heads, seq_len, seq_len)`` tensor per entry
        of ``seq_lens``, placed on the device of ``alibi_slopes``.
    """
    if not seq_lens:
        return []

    target_device = alibi_slopes.device
    head_count = alibi_slopes.shape[0]
    slopes_per_head = alibi_slopes[:, None, None]

    biases: List[torch.Tensor] = []
    for length in seq_lens:
        positions = torch.arange(length, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        relative = positions[None, :] - positions[:, None]

        per_head = relative[None, :].repeat((head_count, 1, 1))
        per_head = per_head.to(target_device)
        per_head.mul_(slopes_per_head)

        if not make_attn_mask:
            biases.append(per_head.to(dtype))
            continue

        # -inf strictly above the diagonal, zero elsewhere.
        causal_mask = torch.empty((1, length, length), dtype=per_head.dtype)
        causal_mask = causal_mask.fill_(-torch.inf).triu_(diagonal=1)
        causal_mask = causal_mask.to(target_device)
        biases.append((per_head + causal_mask).to(dtype))

    return biases
|
||||
|
||||
|
||||
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    Attention implementation for the ROCm backend.

    Prefill tokens go through the attention function chosen in ``__init__``
    (Triton FA, CK FA or a naive SDPA fallback) or through
    ``PagedAttention.forward_prefix``; decode tokens go through
    ``ops.paged_attention_rocm`` or ``PagedAttention.forward_decode``.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the attention configuration and pick the prefill kernel.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be supported by
                PagedAttention.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of key/value heads; must divide
                ``num_heads``.
            alibi_slopes: Optional per-head ALiBi slopes, converted to a
                float32 tensor.
            sliding_window: Optional window size, stored as a
                ``(window, window)`` pair, or ``(-1, -1)`` when unset.
            kv_cache_dtype: KV-cache dtype string passed on to the
                PagedAttention ops.
            blocksparse_params: Must be None (unsupported).
            logits_soft_cap: Must be None (unsupported).

        Raises:
            ValueError: If blocksparse attention or logits soft capping is
                requested, or the head size is unsupported.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "ROCmFlashAttention does not support attention logits soft "
                "capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # Kernel selection: Triton FA when enabled through the environment,
        # otherwise CK FA when the device capability check passes and
        # flash_attn is importable, otherwise the naive SDPA fallback.
        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    # flash_attn is not installed: fall back to naive SDPA.
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        # x is unpacked as (tokens, n_kv_heads, head_dim); each kv head is
        # expanded n_rep times and the copies are folded into the head axis.
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: Key scale forwarded to the PagedAttention ops.
            v_scale: Value scale forwarded to the PagedAttention ops.
            attn_type: Must be ``AttentionType.DECODER``.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        # Prefill results fill the first num_prefill_tokens rows of
        # ``output``; decode results fill the rest.
        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_masks = None
                if self.use_triton_flash_attn:
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=False)  # type: ignore
                    # Only the first bias tensor is handed to the kernel.
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    # Move the token axis next to the head-size axis, the
                    # layout _sdpa_attention slices on.
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
                    # NOTE(review): num_tokens is the whole batch's token
                    # count, while the shape assert below expects
                    # num_prefill_tokens rows -- confirm this path is never
                    # hit with decode tokens in the batch.
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        attn_masks,
                    )
                else:
                    # CK flash-attention (flash_attn_varlen_func).
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    k_scale,
                    v_scale,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            # value_cache is only bound when kv_cache is non-empty (see the
            # cache-write branch above).
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
            use_custom = _use_rocm_custom_paged_attention(
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len)
            if use_custom:
                max_seq_len = decode_meta.max_decode_seq_len
                # Ceiling division: partitions needed for the longest
                # sequence.
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
                # Scratch buffers passed to the custom kernel, sized per
                # partition.
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                # The decode slice of ``output`` is passed as the first
                # argument; the return value is not used.
                ops.paged_attention_rocm(
                    output[num_prefill_tokens:],
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    decode_meta.max_decode_seq_len,
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
                    k_scale,
                    v_scale,
                )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def _sdpa_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
seq_lens: List[int],
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
attn_masks: Optional[List[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
start = 0
|
||||
output = torch.empty((num_tokens, num_heads, head_size),
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
end = start + seq_len
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=True,
|
||||
enable_flash=False,
|
||||
enable_mem_efficient=False):
|
||||
sub_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query[:, start:end, :],
|
||||
key[:, start:end, :],
|
||||
value[:, start:end, :],
|
||||
dropout_p=0.0,
|
||||
is_causal=attn_masks is None,
|
||||
attn_mask=attn_masks[i] if attn_masks else None,
|
||||
scale=scale).movedim(query.dim() - 2, 0)
|
||||
output[start:end, :, :] = sub_out
|
||||
start = end
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    The kernel is selected only when the module-level platform flags allow
    it and the query dtype, head size, block size, GQA ratio and maximum
    sequence length all fall inside the ranges checked below.
    """
    dtype_ok = qtype in (torch.half, torch.bfloat16)
    head_size_ok = head_size in (64, 128)
    block_size_ok = block_size in (16, 32)
    gqa_ok = 1 <= gqa_ratio <= 16
    seq_len_ok = max_seq_len <= 32768

    # rocm custom page attention not support on navi (gfx1*)
    return (_ON_MI250_MI300 and not _ON_NAVI and dtype_ok and head_size_ok
            and block_size_ok and gqa_ok and seq_len_ok)
|
||||
553
vllm-v0.6.2/vllm/attention/backends/torch_sdpa.py
Normal file
553
vllm-v0.6.2/vllm/attention/backends/torch_sdpa.py
Normal file
@@ -0,0 +1,553 @@
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cpu():
|
||||
try:
|
||||
from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
except ImportError:
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
else:
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
|
||||
|
||||
class TorchSDPABackend(AttentionBackend):
    """Backend descriptor for the torch SDPA attention implementation.

    Names the implementation, metadata and state classes of this backend
    and forwards KV-cache shape and block operations to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the attention implementation class of this backend."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class of this backend."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class of this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap between two KV caches to ``PagedAttention``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy within the KV caches to ``PagedAttention``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata used by the torch SDPA backend.

    Extends the common and paged-attention metadata with optional encoder
    and cross-attention fields, and with attention-type dependent
    accessors for sequence lengths, biases and block tables.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]

    # --- Encoder attention and encoder/decoder cross-attention fields ---
    # Encoder sequence lengths, as a list and as a tensor.
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences.
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder.
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory mapping: slot mapping and block tables.
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Initialise the attention-bias slots to ``None``.

        The biases are filled in later through ``set_attn_bias``. Each is a
        list so that a separate bias can be kept per prompt. They are plain
        instance attributes rather than dataclass fields, so they are not
        part of the generated ``__init__`` or ``__repr__``.
        """
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        """Whether every field required for encoder attention is set."""
        required = (self.encoder_seq_lens, self.encoder_seq_lens_tensor,
                    self.max_encoder_seq_len)
        return all(item is not None for item in required)

    @property
    def is_all_cross_attn_metadata_set(self):
        """Whether every field required for cross-attention is set.

        This is the encoder-attention requirement plus the cross slot
        mapping and the cross block tables.
        """
        if not self.is_all_encoder_attn_metadata_set:
            return False
        return (self.cross_slot_mapping is not None
                and self.cross_block_tables is not None)

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` for an all-prefill batch, otherwise ``None``."""
        # Currently chunked prefill is not supported
        if self.num_decode_tokens != 0:
            return None
        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` for an all-decode batch, otherwise ``None``."""
        # Currently chunked prefill is not supported
        has_prefills = self.num_prefills > 0
        if not has_prefills:
            return self
        assert self.num_decode_tokens == 0
        return None

    def get_seq_lens(
        self,
        attn_type: AttentionType,
    ):
        """Select the query and key/value sequence lengths for ``attn_type``.

        Arguments:

        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * Sequence lengths for the query
        * Sequence lengths for key & value

        Raises:
        * AttributeError for an unrecognised attention type
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.seq_lens, self.seq_lens
        if attn_type == AttentionType.ENCODER:
            return self.encoder_seq_lens, self.encoder_seq_lens
        if attn_type == AttentionType.ENCODER_DECODER:
            # Queries come from the decoder, keys/values from the encoder.
            return self.seq_lens, self.encoder_seq_lens
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_attn_bias(
        self,
        attn_type: AttentionType,
    ) -> Optional[List[torch.Tensor]]:
        """Return the attention bias stored for ``attn_type``.

        Arguments:

        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * The stored bias list for that attention type, or None if unset

        Raises:
        * AttributeError for an unrecognised attention type
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.attn_bias
        if attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: AttentionType,
    ) -> None:
        """Store ``attn_bias`` in the slot that belongs to ``attn_type``.

        Arguments:

        * attn_bias: The desired attention bias value
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Raises:
        * AttributeError for an unrecognised attention type
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
            return
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: AttentionType,
    ) -> tuple:
        """Select sequence-length and block-table fields for ``attn_type``.

        Decoder attn -> decoder self-attention fields
        Encoder/decoder cross-attn -> encoder sequence lengths &
                                      cross-attn block tables
        Encoder attn -> encoder sequence lengths & no block tables

        Arguments:

        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:

        * Sequence-lengths tensor
        * Max sequence-length scalar
        * Block tables (or None)

        Raises:
        * AttributeError for an unrecognised attention type
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            # Decoder self-attention: the decode maximum is always used.
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        if attn_type == AttentionType.ENCODER_DECODER:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        if attn_type == AttentionType.ENCODER:
            # No block tables associated with encoder attention
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
) -> None:
    """Validate and store the attention configuration.

    Args:
        num_heads: Number of query heads.
        head_size: Size of each head; must be supported by PagedAttention.
        scale: Softmax scale, stored as a float.
        num_kv_heads: Number of key/value heads; must divide ``num_heads``.
        alibi_slopes: Optional per-head ALiBi slopes, converted to a
            float32 tensor.
        sliding_window: Optional sliding-window size.
        kv_cache_dtype: KV-cache dtype string; only "auto" is accepted.
        blocksparse_params: Must be None (unsupported).
        logits_soft_cap: Must be None (unsupported).

    Raises:
        ValueError: If block-sparse attention or a logits soft cap is
            requested, or the head size is unsupported.
        NotImplementedError: If ``kv_cache_dtype`` is not "auto".
    """
    # The messages below name the backend "SDPA" (previously misspelled
    # "SPDA").
    if blocksparse_params is not None:
        raise ValueError(
            "Torch SDPA does not support block-sparse attention.")
    if logits_soft_cap is not None:
        raise ValueError("Torch SDPA does not support logits soft cap.")
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    self.sliding_window = sliding_window
    self.kv_cache_dtype = kv_cache_dtype

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    # An explicit mask is needed when ALiBi or a sliding window is in use.
    self.need_mask = (self.alibi_slopes is not None
                      or self.sliding_window is not None)

    supported_head_sizes = PagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by PagedAttention. "
            f"Supported head sizes are: {supported_head_sizes}.")
    if kv_cache_dtype != "auto":
        raise NotImplementedError(
            "Torch SDPA backend does not support FP8 KV cache. "
            "Please use xFormers backend instead.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert k_scale == 1.0 and v_scale == 1.0
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale, v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
output = self._run_sdpa_forward(query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta,
|
||||
attn_type=attn_type)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
raise RuntimeError(
|
||||
"Torch SDPA backend doesn't support prefix decoding.")
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
# Decoding run.
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
block_tables_arg,
|
||||
) = decode_meta.get_seq_len_block_table_args(attn_type)
|
||||
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables_arg,
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def _run_sdpa_forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: TorchSDPAMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
|
||||
attn_masks = attn_metadata.get_attn_bias(attn_type)
|
||||
if attn_masks is None:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
attn_metadata.seq_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
|
||||
attn_masks = [None] * len(seq_lens)
|
||||
attn_metadata.set_attn_bias(attn_masks, attn_type)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
causal_attn = (attn_type == AttentionType.DECODER)
|
||||
|
||||
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
|
||||
attn_masks):
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
sub_out = scaled_dot_product_attention(
|
||||
query[None, :, start_q:end_q, :],
|
||||
key[None, :, start_kv:end_kv, :],
|
||||
value[None, :, start_kv:end_kv, :],
|
||||
attn_mask=mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal_attn and not self.need_mask,
|
||||
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
|
||||
output[start_q:end_q, :, :] = sub_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
seq_lens: List[int],
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases: List[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
bias = torch.arange(seq_len, dtype=dtype)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
||||
inf_mask = torch.empty(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
seq_lens: List[int],
|
||||
window_size: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases: List[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
tensor = torch.full(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=dtype,
|
||||
fill_value=1,
|
||||
)
|
||||
shift = 0
|
||||
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
||||
if window_size is not None:
|
||||
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask.to(dtype))
|
||||
|
||||
return attn_biases
|
||||
574
vllm-v0.6.2/vllm/attention/backends/utils.py
Normal file
574
vllm-v0.6.2/vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""Attention backend utils"""
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||
AttentionState)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
|
||||
# Error string(s) for encoder/decoder
|
||||
# unsupported attention scenarios
|
||||
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
|
||||
"with encoder/decoder models.")
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
# Switch to numpy implementation of compute_slot_mapping
|
||||
# if we have at least this many elements. Could be tuned further.
|
||||
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
|
||||
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """Return True when there are no usable block tables.

    That is the case when ``block_tables`` is None, or when it is a dict
    whose values are all None (an empty dict included). Any other object
    yields False.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    return all(entry is None for entry in block_tables.values())
|
||||
|
||||
|
||||
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """Compute the start index of slot mapping.

    Returns ``max(0, query_len - sliding_window)`` for a prompt when a
    sliding window is configured, and 0 otherwise (``sliding_window`` may
    be None). ``context_len`` is accepted but not used here.
    """
    if not is_prompt or sliding_window is None:
        return 0
    return max(0, query_len - sliding_window)
|
||||
|
||||
|
||||
def _compute_slot_mapping_python(slot_mapping: List[int],
|
||||
block_table: List[int], range_start: int,
|
||||
range_end: int, block_size: int):
|
||||
for i in range(range_start, range_end):
|
||||
block_number = block_table[i // block_size]
|
||||
block_offset = i % block_size
|
||||
slot = block_number * block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
|
||||
def _compute_slot_mapping_numpy(slot_mapping: List[int],
|
||||
block_table: List[int], range_start: int,
|
||||
range_end: int, block_size: int):
|
||||
block_table_array = np.array(block_table)
|
||||
idx = np.arange(range_start, range_end)
|
||||
block_offset = idx % block_size
|
||||
idx //= block_size
|
||||
seq_slot_mapping_array = block_table_array[idx]
|
||||
seq_slot_mapping_array *= block_size
|
||||
seq_slot_mapping_array += block_offset
|
||||
slot_mapping.extend(seq_slot_mapping_array)
|
||||
|
||||
|
||||
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """Extend ``slot_mapping`` in place with the slots of one sequence.

    For a profile run, ``seq_len`` PAD_SLOT_ID entries are appended and the
    block tables are not consulted. Otherwise positions in
    [context_len, start_idx) are padded with PAD_SLOT_ID and positions in
    [max(start_idx, context_len), seq_len) are resolved through
    ``block_tables[seq_id]``.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    num_masked = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * num_masked)

    first_position = max(start_idx, context_len)
    seq_block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if seq_len - first_position < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        fill_slots = _compute_slot_mapping_python
    else:
        fill_slots = _compute_slot_mapping_numpy
    fill_slots(slot_mapping, seq_block_table, first_position, seq_len,
               block_size)
|
||||
|
||||
|
||||
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
||||
|
||||
|
||||
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """Accumulates per-sequence attention inputs and turns them into an
    instance of ``_metadata_cls`` holding on-device tensors."""

    # Concrete metadata class instantiated by build(); set by subclasses.
    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Start with empty accumulators and cache settings read from
        ``input_builder`` (runner, sliding window, block size)."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold every sequence of one sequence group into the accumulators.

        Updates the prefill/decode counters, the per-sequence block tables
        and the flat slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            An instance of ``self._metadata_cls``.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padded sequence. The
            # previous `[] * cuda_graph_pad_size` evaluated to `[]` and
            # padded nothing. Empty tables are skipped by the fill loop
            # below, so the resulting tensor is unchanged.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class CommonAttentionState(AttentionState):
    """Attention state shared by backends: owns the placeholder buffers
    used while capturing graphs and copies live metadata into them."""

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that provides the buffers used during capture.

        Allocates slot-mapping, sequence-length and block-table buffers on
        ``self.runner.device`` for up to ``max_batch_size`` sequences and
        releases them on exit.
        """
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        try:
            yield
        finally:
            # Always leave capture mode and drop the buffers, even when
            # the body of the `with` block raised.
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a fresh state object bound to the same runner."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Create decode-only metadata backed by the capture buffers.

        Only valid inside ``graph_capture``. For encoder-decoder models the
        cross-attention and encoder fields are filled in as well.
        """
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Collect the metadata tensors that act as graph input buffers."""
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current decode metadata into the graph input buffers.

        Note that ``slot_mapping`` is not copied by this method.
        """
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        """No per-forward preparation is needed for this state."""
        return

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additonal_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        This method copies the encoder sequence lengths, the cross-slot
        mapping and the cross-block tables from the decode metadata of
        `attn_metadata` into the matching entries of `input_buffers`.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)
|
||||
|
||||
|
||||
def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    Tell whether every field required for encoder attention is set.

    Returns True only when encoder_seq_lens, encoder_seq_lens_tensor and
    max_encoder_seq_len are all not None.
    '''
    if attn_metadata.encoder_seq_lens is None:
        return False
    if attn_metadata.encoder_seq_lens_tensor is None:
        return False
    return attn_metadata.max_encoder_seq_len is not None
|
||||
|
||||
|
||||
def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    Tell whether every field required for enc/dec cross-attention is set.

    This is a superset of the encoder attention requirements: the
    metadata's own is_all_encoder_attn_metadata_set attribute must be
    truthy, and cross_slot_mapping and cross_block_tables must not be None.
    '''
    encoder_ready = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Hand back the falsy attribute value itself, like an `and` chain.
        return encoder_ready
    if attn_metadata.cross_slot_mapping is None:
        return False
    return attn_metadata.cross_block_tables is not None
|
||||
|
||||
|
||||
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table fields of attn_metadata that
    apply to the given type of attention operation.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> encoder sequence lengths &
                                  cross-attn block tables
    Encoder attn -> encoder sequence lengths & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError for any other attention type
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on the phase.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: AttentionType,
) -> Tuple[int, int, int]:
    """
    Work out how many prefill query, prefill key/value and decode query
    tokens apply for the given attention type.

    Args:
        attn_metadata: Attention metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        return (attn_metadata.num_encoder_tokens,
                attn_metadata.num_encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        # Queries come from the decoder; the keys/values come from the
        # encoder (cross-attention).
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # attn_type == AttentionType.DECODER or
    # attn_type == AttentionType.ENCODER_ONLY
    return (attn_metadata.num_prefill_tokens,
            attn_metadata.num_prefill_tokens,
            attn_metadata.num_decode_tokens)
|
||||
787
vllm-v0.6.2/vllm/attention/backends/xformers.py
Normal file
787
vllm-v0.6.2/vllm/attention/backends/xformers.py
Normal file
@@ -0,0 +1,787 @@
|
||||
"""Attention layer with xFormers and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (
|
||||
CommonAttentionState, CommonMetadataBuilder,
|
||||
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
|
||||
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XFormersBackend(AttentionBackend):
    """Backend descriptor for attention with xFormers and PagedAttention.

    Each accessor returns the class that plays the corresponding role for
    this backend; the KV-cache helpers forward to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class."""
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the metadata builder class."""
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as reported by ``PagedAttention``."""
        kv_cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward a block swap request to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata consumed by the XFormers backend.

    NOTE: Plain python objects kept here are not refreshed when a cuda
    graph is replayed. Values that have to change dynamically should be
    stored in a tensor, and that tensor has to be updated through the
    `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Tensor form of seq_lens.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Longest sequence length in the prefill batch; 0 when the batch holds
    # decoding requests only.
    max_prefill_seq_len: int
    # Longest sequence length in the decode batch; 0 when the batch holds
    # prefill requests only.
    max_decode_seq_len: int

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). Per-sequence length, i.e. the computed tokens plus the
    # new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6]
    # give [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) Context lengths, i.e. the tokens computed so far.
    context_lens_tensor: Optional[torch.Tensor] = None

    # Longest query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., subquery lengths [4, 6]
    # give [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Caches for the self-attention prefill/decode metadata views.
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # FIXME: It is for flash attn.
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6]
    # give [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Longest encoder sequence length
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # The biases are filled in while the first attention op executes.
        # They are lists because, with alibi slopes, one bias is needed per
        # prompt (a limitation of the xformers API).
        # Being plain attributes rather than dataclass fields, they do not
        # show up in __repr__ or __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        Whether every piece of metadata needed by encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        Whether every piece of metadata needed by enc/dec cross-attention
        is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Return the prefill-only view of this metadata, or None.

        None is returned when there are no prefills. The view is built on
        first access and served from the cache afterwards.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            # Reuse the prefill-phase structure built earlier.
            return cached

        assert not (self.seq_lens is None and self.encoder_seq_lens is None)
        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        n_prefills = self.num_prefills
        n_prefill_tokens = self.num_prefill_tokens

        # Restrict each optional field to its leading (prefill) part; fields
        # that are None stay None.
        query_start_loc = self.query_start_loc
        if query_start_loc is not None:
            query_start_loc = query_start_loc[:n_prefills + 1]

        slot_mapping = self.slot_mapping
        if slot_mapping is not None:
            slot_mapping = slot_mapping[:n_prefill_tokens]

        seq_lens = self.seq_lens
        if seq_lens is not None:
            seq_lens = seq_lens[:n_prefills]

        seq_lens_tensor = self.seq_lens_tensor
        if seq_lens_tensor is not None:
            seq_lens_tensor = seq_lens_tensor[:n_prefills]

        context_lens_tensor = self.context_lens_tensor
        if context_lens_tensor is not None:
            context_lens_tensor = context_lens_tensor[:n_prefills]

        block_tables = self.block_tables
        if block_tables is not None:
            block_tables = block_tables[:n_prefills]

        # Build the prefill-phase structure and remember it.
        prefill_meta = XFormersMetadata(
            num_prefills=n_prefills,
            num_prefill_tokens=n_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=(
                self.multi_modal_placeholder_index_maps),
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Encoder & cross attn fields are passed through unchanged.
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
        )
        self._cached_prefill_metadata = prefill_meta
        return prefill_meta

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Return the decode-only view of this metadata, or None.

        None is returned when there are no decode tokens. The view is built
        on first access and served from the cache afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            # Reuse the decode-phase structure built earlier.
            return cached

        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        # Restrict each optional field to its trailing (decode) part; fields
        # that are None stay None.
        slot_mapping = self.slot_mapping
        if slot_mapping is not None:
            slot_mapping = slot_mapping[self.num_prefill_tokens:]

        seq_lens_tensor = self.seq_lens_tensor
        if seq_lens_tensor is not None:
            seq_lens_tensor = seq_lens_tensor[self.num_prefills:]

        block_tables = self.block_tables
        if block_tables is not None:
            block_tables = block_tables[self.num_prefills:]

        # Build the decode-phase structure and remember it.
        decode_meta = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Encoder & cross attn fields are passed through unchanged.
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
        )
        self._cached_decode_metadata = decode_meta

        # A batch may be laid out as prefill|decodes; once the two are split
        # apart the query start indices must be relative to the first decode.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        start_locs = decode_meta.query_start_loc
        if start_locs is not None:
            decode_meta.query_start_loc = start_locs - start_locs[0]
        return decode_meta
|
||||
|
||||
|
||||
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[AttentionBias]:
    '''
    Look up the attention bias stored on the metadata for an attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The bias attribute matching the attention type: `attn_bias` for
      DECODER and ENCODER_ONLY, `encoder_attn_bias` for ENCODER,
      `cross_attn_bias` for ENCODER_DECODER

    Raises:
    * AttributeError: if the attention type is none of the above
    '''

    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias on the metadata field that belongs to the
    given attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError: if the attention type is not recognised
    '''

    # Resolve the target attribute first, then assign once.
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        bias_field = "attn_bias"
    elif attn_type == AttentionType.ENCODER:
        bias_field = "encoder_attn_bias"
    elif attn_type == AttentionType.ENCODER_DECODER:
        bias_field = "cross_attn_bias"
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    setattr(attn_metadata, bias_field, attn_bias)
|
||||
|
||||
|
||||
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the XFormers backend.

    Reuses ``CommonMetadataBuilder`` and only sets the metadata class to
    ``XFormersMetadata``.
    """

    _metadata_cls = XFormersMetadata
|
||||
|
||||
|
||||
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        Raises:
            ValueError: If block-sparse parameters or a logits soft cap are
                given (neither is supported here), or if `head_size` is not
                among the head sizes reported by PagedAttention.
            AssertionError: If `num_heads` is not a multiple of
                `num_kv_heads`.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "XFormers does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # Slopes are kept as a float32 tensor when provided, None otherwise.
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: Key scale, forwarded to the PagedAttention calls.
            v_scale: Value scale, forwarded to the PagedAttention calls.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: If the metadata required by ENCODER or
                ENCODER_DECODER attention is not fully set.
        """

        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Split the flattened hidden dimension into (heads, head_size).
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                    value_cache,
                                                    updated_slot_mapping,
                                                    self.kv_cache_dtype,
                                                    k_scale, v_scale)
        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    k_scale,
                    v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            # Decode results go after the prefill results in the output.
            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally

        Returns:
            A tensor with the same shape as the `query` argument.

        Raises:
            ValueError: If no bias is cached yet, alibi slopes are not used
                and `attn_type` is not a known attention type.
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                else:
                    # Exceptions do not %-format their arguments the way
                    # logging calls do, so build the message explicitly.
                    raise ValueError(f"Unknown AttentionType: {attn_type}")

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                # Stored as a one-element list so that both bias flavours
                # (single mask / one alibi bias per prompt) are lists.
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            # Cache the bias on the metadata for subsequent calls.
            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per sequence length.

    Args:
        alibi_slopes: Slopes tensor; its first dimension gives the number
            of heads and its device is used for the bias buffers.
        num_kv_heads: Number of KV heads. When it differs from the number
            of heads, the head dimension of the bias is split into
            (num_kv_heads, num_heads // num_kv_heads).
        dtype: Data type of the bias tensors.
        seq_lens: Sequence lengths, one bias is produced for each.

    Returns:
        A list of `LowerTriangularMaskWithTensorBias`, in the order of
        `seq_lens`.
    """
    biases: List[AttentionBias] = []
    for length in seq_lens:
        positions = torch.arange(length, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Entry (i, j) of the matrix is position j minus position i.
        relative = positions[None, :] - positions[:, None]

        # The last dimension is allocated rounded up to a multiple of 8,
        # and only the first `length` columns are used.
        padded_len = ((length + 7) // 8) * 8
        num_heads = alibi_slopes.shape[0]
        buffer = torch.empty(
            1,  # batch size
            num_heads,
            length,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )
        scaled = buffer[:, :, :, :length].copy_(relative)
        scaled.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            group_size = num_heads // num_kv_heads
            scaled = scaled.unflatten(1, (num_kv_heads, group_size))
        biases.append(LowerTriangularMaskWithTensorBias(scaled))

    return biases
|
||||
Reference in New Issue
Block a user