forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
20
vllm/attention/__init__.py
Normal file
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "AttentionMetadataBuilder",
    "AttentionState",
    "get_attn_backend",
]
0
vllm/attention/backends/__init__.py
Normal file
325
vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)

import torch

from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)


class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"
    # Attention between dec. Q and enc. K/V for encoder-decoder
    ENCODER_DECODER = "encoder_decoder"


class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in slot 3 of block 2, slot 2 of
    # block 0, and slot 1 of block 1 (all 0-indexed), respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }
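
    # Usage sketch (illustrative only, not vLLM API): because
    # asdict_zerocopy() returns live field values rather than deep copies,
    # callers can cheaply rebuild a metadata object with a few fields
    # swapped out, e.g.:
    #
    #   md_dict = attn_metadata.asdict_zerocopy(skip_fields={"slot_mapping"})
    #   new_md = type(attn_metadata)(slot_mapping=new_slots, **md_dict)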


T = TypeVar("T", bound=AttentionMetadata)


class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner."""

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay."""
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


class AttentionLayer(Protocol):

    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        """
        Returns whether this attention implementation supports fused output
        quantization. This is used by the AttnFusionPass to only fuse output
        quantization onto implementations that support it.

        TODO(luka) merge parameters into QuantDescriptor
        :param dtype: quantized dtype
        :param static: static or dynamic quantization
        :param group_shape: quant group shape. (-1, -1) for per-tensor.
        :return: is fusion supported for this type of quantization
        """
        return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"
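
To make the `slot_mapping` convention above concrete, here is a small
standalone sketch (our own helper, not part of the vLLM API) that splits
flat slot indices into (block, offset) pairs:

import torch

def decode_slots(slot_mapping: torch.Tensor, block_size: int):
    """Split flat slot indices into (block index, offset within block)."""
    blocks = torch.div(slot_mapping, block_size, rounding_mode="floor")
    offsets = slot_mapping % block_size
    return blocks, offsets

# The example from the AttentionMetadata docstring, with block size 16:
# token slots [35, 2, 17] live in blocks [2, 0, 1] at offsets [3, 2, 1].
blocks, offsets = decode_slots(torch.tensor([35, 2, 17]), block_size=16)
print(blocks.tolist(), offsets.tolist())  # [2, 0, 1] [3, 2, 1]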
469
vllm/attention/backends/blocksparse_attn.py
Normal file
@@ -0,0 +1,469 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)


@dataclass
class BlocksparseParams:
    max_seqlen: int

    # Num q heads per tensor-parallel rank/partition
    num_heads: int  # per TP partition
    # Num kv heads per tensor-parallel rank/partition
    num_kv_heads: int

    # Block size used for blocksparse attention.
    # This is the block_size used in `local_blocks`, `vert_stride`.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # locally attended tokens / `sparse_block_size`
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks.
    # Controls the sparsity.
    vert_stride: int
    """
    Whether to use the same vertical stride offset for all heads,
    i.e., attend to the same block of tokens on all heads.
    By default, it is False, i.e., attention on the non-local
    blocks depends on the `head_idx`, that is, on
    blocks satisfying
    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    `block_idx = position_id // sparse_block_size`.
    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    for more detail.
    """
    homo_head: bool = False

    # Whether, within a group, the kv offsets that each q head attends to
    # are the same.
    homo_head_group: bool = False

    # Derived from homo_head and homo_head_group.
    head_sliding_step: int = field(init=False)

    # Range of q heads for a TP rank.
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        total_heads = tp_size * self.num_heads
        total_kv_heads = tp_size * self.num_kv_heads

        if self.homo_head:
            self.head_sliding_step = 0
        elif self.homo_head_group:
            head_sliding_step = get_head_sliding_step(total_kv_heads,
                                                      self.vert_stride)
            # Negative indicates sliding along kv heads, i.e., homo q group.
            self.head_sliding_step = -head_sliding_step
        else:
            self.head_sliding_step = get_head_sliding_step(
                total_heads, self.vert_stride)

        self.active_head_range = (
            tp_rank * self.num_heads,
            (tp_rank + 1) * self.num_heads,
        )
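
        # Illustrative sketch (not part of this class): with the default
        # heterogeneous heads, head `h` attends to exactly the vertical
        # blocks `b` satisfying
        #   (b + h * head_sliding_step + 1) % vert_stride == 0,
        # in addition to its local window. E.g., with vert_stride=8 and
        # head_sliding_step=1, head 0 attends to blocks 7, 15, 23, ...
        # and head 1 to blocks 6, 14, 22, ...:
        #
        #   attended = [b for b in range(24)
        #               if (b + h * self.head_sliding_step + 1)
        #                  % self.vert_stride == 0]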


class BlocksparseFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "BLOCK_SPARSE_FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """A copy of Metadata for FlashAttentionBackend,
    to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    block_tables_list: Optional[List[int]] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            block_tables_list=self.block_tables_list,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            block_tables_list=self.block_tables_list,
        )
        return self._cached_decode_metadata


class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):

    _metadata_cls = BlocksparseFlashAttentionMetadata


class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi is not supported for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")
        assert logits_soft_cap is None, ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

        if "num_heads" not in blocksparse_params:
            blocksparse_params["num_heads"] = num_heads
        if "num_kv_heads" not in blocksparse_params:
            blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_kv_heads

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # Normal attention: when block_tables are not filled, it means
            # q and k are the prompt, and they have the same length.
            assert kv_cache.numel() == 0 \
                or prefill_meta.block_tables is None \
                or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        assert output is not None
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
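
The `query_start_loc` and `seq_start_loc` fields above are exclusive prefix
sums of per-sequence lengths. A minimal sketch of how such tensors are built
(our own helper, mirroring the cumsum pattern the metadata builders use):

import torch

def start_locs(lengths: list[int]) -> torch.Tensor:
    """Cumulative start offsets, e.g. lengths [4, 6] -> [0, 4, 10]."""
    out = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0, out=out[1:])
    return out

print(start_locs([4, 6]).tolist())  # [0, 4, 10]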
9 file diffs suppressed because they are too large.
307
vllm/attention/backends/cpu_mla.py
Normal file
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class CPUMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]
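
        # Note: 576 here is the width of one latent cache row. Assuming
        # DeepSeek-style MLA dims, that is kv_lora_rank (512) plus the rope
        # head dim (64), matching the (num_blocks, block_size, head_size)
        # cache shape returned by get_kv_cache_shape above.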


@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    # New for MLA.
    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    input_positions: torch.Tensor = None

    # Required by MLACommonImpl.
    is_profile_run: bool = False


class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Metadata for prefill.
        if input_data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # For chunked prefill.
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # Metadata for decode.
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models.
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([self.input_data.input_positions]))


class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:

        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than the qk head dim, so we pad
        # out v with 0s to match the qk head dim.
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.prefill_query_start_loc,
            seqlen_k=prefill_metadata.prefill_query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # Remove padding.
        output = output.view(-1, self.num_heads,
                             q.shape[-1])[..., :v.shape[-1]]
        return output.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA.
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj(o)
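
The zero-padding trick in `_forward_prefill` above works because attention
output is a convex combination of value rows: padded zero columns stay
exactly zero in the output and can be sliced off afterwards. A
self-contained illustration (the shapes are made up for the demo):

import torch
import torch.nn.functional as F

n_heads, n_tok, qk_dim, v_dim = 2, 4, 192, 128
q = torch.randn(n_heads, n_tok, qk_dim)
k = torch.randn(n_heads, n_tok, qk_dim)
v = torch.randn(n_heads, n_tok, v_dim)

# Pad v with zeros up to the qk head dim so a standard kernel accepts it.
v_padded = F.pad(v, [0, qk_dim - v_dim], value=0)
out = F.scaled_dot_product_attention(q, k, v_padded)

# The padded tail is exactly zero, so slicing recovers the true output.
assert torch.all(out[..., v_dim:] == 0)
out = out[..., :v_dim]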
1530
vllm/attention/backends/dual_chunk_flash_attn.py
Normal file
File diff suppressed because it is too large
1084
vllm/attention/backends/flash_attn.py
Normal file
File diff suppressed because it is too large
1109
vllm/attention/backends/flashinfer.py
Normal file
File diff suppressed because it is too large
249
vllm/attention/backends/flashmla.py
Normal file
@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata,
                                                MLACommonMetadataBuilder,
                                                MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA"

    @staticmethod
    def get_impl_cls() -> Type["FlashMLAImpl"]:
        return FlashMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashMLAState"]:
        return FlashMLAState


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
                                                   torch.Tensor]] = None
    decode_num_splits: Optional[torch.Tensor] = None

    @property
    def decode_metadata(self):
        decode_metadata = super().decode_metadata
        # TODO: cache assignment?
        if decode_metadata is not None:
            decode_metadata.decode_tile_scheduler_metadata = \
                self.decode_tile_scheduler_metadata
            decode_metadata.decode_num_splits = \
                self.decode_num_splits
        return decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        raise NotImplementedError(
            "advance_step is not implemented for FlashMLA")


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        m = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                          batch_size)

        if m.num_decode_tokens > 0:
            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
                get_mla_metadata(
                    m.seq_lens_tensor[m.num_prefills:],
                    self.num_q_heads,
                    1,  # MQA for the decode path
                )

        return m


class FlashMLAState(MLACommonState[FlashMLAMetadata]):

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Run a dummy `get_mla_metadata` so we can get the right shapes.
        self._graph_decoder_tile_scheduler_metadata, \
            self._graph_decode_num_splits = get_mla_metadata(
                torch.ones(
                    max_batch_size, dtype=torch.int32,
                    device=self.runner.device),
                self.num_q_heads,
                1,  # MQA for the decode path
            )

        with super().graph_capture(max_batch_size):
            yield

        del self._graph_decoder_tile_scheduler_metadata
        del self._graph_decode_num_splits

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)
        assert metadata.num_decode_tokens > 0

        decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
            self._graph_seq_lens[:batch_size],
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        self._graph_decoder_tile_scheduler_metadata.copy_(
            decoder_tile_scheduler_metadata)
        self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)

        metadata.decode_tile_scheduler_metadata = \
            self._graph_decoder_tile_scheduler_metadata
        metadata.decode_num_splits = \
            self._graph_decode_num_splits[:batch_size + 1]

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        input_buffers["decode_tile_scheduler_metadata"] = \
            attn_metadata.decode_metadata.decode_tile_scheduler_metadata
        input_buffers["decode_num_splits"] = \
            attn_metadata.decode_metadata.decode_num_splits

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        input_buffers["decode_tile_scheduler_metadata"].copy_(
            attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
        input_buffers["decode_num_splits"].copy_(
            attn_metadata.decode_metadata.decode_num_splits)


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str] = None,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            if self.kv_cache_dtype != "fp8":
                raise NotImplementedError(
                    "FlashMLA with KV cache dtypes other than fp8 is "
                    "not yet supported")

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: FlashMLAMetadata,
            k_scale=None,
            kv_cache_dtype="auto",
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_tables,
            cache_seqlens=decode_meta.seq_lens_tensor,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
            num_splits=decode_meta.decode_num_splits,
            softmax_scale=self.scale,
            causal=True,
            k_scale=k_scale,
            kv_cache_dtype=kv_cache_dtype,
        )

        return self._v_up_proj(o)
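
A shape-only sketch of the decode-path reshaping above, under assumed
DeepSeek-style MLA dims (kv_lora_rank=512, rope dim=64, giving head size 576):
each request contributes one query token, so q gets a singleton seqlen dim,
while the latent KV cache gets a singleton head dim (MQA: one shared KV head).

import torch

batch, n_heads, kv_lora_rank, rope_dim = 3, 16, 512, 64
q_nope = torch.randn(batch, n_heads, kv_lora_rank)
q_pe = torch.randn(batch, n_heads, rope_dim)

q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)
print(q.shape)  # torch.Size([3, 1, 16, 576])

num_blocks, block_size = 8, 64
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                  kv_lora_rank + rope_dim)
print(kv_c_and_k_pe_cache.unsqueeze(-2).shape)  # torch.Size([8, 64, 1, 576])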
318
vllm/attention/backends/hpu_attn.py
Normal file
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                               HPUPagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class HPUAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "HPU_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                    num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        super(AttentionImpl, self).__init__()
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in HPU is not supported yet, it will fall back "
                "to global attention for long context.")
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.batch2block_matmul = Matmul()
        self.block2batch_matmul = Matmul()
        self.k_cache = VLLMKVCache()
        self.v_cache = VLLMKVCache()
        self.fused_scaled_dot_product_attention = kernels.fsdpa()

        self.prefill_impl = 'naive'
        if "flex_attention" in enabled_flags():
            self.prefill_impl = 'flex'
        if "fsdpa" in enabled_flags():
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            self.prefill_impl = 'fsdpa'

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        if alibi_slopes is not None:
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.prefill_impl == 'fsdpa':
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type
        if self.attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "HPUAttention with FP8 KV cache not yet supported")
|
||||
|
||||
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with prompt attention and HPUPagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for HPUAttentionImpl")

        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        key_cache = None
        value_cache = None
        if attn_metadata.is_prompt and self.attn_type \
                is not AttentionType.ENCODER_ONLY:
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)

            attn_bias = attn_metadata.attn_bias
            if attn_bias is not None and self.alibi_slopes is not None:
                position_bias = _make_alibi_bias(self.alibi_slopes,
                                                 self.num_kv_heads,
                                                 attn_bias.dtype,
                                                 attn_bias.shape[-1])
                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                attn_bias.add_(position_bias)

            block_list = attn_metadata.block_list if attn_metadata \
                and attn_metadata.block_list is not None else None

            out = ops.prompt_attention(
                impl=self.prefill_impl,
                query=query.view(query_shape),
                key=key.view(kv_shape),
                value=value.view(kv_shape),
                is_causal=True,
                attn_bias=attn_bias,
                valid_seq_lengths=attn_metadata.seq_lens_tensor,
                **self.common_attention_args(block_list, key_cache,
                                             value_cache))
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_groups=attn_metadata.block_groups,
                **self.common_attention_args(attn_metadata.block_list,
                                             key_cache, value_cache))
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

    def common_attention_args(self,
                              block_list=None,
                              key_cache=None,
                              value_cache=None):
        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
            if self.fused_scaled_dot_product_attention is not None else None
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'fsdpa_op': fsdpa_op,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'softmax_op': self.softmax,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
) -> torch.Tensor:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Calculate a matrix where each element is the difference between the
    # ith and jth positions.
    bias = bias[None, :] - bias[:, None]
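    # For example, with seq_len = 3 the line above produces the
    # relative-position matrix
    #   [[ 0,  1,  2],
    #    [-1,  0,  1],
    #    [-2, -1,  0]]
    # which is scaled per head by its ALiBi slope further below.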

    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias
403
vllm/attention/backends/ipex_attn.py
Normal file
@@ -0,0 +1,403 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)

_PARTITION_SIZE = 512


class IpexAttnBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "IPEX"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        return IpexAttnBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        return IpexAttnMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        from vllm._ipex_ops import ipex_ops as ops
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        from vllm._ipex_ops import ipex_ops as ops
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)


@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for IpexAttnBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list (one bias per prompt) because the xFormers API
        # requires the bias to be set per prompt when alibi slopes are
        # used. It will not appear in __repr__ or __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None

        return self


class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in Ipex is not supported yet, it will fall"
                " back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError(
                "IPEX backend does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.sliding_window is not None)
        if logits_soft_cap is None:
            logits_soft_cap = -1
        self.logits_soft_cap = logits_soft_cap

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
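        # Split the fused cache into key and value views. x is the innermost
        # key-cache packing factor; with x == 1 the key cache is viewed as
        # [num_blocks, num_kv_heads, head_size, block_size, 1], i.e.
        # effectively unpacked, mirroring the value-cache layout below.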
        x = 1
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: IpexAttnMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for IpexAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache.numel() == 0
                    or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                if attn_metadata.attn_bias is None:
                    if self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, None, dtype=query.dtype)
                    attn_metadata.attn_bias = att_masks

                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype,
                    device=query.device)
                ipex_ops.varlen_attention(
                    query,
                    key,
                    value,
                    output,
                    attn_metadata.seqlen_q,
                    attn_metadata.seqlen_q,
                    self.alibi_slopes,
                    attn_metadata.max_seqlen,
                    attn_metadata.max_seqlen,
                    pdropout=0.0,
                    softmax_scale=self.scale,
                    zero_tensors=False,
                    is_causal=True,
                    return_softmax=False,
                    gen_=None,
                    window_size_left=-1,
                    window_size_right=-1,
                    logits_soft_cap=self.logits_soft_cap,
                )
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")

        else:
            # Decoding run.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            num_seqs, num_heads, head_size = query.shape
            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                                  _PARTITION_SIZE)
            # NOTE(woosuk): We use a simple heuristic to decide whether to use
            # PagedAttention V1 or V2. If the number of partitions is 1, we use
            # V1 to avoid the overhead of reduction. Also, if the number of
            # sequences or heads is large, we use V1 since there is enough work
            # to parallelize.
            # TODO(woosuk): Tune this heuristic.
            # For context len > 8192, use V2 kernel to avoid shared memory
            # shortage.
            use_v1 = (max_seq_len <= 8192 and
                      (max_num_partitions == 1 or num_seqs * num_heads > 512))
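            # Worked example of the heuristic (hypothetical sizes): with
            # max_seq_len = 4096, max_num_partitions = ceil(4096 / 512) = 8,
            # so V1 is chosen only if the batch already exposes enough
            # parallelism (num_seqs * num_heads > 512); otherwise V2 splits
            # each sequence into 8 partitions and reduces across them.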
            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    layer._k_scale_float,
                    layer._v_scale_float,
                )
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ipex_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    layer._k_scale_float,
                    layer._v_scale_float,
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element is the difference between
        # the ith and jth positions.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype,
            device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
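        # torch.log maps the 0/1 mask to -inf/0, turning it into an additive
        # attention bias: allowed positions contribute 0 and masked positions
        # contribute -inf before the softmax. With window_size=None the
        # helper therefore degenerates to a plain causal mask.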
        attn_biases.append(mask.to(dtype))

    return attn_biases
0
vllm/attention/backends/mla/__init__.py
Normal file
1405
vllm/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
356
vllm/attention/backends/pallas.py
Normal file
@@ -0,0 +1,356 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger

logger = init_logger(__name__)


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

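        # Note: "batch" megacore mode is only a preference here; the actual
        # fallback to None for odd batch sizes happens per call in the
        # paged_attention() helper at the bottom of this file.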
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq
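            # Worked example (hypothetical sizes): with block_size 16 and a
            # maximum sequence length of 8192, the block table holds 512
            # int32 entries per sequence, so size_per_seq = 4 * 512 = 2048
            # bytes and max_num_seq = 512 * 1024 // 2048 = 256 sequences per
            # kernel invocation.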

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
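    # Marking the caches as buffer donors lets XLA reuse their storage for
    # the updated caches in place rather than allocating fresh output
    # buffers on every cache write.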

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
400
vllm/attention/backends/placeholder_attn.py
Normal file
@@ -0,0 +1,400 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
from vllm.utils import async_tensor_h2d

# Placeholder attention backend for models like Mamba and pooling models that
# lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (1, 1, 1, 1, 1)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        return

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        return


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int]

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Placeholder.
    block_tables: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])

        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests
        # in the batch. For --enforce-eager mode, num_seqs == num_queries.
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert not turn_prefills_into_decodes, \
            ("Multi-Step + Chunked-Prefill is not supported for "
             "attention-free models. turn_prefills_into_decodes is a "
             "Multi-Step + Chunked-Prefill specific parameter.")

        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Update sequences, masking off entries greater than num_queries
        device = self.seq_lens_tensor.device
        mask = torch.arange(self.seq_lens_tensor.size(0),
                            device=device) < num_queries
        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
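        # Entries at index >= num_queries belong to cuda-graph padding and
        # must keep their dummy lengths, hence the masked increment above.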
        if sampled_token_ids is not None:
            model_input.input_tokens.masked_scatter_(
                mask, sampled_token_ids[:num_queries])


class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):

        self.input_builder = input_builder
        self.runner = input_builder.runner

    def prepare(self):
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        """
        is_prompt = inter_data.is_prompt

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """

        # Some input builders such as ModelInputForCPUBuilder do not have the
        # "inter_data_list" attribute.
        # Let's check inter_data_list exists before we reference it.
        if hasattr(self.input_builder, "inter_data_list"):
            for inter_data in self.input_builder.inter_data_list:
                self._add_seq_group(inter_data,
                                    self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))
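        # E.g. query_lens [1, 1, 3] yields query_start_loc [0, 1, 2, 5]:
        # entry i is the offset of sequence i's first token in the flat batch.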

        if use_captured_graph:
            num_decode_tokens = batch_size - self.num_prefill_tokens
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        # Placeholders
        slot_mapping_tensor = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class PlaceholderAttentionImpl(AttentionImpl):

    def __init__(self, *args, **kwargs) -> None:
        return

    def forward(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError
435
vllm/attention/backends/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,435 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, Union

import torch

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata,
                                                MLACommonMetadataBuilder,
                                                MLACommonState)
from vllm.attention.backends.utils import (compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
                                               get_aiter_mla_metadata)

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUBuilder


def is_aiter_mla_enabled() -> bool:
    return envs.VLLM_ROCM_USE_AITER \
        and envs.VLLM_ROCM_USE_AITER_MLA


class AiterMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_impl_cls() -> Type["AiterMLAImpl"]:
        return AiterMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AiterMLAMetadata"]:
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["AiterMLAState"]:
        return AiterMLAState


@dataclass
class AiterMLAMetadata(MLACommonMetadata):
    # The following 5 tensors are for current version of AITER MLA
    block_table_bound: Optional[torch.Tensor] = None
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_lens: Optional[torch.Tensor] = None

    # This is just to make the new AITER MLA API work
    # -- MTP support is not added yet.
    qo_indptr: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self):
        prefill_metadata = super().prefill_metadata
        self._cached_prefill_metadata = prefill_metadata

        if prefill_metadata is not None:
            prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
            prefill_metadata.paged_kv_indices = self.paged_kv_indices
            prefill_metadata\
                .paged_kv_last_page_lens = self.paged_kv_last_page_lens
            prefill_metadata.block_table_bound = self.block_table_bound
            prefill_metadata.qo_indptr = self.qo_indptr

            # update the cache
            self._cached_prefill_metadata = self.__class__(
                **prefill_metadata.__dict__)

        return self._cached_prefill_metadata

    @property
    def decode_metadata(self):
        decode_metadata = super().decode_metadata

        self._cached_decode_metadata = decode_metadata

        if decode_metadata is not None:
            decode_metadata.paged_kv_indptr = self.paged_kv_indptr
            decode_metadata.paged_kv_indices = self.paged_kv_indices
            decode_metadata\
                .paged_kv_last_page_lens = self.paged_kv_last_page_lens
            decode_metadata.block_table_bound = self.block_table_bound
            decode_metadata.qo_indptr = self.qo_indptr

            # update the cache
            self._cached_decode_metadata = self.__class__(
                **decode_metadata.__dict__)

        return self._cached_decode_metadata

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:

        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=input_tokens,
            sampled_token_ids=sampled_token_ids,
            input_positions=input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_lens=self.paged_kv_last_page_lens,
            block_table_bound=self.block_table_bound)


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        super().__init__(input_builder)
        assert self.block_size == 1, "AITER MLA requires only block size 1."

    def prepare(self):
        super().prepare()
        self.paged_kv_indices: list[int] = []
        self.paged_kv_indptr: list[int] = [0]
        self.paged_kv_last_page_lens: list[int] = []
        self.total_blocks = 0
        self.qo_indptr: list[int] = [0]

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                       prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)
            if is_profile_run:
                return

            # Update paged_kv_* tensors only for non-profile run
            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        self.total_blocks += len(block_table)
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
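        # block_table_bound is ceil(seq_len / block_size); with the block
        # size of 1 enforced in AiterMLAMetadataBuilder.__init__ this is
        # simply seq_len.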
self.paged_kv_indices.extend(block_table[:block_table_bound])
|
||||
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
|
||||
block_table_bound)
|
||||
self.qo_indptr.append(self.qo_indptr[-1] + 1)
|
||||
|
||||
last_page_len = seq_len % self.block_size
|
||||
if last_page_len == 0:
|
||||
last_page_len = self.block_size
|
||||
self.paged_kv_last_page_lens.append(last_page_len)
|
||||
|
||||
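    # Worked example of the bookkeeping above (illustrative values, with
    # block_size == 1 as asserted in __init__): after adding two sequences
    # with seq_len 5 and 3,
    #   paged_kv_indptr         == [0, 5, 8]
    #   paged_kv_last_page_lens == [1, 1]
    #   qo_indptr               == [0, 1, 2]
    # so sequence i reads its block ids from
    # paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]].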
    def build(self, seq_lens: list[int], query_lens: list[int],
              cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
        metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                                 batch_size)
        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        if use_captured_graph:
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
            last_qo_indptr = self.qo_indptr[-1]
            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

        # For the current version of AITER MLA.
        if len(self.paged_kv_indptr) > 0:
            # Extend to the maximum number of blocks as returned by the
            # scheduler.
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device=device,
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device=device,
                                                  dtype=torch.int)
            paged_kv_last_page_lens_tensor = torch.tensor(
                self.paged_kv_last_page_lens, device=device, dtype=torch.int)
            block_table_bound_tensor = torch.zeros(
                len(self.paged_kv_indptr) - 1, device=device, dtype=torch.int)

            qo_indptr = torch.tensor(self.qo_indptr,
                                     device=device,
                                     dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_lens_tensor = None
            block_table_bound_tensor = None
            qo_indptr = None

        metadata.paged_kv_indptr = paged_kv_indptr_tensor
        metadata.paged_kv_indices = paged_kv_indices_tensor
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
        metadata.block_table_bound = block_table_bound_tensor
        metadata.qo_indptr = qo_indptr

        return metadata


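# Note on the CUDA-graph padding in build() above: padded decode slots
# repeat the last indptr value and use a last-page length of 0, i.e. they
# describe zero-length sequences during graph replay.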
class AiterMLAState(MLACommonState[AiterMLAMetadata]):

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
            get_aiter_mla_metadata(
                max_batch_size=max_batch_size,
                block_size=self.runner.block_size,
                max_block_per_batch=self.runner.get_max_block_per_batch(),
                device=self.runner.device)
        self._paged_kv_indices_tensor = kv_indices
        self._paged_kv_indptr_tensor = kv_indptr
        self._paged_kv_last_page_lens_tensor = last_page_lens
        self._qo_indptr_tensor = qo_indptr

        with super().graph_capture(max_batch_size):
            yield

        del self._paged_kv_indices_tensor
        del self._paged_kv_indptr_tensor
        del self._paged_kv_last_page_lens_tensor
        del self._qo_indptr_tensor

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)

        paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
        paged_kv_indices = self._paged_kv_indices_tensor
        paged_kv_last_page_lens = \
            self._paged_kv_last_page_lens_tensor[:batch_size]
        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

        metadata.paged_kv_indptr = paged_kv_indptr
        metadata.paged_kv_indices = paged_kv_indices
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
        metadata.qo_indptr = qo_indptr

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata: AiterMLAMetadata,
                                is_encoder_decoder_model: bool = False):
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        decode_meta = attn_metadata.decode_metadata
        input_buffers["paged_kv_indptr"] = decode_meta.paged_kv_indptr
        input_buffers["paged_kv_indices"] = decode_meta.paged_kv_indices
        input_buffers["paged_kv_last_page_lens"] = \
            decode_meta.paged_kv_last_page_lens
        input_buffers["qo_indptr"] = attn_metadata.qo_indptr

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata: AiterMLAMetadata,
                                    is_encoder_decoder_model: bool = False):
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        decode_meta = attn_metadata.decode_metadata
        num_total_blocks = decode_meta.paged_kv_indices.shape[0]
        input_buffers["paged_kv_indptr"].copy_(
            decode_meta.paged_kv_indptr, non_blocking=True)
        input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
            decode_meta.paged_kv_indices, non_blocking=True)
        input_buffers["paged_kv_last_page_lens"].copy_(
            decode_meta.paged_kv_last_page_lens, non_blocking=True)
        input_buffers["qo_indptr"].copy_(
            decode_meta.qo_indptr, non_blocking=True)


class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            softmax_scale: float, return_softmax_lse: bool,
            **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
        return self.flash_attn_varlen_func(
            q,
            k,
            v,
            **kwargs,
        )

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.empty(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.qo_indptr,
                             attn_metadata.max_query_len,
                             attn_metadata.paged_kv_indptr,
                             attn_metadata.paged_kv_indices,
                             attn_metadata.paged_kv_last_page_lens)

        return self._v_up_proj(o)
1096
vllm/attention/backends/rocm_flash_attn.py
Normal file
File diff suppressed because it is too large
707
vllm/attention/backends/torch_sdpa.py
Normal file
@@ -0,0 +1,707 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from torch.nn.functional import scaled_dot_product_attention

# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder

logger = init_logger(__name__)


class TorchSDPABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError("Swap is not supported in TorchSDPABackend.")

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for TorchSDPABackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    chunked_prefill: bool
    seq_lens: Optional[List[int]] = None  # For non-chunked prefill

    # For chunked prefill only
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    prefill_query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # For V1 logits index only
    query_start_loc: Optional[torch.Tensor] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when
        # alibi slopes are used; this is a limitation of the
        # xformers API.
        # Will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        if self.num_prefill_tokens == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        if self.num_decode_tokens == 0:
            return None
        return self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        '''
        Extract appropriate sequence lengths from attention metadata
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * Appropriate sequence lengths tensor for query
        * Appropriate sequence lengths tensor for key & value
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.seq_lens
        elif attn_type == AttentionType.ENCODER:
            seq_lens_q = self.encoder_seq_lens
            seq_lens_kv = self.encoder_seq_lens
        elif attn_type == AttentionType.ENCODER_DECODER:
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.encoder_seq_lens
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")
        return seq_lens_q, seq_lens_kv

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[List[torch.Tensor]]:
        '''
        Extract appropriate attention bias from attention metadata
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * Appropriate attention bias value given the attention type
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            return self.attn_bias
        elif attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: str,
    ) -> None:
        '''
        Update appropriate attention bias field of attention metadata,
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_bias: The desired attention bias value
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        '''
        The particular choice of sequence-length- and block-table-related
        attributes which should be extracted from attn_metadata is dependent
        on the type of attention operation.

        Decoder attn -> select entirely decoder self-attention-related fields
        Encoder/decoder cross-attn -> select encoder sequence lengths &
                                      cross-attn block-tables fields
        Encoder attn -> select encoder sequence lengths fields & no block tables

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * is_prompt: True if prefill, False otherwise
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:

        * Appropriate sequence-lengths tensor
        * Appropriate max sequence-length scalar
        * Appropriate block tables (or None)
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            # Decoder self-attention
            # Choose max_seq_len based on whether we are in prompt_run
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        elif attn_type == AttentionType.ENCODER_DECODER:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        elif attn_type == AttentionType.ENCODER:
            # No block tables associated with encoder attention
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")


class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # For chunked prefill
        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)
        else:
            prefill_block_tables = None
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None

        # For paged attention
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        attn_metadata = TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata


class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SDPA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("Torch SDPA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        if use_irope:
            logger.warning_once(
                "Using irope in Torch SDPA is not supported yet; it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
            raise NotImplementedError(
                "Torch SDPA backend FP8 KV cache requires "
                "intel_extension_for_pytorch support.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TorchSDPABackendImpl")

        # For warming-up
        if attn_metadata is None:
            return query

        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder self-attention or
            # encoder-decoder cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache,
            # i.e. for later use by paged attention.
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only).
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running.
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode).
                    updated_slot_mapping = attn_metadata.slot_mapping

                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix).
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive the token count from the query shape and treat
            # them as 100% prefill tokens.
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention.
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            if not prefill_meta.chunked_prefill:
                assert attn_metadata.seq_lens is not None
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.prefill_query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
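        # Layout note (illustrative): movedim above turns [num_tokens,
        # num_heads, head_size] into [num_heads, num_tokens, head_size], so
        # each slice query[None, :, start_q:end_q, :] is one sequence's
        # [1, num_heads, seq_len_q, head_size] view, and SDPA runs once per
        # sequence with that sequence's own mask.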


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases
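# For example, for seq_len = 3 and one head with slope s, the bias built
# above is
#   [[ 0,   -inf, -inf],
#    [-s,    0,   -inf],
#    [-2*s, -s,    0  ]]
# i.e. (j - i) * s at or below the diagonal and -inf above it.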


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
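# For example, for seq_len = 3 and window_size = 2, log(mask) above yields
#   [[ 0,   -inf, -inf],
#    [ 0,    0,   -inf],
#    [-inf,  0,    0  ]]
# so each query attends to at most the last `window_size` positions.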
55
vllm/attention/backends/tree_decoding_utils.py
Normal file
@@ -0,0 +1,55 @@
from typing import List

import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention


def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    if backend.get_name() == "rocm-flash-attn" or \
            backend.get_name() == "xformers":

        key_caches = []
        value_caches = []

        num_layers = len(kv_caches)
        token_num = src_to_dists.shape[0]

        # Staging buffer holding the gathered K/V entries for all layers.
        tmp_store_kv = torch.empty(
            (2, num_layers, token_num, num_kv_heads, head_size),
            dtype=kv_caches[0].dtype, device=kv_caches[0].device)
        keys = tmp_store_kv[0].contiguous()
        values = tmp_store_kv[1].contiguous()

        for kv_cache in kv_caches:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, num_kv_heads, head_size)
            key_caches.append(key_cache)
            value_caches.append(value_cache)

        ops.read_cache(
            keys,
            values,
            key_caches,
            value_caches,
            src_to_dists[:, 0].contiguous(),
            kv_cache_dtype,
        )

        ops.write_cache_multi_layers(
            keys,
            values,
            key_caches,
            value_caches,
            src_to_dists[:, 1].contiguous(),
            kv_cache_dtype,
        )
    else:
        raise NotImplementedError(
            "Only the ROCm FlashAttention and XFormers backends support "
            "move_cache for now!")
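# Usage sketch (hypothetical values): src_to_dists is an integer tensor of
# shape [num_tokens, 2]; column 0 holds the source slots that are read into
# the staging buffer, and column 1 holds the destination slots the gathered
# K/V entries are written back to across all layers.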
184
vllm/attention/backends/triton_config.py
Normal file
@@ -0,0 +1,184 @@
import bisect
import functools
import json
import os
from enum import Enum
from typing import Any, Dict, Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class KERNLE_KINDS(Enum):
    v1_2stages = 0
    v1_2stages_tc = 1
    v2 = 2
    v2_tc = 3
    TOTAL_KIND = 4


class BestConfig:

    def __init__(self):
        self.batch_size = 0
        self.seq_len = 0
        self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
        self.BLOCK_N = 0
        self.BLOCK_DIM = 0
        self.num_stages = 0
        self.num_warps = 0
        self.NUM_KV_SPLITS = 0
        self.BLOCK_N_2 = 0
        self.num_stages_2 = 0
        self.num_warps_2 = 0
        self.best_us = 0
        self.decode_fwd_stage1 = None
        self.decode_fwd_stage2 = None


def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int,
                             cache_dtype: Optional[str]) -> str:
    if cache_dtype == "default":
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    if "K100_AI" in device_name:
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
    elif "BW" in device_name:
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
    else:
        raise ValueError(f"Unsupported device name: {device_name}")


def get_attention_mla_configs_json(
        QH: int, KVH: int, QKD: int, VD: int,
        cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:

    # First, look for an optimized configuration in the configs directory.
    json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            # If a configuration has been found, return it.
            return json.load(f)
    else:
        logger.warning(
            "Could not find a tuned decode attention configuration at %s for "
            "this attention layer; falling back to the default JSON, which "
            "may not give the best performance. Please tune one.",
            config_file_path)

        json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
        config_file_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "configs",
            json_file_name)
        if os.path.exists(config_file_path):
            with open(config_file_path) as f:
                logger.warning(
                    "Using the default decode attention configuration from "
                    "%s for this attention layer. It may not give the best "
                    "performance.", config_file_path)
                # If a configuration has been found, return it.
                return json.load(f)
        else:
            raise ValueError(
                "No default configuration found; please provide one matching "
                "QH=16, KVH=1, QKD=576, VD=512.")


def get_config_map(attention_configs):
    ret_map = {}
    for bs in attention_configs.keys():
        int_bs = int(bs)
        seq_map = {}
        seq_configs = attention_configs[bs]
        ret_map[int_bs] = seq_map
        for seq_len in seq_configs.keys():
            int_seq_len = int(seq_len)
            kind_config = seq_configs[seq_len]
            configs = BestConfig()
            configs.best_us = kind_config['best_us']
            seq_map[int_seq_len] = configs
            if kind_config['kernel_kind'] == 'v1_2stages':
                best_config = kind_config['best_config']
                stage1 = best_config['stage1']
                stage2 = best_config['stage2']
                configs.kernel_kind = KERNLE_KINDS.v1_2stages
                configs.BLOCK_N = stage1['BLOCK_N']
                configs.num_stages = stage1['num_stages']
                configs.num_warps = stage1['num_warps']
                configs.BLOCK_N_2 = stage2['BLOCK_N']
                configs.num_stages_2 = stage2['num_stages']
                configs.num_warps_2 = stage2['num_warps']
            elif kind_config['kernel_kind'] == 'v1_2stages_tc':
                best_config = kind_config['best_config']
                stage1 = best_config['stage1']
                stage2 = best_config['stage2']
                configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc
                configs.BLOCK_N = stage1['BLOCK_N']
                configs.num_stages = stage1['num_stages']
                configs.num_warps = stage1['num_warps']
                configs.BLOCK_N_2 = stage2['BLOCK_N']
                configs.num_stages_2 = stage2['num_stages']
                configs.num_warps_2 = stage2['num_warps']
            elif kind_config['kernel_kind'] == 'v2':
                best_config = kind_config['best_config']
                stage1 = best_config['stage1']
                stage2 = best_config['stage2']
                configs.kernel_kind = KERNLE_KINDS.v2
                configs.BLOCK_N = stage1['BLOCK_N']
                configs.num_stages = stage1['num_stages']
                configs.num_warps = stage1['num_warps']
                configs.num_stages_2 = stage2['num_stages']
                configs.num_warps_2 = stage2['num_warps']
            elif kind_config['kernel_kind'] == 'v2_tc':
                best_config = kind_config['best_config']
                stage1 = best_config['stage1']
                stage2 = best_config['stage2']
                configs.kernel_kind = KERNLE_KINDS.v2_tc
                configs.BLOCK_N = stage1['BLOCK_N']
                configs.BLOCK_DIM = stage1['BLOCK_DIM']
                configs.num_stages = stage1['num_stages']
                configs.num_warps = stage1['num_warps']
                configs.num_stages_2 = stage2['num_stages']
                configs.num_warps_2 = stage2['num_warps']
    return ret_map


@functools.lru_cache
def get_attention_mla_configs(
        QH: int, KVH: int, QKD: int, VD: int,
        cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
    attention_configs = get_attention_mla_configs_json(
        QH, KVH, QKD, VD, cache_dtype)
    return get_config_map(attention_configs)


def get_closest_key(dic_keys, target_key):
    # bisect requires a sorted key list.
    keys = sorted(dic_keys)
    idx = bisect.bisect_left(keys, target_key)
    if idx == 0:
        return keys[0]
    if idx == len(keys):
        return keys[-1]
    left_key = keys[idx - 1]
    right_key = keys[idx]
    if target_key - left_key <= right_key - target_key:
        return left_key
    else:
        return right_key
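# For example, with keys [1, 4, 8]: a target of 5 returns 4 (ties between
# neighbours go to the left key), a target of 0 returns 1, and a target of
# 9 returns 8.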

def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
    closest_bs_key = get_closest_key(config.keys(), bs_key)
    closest_mean_kv_seqlen_key = get_closest_key(
        config[closest_bs_key].keys(), mean_kv_seqlen_key)
    return config[closest_bs_key][closest_mean_kv_seqlen_key]


def get_config(bs_key, mean_kv_seqlen_key, config):
    if bs_key in config and mean_kv_seqlen_key in config[bs_key]:
        return config[bs_key][mean_kv_seqlen_key]
    else:
        raise ValueError(
            f"No matching configuration found for bs key: {bs_key} and "
            f"mean kv seq key: {mean_kv_seqlen_key} when initializing the "
            "decode attention db")
135
vllm/attention/backends/triton_mla.py
Normal file
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any, Dict, List, Optional, Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger

from .triton_config import get_attention_mla_configs_json

logger = init_logger(__name__)


class TritonMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        return TritonMLAImpl


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

        if envs.VLLM_USE_TRITON_OPT_MLA:
            self.attn_configs = get_attention_mla_configs_json(
                self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim,
                self.kv_lora_rank, "fp16")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas): Allocate ahead of time.
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                num_kv_splits,
                # NOTE(lucas): idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # TODO
        max_seq_len = torch.max(decode_meta.seq_lens_tensor).item()
        if os.environ.get('PA_MATCH_USE_MEAN_SEQ') == '1':
            match_seq_len = int(
                (decode_meta.seq_lens_tensor.sum() / max(1, B)).item())
        else:
            match_seq_len = max_seq_len

        # Default when VLLM_USE_TRITON_OPT_MLA is disabled; the kernel is
        # then expected to fall back to its built-in configuration.
        best_config = None
        if envs.VLLM_USE_TRITON_OPT_MLA:
            best_config = self.attn_configs[min(
                self.attn_configs.keys(),
                key=lambda x: abs(int(x) - match_seq_len))]
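        # For example (hypothetical keys): if self.attn_configs has keys
        # "128", "1024", and "4096", a match_seq_len of 900 selects the
        # "1024" entry, since the lookup above picks the key whose integer
        # value is closest to match_seq_len.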

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             num_kv_splits, self.scale, best_config, PAGE_SIZE)

        return self._v_up_proj(o)
635
vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,635 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)

import numpy as np
import torch

from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                            AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerBase

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")

PAD_SLOT_ID = -1

# Switch to the numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    return (isinstance(block_tables, dict)
            and all(value is None for value in block_tables.values()))


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """
    Compute the start index of slot mapping.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        start_idx = max(0, query_len - sliding_window)
    return start_idx
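# For example, a prompt with query_len = 10 under sliding_window = 8 gets
# start_idx = max(0, 10 - 8) = 2, so its first two tokens are later masked
# with PAD_SLOT_ID in the slot mapping.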


def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    idx //= block_size
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)
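# Worked example for both implementations above: with block_size = 4 and
# block_table = [7, 1], token positions 0..7 map to slots
# [28, 29, 30, 31, 4, 5, 6, 7]; position i lives in block
# block_table[i // 4] at offset i % 4.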
|
||||
|
||||
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
|
||||
seq_id: int, seq_len: int, context_len: int,
|
||||
start_idx: int, block_size: int,
|
||||
block_tables: Dict[int, List[int]]):
|
||||
"""
|
||||
Compute slot mapping.
|
||||
"""
|
||||
if is_profile_run:
|
||||
# During memory profiling, the block tables are not
|
||||
# initialized yet. In this case, we just use a dummy
|
||||
# slot mapping.
|
||||
# In embeddings, the block tables are {seq_id: None}.
|
||||
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
|
||||
return
|
||||
|
||||
# Mask the [0, start_idx) tokens of the prompt with
|
||||
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
|
||||
# sliding_window). For example, if the prompt len is 10,
|
||||
# sliding window is 8, and block size is 4, the first two
|
||||
# tokens are masked and the slot mapping will be
|
||||
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
padding_mask_len = max(0, start_idx - context_len)
|
||||
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
|
||||
|
||||
range_start = max(start_idx, context_len)
|
||||
range_end = seq_len
|
||||
numel = range_end - range_start
|
||||
block_table = block_tables[seq_id]
|
||||
|
||||
# numpy implementation will be faster than python if we have
|
||||
# many elements, otherwise it will be slower.
|
||||
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
|
||||
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
|
||||
range_end, block_size)
|
||||
else:
|
||||
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
|
||||
range_end, block_size)
|
||||
|
||||
|
||||
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
||||
|
||||
|
||||
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
|
||||
_metadata_cls: Type[TAttentionMetadata]
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if inter_data.prefix_cache_hit:
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))
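        # Example (illustrative): query_lens [4, 6] yields query_start_loc
        # [0, 4, 10], so sequence i's query tokens occupy the half-open slice
        # [loc[i], loc[i+1]) of the flattened token dimension; seq_start_loc
        # follows the same scheme for the full sequence lengths.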

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(
                input_block_tables).pin_memory().to(device, non_blocking=True)
        else:
            has_empty: bool = any(len(bt) == 0 for bt in self.block_tables)
            has_non_empty = any(len(bt) > 0 for bt in self.block_tables)
            max_block_length = 0
            if has_empty and has_non_empty:
                for inter_data in self.input_builder.inter_data_list:
                    block_tables = inter_data.block_tables
                    if block_tables:
                        for seq_id in inter_data.seq_ids:
                            if seq_id in block_tables:
                                block_table = block_tables[seq_id]
                                max_block_length = max(max_block_length,
                                                       len(block_table))
            if max_block_length > 0:
                block_tables = make_tensor_with_pad(
                    self.block_tables,
                    pad=0,
                    dtype=torch.int,
                    device=device,
                    max_len=max_block_length,
                )
            else:
                block_tables = make_tensor_with_pad(
                    self.block_tables,
                    pad=0,
                    dtype=torch.int,
                    device=device,
                )

        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc,
                                                  torch.int32, device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            block_tables_list=self.block_tables)


class CommonAttentionState(AttentionState):

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):

        self._is_graph_capturing = True

        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)

        yield

        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

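    # Illustrative usage from a model runner (hypothetical call site):
    #
    #     with state.graph_capture(max_batch_size):
    #         attn_metadata = state.graph_capture_get_metadata_for_batch(bs)
    #         ... capture the CUDA graph for batch size `bs` ...
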
    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # Encoder-decoder models work only with the XFormers, Flash
            # Attention and ROCm Flash backends. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS', " \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            # Encoder-decoder models work only with the XFormers, Flash
            # Attention and ROCm Flash backends. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS', " \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additional_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # Encoder-decoder models work only with the XFormers and Flash
            # Attention backends. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        return

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During the decode phase the cross_slot_mapping will be empty. Hence
        # set an empty tensor for CUDA graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additional_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        This method fills the input buffers with encoder-decoder specific
        tensors by copying data from `attn_metadata` into the corresponding
        buffers in the `input_buffers` dictionary. The copied data includes
        encoder sequence lengths, cross-slot mappings, and cross-block
        tables.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)


def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for encoder attention is set.
    '''
    return ((attn_metadata.encoder_seq_lens is not None)
            and (attn_metadata.encoder_seq_lens_tensor is not None)
            and (attn_metadata.max_encoder_seq_len is not None))


def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for enc/dec cross-attention is set.

    Superset of encoder attention required metadata.
    '''
    return (attn_metadata.is_all_encoder_attn_metadata_set
            and (attn_metadata.cross_slot_mapping is not None)
            and (attn_metadata.cross_block_tables is not None))


def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes to extract from attn_metadata depends on the type of
    attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-table fields
    Encoder attn -> select encoder sequence-length fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


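# Usage sketch for get_seq_len_block_table_args (illustrative only; the
# metadata object is whatever the caller built for this step):
#
#     seq_lens_t, max_seq_len, block_tables = get_seq_len_block_table_args(
#         decode_meta, is_prompt=False, attn_type=AttentionType.DECODER)

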
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query and key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (AttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    num_prefill_query_tokens = 0
    num_decode_query_tokens = 0
    num_prefill_kv_tokens = 0
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_encoder_tokens
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = 0
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key/value tokens come from the encoder side in cross-attention.
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens
    else:  # attn_type == AttentionType.DECODER or
           # attn_type == AttentionType.ENCODER_ONLY
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens

    return (num_prefill_query_tokens, num_prefill_kv_tokens,
            num_decode_query_tokens)


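# Worked example (illustrative): with num_prefill_tokens=7,
# num_decode_tokens=3 and DECODER attention, the function returns (7, 7, 3),
# so the flattened query splits as query[:7] (prefill) and query[7:] (decode).

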
@dataclass
class MLADims:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    hf_text_config = model_config.hf_text_config

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )
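
# Sketch (hypothetical values, e.g. a DeepSeek-style MLA config): after
#
#     dims = get_mla_dims(model_config)
#
# the full per-head query/key width is dims.qk_nope_head_dim +
# dims.qk_rope_head_dim, while the compressed KV cache holds roughly
# dims.kv_lora_rank + dims.qk_rope_head_dim entries per token.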
818
vllm/attention/backends/xformers.py
Normal file
@@ -0,0 +1,818 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
                                         BlockDiagonalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class XFormersBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


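# E.g. (illustrative): num_blocks=1024, block_size=16, num_kv_heads=8,
# head_size=128 gives XFormersBackend.get_kv_cache_shape(...) ==
# (2, 1024, 16 * 8 * 128) -- one plane for keys and one for values, matching
# the kv_cache layout documented in XFormersImpl.forward below.

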
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    tree_attention_masks_tensor: Optional[torch.Tensor] = None
    block_tables_list: Optional[List[int]] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when
        # alibi slopes are used, due to a limitation of the xformers
        # API.
        # These will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            tree_attention_masks_tensor=self.tree_attention_masks_tensor,
            block_tables_list=self.block_tables_list)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            tree_attention_masks_tensor=self.tree_attention_masks_tensor,
            block_tables_list=self.block_tables_list)

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata


def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[AttentionBias]:
    '''
    Extract the appropriate attention bias from the attention metadata
    according to the attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * Appropriate attention bias value given the attention type
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: str,
) -> None:
    '''
    Update the appropriate attention bias field of the attention metadata,
    according to the attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):

    _metadata_cls = XFormersMetadata


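# A backend-specific builder only needs to pin `_metadata_cls`;
# `CommonMetadataBuilder` supplies prepare(), _add_seq_group() and build().

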
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("XFormers does not support logits soft cap. "
                                "Outputs may be slightly off.")
        if use_irope:
            logger.warning_once(
                "Using irope in XFormers is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type

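    # Layout example (illustrative): two prompts of lengths 3 and 4 plus two
    # decode requests give num_prefill_tokens=7 and num_decode_tokens=2, so
    # the flattened query has 9 rows: rows 0-6 are prefill, rows 7-8 decode.
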
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) stored on this
        impl affects forward() behavior:

        * DECODER: normal decoder-only behavior;
          use decoder self-attention block table
        * ENCODER: no KV caching; pass encoder sequence
          attributes (encoder_seq_lens/encoder_seq_lens_tensor/
          max_encoder_seq_len) to kernel, in lieu of decoder
          sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
          Used for the encoder branch of encoder-decoder models.
        * ENCODER_ONLY: no KV caching; uses the normal attention
          attributes (seq_lens/seq_lens_tensor/max_seq_len).
        * ENCODER_DECODER: cross-attention behavior;
          use cross-attention block table for caching KVs derived
          from encoder hidden states; since KV sequence lengths
          will match encoder sequence lengths, pass encoder sequence
          attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
          max_encoder_seq_len)

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                               block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for the profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersImpl")

        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV cache is updated during decoder self-attention or
            # encoder-decoder cross-attention, but not during encoder
            # attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            tree_attention_masks_tensor = \
                decode_meta.tree_attention_masks_tensor
            attn_masks_stride = (tree_attention_masks_tensor.stride(0)
                                 if tree_attention_masks_tensor is not None
                                 else 0)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                attn_masks=tree_attention_masks_tensor,
                attn_masks_stride=attn_masks_stride,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened into the `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the first attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens,
                        attn_metadata.encoder_seq_lens,
                        device=query.device)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens, device=query.device)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, device=query.device)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens, device=query.device)
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short
        # prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element is the jth element minus
        # the ith element, i.e. the relative position j - i.
        bias = bias[None, :] - bias[:, None]

        padded_len = (seq_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
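
# Worked example (illustrative): for seq_len=3, bias[None, :] - bias[:, None]
# is the relative-position matrix
#     [[ 0,  1,  2],
#      [-1,  0,  1],
#      [-2, -1,  0]],
# which is then scaled per head by alibi_slopes before the causal mask is
# applied by LowerTriangularMaskWithTensorBias.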
481
vllm/attention/layer.py
Normal file
@@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = (64 if (envs.VLLM_USE_FLASH_ATTN_PA
                                 or envs.VLLM_USE_FLASH_MLA) else 16)
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep float32 copies of k/v_scale for attention
        # backends that don't support scale tensors (FlashInfer).
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that they can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype,
                                        block_size,
                                        is_attention_free,
                                        blocksparse_params is not None,
                                        use_mla=use_mla)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCm) and CPU platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            if not envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Cross-layer KV sharing is not supported in V0.")

            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor during init, which will be
        # replaced by bind_kv_cache.
        # This variable will not be accessed if use_direct_call is True.
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
# For some alternate attention backends like MLA the attention output
|
||||
# shape does not match the query shape, so we optionally let the model
|
||||
# definition specify the output tensor shape.
|
||||
output_shape: Optional[torch.Size] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
|
||||
Attention metadata (`attn_metadata`) is set using a context manager in
|
||||
the model runner's `execute_model` method. It is accessed via forward
|
||||
context using
|
||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||
"""
|
||||
if self.calculate_kv_scales:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
# processed differently.
|
||||
if not self.use_mla:
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
def calc_kv_scales(self, query, key, value):
|
||||
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
|
||||
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
|
||||
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
|
||||
self._k_scale_float = self._k_scale.item()
|
||||
self._v_scale_float = self._v_scale.item()
|
||||
# We only calculate the scales once
|
||||
self.calculate_kv_scales = False
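    # A minimal sketch of what `calc_kv_scales` computes, runnable outside
    # vLLM. The range constants are env-configured; 200.0 below is a
    # hypothetical placeholder, not this fork's actual Q_SCALE_CONSTANT:
    #
    #   >>> import torch
    #   >>> q_range = torch.tensor(200.0, dtype=torch.float32)
    #   >>> query = torch.randn(16, 32 * 128)
    #   >>> q_scale = torch.abs(query).max() / q_range
    #   >>> # Dividing query by q_scale maps it into roughly [-200, 200],
    #   >>> # the band the fp8 backend expects before rounding.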

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)


class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype=None,
            block_size=(64 if envs.VLLM_USE_FLASH_ATTN_PA
                        or envs.VLLM_USE_FLASH_MLA else 16),
            is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        if current_platform.is_rocm():
            # Currently, only torch_sdpa is supported on ROCm.
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
                           _Backend.FLEX_ATTENTION):
                backend = _Backend.XFORMERS

            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
            } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size"""
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)

        return out.reshape(bsz, q_len, -1)
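
    # A hedged usage sketch for this cache-free layer (hypothetical ViT-like
    # sizes, not taken from a specific model in this repo):
    #
    #   >>> import torch
    #   >>> mha = MultiHeadAttention(num_heads=16, head_size=64,
    #   ...                          scale=64**-0.5)
    #   >>> x = torch.randn(2, 197, 16 * 64)  # batch x seq_len x hidden_size
    #   >>> out = mha(x, x, x)                # self-attention over patches
    #   >>> out.shape                         # torch.Size([2, 197, 1024])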


def wait_for_kv_layer_from_connector(layer_name: str):
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)
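
# Once registered, the wrappers above are reachable through the torch library
# namespace, which is what `Attention.forward` relies on when
# `use_direct_call` is False. A minimal sketch (tensor shapes and the layer
# name are hypothetical):
#
#   >>> out = torch.ops.vllm.unified_attention(q, k, v, "layers.0.attn")
#   >>> # or, with a preallocated buffer so cudagraph capture owns the output:
#   >>> torch.ops.vllm.unified_attention_with_output(q, k, v, out,
#   ...                                              "layers.0.attn")
#
# `mutates_args=["output"]` on the second op tells torch.compile the buffer
# is written in place rather than being a pure function output.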
0
vllm/attention/ops/__init__.py
Normal file
433
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
Normal file
@@ -0,0 +1,433 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton


def blocksparse_flash_attn_varlen_fwd(
        q,
        k,
        v,  # (#tokens, n_heads, head_size)
        cu_seqlens_k,
        cu_seqlens_q,
        sm_scale,
        sparse_layout,
        *,
        block_size=64,
        q_block_size=None,
        max_seqlen=None):
    # split q to blocks

    assert isinstance(sparse_layout, (list, tuple))

    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1
    q_block_size = q_block_size or block_size

    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    # TODO(linxihui): allow k, v to have different head_size
    assert k.shape == v.shape
    assert cu_seqlens_k.dim() == 1

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size:  # decoding only
            cu_seqlens_q = torch.arange(
                0,
                batch_size + 1,
                dtype=cu_seqlens_k.dtype,
                device=cu_seqlens_k.device,
            )
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            raise ValueError("cu_seqlens_q must be specified "
                             "if q is a mix of prefilling and decoding.")
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # Switch to CPU to avoid too many kernel launches when iterating over.
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
        "length of q should either be 1 (decoding) or same as k (prefilling).")

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    n_blocks = (q_lens + q_block_size - 1) // q_block_size

    q_batch_ids = torch.tensor(
        [i for i, n in enumerate(n_blocks) for _ in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )
    q_start_sids = torch.tensor(
        [i * q_block_size for n in n_blocks for i in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all().item()
    grid = (len(q_start_sids), n_heads, 1)

    _fwd_kernel_batch_inference[grid](
        q,
        k,
        v,
        out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,
        0,
        *q.stride(),
        0,
        *k.stride(),
        0,
        *v.stride(),
        0,
        *out.stride(),
        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),
        q_k_ratio,
        HAS_BATCH_DIM=False,
        D_HEAD=head_size,
        BLOCK_M=q_block_size,
        BLOCK_N=block_size,
        BLOCK_D=block_d,
        BLOCK_M_LOADING=(16 if decoding_only else
                         q_block_size),  # smaller for decoding
        EVEN_D=block_d == head_size,
        num_warps=1 if decoding_only else 4,
        num_stages=3)

    return out


@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h,
    layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h,
    offs_m,
    offs_n,
    offs_d,
    stride_kt,
    stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    if LAST_K_BLOCK:
        if EVEN_D:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=offs_n[None, :] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=(offs_n[None, :] + start_n < k_seqlen) &
                (offs_d[:, None] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD,
                        other=0.0)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # The following masking is needed only when LAST_K_BLOCK
    # or BLOCK_M < BLOCK_N.
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(
            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
            0,
            float("-inf"),
        )

    # flash-attn2 online softmax update
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    p = tl.math.exp2(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)
    alpha = tl.math.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]
    # update m_i
    m_i = m_ij
    l_i = l_i * alpha + l_ij

    p = p.to(Q.dtype.element_ty)
    # update acc
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=offs_n[:, None] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=(offs_n[:, None] + start_n < k_seqlen) &
                (offs_d[None, :] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD,
                        other=0.0)

    acc += tl.dot(p, v)

    return acc, l_i, m_i
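
# The update above is the standard flash-attention-2 online softmax in base-2
# arithmetic. A small PyTorch sketch of the same recurrence over key blocks
# (illustrative only; the kernel fuses this with the sparse block layout):
#
#   >>> import torch
#   >>> def online_step(acc, l_i, m_i, qk, v):
#   ...     # qk is already scaled by sm_scale * 1.44269504
#   ...     m_ij = torch.maximum(m_i, qk.max(-1).values)
#   ...     p = torch.exp2(qk - m_ij[:, None])
#   ...     alpha = torch.exp2(m_i - m_ij)
#   ...     acc = acc * alpha[:, None] + p @ v
#   ...     return acc, l_i * alpha + p.sum(-1), m_ij
#   >>> # after the last block: out = acc / l_i[:, None]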


@triton.heuristics({
    "M_LT_N":
    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
    Q,
    K,
    V,
    Out,
    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,
    stride_qb,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ob,
    stride_ot,
    stride_oh,
    stride_od,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h,
    layout_crow_stride_m,
    layout_col_stride_h,
    layout_col_stride_m,
    q_k_ratio,
    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """
    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of m-dim (q, row) and n-dim (k, col)

    TODO(linxihui):
    Optimize grouped-attn
    """
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    off_h_for_kv = off_h // q_k_ratio

    if HAS_BATCH_DIM:
        off_z = tl.program_id(2)
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob
        start_m = off_zm
        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
    else:
        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
        q_start_sid = tl.load(q_start_sids + off_zm)
        start_m = q_start_sid // BLOCK_M  # q_sbid

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
    past_len = k_seqlen - q_seqlen

    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=offs_m[:, None] < q_seqlen,
            other=0.0,
        )
    else:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
            other=0.0,
        )

    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
                       q_pbid * layout_crow_stride_m)

    # TODO(linxihui): load at once, with any Triton version
    # that supports `tl.split`, e.g., Triton 3.0
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    sm_scale *= (
        1.44269504  # 1/log(2), as we use base-2 exponential and logarithm
    )

    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc, l_i, m_i, q, Q, k_block_col_idx, layout_col_ptr,
            layout_col_stride_h, layout_col_stride_m, k_ptrs, v_ptrs, off_h,
            offs_m, offs_n, offs_d, stride_kt, stride_vt, sm_scale, k_seqlen,
            past_len, False, BLOCK_M_LOADING, BLOCK_N, D_HEAD, EVEN_D, M_LT_N)

    acc, l_i, m_i = _fwd_kernel_inner(
        acc, l_i, m_i, q, Q, k_block_end - 1, layout_col_ptr,
        layout_col_stride_h, layout_col_stride_m, k_ptrs, v_ptrs, off_h,
        offs_m, offs_n, offs_d, stride_kt, stride_vt, sm_scale, k_seqlen,
        past_len, True, BLOCK_M_LOADING, BLOCK_N, D_HEAD, EVEN_D, M_LT_N)

    # flash-attn 2
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]

    # write output
    if EVEN_D:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=offs_m[:, None] < q_seqlen,
        )
    else:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
        )
239
vllm/attention/ops/blocksparse_attention/interface.py
Normal file
@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import torch

from vllm.platforms import current_platform

from .utils import (dense_to_crow_col, get_head_sliding_step,
                    get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

if IS_COMPUTE_8_OR_ABOVE:
    from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd


class LocalStridedBlockSparseAttn(torch.nn.Module):

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        super().__init__()
        if use_spda is None:
            use_spda = current_platform.is_rocm() or \
                current_platform.is_cpu() or not \
                IS_COMPUTE_8_OR_ABOVE
        device = device or (torch.cuda.current_device()
                            if current_platform.is_cuda_alike() else "cpu")
        device = torch.device(device)
        # NOTE: The vLLM CPU backend supports BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "q_block_size smaller than block_size is not supported, "
                    "as it would be slower.")

        self.sparse_layout = sparse_layout

    def get_attn_pattern(self, dtype, device):
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention, where `q[:, i*r:(i*r + r)]`
            corresponds to `k[:, i]`, with `r` being the q/k head ratio.
        cu_seqlens_k: shape=(batch_size + 1,),
            indicating segments of samples,
            e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of sample i.
        cu_seqlens_q: shape=(batch_size + 1, ).
            Default None: same as cu_seqlens_k for prefilling or
            [0, 1, .., batch_size] for decoding.
            The only case you need to specify it is when q is a mix of
            prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).

        return: tensor of the same shape as q.
        """
        assert (
            IS_COMPUTE_8_OR_ABOVE
        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
            Triton kernel."

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """
        :param x: (total_tokens, n_heads, head_size)
        :return: (batch, n_heads, length, head_size)
        """
        x_padded = x.new_empty(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """
        :param x_padded: (batch, n_heads, length, head_size)
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        total_n_tokens = cu_seqlens[-1]
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """For CPU, V100 or other older GPUs.
        NOTE: torch SDPA supports nested tensors,
        but seems extremely slow. Choose to pad instead.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SDPA."
        assert q.size(0) == k.size(0), "Can only handle prompt with SDPA."

        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
                  for x in [k, v])
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or
        `self.spda` (CPU, Volta, Turing or older) based on
        the device type and CUDA compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention, where `q[:, i*r:(i*r + r)]`
            corresponds to `k[:, i]`, with `r` being the q/k head ratio.
        cu_seqlens_k: shape=(batch_size + 1,), indicating segments of samples,
            e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of sample i.
        cu_seqlens_q: shape=(batch_size + 1, ).
            Default None: same as cu_seqlens_k for prefilling or
            [0, 1, .., batch_size] for decoding.
            The only case you need to specify it
            is when q is a mix of prefilling
            and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).

        return: tensor of the same shape as q.
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
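
# A hedged usage sketch (hypothetical sizes; the Triton path needs Ampere or
# newer, otherwise construct with use_spda=True and run on CPU/older GPUs):
#
#   >>> import torch
#   >>> attn = LocalStridedBlockSparseAttn(n_heads=32, max_seqlen=8192,
#   ...                                    local_blocks=16, vert_stride=8,
#   ...                                    block_size=64)
#   >>> q = torch.randn(4096, 32, 128, device="cuda", dtype=torch.bfloat16)
#   >>> k = torch.randn(4096, 8, 128, device="cuda", dtype=torch.bfloat16)
#   >>> v = torch.randn_like(k)
#   >>> cu_k = torch.tensor([0, 4096], dtype=torch.int32, device="cuda")
#   >>> out = attn(q, k, v, cu_k)  # same shape as q; one prefill sample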
246
vllm/attention/ops/blocksparse_attention/utils.py
Normal file
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Helper functions for 3D sparse patterns.
# These functions are not optimized and are very inefficient.
# Avoid calling them too frequently, or use a caching mechanism.

from functools import lru_cache

import numpy as np
import torch

from vllm.triton_utils import triton


class csr_matrix:
    """Simple implementation of CSR matrix conversion without scipy.
    This replaces the scipy.sparse.csr_matrix() that was previously used."""

    def __init__(self, input_array):
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")

        self.shape = input_array.shape
        rows, cols = self.shape
        data = []
        indices = []
        indptr = [0]

        for i in range(rows):
            for j in range(cols):
                if input_array[i, j]:
                    data.append(input_array[i, j])
                    indices.append(j)
            indptr.append(len(indices))

        self.data = np.array(data)
        self.indices = np.array(indices)
        self.indptr = np.array(indptr)


def dense_to_crow_col(x: torch.Tensor):
    """Convert a 2D/3D torch tensor (x) to CSR row/col indexing.
    NOTE: col_indices are padded with -1.
    """
    device = x.device
    pad = -1
    dim = x.dim()
    assert x.dim() in (2, 3)
    if x.dim() == 2:
        x = x[None]
    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
    cols = [torch.from_numpy(xi.indices) for xi in x]
    max_cols = max(len(xi) for xi in cols)
    cols = [
        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
        for xi in cols
    ]
    cols = torch.vstack(cols)
    if dim == 2:
        crows = crows[0]
        cols = cols[0]
    return crows.to(device), cols.to(device)
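
# A small worked example of the conversion (values follow directly from CSR
# semantics; runnable anywhere torch is available):
#
#   >>> import torch
#   >>> mask = torch.tensor([[1, 0, 0],
#   ...                      [1, 1, 0],
#   ...                      [0, 1, 1]])
#   >>> crows, cols = dense_to_crow_col(mask)
#   >>> crows  # tensor([0, 1, 3, 5]): row i spans cols[crows[i]:crows[i+1]]
#   >>> cols   # tensor([0, 0, 1, 1, 2])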


def crow_col_to_dense(crows: torch.Tensor,
                      cols: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    dim = crows.dim()
    if dim == 1:
        crows = crows[None]
        cols = cols[None]
    device = crows.device
    crows, cols = crows.cpu(), cols.cpu()  # faster on CPU
    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
    x = torch.zeros(shape, dtype=dtype)
    for i in range(shape[0]):
        for j in range(shape[1]):
            x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
    if dim == 1:
        x = x[0]
    return x.to(device)


def dense_to_ccol_row(x: torch.Tensor):
    """Similar, but to CSC format."""
    x = x.transpose(-2, -1)
    return dense_to_crow_col(x)


def ccol_row_to_dense(ccol: torch.Tensor,
                      rows: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()


def _get_sparse_attn_mask_homo_head(
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 128,
    local_blocks: int = 4,
    vert_stride: int = 4,
    return_dense: bool = False,
):
    """
    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format.
        - block dense mask
        - all-token dense mask (be aware that it can be
            OOM if it is too big) if `return_dense==True`,
            otherwise, None
    """
    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        q_pos = torch.arange(num_blocks)[:, None]
        k_pos = torch.arange(num_blocks)[None]
        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
        block_mask_dense = (((q_pos >= k_pos)
                             & ((q_pos - k_pos < local_blocks)
                                | mask_vert_strided)).to(device).to(dtype))
        num_blocks_q = triton.cdiv(q_len, block_size)
        block_mask_dense_output = (dense_to_crow_col(
            block_mask_dense[-num_blocks_q:].contiguous()))
    if return_dense:
        mask_dense = torch.kron(
            block_mask_dense,
            block_mask_dense.new_ones((block_size, block_size)),
        )
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
        return (
            block_mask_dense_output,
            block_mask_dense,
            mask_dense,
        )
    else:
        return (
            block_mask_dense_output,
            block_mask_dense,
            None,
        )


def binary_mask_to_bias(mask_dense: torch.Tensor):
    mask_dense = 1 - mask_dense
    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
    return mask_dense


def get_head_sliding_step(n_heads: int,
                          vert_stride: int,
                          homo_head: bool = False):
    if homo_head:
        return 0
    return max(1, int(vert_stride / n_heads))


@lru_cache
def get_sparse_attn_mask(
    n_heads: int,
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 64,
    local_blocks: int = 4,
    vert_stride: int = 4,
    homo_head: bool = True,
    return_dense: bool = False,
    dense_mask_type: str = "binary",
):
    """
    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
        or "bias" (-inf for skip token, 0 for others)
    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format.
        - block dense mask
        - all-token dense mask (be aware that it can be OOM if it
            is too big) if `return_dense==True`, otherwise, None
    """
    assert dense_mask_type in ("binary", "bias")
    if homo_head:
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = (
                _get_sparse_attn_mask_homo_head(
                    q_len,
                    max_seqlen,
                    dtype,
                    device,
                    block_size,
                    local_blocks,
                    vert_stride,
                    return_dense,
                ))
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads,
                                                     *mask_dense.shape)
                if dense_mask_type == "bias":
                    mask_dense = binary_mask_to_bias(mask_dense)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        q_pos = torch.arange(num_blocks)[None, :, None]
        k_pos = torch.arange(num_blocks)[None, None]
        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
        mask_vert_strided = [
            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
            vert_stride == 0 for h in range(n_heads)
        ]
        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
        block_mask_dense = (((q_pos >= k_pos)
                             & ((q_pos - k_pos < local_blocks)
                                | mask_vert_strided)).to(device).to(dtype))
        num_blocks_q = triton.cdiv(q_len, block_size)
        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
    if return_dense:
        mask_dense = torch.kron(
            block_mask_dense,
            block_mask_dense.new_ones((block_size, block_size)),
        )
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
        if dense_mask_type == "bias":
            mask_dense = binary_mask_to_bias(mask_dense)

        return (
            dense_to_crow_col(block_mask_dense_output),
            block_mask_dense,
            mask_dense,
        )
    else:
        return (
            dense_to_crow_col(block_mask_dense_output),
            block_mask_dense,
            None,
        )
368
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,368 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein <ngl@zurich.ibm.com>
#  - Jan van Lunteren <jvl@zurich.ibm.com>
#  - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
#  - Thomas Parnell <tpa@zurich.ibm.com>

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def kernel_paged_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        num_queries_per_kv_padded: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        x: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.int64,  # int
        stride_k_cache_4: tl.int64,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.int64,  # int
        filter_by_query_len: tl.constexpr,  # bool
        query_start_len_ptr,  # [num_seqs+1]
):
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
                                              1)
        cur_batch_query_len = cur_batch_in_all_stop_index \
            - cur_batch_in_all_start_index
        if cur_batch_query_len > 1:
            return
    else:
        cur_batch_in_all_start_index = seq_idx

    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded)

    query_offset = (cur_batch_in_all_start_index * query_stride_0 +
                    query_head_idx[:, None] * query_stride_1)

    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask = head_mask & (query_head_idx < num_query_heads)

    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # Q : (num_queries_per_kv, HEAD_SIZE,)
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        mask=dim_mask[None, :] & head_mask[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                   dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
                              mask=head_mask,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_1 +
                    offs_d[None, :] * stride_v_cache_2 +
                    offs_n[:, None] * stride_v_cache_3)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_1 +
                    (offs_d[:, None] // x) * stride_k_cache_2 +
                    offs_n[None, :] * stride_k_cache_3 +
                    (offs_d[:, None] % x) * stride_k_cache_4)

        # K : (HEAD_SIZE, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # S : (num_queries_per_kv, BLOCK_SIZE,)
        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
                     float("-inf")).to(tl.float32)
        S += scale * tl.dot(Q, K)

        context_len = seq_len - 1

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
                         -10000)

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (num_queries_per_kv,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # P : (num_queries_per_kv, BLOCK_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (num_queries_per_kv,)
        l_j = tl.sum(P, axis=1)

        # alpha : (num_queries_per_kv,)
        alpha = tl.exp(M - m_j)

        # acc : (num_queries_per_kv, HEAD_SIZE,)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (num_queries_per_kv, HEAD_SIZE,)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
    acc = acc / L[:, None]

    output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                     query_head_idx * output_stride_1)

    tl.store(
        output_ptr + output_offset[:, None] +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        acc,
        mask=dim_mask[None, :] & head_mask[:, None],
    )


def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_seq_len,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
):

    if sm_scale is None:
        sm_scale = 1.0 / (query.shape[1]**0.5)

    use_alibi_slopes = alibi_slopes is not None

    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if max_query_len > 1:
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_seq_len=max_seq_len,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_query_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    num_queries_per_kv = query.shape[1] // key.shape[1]
    head_size = query.shape[2]

    # Conversion of FP8 tensors from uint8 storage to the
    # appropriate torch.dtype for interpretation by Triton.
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                    16)

    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
                                                 block_size,
                                                 num_queries_per_kv,
                                                 max_seq_len, sliding_window,
                                                 kv_cache_dtype, alibi_slopes)
    if use_custom:
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
        assert _PARTITION_SIZE_ROCM % block_size == 0
        total_num_seq = block_table.shape[0]
        tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions,
                  head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            block_size=block_size,
            max_seq_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
        )
    else:
        kernel_paged_attention_2d[(
            num_seqs,
            num_kv_heads,
        )](
            output_ptr=output,
            query_ptr=query,
            key_cache_ptr=key_cache,
            value_cache_ptr=value_cache,
            block_tables_ptr=block_table,
            seq_lens_ptr=seq_lens,
            alibi_slopes_ptr=alibi_slopes,
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
            block_table_stride=block_table.stride(0),
            query_stride_0=query.stride(0),
            query_stride_1=query.stride(1),
            output_stride_0=output.stride(0),
            output_stride_1=output.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            SLIDING_WINDOW=sliding_window,
            x=key_cache.shape[4],
            stride_k_cache_0=key_cache.stride(0),
            stride_k_cache_1=key_cache.stride(1),
            stride_k_cache_2=key_cache.stride(2),
            stride_k_cache_3=key_cache.stride(3),
            stride_k_cache_4=key_cache.stride(4),
            stride_v_cache_0=value_cache.stride(0),
            stride_v_cache_1=value_cache.stride(1),
            stride_v_cache_2=value_cache.stride(2),
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
        )
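
# A hedged sketch of how this entry point is typically driven (names and
# shapes hypothetical; in vLLM the attention backend supplies all of these
# from its metadata rather than hand-built tensors):
#
#   >>> chunked_prefill_paged_decode(
#   ...     query, key, value, output, "auto", key_cache, value_cache,
#   ...     block_table, query_start_loc, seq_lens, max_seq_len,
#   ...     max_query_len, k_scale, v_scale)
#
# Prefill rows (query length > 1) are handled by `context_attention_fwd`
# with `skip_decode=True`; the 2D kernel above then covers only the rows
# whose query length is exactly 1.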
1308
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
Normal file
File diff suppressed because it is too large
156
vllm/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_rocm():
    import flash_mla_cuda
    _flashmla_C_AVAILABLE = True


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is only supported on CUDA and ROCm devices."
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    if current_platform.is_rocm():
        return flash_mla_cuda.get_mla_metadata(cache_seqlens,
                                               num_heads_per_head_k,
                                               num_heads_k)
    else:
        return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
                                                      num_heads_per_head_k,
                                                      num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype="auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply a causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    if current_platform.is_rocm():
        if kv_cache_dtype == "fp8":
            out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
                q,
                k_cache,
                None,
                head_dim_v,
                cache_seqlens,
                block_table,
                softmax_scale,
                causal,
                tile_scheduler_metadata,
                num_splits,
                k_scale,
                "fp8_e4m3",
            )
            return out, softmax_lse
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
        )
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
        )
    return out, softmax_lse
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
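# --- Usage sketch (illustrative, not part of the file above) ---
# A minimal decode-time call sequence for the FlashMLA interface, assuming a
# CUDA Hopper device and a vLLM build with vllm._flashmla_C compiled in. All
# shapes below are illustrative assumptions; MLA typically uses head_dim=576
# with head_dim_v=512.
import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

supported, reason = is_flashmla_supported()
assert supported, reason

batch, seq_q, heads_q, heads_k = 2, 1, 16, 1
head_dim, head_dim_v, block_sz, num_blocks = 576, 512, 64, 8

cache_seqlens = torch.full((batch, ), 128, dtype=torch.int32, device="cuda")
tile_md, num_splits = get_mla_metadata(cache_seqlens,
                                       seq_q * heads_q // heads_k, heads_k)

q = torch.randn(batch, seq_q, heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, block_sz, heads_k, head_dim,
                      dtype=torch.bfloat16, device="cuda")
# Each sequence owns two pages, enough to cover its 128 cached tokens.
block_table = torch.arange(batch * 2, dtype=torch.int32,
                           device="cuda").view(batch, 2)

out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens,
                                  head_dim_v, tile_md, num_splits,
                                  causal=True)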
88
vllm/attention/ops/hpu_paged_attn.py
Normal file
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from vllm_hpu_extension import cache_ops, ops

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention."""
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]


class HPUPagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        key_cache = kv_cache[0]
        value_cache = kv_cache[1]
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
                             key_cache: torch.Tensor,
                             value_cache: torch.Tensor,
                             slot_mapping: torch.Tensor, kv_cache_dtype: str,
                             is_prompt: bool) -> None:
        cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype, is_prompt)

    @staticmethod
    def forward_decode(**kwargs) -> torch.Tensor:
        return ops.flat_pa(**kwargs)

    @staticmethod
    def swap_blocks(
        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dsts: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
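# --- Usage sketch (illustrative only) ---
# How the HPU cache shape and split fit together; the numbers are arbitrary
# assumptions, and a real run requires a Gaudi host with vllm_hpu_extension.
import torch

num_blocks, block_size, num_kv_heads, head_size = 16, 128, 8, 128
shape = HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                             num_kv_heads, head_size)
# Stack key and value planes the way split_kv_cache expects (index 0 = key).
kv_cache = torch.zeros((2, *shape), dtype=torch.bfloat16)
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)
assert key_cache.shape == shape and value_cache.shape == shape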
195
vllm/attention/ops/ipex_attn.py
Normal file
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional, Tuple

try:
    import intel_extension_for_pytorch.llm.modules as ipex_modules
    _use_ipex = True
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
    _use_ipex = False

import torch

from vllm import _custom_ops as ops


class _PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 80, 96, 112, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        return 2, num_blocks, block_size * num_kv_heads * head_size

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        tp_rank: int = 0
        blocksparse_local_blocks: int = 0
        blocksparse_vert_stride: int = 0
        blocksparse_block_size: int = 64
        blocksparse_head_sliding_step: int = 0
        block_size = value_cache.shape[3]

        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
        *args,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)


class _IPEXPagedAttention(_PagedAttention):

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache,
            slot_mapping.flatten().int())

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        block_size = value_cache.shape[2]
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output, query.contiguous(), key_cache, value_cache, head_mapping,
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)


PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
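# --- Usage sketch (illustrative only) ---
# The alias above resolves at import time: `PagedAttention` is the IPEX
# implementation when intel_extension_for_pytorch imports cleanly, else the
# vLLM custom-op fallback. The flat cache layout packs each block's
# block_size * num_kv_heads * head_size elements contiguously; sizes below
# are arbitrary assumptions.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                          num_kv_heads, head_size)
kv_cache = torch.zeros(shape, dtype=torch.float16)  # (2, 8, 16 * 4 * 64)
key_cache, value_cache = PagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)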
44
vllm/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.platforms import current_platform
from vllm import envs


def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # does not support the FP8 dtype, so we fall back to the Triton kernel.
    def supported_dtypes(o: torch.Tensor) -> bool:
        return o.dtype in [torch.float32, torch.half, torch.bfloat16]

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # loads/stores 128 bits (16 bytes) per memory transaction within a
    # thread. Hence, the head size (headdim) must be a multiple of the
    # pack size (float32 -> 4, half/bfloat16 -> 8).
    def supported_headdim(o: torch.Tensor) -> bool:
        headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        if o.dtype == torch.float32:
            return headdim % 4 == 0
        return headdim % 8 == 0

    # The dtype/headdim guards gate the CUDA path as well, which is what
    # enables the FP8 fallback described above.
    if ((current_platform.is_cuda() or envs.VLLM_USE_MERGE_ATTN_STATES_OPT)
            and supported_dtypes(output) and supported_headdim(output)):
        from vllm._custom_ops import merge_attn_states
        return merge_attn_states(output, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse, output_lse)
    else:
        from vllm.attention.ops.triton_merge_attn_states import (
            merge_attn_states)
        return merge_attn_states(output, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse, output_lse)
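# --- Reference semantics (illustrative only, not part of the file above) ---
# A pure-PyTorch sketch of what either backend computes: combining two
# partial attention results via their log-sum-exp (LSE) statistics. The
# [NUM_HEADS, NUM_TOKENS] LSE layout is an assumption of this sketch.
import torch


def merge_attn_states_reference(prefix_output: torch.Tensor,
                                prefix_lse: torch.Tensor,
                                suffix_output: torch.Tensor,
                                suffix_lse: torch.Tensor) -> torch.Tensor:
    # outputs: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]; lse: [NUM_HEADS, NUM_TOKENS]
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)  # [tokens, heads, 1]
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)
    max_lse = torch.maximum(p_lse, s_lse)
    p_weight = torch.exp(p_lse - max_lse)  # stabilized exp of each LSE
    s_weight = torch.exp(s_lse - max_lse)
    return (prefix_output * p_weight +
            suffix_output * s_weight) / (p_weight + s_weight)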
903
vllm/attention/ops/nki_flash_attn.py
Normal file
@@ -0,0 +1,903 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim

from vllm.utils import cdiv


def is_power_of_2(x):
    return x > 0 and (x & (x - 1)) == 0


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into SRAM

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    In case `num_tiles > B_P_SIZE`, we need to further tile the `num_tiles`
    dimension.
    """
    B_P_SIZE = 128

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    block_tables_sbuf = nl.zeros(
        (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf

@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate a new `block_tables` for a `head_id` after flattening the
       `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that the `block_table` for each tile is mapped
       to the SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
    fully utilize hardware parallelization. The solution is to tile `block_size`
    into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
    block_size_tiling_factor = B_P_SIZE`. After tiling, the KV cache has shape
    `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.

    Note:
    We don't further tile the D dimension, as a small DMA size also hurts
    performance.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not a power of 2"

    num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed

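# --- Index arithmetic sketch (illustrative only) ---
# The flattening above maps (block_id, head_id, sub_block) to a single row of
# the reshaped 3D cache; the sizes below are made-up examples.
num_head, block_size_tiling_factor = 4, 2
block_id, head_id, sub_block = 7, 3, 1
flat_row = ((block_id * num_head + head_id) * block_size_tiling_factor +
            sub_block)
assert flat_row == 63  # row 63 of shape (2, num_block * num_head * factor, _)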
@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    kv_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load the KV cache and transform Key and Value into the layout required
    by the attention matmuls

    Vectorized DMA load layout:
        Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
        Key: (par_dim(B_D_SIZE), seqlen_kv)
        Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
            equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
    """
    # load key cache
    num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        # define the index grids before the indirect load that uses them
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded

@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      B_F_SIZE=512):
    for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
        if nisa.get_nc_version() == nisa.nc_version.gen3:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.sbuf,
                                       dtype=p_local.dtype)
        else:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.psum,
                                       dtype=np.float32)

        for j in nl.affine_range(B_F_SIZE // 128):
            j_128_slice = nl.ds(j * 128, 128)
            i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)

            if nisa.get_nc_version() == nisa.nc_version.gen3:
                p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
                    p_local[:, i_j_128_slice])
            else:
                p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                    p_local[:, i_j_128_slice])

        p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
            p_local_t_tmp, dtype=p_local_transposed.dtype)

@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    o_buffer,
    l_buffer,
    m_buffer,
    kernel_dtype,
    acc_type,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    qk_res_buffer=None,
):
    """
    The flash attention core function to calculate self attention between a
    tile of q and a block of K and V.
    The q_local_tile has shape (B_P_SIZE, B_D_SIZE).
    K and V have shape (B_D_SIZE, LARGE_TILE_SZ); their free dimension is
    split into tiles of size B_F_SIZE.

    The results are stored in the following three buffers:
        o_buffer: (B_P_SIZE, d)
        l_buffer: (B_P_SIZE, 1)
        m_buffer: (B_P_SIZE, 1)

    All IO buffers are in SBUF.
    """
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # masks are used to apply computation only to the lower half of
            # the matrix, which reduces the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
            multiplication_required_selection = True

        if multiplication_required_selection:
            qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                                 dtype=np.float32,
                                 buffer=nl.psum)  # (128, 512)
            qk_psum[:, :] = nl.matmul(q_local_tile,
                                      k[:, k_i_b_f_slice],
                                      transpose_x=True)  # (p(128), 512)
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask[:, k_i_b_f_slice],
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                -9984.0,
                dtype=acc_type,
            )
        else:
            qk_res_buf[:, k_i_b_f_slice] = -9984.0

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
        )

    if qk_res_buffer is not None:
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_)  # (128, 1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # compute the partial row-tile sum of exp(qk - max)
        # FIXME: use activation accumulate to accumulate over the k_r_i loop?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)

        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(nl.subtract(l_prev, m_current)),
            ps,
        )
        l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))

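# --- Reference semantics (illustrative only) ---
# The m/l/o bookkeeping above is the standard online-softmax recurrence.
# A NumPy sketch of one accumulation step over a new block of scores `s`:
import numpy as np


def online_softmax_step(o, l, m, s, v):
    """o: (P, d) running output, l: (P,) running log-sum-exp,
    m: (P,) running max, s: (P, F) new scores, v: (F, d) new values."""
    m_new = np.maximum(m, s.max(axis=-1))
    alpha = np.exp(m - m_new)       # rescales previously accumulated output
    p = np.exp(s - m_new[:, None])  # unnormalized probabilities
    o_new = o * alpha[:, None] + p @ v
    l_new = m_new + np.log(np.exp(l - m_new) + p.sum(axis=-1))
    return o_new, l_new, m_new

# After the last block, the normalized output is o * exp(m - l)[:, None],
# which is exactly the final multiply performed when writing back to HBM.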
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    B_P_SIZE = 128
    B_D_SIZE = v_hbm_tile.shape[-1]
    loaded = nl.load(v_hbm_tile[
        nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
        :,
    ])
    if cur_v_tile.dtype != loaded.dtype:
        loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded

@nki.jit
def flash_paged_attention(
    query,
    key,
    value,
    kv_cache,
    block_tables,
    mask,
    softmax_scale=None,
    mixed_precision=True,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.

    IO tensor layouts:
    - query: shape (1, n_heads, d, seq_q)
    - key: shape (1, n_kv_heads, d, seq_k)
    - value: shape (1, n_kv_heads, seq_v, d)
    - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
    - block_tables: (num_active_blocks, )
    - mask: (seq_q, num_active_blocks * block_size + seq_q)
    - o: shape (1, n_heads, seq_q, d)

    - This kernel requires seq_k == seq_v
    - We use continuous batching by default, so the batch dimension is
      always 1, and different requests are concatenated along the sequence
      dimension.
    - We use paged cache blocks (kv_cache) to store the KV cache.

    IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype except for
      block_tables (int32) and mask (int32)
    - If mixed_precision is True, then all Tensor Engine operations will be
      performed in bfloat16 and accumulation will be performed in float32.
      Otherwise the intermediates will be in the same type as the inputs.

    Compile-time Constants:
    - softmax_scale: scaling for softmax; if None, defaults to
      `1.0/(d**0.5)`
    - mixed_precision: flag to run non-matmul ops in fp32 precision;
      defaults to `True`. If False, we use the same precision as the input
      types.
    - LARGE_TILE_SZ: `default=2048`, size of the KV tile used for the
      attention computation reduction

    GQA support notes:
      the SPMD grid used to launch the kernel should be over kv_heads
      instead of n_heads

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    _, num_blocks, k_h, block_size, _ = kv_cache.shape
    q_h_per_k_h = h // k_h
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f"we do not support head_dim > 128, got {d=}"
    cache_shape = (2, num_blocks, k_h, block_size, d)
    assert (tuple(kv_cache.shape) == cache_shape
            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
    assert key is None or tuple(key.shape) == (
        1,
        k_h,
        d,
        seqlen_q,
    ), f"key shape {key.shape} mismatch!"
    assert value is None or tuple(value.shape) == (
        1,
        k_h,
        seqlen_q,
        d,
    ), f"value shape {value.shape} mismatch!"

    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (
        LARGE_TILE_SZ % B_F_SIZE == 0
    ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"

    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), f"{num_blocks_per_large_tile=} is expected to be a power of 2"
    if seqlen_q > B_F_SIZE:
        MAX_REDUCTION_TILE = 2048
        if seqlen_q // 2 > MAX_REDUCTION_TILE:
            assert (
                seqlen_q % MAX_REDUCTION_TILE == 0
            ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
        else:
            assert (seqlen_q % B_F_SIZE == 0
                    ), f"{seqlen_q=} should be divisible by {B_F_SIZE=}"

    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
                   buffer=nl.shared_hbm)
    hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
        None,
        None,
        None,
        None,
    )
    if return_debug_tensors:
        hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
                                dtype=acc_type,
                                buffer=nl.shared_hbm)
        qk_res_buffer = nl.zeros(
            (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
            dtype=acc_type,
            buffer=nl.sbuf,
            lazy_initialization=True,
        )
    block_tables_sbuf = load_block_tables(
        block_tables_hbm=block_tables,
        num_tiles=num_large_k_tile,
        num_blocks_per_tile=num_blocks_per_large_tile,
    )

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    if num_blocks_per_large_tile < B_P_SIZE:
        # we checked num_blocks_per_tile is a power of 2
        assert B_P_SIZE % num_blocks_per_large_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
        # We assume block_size >= block_size_tiling_factor
        assert block_size % block_size_tiling_factor == 0
    else:
        block_size_tiling_factor = 1
    tiled_block_size = block_size // block_size_tiling_factor

    # Indirect DMA load must be placed along Partition Dimension
    block_tables_sbuf = transform_block_tables_for_indirect_load(
        block_tables_sbuf,
        block_size_tiling_factor=block_size_tiling_factor,
        num_head=k_h,
        head_id=head_id,
    )

    # Flatten KV cache to be 3D for loading into SBUF
    new_cache_shape = (
        2,
        num_blocks * k_h * block_size_tiling_factor,
        tiled_block_size * d,
    )
    kv_cache = kv_cache.reshape(new_cache_shape)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    m_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )

    for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
        num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
        cur_k_tile = nl.ndarray(
            (par_dim(B_D_SIZE), LARGE_TILE_SZ),
            dtype=kernel_dtype,
        )
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
            dtype=kernel_dtype,
        )
        load_kv_tile_from_cache(
            cur_k_tile=cur_k_tile,
            cur_v_tile=cur_v_tile,
            kv_cache=kv_cache,
            block_tables=block_tables_sbuf,
            large_k_tile_idx=large_k_tile_idx,
            num_blocks_per_large_tile=num_blocks_per_large_tile,
            tiled_block_size=tiled_block_size,
            B_P_SIZE=B_P_SIZE,
            B_D_SIZE=B_D_SIZE,
        )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=False,
                    q_tile_idx=i,
                    initialize=large_k_tile_idx == 0,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                )

    # compute attention between input query, key and value
    if key is not None and value is not None:
        B_F_SIZE = min(seqlen_q, B_F_SIZE)
        LARGE_TILE_SZ = seqlen_q

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
            dtype=kernel_dtype,
        )

        loaded = nl.load(key[batch_id, head_id, :, :])
        if loaded.dtype != kernel_dtype:
            loaded = nl.copy(loaded, dtype=kernel_dtype)
        cur_k_tile[:, :] = loaded

        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                large_tile_idx=0,
                v_i=v_i,
                LARGE_TILE_SZ=LARGE_TILE_SZ,
            )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(context_kv_len, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):

                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=True,
                    q_tile_idx=i,
                    initialize=False,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    qk_res_buffer=(qk_res_buffer[i, i_q_h]
                                   if qk_res_buffer is not None else None),
                )

    # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            out = nl.multiply(
                o_buffer[i, i_q_h],
                nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
                dtype=kernel_dtype,
            )

            nl.store(
                o[
                    batch_id,
                    head_id * q_h_per_k_h + i_q_h,
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    :,
                ],
                out,
            )
            # maximum and summation statistics
            if return_debug_tensors:
                nl.store(
                    hbm_m_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    m_buffer[i, i_q_h, :, :],
                )
                nl.store(
                    hbm_l_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    l_buffer[i, i_q_h],
                )
                nl.store(
                    hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
                    qk_res_buffer[batch_id, i_q_h, :, :],
                )

    if return_debug_tensors:
        return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
    return o

def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
    """
    Reorder the mask to make it compatible with the flash attention kernel.

    We vectorize the KV cache read to improve DMA utilization. However, the
    layout that maximizes DMA bandwidth changes the order in which tokens are
    consumed.

    The token layout (inner 2 dimensions) after the vectorized load is
    (B_P_SIZE, tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size`
    tokens. And each step the engine consumes a column (rather than a row) of
    B_P_SIZE tokens. Therefore, the tokens are visited in a strided way.

    To make sure the mask matches the order in which tokens are consumed, we
    need to transpose the mask accordingly.
    """
    total_query_len, total_seq_len = mask.shape
    context_kv_len = total_seq_len - total_query_len

    B_P_SIZE = 128
    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be at least {B_P_SIZE=}"
    num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
    if tiled_block_size > 1:
        # Mask reordering is needed when tiled_block_size > 1
        device = mask.device
        mask = mask.cpu()
        context_mask = mask[:, :context_kv_len]
        context_mask = context_mask.view(
            total_query_len,
            context_kv_len // LARGE_TILE_SZ,
            num_tiled_blocks // B_P_SIZE,
            B_P_SIZE,
            tiled_block_size,
        )
        context_mask = context_mask.transpose(3, 4).reshape(
            total_query_len, context_kv_len)
        new_mask = mask[:, context_kv_len:]
        return torch.concat([context_mask, new_mask], dim=1).to(device)
    else:
        return mask

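# --- Usage sketch (illustrative only) ---
# Small demo of the mask reorder with assumed sizes: LARGE_TILE_SZ=256 and
# block_size=4 give num_tiled_blocks=128 and tiled_block_size=2, so context
# columns are permuted from (B_P_SIZE, tiled_block_size) row-major order to
# the column-major order in which the engine consumes them.
import torch

seq_q, large_tile, block_size = 8, 256, 4
context_kv_len = large_tile  # one large KV tile of context
mask = torch.zeros(seq_q, context_kv_len + seq_q, dtype=torch.int32)
mask[:, :context_kv_len] = torch.arange(context_kv_len, dtype=torch.int32)

reordered = reorder_context_mask(mask, large_tile, block_size)
assert reordered.shape == mask.shape
# The causal tail (query-vs-new-token columns) is left untouched.
assert torch.equal(reordered[:, context_kv_len:], mask[:, context_kv_len:])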
def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    kv_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    mixed_precision=True,
):
    """
    Compute flash paged attention for variable length sequences.

    This function is a wrapper around the flash attention NKI kernel. It takes
    the following arguments:
    - query: (1, n_heads, d, seq_q)
    - key: (1, n_kv_heads, d, seq_k)
    - value: (1, n_kv_heads, seq_v, d)
    - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
    - block_tables: (n_active_blocks, )
    - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

    Notes:
    - attn_mask must be reordered outside using `reorder_context_mask`
    - The key/value cache layout must be (n_blocks, n_kv_heads, block_size,
      d) for better DMA throughput
    """
    if n_kv_head is None:
        n_kv_head = kv_cache.shape[2]
    assert kv_cache.shape[0] == 2
    assert kv_cache.shape[2] == n_kv_head
    if head_size is None:
        head_size = kv_cache.shape[-1]

    kwargs = dict(
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=1.0 / (head_size**0.5),
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
    )

    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o

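# --- Usage sketch (illustrative only) ---
# End-to-end wiring, assuming a Neuron host with the neuronxcc toolchain;
# every size below is an assumption chosen to satisfy the kernel's
# divisibility checks (seqlen_q % 128 == 0, context_kv_len % LARGE_TILE_SZ
# == 0, blocks-per-tile a power of 2).
import torch

n_heads, n_kv_heads, d, seq_q = 8, 2, 128, 128
block_size, n_active_blocks, large_tile = 32, 64, 2048
context_kv_len = n_active_blocks * block_size  # 2048

query = torch.randn(1, n_heads, d, seq_q, dtype=torch.bfloat16)
key = torch.randn(1, n_kv_heads, d, seq_q, dtype=torch.bfloat16)
value = torch.randn(1, n_kv_heads, seq_q, d, dtype=torch.bfloat16)
kv_cache = torch.randn(2, 128, n_kv_heads, block_size, d,
                       dtype=torch.bfloat16)
block_table = torch.arange(n_active_blocks, dtype=torch.int32)

attn_mask = torch.ones(seq_q, context_kv_len + seq_q, dtype=torch.int32)
attn_mask = reorder_context_mask(attn_mask, large_tile, block_size)

out = flash_attn_varlen_nkifunc(query, key, value, kv_cache, block_table,
                                attn_mask, LARGE_TILE_SZ=large_tile)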
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        kv_cache (torch.Tensor): Key/value cache tensor with shape
            (2, num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
            with shape (num_tokens)

    Returns:
        None: Updates the kv_cache tensor in-place
    """
    block_size = kv_cache.size(3)
    n_kv_head = key.size(1)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Create the head indices tensor
    head_indices = torch.arange(n_kv_head, device=key.device)

    # Update caches using index_put_
    kv_cache.index_put_(
        (torch.tensor([0], device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), key)

    kv_cache.index_put_(
        (torch.tensor([1], device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), value)
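# --- Usage sketch (illustrative only) ---
# Scattering two new tokens into assumed slots of a tiny cache: slot 17 lands
# in block 17 // 8 = 2 at offset 17 % 8 = 1.
import torch

num_blocks, n_kv_head, block_size, d_head = 4, 2, 8, 32
kv_cache = torch.zeros(2, num_blocks, n_kv_head, block_size, d_head)

key = torch.randn(2, n_kv_head, d_head)
value = torch.randn(2, n_kv_head, d_head)
slot_mapping = torch.tensor([5, 17])

reshape_and_cache(key, value, kv_cache, slot_mapping)
assert torch.equal(kv_cache[0, 2, :, 1], key[1])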
504
vllm/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs
from vllm.utils import SUPPORT_TC

if HAS_TRITON:
    from vllm.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        # CUTLASS key_cache layout:
        #     [num_blocks, num_kv_heads, block_size, head_size]
        # Triton key_cache layout:
        #     [num_blocks, num_kv_heads, head_size // x, block_size, x]
        # value_cache layout:
        #     [num_blocks, num_kv_heads, head_size, block_size]
        if envs.VLLM_USE_FLASH_ATTN_PA:
            key_cache = kv_cache[0]
            key_cache = key_cache.view(num_blocks, num_kv_heads, -1,
                                       head_size)
            value_cache = kv_cache[1]
            value_cache = value_cache.view(num_blocks, num_kv_heads,
                                           head_size, -1)
        else:
            key_cache = kv_cache[0]
            key_cache = key_cache.view(num_blocks, num_kv_heads,
                                       head_size // x, -1, x)
            value_cache = kv_cache[1]
            value_cache = value_cache.view(num_blocks, num_kv_heads,
                                           head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        if envs.VLLM_USE_FLASH_ATTN_PA:
            ops.reshape_and_cache_cuda(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping.flatten(),
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
        else:
            ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping.flatten(),
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
        attn_masks: Optional[torch.Tensor] = None,
        attn_masks_stride: int = 0
    ) -> torch.Tensor:
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use the V2 kernel to avoid a shared memory
        # shortage.

        if use_tc and head_size == 128:
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
            if attn_masks is None:
                ops.paged_attention_v1_opt_tc(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    tp_rank,
                    blocksparse_local_blocks,
                    blocksparse_vert_stride,
                    blocksparse_block_size,
                    blocksparse_head_sliding_step
                )
            else:
                ops.paged_attention_v1_opt_tc_with_mask(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    tp_rank,
                    blocksparse_local_blocks,
                    blocksparse_vert_stride,
                    blocksparse_block_size,
                    blocksparse_head_sliding_step,
                    attn_masks,
                    attn_masks_stride
                )
            return output

        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))

        if use_v1:
            # Run PagedAttention V1.
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")

            if envs.VLLM_USE_OPT_OP:
                if attn_masks is None:
                    ops.paged_attention_v1_opt(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step
                    )
                else:
                    ops.paged_attention_v1_opt_with_mask(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step,
                        attn_masks,
                        attn_masks_stride
                    )
            else:
                if attn_masks is None:
                    ops.paged_attention_v1(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step
                    )
                else:
                    ops.paged_attention_v1_with_mask(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step,
                        attn_masks,
                        attn_masks_stride
                    )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)

            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V2 SIZE:")
                print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
                print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")

            if envs.VLLM_USE_OPT_OP:
                if attn_masks is None:
                    ops.paged_attention_v2_opt(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step
                    )
                else:
                    ops.paged_attention_v2_opt_with_mask(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step,
                        attn_masks,
                        attn_masks_stride
                    )
            else:
                if attn_masks is None:
                    ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step
                    )
                else:
                    ops.paged_attention_v2_with_mask(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                        tp_rank,
                        blocksparse_local_blocks,
                        blocksparse_vert_stride,
                        blocksparse_block_size,
                        blocksparse_head_sliding_step,
                        attn_masks,
                        attn_masks_stride
                    )
        return output

|
||||
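    # NOTE (illustrative, not part of the kernel API): in the V2 path above,
    # max_num_partitions is derived earlier in forward_decode from max_seq_len
    # and _PARTITION_SIZE, and the assert guarantees that each partition
    # covers a whole number of KV blocks. A sketch of the buffer sizing,
    # assuming the names used above:
    #
    #     max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) \
    #         // _PARTITION_SIZE
    #     tmp_output: (num_seqs, num_heads, max_num_partitions, head_size)
    #     exp_sums / max_logits: (num_seqs, num_heads, max_num_partitions)
    #
    # The per-partition partial results are then reduced to the final output
    # inside the paged_attention_v2* custom ops.
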
    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
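
    # Usage sketch (illustrative only): swap_blocks moves whole KV blocks
    # between two caches (e.g. GPU cache <-> CPU swap space), while
    # copy_blocks duplicates blocks within each layer's cache, e.g. when
    # sequences fork:
    #
    #     # src_to_dists: (num_pairs, 2) tensor of (src_block, dst_block)
    #     PagedAttention.copy_blocks(kv_caches, src_to_dists)
    #
    # Both delegate to the custom ops, so the cache layouts must match the
    # (key_cache, value_cache) pair produced by the cache allocator.
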
120
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from vllm.utils import cdiv


def _kv_cache_update_kernel(
    # Prefetch
    # [3, padded_num_slices], list of (kv_cache_start, new_kv_start,
    # slice_len)
    slices_ref,
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    # [num_slices_per_block, page_size, num_combined_kv_heads, head_dim]
    scratch,
    sem,
):
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        new_kv_start = slices_ref[1, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = slices_ref[0, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()


@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
    slices: jax.Array,
    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    kv_cache: jax.Array,
    num_kv_update_slices: jax.Array,  # [1]
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    assert slices.shape[1] % num_slices_per_block == 0
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    assert head_dim % 128 == 0
    # TODO: Add dynamic check to make sure that all the slice lengths are
    # smaller than or equal to page_size

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices]
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        input_output_aliases={len(scalar_prefetches) + 1: 0},
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
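
# Usage sketch (illustrative only; shapes and values are made up):
#
#     import jax.numpy as jnp
#     new_kv = jnp.zeros((16, 8, 128), jnp.bfloat16)
#     kv_cache = jnp.zeros((4 * 32, 8, 128), jnp.bfloat16)
#     # one slice: copy new_kv[0:16] into the cache starting at row 32
#     slices = jnp.array([[32], [0], [16]], jnp.int32)
#     slices = jnp.pad(slices, ((0, 0), (0, 7)))  # pad to a multiple of 8
#     updated = kv_cache_update(new_kv, slices, kv_cache,
#                               jnp.array([1], jnp.int32))
#
# Each grid step stages up to num_slices_per_block slices through VMEM with
# overlapped async DMA copies; input_output_aliases lets the update happen
# in place on the cache buffer.
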
906
vllm/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,906 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

# Static kernel parameters
# BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
# NUM_WARPS = 4 if current_platform.is_rocm() else 8

# This fork pins the base block size to 32 on all devices (upstream used the
# capability-dependent values commented out above).
BASE_BLOCK = 32
NUM_WARPS = 8


# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)


# Here's an example autotuner config for this kernel. This config does provide
# a performance improvement, but dramatically increases first call latency in
# triton 3.2. Because of this tradeoff, it's currently commented out.
# @triton.autotune(
#     configs=[
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
#                        "num_unroll_cache": 4, \
#                        "num_unroll_request": 1 } | \
#                       ({"kpack": 2, "waves_per_eu": 2} \
#                        if current_platform.is_rocm() else {}), \
#                       num_warps=4, \
#                       num_stages=1)
#     ],
#     key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
# )
@triton.jit
def _fwd_kernel(Q,
                K,
                V,
                K_cache,
                V_cache,
                B_Loc,
                sm_scale,
                k_scale,
                v_scale,
                B_Start_Loc,
                B_Seqlen,
                x: tl.constexpr,
                Out,
                stride_b_loc_b,
                stride_b_loc_s,
                stride_qbs,
                stride_qh,
                stride_qd,
                stride_kbs,
                stride_kh,
                stride_kd,
                stride_vbs,
                stride_vh,
                stride_vd,
                stride_obs,
                stride_oh,
                stride_od,
                stride_k_cache_bs,
                stride_k_cache_h,
                stride_k_cache_d,
                stride_k_cache_bl: tl.constexpr,
                stride_k_cache_x,
                stride_v_cache_bs,
                stride_v_cache_h,
                stride_v_cache_d,
                stride_v_cache_bl,
                num_queries_per_kv: tl.constexpr,
                IN_PRECISION: tl.constexpr,
                BLOCK_M: tl.constexpr,
                BLOCK_DMODEL: tl.constexpr,
                BLOCK_DMODEL_PADDED: tl.constexpr,
                BLOCK_SIZE: tl.constexpr,
                BLOCK_N: tl.constexpr,
                SLIDING_WINDOW: tl.constexpr,
                num_unroll_cache: tl.constexpr,
                num_unroll_request: tl.constexpr,
                SKIP_DECODE: tl.constexpr,
                MAX_Q_LEN: tl.constexpr = 0,
                MAX_CTX_LEN: tl.constexpr = 0):

    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [BLOCK_SIZE]; starts at 0
    offs_bs_n = tl.arange(0, BLOCK_SIZE)
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
            loop_unroll_factor=num_unroll_cache):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
        # -- compute qk ----
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     (start_n // BLOCK_SIZE) * stride_b_loc_s)
        # [D,BLOCK_SIZE]
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)

        # [BLOCK_SIZE,D]
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 offs_bs_n[:, None] * stride_v_cache_bl)

        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
                BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            k_load = tl.load(
                K_cache + off_k,
                mask=dim_mask[:, None] &
                ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                other=0.0)  # [D,N]
        else:
            k_load = tl.load(K_cache + off_k)

        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_bs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                          (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
                          -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
                BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            v_load = tl.load(
                V_cache + off_v,
                mask=dim_mask[None, :] &
                ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                other=0.0)  # [N,D]
        else:
            v_load = tl.load(V_cache + off_v)

        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    for start_n in tl.range(0, \
            block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
            loop_unroll_factor=num_unroll_request):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=dim_mask[:, None] &
                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk, -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=dim_mask[None, :] &
                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
    return

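# Both loops in _fwd_kernel maintain a numerically stable *online softmax*:
# a running row maximum m_i, a running denominator l_i, and an accumulator
# acc that is rescaled whenever the maximum grows. A minimal NumPy sketch of
# the same recurrence (illustrative only, not part of this module):
#
#     import numpy as np
#
#     def streaming_softmax_attention(score_blocks, value_blocks):
#         # score_blocks[i]: (M, N) logits; value_blocks[i]: (N, D) values
#         m = np.full(score_blocks[0].shape[0], -np.inf)
#         l = np.zeros_like(m)
#         acc = np.zeros((score_blocks[0].shape[0],
#                         value_blocks[0].shape[1]))
#         for s, v in zip(score_blocks, value_blocks):
#             m_new = np.maximum(m, s.max(axis=1))
#             p = np.exp(s - m_new[:, None])
#             alpha = np.exp(m - m_new)
#             acc = acc * alpha[:, None] + p @ v
#             l = l * alpha + p.sum(axis=1)
#             m = m_new
#         return acc / l[:, None]
#
# This matches the m_ij / alpha / l_i updates above; the kernel seeds l_i
# with 1.0 instead of 0.0 so that fully masked rows stay finite.

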
@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                other=0.0)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # acc /= l_i[:, None]
    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
    return


@triton.jit
def _fwd_kernel_alibi(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    B_Start_Loc,
    B_Seqlen,
    Alibi_slopes,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                other=0.0)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = 0
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]

        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # init alibi
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = cur_batch_ctx_len
    # # init debugger
    # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
    # offset_db_k = tl.arange(0, BLOCK_N)
    # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len:, BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision='ieee')
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] &
             (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
    return

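# The alibi term above implements the standard ALiBi bias: each head adds
# slope * (k_pos - q_pos) to the logits, which is <= 0 for causal positions.
# A sketch of the same bias matrix in plain PyTorch (illustrative only):
#
#     import torch
#
#     def alibi_bias(slope: float, q_pos: torch.Tensor,
#                    k_pos: torch.Tensor) -> torch.Tensor:
#         # q_pos: (M,) absolute query positions; k_pos: (N,) key positions
#         bias = (k_pos[None, :] - q_pos[:, None]).float() * slope
#         # future keys get -inf, matching the tl.where() masking above
#         return bias.masked_fill(k_pos[None, :] > q_pos[:, None],
#                                 float("-inf"))

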
@torch.inference_mode()
def context_attention_fwd(q,
                          k,
                          v,
                          o,
                          kv_cache_dtype: str,
                          k_cache,
                          v_cache,
                          b_loc,
                          b_start_loc,
                          b_seq_len,
                          max_seq_len,
                          max_input_len,
                          k_scale: torch.Tensor,
                          v_scale: torch.Tensor,
                          alibi_slopes=None,
                          sliding_window=None,
                          sm_scale=None,
                          skip_decode=False):

    q_dtype_is_f32 = q.dtype is torch.float32

    # Turing does not have a tensor core for float32 multiplication, so use
    # 'ieee' as a fallback to keep the triton kernels working. There is also
    # a warning in vllm/config.py to inform users about this fallback
    # implementation.
    IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        k_cache = k_cache.view(target_dtype)
        v_cache = v_cache.view(target_dtype)

    if ((k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8)
            and kv_cache_dtype == "auto"):
        raise ValueError("kv_cache_dtype='auto' unsupported for "
                         "FP8 KV Cache prefill kernel")

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    assert batch + 1 == len(b_start_loc)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if alibi_slopes is not None:
        # need to reduce num. blocks when using fp32
        # due to increased use of GPU shared memory
        # if q.dtype is torch.float32:
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
        # batch, head,
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
        _fwd_kernel_alibi[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            alibi_slopes,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            # k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
            k_cache.stride(4),
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
            v_cache.stride(3),
            num_queries_per_kv=num_queries_per_kv,
            IN_PRECISION=IN_PRECISION,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SKIP_DECODE=skip_decode,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return

    max_seq_len = 0 if max_seq_len is None else max_seq_len
    extra_kargs = {}
    if current_platform.is_rocm():
        extra_kargs = {"kpack": 2, "waves_per_eu": 2}

    grid = lambda META: (batch, head,
                         triton.cdiv(max_input_len, META["BLOCK_M"]))
    _fwd_kernel[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        b_loc,
        sm_scale,
        k_scale,
        v_scale,
        b_start_loc,
        b_seq_len,
        k_cache.shape[4],
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        # k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
        k_cache.stride(4),
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
        v_cache.stride(3),
        BLOCK_SIZE=v_cache.shape[3],
        num_queries_per_kv=num_queries_per_kv,
        IN_PRECISION=IN_PRECISION,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        SLIDING_WINDOW=sliding_window,
        SKIP_DECODE=skip_decode,
        BLOCK_M=128,
        BLOCK_N=64,
        num_unroll_cache=4,
        num_unroll_request=1,
        num_warps=4,
        num_stages=1,
        **extra_kargs)
    return
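
# Call sketch (illustrative only): for a batch of variable-length prompts
# packed along dim 0, with a paged KV cache, the launcher above is driven
# roughly like this (all tensors hypothetical):
#
#     context_attention_fwd(
#         q, k, v, out,                # (total_tokens, num_heads, head_size)
#         kv_cache_dtype="auto",
#         k_cache=k_cache, v_cache=v_cache,
#         b_loc=block_tables,          # (batch, max_blocks_per_seq)
#         b_start_loc=cu_seqlens,      # (batch + 1,)
#         b_seq_len=seq_lens,          # (batch,)
#         max_seq_len=None,
#         max_input_len=max_query_len,
#         k_scale=k_scale, v_scale=v_scale)
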
100
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
                                   dtype=torch.int32,
                                   device=device)
    paged_kv_indptr = torch.zeros(max_batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                         block_size,
                                         dtype=torch.int32)
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
                                             kv_buffer.view(
                                                 -1, 1, 1, q.shape[-1]),
                                             o,
                                             qo_indptr,
                                             max_seqlen_qo,
                                             kv_indptr,
                                             kv_indices,
                                             kv_last_page_lens,
                                             sm_scale=sm_scale,
                                             logit_cap=logit_cap)


def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    from aiter.mla import mla_decode_fwd

    mla_decode_fwd(q,
                   kv_buffer.view(-1, 1, 1, q.shape[-1]),
                   o,
                   qo_indptr,
                   kv_indptr,
                   kv_indices,
                   kv_last_page_lens,
                   max_seqlen_qo,
                   sm_scale=sm_scale,
                   logit_cap=logit_cap)


def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    pass


if current_platform.is_rocm():
    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                              op_func=mla_decode_fwd_impl,
                              mutates_args=["o"],
                              fake_impl=mla_decode_fwd_fake,
                              tags=[torch.Tag.needs_fixed_stride_order])
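
# Registering through direct_register_custom_op makes the AITER kernel
# callable as torch.ops.vllm.rocm_aiter_mla_decode_fwd (used by
# aiter_mla_decode_fwd above), so it composes with torch.compile: the fake
# impl gives the compiler a shape-only stand-in, and mutates_args=["o"]
# declares the in-place write to the output tensor.
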
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils import cdiv

FP8_DTYPE = current_platform.fp8_dtype()


class AITERPagedAttention(PagedAttention):

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache, slot_mapping,
                                                kv_cache_dtype, k_scale,
                                                v_scale)
        else:
            kv_cache_torch_dtype = (FP8_DTYPE
                                    if "fp8" in kv_cache_dtype else torch.int8)
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)

            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key, value, key_cache, value_cache, k_scale, v_scale,
                slot_mapping.flatten(), True)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step)

        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
                              seq_lens, max_num_blocks_per_seq, k_scale,
                              v_scale, output)
        return output
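
# Behavior note: for int8/fp8 caches this class quantizes per token on the
# write path (reshape_and_cache_with_pertoken_quant) and decodes with AITER's
# assembly paged-attention kernel (pa_fwd_asm); every other cache dtype falls
# straight through to the base PagedAttention implementation above.
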
1614
vllm/attention/ops/triton_decode_attention.py
Normal file
File diff suppressed because it is too large
984
vllm/attention/ops/triton_flash_attention.py
Normal file
@@ -0,0 +1,984 @@
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from
Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team

Features supported:

1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.

Not currently supported:

1) Non power of two head dims

"""

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

# Avoid misleading ROCm warning.
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx1x
else:
    on_gfx1x = lambda *args, **kwargs: False

torch_dtype: tl.constexpr = torch.float16


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)


@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                                  stride).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                             stride)
    rng_keep = rng_output > dropout_p
    return rng_keep


@triton.jit
def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor


@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
    USE_FP8: tl.constexpr,
    qk_scale,
    p_descale,
):
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is
            # okay. check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        if USE_FP8:
            qk *= qk_scale
        if bias_ptr is not None:
            bias = load_fn(bias_ptr, False, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias
            # with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i --
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        if USE_FP8:
            p *= p_descale

        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)

        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i

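# NOTE: this kernel works in base 2 (tl.math.exp2) rather than base e, since
# exp2 maps directly onto a hardware instruction. exp(x) == exp2(x * log2(e)),
# and the constant 1.44269504089 above is log2(e), which is why the bias is
# rescaled before being added. Quick sanity check (illustrative only):
#
#     >>> import math
#     >>> x = 0.7
#     >>> math.isclose(math.exp(x), 2.0 ** (x * 1.44269504089))
#     True

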
def get_cdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 1,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': True
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 64,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# "BLOCK_M": 16,
|
||||
# "BLOCK_N": 16,
|
||||
# "waves_per_eu": 1,
|
||||
# "PRE_LOAD_V": False,
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 4,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 2,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# # Fall-back config.
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_autotune_configs():
|
||||
if on_gfx1x():
|
||||
return get_rdna_autotune_configs()
|
||||
else:
|
||||
return get_cdna_autotune_configs()
|
||||
|
||||
|
||||
autotune_configs, autotune_keys = get_autotune_configs()
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=autotune_configs,
|
||||
key=autotune_keys,
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz: tl.int64,
|
||||
stride_qh: tl.int64,
|
||||
stride_qm: tl.int64,
|
||||
stride_qk: tl.int64,
|
||||
stride_kz: tl.int64,
|
||||
stride_kh: tl.int64,
|
||||
stride_kn: tl.int64,
|
||||
stride_kk: tl.int64,
|
||||
stride_vz: tl.int64,
|
||||
stride_vh: tl.int64,
|
||||
stride_vk: tl.int64,
|
||||
stride_vn: tl.int64,
|
||||
stride_oz: tl.int64,
|
||||
stride_oh: tl.int64,
|
||||
stride_om: tl.int64,
|
||||
stride_on: tl.int64,
|
||||
stride_bz: tl.int64,
|
||||
stride_bh: tl.int64,
|
||||
stride_bm: tl.int64,
|
||||
stride_bn: tl.int64,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
USE_FP8_OUT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular
        # attn matrix.
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only if
        # IS_CAUSAL. Otherwise we want to always iterate through all
        # n_blocks.
        n_blocks = min(n_blocks, n_blocks_seqlen)
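        # Worked example (illustrative): with BLOCK_M = BLOCK_N = 32,
        # seqlen_q = 128 and seqlen_k = 64, the WG with start_m = 3 covers
        # query rows 96..127, whose causal boundary ends at key position
        # (3 + 1) * 32 + 64 - 128 = 64, i.e. n_blocks_seqlen = 2, so both of
        # the cdiv(64, 32) = 2 key blocks are visited; the WG with
        # start_m = 0 gets cdiv(32 + 64 - 128, 32) <= 0 and exits below.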
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            #          acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
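    # Worked example (illustrative): with BLOCK_N = 16, seqlen_k = 10 gives
    # n_extra_tokens = 6 (the single key block carries 6 padding slots),
    # while seqlen_k = 45 gives n_extra_tokens = 45 % 16 = 13 (the last of
    # the three key blocks holds only 13 real tokens).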
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
            + (off_z * HQ + off_h_q) \
            * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any
    # dropout. In this case, we return an invalid pointer to indicate the
    # mask is not valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base
    # offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # Scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    if not USE_FP8:
        q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
        acc_scale = 1.0
    else:
        qk_scale *= q_scale * k_scale
        acc_scale = p_scale * v_scale

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # If IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
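    # Worked example (illustrative): for a causal, aligned case with
    # BLOCK_M = 32, BLOCK_N = 16 and seqlen_q = seqlen_k = 64, is_modulo_mn
    # is True, so masked_blocks = 32 // 16 = 2 and the first n_blocks - 2
    # key blocks go through the fast, mask-free loop below.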
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are masked and take the boundary-aware path.
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, n_full_blocks))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
    # epilogue

    if USE_FP8:
        acc *= acc_scale
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that
    # here and store 0s where there are NaNs as these rows should've been
    # zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if USE_FP8_OUT:
        acc *= o_descale
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None]
                             >= out_mask_boundary[None, :])
            z = tl.zeros((1, ), tl.float32)
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q \
    #     + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be negative.
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #     boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size,
    #                        dtype=tl.int32)
    #     # This is a > check because mask being 0 blocks the store.
    #     l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))


def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        total_q, nheads_q, head_size = q.shape
        total_k, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        batch, nheads_q, seqlen_q, head_size = q.shape
        _, nheads_k, seqlen_k, _ = k.shape
        assert max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    assert (nheads_q % nheads_k) == 0


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
        fp8_scales=None,
        fp8_out_scale=None,
    ):
        if fp8_scales is not None:
            use_fp8 = True
            (q_scale, k_scale, v_scale, p_scale) = fp8_scales
            float8 = current_platform.fp8_dtype()

            def check_and_convert(t, scale):
                if t.dtype != float8:
                    descale = 1.0 / scale
                    ts = (t * descale).clamp(min=float8_info.min,
                                             max=float8_info.max)
                    return ts.to(float8)
                else:
                    return t

            q = check_and_convert(q, q_scale)
            k = check_and_convert(k, k_scale)
            v = check_and_convert(v, v_scale)
        else:
            use_fp8 = False
            q_scale = k_scale = v_scale = p_scale = 1.0

        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        if True:  # varlen
            total_q, nheads_q, head_size = q.shape
            total_k, nheads_k, _ = k.shape
            batch = len(cu_seqlens_q) - 1
            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        else:
            batch, seqlen_q, nheads_q, head_size = q.shape
            _, seqlen_k, nheads_k, _ = k.shape
            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

        # Pad head_size up to the nearest supported head dimension in
        # {32, 64, 128, 256}.
        unpadded_head_dims = {32, 64, 128, 256}
        if head_size not in unpadded_head_dims:
            padded_d_model = None
            for i in sorted(unpadded_head_dims):
                if i > head_size:
                    padded_d_model = i
                    break
            assert padded_d_model is not None
        else:
            padded_d_model = head_size

        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        p_descale = 1.0 / p_scale
        o_descale = (1.0 / fp8_out_scale.item()
                     if fp8_out_scale is not None else 1.0)

        arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
        arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            q_scale,
            k_scale,
            v_scale,
            p_scale,
            p_descale,
            o_descale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=arg_max_seqlens_q,
            MAX_SEQLENS_K=arg_max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
            USE_FP8=use_fp8,
            USE_FP8_OUT=fp8_out_scale is not None,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax


triton_attention = _attention.apply
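
# A minimal usage sketch (hedged; shapes are illustrative only): packed
# varlen layout with two sequences of lengths 5 and 11, 8 heads, head
# dimension 128.
#
#     q = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
#     k = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
#     v = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
#     cu = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
#     out, _ = triton_attention(q, k, v, None, cu, cu, 11, 11, True,
#                               128 ** -0.5, None, None, None)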
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.triton_utils import tl, triton


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
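
# For reference, the per-(token, head) math the kernel below implements,
# written as a plain-Python sketch (p_*/s_* are the prefix/suffix partial
# outputs and their log-sum-exp values):
#
#     m = max(p_lse, s_lse)
#     p_se, s_se = exp(p_lse - m), exp(s_lse - m)
#     out = (p_out * p_se + s_out * s_se) / (p_se + s_se)
#     out_lse = log(p_se + s_se) + m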


@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

    # FA2 and FA3 have different behavior when the sum-exp is 0, which
    # typically arises with zero-length seqlens. FA3 returns -inf here while
    # FA2 returns inf. If we see an inf, assume FA2 and convert it to -inf
    # for consistency and correctness. Inf generally doesn't make sense in
    # this context outside of the undefined-behavior/FA2 case, so this is a
    # safe assumption.
    p_lse = float('-inf') if p_lse == float('inf') else p_lse
    s_lse = float('-inf') if s_lse == float('inf') else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # Will reuse precomputed exp values for scale factor computation.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = (p_se + s_se)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the
    # output. Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse)
    # directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)
738
vllm/attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,738 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein <ngl@zurich.ibm.com>
#  - Jan van Lunteren <jvl@zurich.ibm.com>
#  - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
#  - Thomas Parnell <tpa@zurich.ibm.com>

import torch
import triton
import triton.language as tl

from vllm.logger import init_logger

logger = init_logger(__name__)


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def apply_softcap(S, x):
    Sdiv = S / x
    p1 = tl.exp(Sdiv)
    p2 = tl.exp(-Sdiv)
    return x * (p1 - p2) / (p1 + p2)
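
# Note: (e**y - e**-y) / (e**y + e**-y) == tanh(y), so apply_softcap computes
# the standard logit soft-capping x * tanh(S / x), which bounds the attention
# scores to the open interval (-x, x).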


@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
                 BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
    left: tl.int32 = 0
    right = num_seqs
    while left < right:
        mid = (left + right) // 2
        val = tl.load(query_start_len_ptr + mid)
        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val

        if mid_val <= target_idx:
            left = mid + 1
        else:
            right = mid

    return left - 1
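
# Illustrative example: with query_start_len = [0, 5, 16, 20] (three
# sequences) and use_q_block_mode=False, target_idx = 7 falls inside the
# second sequence, so find_seq_idx returns 1; in q-block mode the same
# binary search runs over cumulative query-block counts
# (val // BLOCK_Q + mid) instead of token offsets.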


@triton.jit
def kernel_unified_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
):
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    offs_n[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    offs_n[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        seq_offset = j * BLOCK_SIZE + offs_n

        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, BLOCK_SIZE)
        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking
        # of the entire row. In this case we need to set m_j to 0 to avoid
        # NaN.
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
    acc = acc / L[:, None]
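    # Loop invariant (for reference): after each tile, M holds the running
    # row max, L holds sum_j exp(S_j - M), and acc holds
    # sum_j exp(S_j - M) * V_j, so acc / L equals softmax(S) @ V exactly as
    # if the whole row of scores had been materialized at once.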

    output_offset = (query_offset_0[:, None] * output_stride_0 +
                     query_offset_1[:, None] * output_stride_1 +
                     offs_d[None, :])

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )


@triton.jit
def kernel_unified_attention_3d(
    segm_output_ptr,
    # [num_tokens, num_query_heads, num_segments, head_size]
    segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)

    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)

    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv

    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles within current segment
    for j in range(
            segm_idx * blocks_per_segment,
            min((segm_idx + 1) * blocks_per_segment, num_blocks),
    ):
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    offs_n[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    offs_n[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        seq_offset = j * BLOCK_SIZE + offs_n

        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, BLOCK_SIZE)
        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking
        # of the entire row. In this case we need to set m_j to 0 to avoid
        # NaN.
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    segm_offset = (query_offset_0.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset,
             L,
             mask=query_mask_0 & query_mask_1)


@triton.jit
def reduce_segments(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size]
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs,
                           BLOCK_Q, False)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)

    # create masks for subsequent loads
    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # load segment maxima
    segm_offset = (query_token_idx.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_head_idx * NUM_SEGMENTS_PER_SEQ +
                   tl.arange(0, NUM_SEGMENTS_PER_SEQ))
    segm_max = tl.load(segm_max_ptr + segm_offset,
                       mask=segm_mask,
                       other=float("-inf"))
    overall_max = tl.max(segm_max)

    # load and rescale segment exp sums
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset,
                          mask=segm_mask,
                          other=0.0)
    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
    overall_expsum = tl.sum(segm_expsum)

    # load, rescale, and add segment attention outputs
    segm_output_offset = (
        query_token_idx.to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    segm_output *= tl.exp(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    # write result
    output_offset = (query_token_idx * output_stride_0 +
                     query_head_idx * output_stride_1 +
                     tl.arange(0, HEAD_SIZE_PADDED))
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)


def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
):
    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    block_size = v.shape[1]
    assert q.element_size() >= 2 or block_size >= 32, \
        "Block size must be at least 32 for fp8"

    use_alibi_slopes = alibi_slopes is not None

    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]

    BLOCK_M = 16
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch the kernel with:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on the CPU.
    # Instead we use the upper bound:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #   = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #   = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
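    # Worked example (illustrative): two sequences with query lengths 3 and
    # 17 and BLOCK_Q = 4 need ceil(3/4) + ceil(17/4) = 1 + 5 = 6 blocks; the
    # bound gives (3 + 17) // 4 + 2 = 7, one spare block that simply exits
    # early inside the kernel.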

    # Use the 2D kernel if the batch contains a prefill or already exposes
    # enough (q-block, kv-head) programs to fill the GPU; the split-KV 3D
    # path below is reserved for small decode batches.
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        kernel_unified_attention_2d[(
            total_num_q_blocks,
            num_kv_heads,
        )](
            output_ptr=out,
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            block_table_stride=block_table.stride(0),
            query_stride_0=q.stride(0),
            query_stride_1=q.stride(1),
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            USE_SOFTCAP=(softcap > 0),
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
            stride_k_cache_2=k.stride(2),
            stride_k_cache_3=k.stride(3),
            stride_v_cache_0=v.stride(0),
            stride_v_cache_1=v.stride(1),
            stride_v_cache_2=v.stride(2),
            stride_v_cache_3=v.stride(3),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            num_seqs=num_seqs,
            BLOCK_M=BLOCK_M,
        )
    else:
        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
        # value that showed good performance in tests
        NUM_SEGMENTS = 16

        segm_output = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            triton.next_power_of_2(head_size),
            dtype=torch.float32,
            device=q.device,
        )
        segm_max = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )
        segm_expsum = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )

        kernel_unified_attention_3d[(
            total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
                segm_output_ptr=segm_output,
                segm_max_ptr=segm_max,
                segm_expsum_ptr=segm_expsum,
                query_ptr=q,
                key_cache_ptr=k,
                value_cache_ptr=v,
                block_tables_ptr=block_table,
                seq_lens_ptr=seqused_k,
                alibi_slopes_ptr=alibi_slopes,
                scale=softmax_scale,
                k_scale=k_descale,
                v_scale=v_descale,
                softcap=softcap,
                num_query_heads=num_query_heads,
                num_queries_per_kv=num_queries_per_kv,
                block_table_stride=block_table.stride(0),
                query_stride_0=q.stride(0),
                query_stride_1=q.stride(1),
                BLOCK_SIZE=block_size,
                HEAD_SIZE=head_size,
                HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
                USE_ALIBI_SLOPES=use_alibi_slopes,
                USE_SOFTCAP=(softcap > 0),
                SLIDING_WINDOW=(1 + window_size[0]),
                stride_k_cache_0=k.stride(0),
                stride_k_cache_1=k.stride(1),
                stride_k_cache_2=k.stride(2),
                stride_k_cache_3=k.stride(3),
                stride_v_cache_0=v.stride(0),
                stride_v_cache_1=v.stride(1),
                stride_v_cache_2=v.stride(2),
                stride_v_cache_3=v.stride(3),
                query_start_len_ptr=cu_seqlens_q,
                BLOCK_Q=BLOCK_Q,
                num_seqs=num_seqs,
                BLOCK_M=BLOCK_M,
                NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
            )

        reduce_segments[(q.shape[0], num_query_heads)](
            output_ptr=out,
            segm_output_ptr=segm_output,
            segm_max_ptr=segm_max,
            segm_expsum_ptr=segm_expsum,
            seq_lens_ptr=seqused_k,
            num_seqs=num_seqs,
            num_query_heads=num_query_heads,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            block_table_stride=block_table.stride(0),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        )
214
vllm/attention/selector.py
Normal file
@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Union

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

logger = init_logger(__name__)


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Convert a string backend name to a _Backend enum value.

    Returns:
    * _Backend: enum value if backend_name is a valid in-tree type
    * None: otherwise, i.e. the name is an invalid in-tree type or an
      out-of-tree platform is loaded.
    """
    assert backend_name is not None
    return _Backend[backend_name] if backend_name in _Backend.__members__ \
        else None
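
# Illustrative example: backend_name_to_enum("FLASH_ATTN") yields
# _Backend.FLASH_ATTN (a member of the upstream enum), while an unrecognized
# name such as "MY_BACKEND" yields None.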


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    Returns:

    * _Backend enum value if an override is specified
    * None otherwise
    '''
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    return (None
            if backend_name is None else backend_name_to_enum(backend_name))


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    '''
    return forced_attn_backend


def supports_head_size(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
) -> bool:
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            return False

    assert isinstance(attn_backend, type)

    # TODO: Update the interface once V0 is removed
    if get_supported_head_sizes := getattr(attn_backend,
                                           "get_supported_head_sizes", None):
        return head_size in get_supported_head_sizes()
    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
        try:
            validate_head_size(head_size)
            return True
        except Exception:
            return False

    raise NotImplementedError(f"{attn_backend.__name__} does not support "
                              "head size validation")


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_mla: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
    # value to be returned from the cache if the value changes between calls.
    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to
    # the private function.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        is_attention_free=is_attention_free,
        is_blocksparse=is_blocksparse,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
    )


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_v1: bool = False,
    use_mla: bool = False,
) -> type[AttentionBackend]:
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        from vllm.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend = None
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size,
        use_v1, use_mla)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")
    return resolve_obj_by_qualname(attention_cls)


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
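
# A minimal usage sketch (hedged; assumes FLASH_ATTN is a member of
# _Backend, as it is upstream):
#
#     from vllm.platforms import _Backend
#     with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
#         ...  # all attention ops in this scope use the forced backend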
76
vllm/attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if current_platform.is_cuda():
    from vllm import _custom_ops as ops
    reshape_and_cache_flash = ops.reshape_and_cache_flash
    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                      get_scheduler_metadata)
elif current_platform.is_rocm():
    from vllm import _custom_ops as ops
    reshape_and_cache_cuda = ops.reshape_and_cache_cuda
    from flash_attn import vllm_flash_attn_varlen_func
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops
    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = ops.flash_attn_varlen_func
    get_scheduler_metadata = ops.get_scheduler_metadata


def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
    if current_platform.is_xpu():
        return 2
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)
        device_capability = current_platform.get_device_capability()

        assert device_capability is not None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        if envs.VLLM_FLASH_ATTN_VERSION is not None:
            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
            fa_version = envs.VLLM_FLASH_ATTN_VERSION

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform; "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error("FA version %d is not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))

        assert is_fa_version_supported(fa_version)
        return fa_version
    except (ImportError, AssertionError):
        return None


def flash_attn_supports_fp8() -> bool:
    return get_flash_attn_version() == 3 and \
        current_platform.get_device_capability().major == 9


def is_flash_attn_varlen_func_available() -> bool:
    return (current_platform.is_cuda() or current_platform.is_rocm()
            or current_platform.is_xpu())