Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
0
vllm/v1/attention/backends/__init__.py
Normal file
0
vllm/v1/attention/backends/__init__.py
Normal file
739
vllm/v1/attention/backends/flash_attn.py
Normal file
739
vllm/v1/attention/backends/flash_attn.py
Normal file
@@ -0,0 +1,739 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
# from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
# if current_platform.is_cuda():
|
||||
# from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, merge_attn_states
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
|
||||
return FlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
key_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
cu_prefix_kv_lens: Optional[torch.Tensor]
|
||||
cu_suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# for local attention
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor
|
||||
local_k_start_loc: torch.Tensor
|
||||
local_seqused_k: torch.Tensor
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.tensor,
|
||||
page_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle if we are starting in the middle of a local attention block,
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
|
||||
q_seqlens).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
|
||||
attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make,
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = \
|
||||
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1),
|
||||
attn_chunk_size)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
|
||||
.astype(np.int32)
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1],
|
||||
attn_chunk_size,
|
||||
dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_k_local = np.pad(np.cumsum(seqlens_k_local), (1, 0))\
|
||||
.astype(np.int32)
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||
(rarange * attn_chunk_size + \
|
||||
np.repeat(tokens_in_last_block, local_blocks))
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // page_size
|
||||
assert attn_chunk_size % page_size == 0, \
|
||||
f"attn_chunk_size {attn_chunk_size} is not " \
|
||||
f"divisible by page_size {page_size}"
|
||||
pages_per_local_batch = attn_chunk_size // page_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For out example if we have a block-table like (assuming page_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices= np.broadcast_to(
|
||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||
(virtual_batches, pages_per_local_batch)) \
|
||||
+ np.expand_dims(block_starts, axis=1)
|
||||
block_indices = block_indices.flatten()
|
||||
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch)
|
||||
block_table_local = block_table[batch_indices, block_indices]\
|
||||
.view(virtual_batches, -1)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, cu_seqlens_k_local, \
|
||||
block_table_local
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner"):
|
||||
self.runner = runner
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int):
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.runner.device,
|
||||
non_blocking=True)
|
||||
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
|
||||
key_start_loc = torch.zeros([seq_lens_cpu.shape[0]+1])
|
||||
key_start_loc[1:] = seq_lens_cpu
|
||||
key_start_loc = key_start_loc.cumsum(dim=0).to(seq_lens.dtype)
|
||||
key_start_loc = key_start_loc.to(self.runner.device, non_blocking=True)
|
||||
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, virt_k_cu_seqlens_np, \
|
||||
virt_block_table = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
block_table,
|
||||
self.runner.block_size,
|
||||
)
|
||||
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(
|
||||
virt_q_cu_seqlens_np).to(self.runner.device,
|
||||
non_blocking=True),
|
||||
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True),
|
||||
local_block_table=virt_block_table,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=virt_k_seqlens_np.max(),
|
||||
local_k_start_loc=torch.from_numpy(
|
||||
virt_k_cu_seqlens_np).to(self.runner.device,
|
||||
non_blocking=True),
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
cu_suffix_kv_lens = suffix_kv_lens.new_zeros([suffix_kv_lens.shape[0]+1])
|
||||
cu_suffix_kv_lens[1:] = suffix_kv_lens
|
||||
cu_suffix_kv_lens = cu_suffix_kv_lens.cumsum(dim=0).int()
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
cu_prefix_kv_lens = None
|
||||
cu_suffix_kv_lens = None
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
key_start_loc=key_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
local_attn_metadata=local_attn_metadata,
|
||||
cu_prefix_kv_lens=cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=cu_suffix_kv_lens,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}. "
|
||||
"Set VLLM_USE_V1=0 to use another attention backend.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.use_irope = use_irope
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support fp8 kv-cache on this device.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_scale: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens] and
|
||||
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
|
||||
# the slot_mapping's shape to determine the number of actual tokens.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if not attn_metadata.use_cascade or use_local_attn:
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
seqused_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
cu_seqlens_k = local_metadata.local_k_start_loc
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
cu_seqlens_k = attn_metadata.key_start_loc
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
flash_attn_varlen_func( # noqa
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
sqrt_alibi=False,
|
||||
out=output[:num_actual_tokens],
|
||||
)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
assert not use_local_attn, (
|
||||
"Cascade attention does not support local attention.")
|
||||
# Cascade attention (rare case).
|
||||
cascade_attention(
|
||||
output[:num_actual_tokens],
|
||||
query[:num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
cu_query_lens=attn_metadata.query_start_loc,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
|
||||
cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale,
|
||||
k_descale=layer._k_scale,
|
||||
v_descale=layer._v_scale,
|
||||
)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
def use_cascade_attention(
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
num_sms: int,
|
||||
) -> bool:
|
||||
"""Decide whether to use cascade attention.
|
||||
|
||||
This function 1) checks whether cascade attention is supported with the
|
||||
given configuration, and 2) heuristically decides whether using cascade
|
||||
attention can improve performance.
|
||||
"""
|
||||
# Too short common prefix. Probably not worth using cascade attention.
|
||||
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
|
||||
# NOTE(woosuk): This is the common case. We should return False as soon as
|
||||
# possible to avoid any unnecessary computation.
|
||||
if common_prefix_len < 256:
|
||||
return False
|
||||
# Cascade attention is currently not supported with these variants.
|
||||
if use_alibi or use_sliding_window:
|
||||
return False
|
||||
# Too few queries. Probably not worth using cascade attention.
|
||||
# We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
|
||||
num_reqs = len(query_lens)
|
||||
if num_reqs < 8:
|
||||
return False
|
||||
|
||||
# Heuristics to decide whether using cascade attention is beneficial.
|
||||
# 1. When FlashDecoding is not used for normal attention, cascade attention
|
||||
# is likely to be faster since it saves memory bandwidth.
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
# The criteria for using FlashDecoding can be found in the following link:
|
||||
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
|
||||
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
|
||||
and not use_alibi and np.all(query_lens == 1))
|
||||
if not use_flash_decoding:
|
||||
# Use cascade attention.
|
||||
return True
|
||||
else:
|
||||
# flash_decoding not supported now!
|
||||
return False
|
||||
|
||||
# 2. When FlashDecoding is used for normal attention, it is not clear
|
||||
# whether cascade attention is beneficial, because FlashDecoding can
|
||||
# launch more CTAs than cascade attention.
|
||||
# We use a simple performance model to compare the two methods.
|
||||
# NOTE(woosuk): The performance model is very rough and may not be
|
||||
# accurate.
|
||||
num_tokens = num_reqs
|
||||
# NOTE(woosuk): These are default tile sizes. flash-attn might use
|
||||
# different tile sizes (e.g., 64 or 256) depending on the configuration.
|
||||
q_tile_size = 128
|
||||
kv_tile_size = 128
|
||||
num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)
|
||||
|
||||
cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
|
||||
cascade_waves = cdiv(cascade_ctas, num_sms)
|
||||
cascade_time = cascade_waves * num_prefix_tiles
|
||||
|
||||
flash_decoding_ctas = (num_reqs * num_kv_heads *
|
||||
cdiv(num_queries_per_kv, q_tile_size))
|
||||
flash_decoding_ctas *= num_prefix_tiles
|
||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||
|
||||
# Use cascade attention if it is faster than FlashDecoding.
|
||||
return cascade_time < flash_decoding_time
|
||||
|
||||
|
||||
def cascade_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_query_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
cu_prefix_query_lens: torch.Tensor,
|
||||
cu_prefix_kv_lens: torch.Tensor,
|
||||
cu_suffix_kv_lens: torch.Tensor,
|
||||
max_kv_len: int,
|
||||
softmax_scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: tuple[int, int],
|
||||
logits_soft_cap: float,
|
||||
block_table: torch.Tensor,
|
||||
common_prefix_len: int,
|
||||
fa_version: int,
|
||||
q_descale: Optional[torch.Tensor] = None,
|
||||
k_descale: Optional[torch.Tensor] = None,
|
||||
v_descale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
|
||||
# TODO: Support sliding window.
|
||||
assert sliding_window == (-1, -1), (
|
||||
"Cascade attention does not support sliding window.")
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
block_size = key_cache.shape[-2]
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
assert num_common_kv_blocks > 0
|
||||
assert q_descale is None or q_descale==1, f"q_descale is not None, q_descale: {q_descale}"
|
||||
assert k_descale is None or k_descale==1, f"k_descale is not None, k_descale: {k_descale}"
|
||||
assert v_descale is None or v_descale==1, f"v_descale is not None, v_descale: {v_descale}"
|
||||
|
||||
# Process shared prefix.
|
||||
prefix_output, prefix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
cu_seqlens_k=cu_prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
|
||||
# Process suffix per query.
|
||||
suffix_output, suffix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_suffix_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
|
||||
def ref_merge_state(out_1, lse_1, out_2, lse_2):
|
||||
num_heads, seq_len = lse_1.shape
|
||||
|
||||
lse_2 = lse_2.transpose(0,1).view(seq_len, num_heads, 1)
|
||||
lse_1 = lse_1.transpose(0,1).view(seq_len, num_heads, 1)
|
||||
|
||||
s_max = torch.maximum(lse_1, lse_2)
|
||||
|
||||
d = torch.exp2(lse_1-s_max) + torch.exp2(lse_2-s_max)
|
||||
v_merged = out_1 * torch.exp2(lse_1-s_max) + out_2 * torch.exp2(lse_2-s_max)
|
||||
v_merged = v_merged / d
|
||||
return v_merged, (torch.log2(d) + s_max).view(seq_len, num_heads)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
# merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
|
||||
# suffix_lse)
|
||||
merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)
|
||||
|
||||
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
1194
vllm/v1/attention/backends/mla/common.py
Normal file
1194
vllm/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
149
vllm/v1/attention/backends/mla/flashmla.py
Normal file
149
vllm/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, runner):
|
||||
super().__init__(runner)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
def _build_decode(self, input_positions: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=attn_metadata.decode.
|
||||
tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
154
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
154
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported")
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_c_normed: torch.Tensor=None,
|
||||
k_pe: torch.Tensor=None,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
# num_kv_splits = 4 # TODO: heuristic
|
||||
|
||||
# # TODO(lucas) Allocate ahead of time
|
||||
# attn_logits = torch.empty(
|
||||
# (
|
||||
# B,
|
||||
# self.num_heads,
|
||||
# num_kv_splits,
|
||||
# # NOTE(lucas) idk why the +1 is here but sglang has it so we
|
||||
# # just mirror that
|
||||
# self.kv_lora_rank + 1,
|
||||
# ),
|
||||
# dtype=torch.float32,
|
||||
# device=q.device,
|
||||
# )
|
||||
|
||||
# # Add a head dim of 1
|
||||
# kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
# kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
# PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# # Run MQA
|
||||
# decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
# attn_metadata.decode.block_table,
|
||||
# attn_metadata.decode.seq_lens, attn_logits,
|
||||
# num_kv_splits, self.scale, PAGE_SIZE)
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
|
||||
else:
|
||||
# fused q concat & cache write
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope,
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=attn_metadata.decode.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
198
vllm/v1/attention/backends/pallas.py
Normal file
198
vllm/v1/attention/backends/pallas.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["PallasMetadata"]:
|
||||
return PallasMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Used in the PallasAttentionBackendImpl
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("Paged attention Pallas kernel does "
|
||||
"not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
tpu_version = torch_xla.tpu.version()
|
||||
if tpu_version < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
# For determine_available_memory case.
|
||||
if kv_cache.numel() == 0:
|
||||
if output is None:
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(key, value, kv_cache, slot_mapping)
|
||||
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.query_start_loc,
|
||||
attn_metadata.num_seqs,
|
||||
# By default, the system utilizes optimized block size and
|
||||
# vmem_limit_bytes parameters from the kernel repository. However,
|
||||
# these can be manually adjusted for debugging if necessary.
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=self.scale,
|
||||
sliding_window=self.sliding_window,
|
||||
soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
|
||||
"""
|
||||
_, _, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||
head_size)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
kv_cache.index_copy_(0, slot_mapping, kv)
|
||||
198
vllm/v1/attention/backends/triton_attn.py
Normal file
198
vllm/v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonAttentionImpl"]:
|
||||
return TritonAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
|
||||
return FlashAttentionMetadataBuilder
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"TritonAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.use_irope = use_irope
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by TritonAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonAttentionImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
sequesd_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
sequesd_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=sequesd_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
|
||||
return output
|
||||
Reference in New Issue
Block a user