init
283 vllm/attention/backends/flash_attn.py Normal file
@@ -0,0 +1,283 @@
"""Attention layer with Flash and PagedAttention.

NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
import torch_musa
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
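        # The layout is delegated to PagedAttention; per the forward()
        # docstring below, the resulting cache is
        # [2, num_blocks, block_size * num_kv_heads * head_size], i.e. the
        # key and value caches stacked along the first dimension.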
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding run.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
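    # Worked example: a sequence with 8 tokens already computed that gets 4
    # new tokens scheduled this step has context_len = 8, query_len = 4, and
    # seq_len = 8 + 4 = 12.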

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # Store the window as a (left, right) pair so that
        # `self.sliding_window[0]` used in the prefix-enabled path below is
        # valid; (-1, -1) means the window is disabled.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        self.need_mask = (self.alibi_slopes is not None
                          or sliding_window is not None)

        assert self.num_heads % self.num_kv_heads == 0
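        # E.g., 32 query heads sharing 8 KV heads yields
        # num_queries_per_kv = 4 (grouped-query attention); equal head counts
        # yield 1 (standard multi-head attention).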
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # enable musa flash attention
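        # These global flags steer which scaled_dot_product_attention kernel
        # is used; with torch_musa imported, the flash path is assumed to
        # dispatch to the MUSA flash-attention implementation.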
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(True)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
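            # slot_mapping gives, for each input token, the flat slot index
            # in the paged cache (roughly block_number * block_size + offset
            # within the block) where its key/value are written.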
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
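        # Move the token dimension next-to-last and add a batch dimension so
        # the prefill Q/K/V become [1, num_heads, num_prefill_tokens,
        # head_size], the layout scaled_dot_product_attention expects.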
        query = query.movedim(0, query.dim() - 2).unsqueeze(0)
        key = key.movedim(0, key.dim() - 2).unsqueeze(0)
        value = value.movedim(0, value.dim() - 2).unsqueeze(0)

        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
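            # Build a lower-triangular boolean mask. For a boolean attn_mask,
            # scaled_dot_product_attention attends to positions marked True
            # and masks out False, so this enforces causality explicitly;
            # is_causal is therefore left False below.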
            tensor = torch.full(
                (1, 1, num_tokens, num_tokens),
                dtype=torch.bool,
                fill_value=1,
                device=query.device)
            att_mask = torch.tril(tensor, diagonal=0)
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_output = scaled_dot_product_attention(
                    query.contiguous(),
                    key.contiguous(),
                    value.contiguous(),
                    attn_mask=att_mask.contiguous(),
                    dropout_p=0.0,
                    is_causal=False,
                )
                attn_output = attn_output.squeeze(0).permute(1, 0, 2).contiguous()
                assert output[:num_prefill_tokens].shape == attn_output.shape
                output[:num_prefill_tokens] = attn_output
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
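            # Each decode query holds a single new token per sequence; the
            # keys/values for the full sequence are gathered from the paged
            # cache via block_tables rather than from the key/value tensors
            # passed to this call.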
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)