# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_prefills_and_extends,
)
from vllm.v1.kv_cache_interface import AttentionSpec
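# Tunable ROCm constants: _PARTITION_SIZE_ROCM is the sequence-partition
# size handed to aiter's paged_attention_v1 kernel, and
# _CP_TOKENS_PER_ITER_ROCM caps how many context KV tokens are gathered
# per iteration during chunked prefill (it sizes the shared workspace
# allocated in the metadata builder below).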
_PARTITION_SIZE_ROCM = 256
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
import aiter
from vllm.triton_utils import tl, triton
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens):
return min(total_tokens, get_cu_count())
@triton.jit
def cp_mha_gather_cache_kernel(
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
value_cache_ptr, # [num_blocks, page_size, num_head, head_size]
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
block_table_ptr, # [num_batches, max_block_num]
cu_seqlens_kv_ptr, # [num_batches + 1]
token_to_batch_ptr, # [max_cum_tokens]
seq_start_ptr, # [num_batches]
k_scale_ptr,
v_scale_ptr,
num_heads,
head_size,
x,
max_block_num,
DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
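# One program per gathered token: map the flat token index back to its
# batch, translate the in-sequence position to a (block, slot) pair via
# the block table, then copy that token's K/V row from the paged cache
# into the contiguous key/value buffers, optionally dequantizing with
# the per-layer k/v scales.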
token_id = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
if DEQUANT:
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)
key_ptr_offset = key_ptr + token_id * head_size * num_heads
value_ptr_offset = value_ptr + token_id * head_size * num_heads
batch_idx = tl.load(token_to_batch_ptr + token_id)
batch_start = tl.load(seq_start_ptr + batch_idx)
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
batch_offset = token_id - token_start + batch_start
block_offset = batch_offset // PAGE_SIZE
block_id = tl.load(
block_table_ptr + max_block_num * batch_idx + block_offset
).to(tl.int64)
slot_id = batch_offset % PAGE_SIZE
if CACHE_FORMAT == "NHD":
# for kv cache layout as
# K: [num_blocks, page_size, num_head, head_dim]
# V: [num_blocks, page_size, num_head, head_dim]
key_cache_ptr_offset = (
key_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
)
value_cache_ptr_offset = (
value_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
)
for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
mask = (col_offsets + i) < head_size * num_heads
k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
if DEQUANT:
k_dtype = k_reg.dtype
v_dtype = v_reg.dtype
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
def cp_mha_gather_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
block_tables: torch.Tensor,
k_scales: torch.Tensor,
v_scales: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
token_to_batch: torch.Tensor,
seq_starts: torch.Tensor,
dequant: bool,
kv_cache_layout: str,
total_tokens: int,
):
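# Gathers `total_tokens` K/V rows from the paged cache into the dense
# `key`/`value` buffers, e.g. key_cache of shape
# [num_blocks, page_size, num_heads, head_dim] into key of shape
# [total_tokens, num_heads, head_dim]. Row i is located through
# token_to_batch[i], seq_starts, and the batch's block-table row.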
assert kv_cache_layout in ["v0", "NHD", "HND"], (
"kv_cache_layout only supports v0, NHD, and HND"
)
head_dim = key.shape[2]
x = 0
# For the NHD layout, the K cache is
# [num_blocks, page_size, num_heads, head_dim].
assert kv_cache_layout == "NHD", (
"ROCM_AITER_FA_BACKEND only supports the NHD kv cache layout for now"
)
assert head_dim == key_cache.shape[3], (
"Expected a kv cache layout of [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
)
page_size = key_cache.shape[1]
num_heads = key_cache.shape[2]
grid = (total_tokens,)
cp_mha_gather_cache_kernel[grid](
key_cache,
value_cache,
key,
value,
block_tables,
cu_seqlens_kv,
token_to_batch,
seq_starts,
k_scales,
v_scales,
num_heads,
head_dim,
x,
block_tables.size(1),
DEQUANT=dequant,
PAGE_SIZE=page_size,
CACHE_FORMAT=kv_cache_layout,
BLOCK_SIZE=head_dim,
)
logger = init_logger(__name__)
@dataclass
class AiterFlashAttentionDecodeMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
@dataclass
class AiterFlashAttentionPrefillMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
@dataclass
class AiterChunkSlidingWindowMetadata:
swa_seqlens: torch.Tensor
swa_cu_seqlens: torch.Tensor
swa_seq_starts: torch.Tensor
swa_token_to_batch: torch.Tensor
swa_max_seqlens: int
swa_total_tokens: int
swa_workspace: torch.Tensor
@dataclass
class AiterChunkContextMetadata:
workspace: torch.Tensor
cu_seq_lens_chunk: torch.Tensor
chunk_starts: torch.Tensor
token_to_batch: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
seq_lens: torch.Tensor
num_chunks: int
total_token_per_batch: list[int]
swa_metadata: AiterChunkSlidingWindowMetadata | None
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
chunk_context_metadata: AiterChunkContextMetadata
@dataclass
class AiterFlashAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
num_actual_kv_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
# prefill, extend, and decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
num_extends: int
num_extend_tokens: int
decode_metadata: AiterFlashAttentionDecodeMetadata | None
prefill_metadata: AiterFlashAttentionPrefillMetadata | None
extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None
# For cascade attention.
use_cascade: bool
common_prefix_len: int
total_tokens: int
class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
)
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
# Sliding window size shared by the attention layers (if any);
# populated from the layer configs below.
self.aot_sliding_window: tuple[int, int] | None = None
self.total_tokens: int = 0
sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer in layers.values():
assert isinstance(layer.impl, AiterFlashAttentionImpl)
sliding_window_configs.add(layer.impl.sliding_window)
while len(sliding_window_configs) > 0:
sliding_window_config = sliding_window_configs.pop()
if sliding_window_config is not None and sliding_window_config[0] != -1:
assert self.aot_sliding_window is None, (
"Aiter Flash ATTENTION can only support one valid sliding window!"
)
self.aot_sliding_window = sliding_window_config
self.extend_workspace = torch.empty(
[2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
dtype=self.model_config.dtype,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
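# Worst-case KV footprint for graph capture: assume each of the
# max_num_partial_prefills slots holds a max_model_len sequence.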
self.total_tokens = (
self.model_config.max_model_len
* self.vllm_config.scheduler_config.max_num_partial_prefills
)
res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata)
self.total_tokens = 0
return res
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> "AiterFlashAttentionMetadata":
split_ret = split_decodes_prefills_and_extends(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
(
num_decodes,
num_extends,
num_prefills,
num_decode_tokens,
num_extend_tokens,
num_prefill_tokens,
) = split_ret
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
decode_metadata = None
if num_decodes > 0:
decode_metadata = AiterFlashAttentionDecodeMetadata(
max_query_len=query_lens_cpu[:num_decodes].max().item(),
min_query_len=query_lens_cpu[:num_decodes].min().item(),
max_seq_len=seq_lens[:num_decodes].max().item(),
query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
)
prefill_metadata = None
if num_prefills > 0:
query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
query_start_loc_device = common_attn_metadata.query_start_loc[
num_decodes + num_extends :
]
prefill_metadata = AiterFlashAttentionPrefillMetadata(
max_query_len=query_lens_for_prefill.max().item(),
min_query_len=query_lens_for_prefill.min().item(),
max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
)
extend_metadata = None
if num_extends > 0:
num_extends_slice = slice(num_decodes, num_decodes + num_extends)
query_lens_for_extend = query_lens_cpu[num_extends_slice]
seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
swa_metadata = None
if self.aot_sliding_window is not None:
swa_seqlen_for_extend = torch.minimum(
seq_lens_for_extend,
query_lens_for_extend + self.aot_sliding_window[0] + 1,
)
cu_seq_lens = torch.zeros(
num_extends + 1,
dtype=torch.int32,
device=seq_lens_for_extend.device,
)
torch.cumsum(
swa_seqlen_for_extend,
dim=0,
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:],
)
token_to_seq = torch.arange(
0,
num_extends,
dtype=torch.int32,
device=seq_lens_for_extend.device,
)
token_to_seq = torch.repeat_interleave(
token_to_seq, swa_seqlen_for_extend
)
fetched_shape = cu_seq_lens[-1].item()
# TODO(ganyi): Maybe reuse these 2 buffer from extend_workspace
swa_workspace = torch.empty(
(2, fetched_shape, self.num_heads_kv, self.headdim),
dtype=self.vllm_config.model_config.dtype,
device=self.device,
)
seq_starts = seq_lens_for_extend - swa_seqlen_for_extend
max_seqlen_k = swa_seqlen_for_extend.max().item()
total_tokens = fetched_shape
swa_metadata = AiterChunkSlidingWindowMetadata(
swa_seqlens=swa_seqlen_for_extend.to(
self.device, non_blocking=True
),
swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True),
swa_seq_starts=seq_starts.to(self.device, non_blocking=True),
swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True),
swa_max_seqlens=max_seqlen_k,
swa_total_tokens=total_tokens,
swa_workspace=swa_workspace,
)
# Allocate an equal share of the gather workspace to each
# chunked-prefill request.
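# Worked example (illustrative numbers): with
# _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 and num_extends = 4, each
# request may gather up to 8192 context tokens per iteration, so a
# request with 20000 cached tokens needs cdiv(20000, 8192) = 3 chunks.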
max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)
chunk_starts = (
torch.arange(num_chunks, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, num_extends)
* max_context_chunk
)
chunk_ends = torch.min(
computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
min=0
) # [num_chunks, num_extends]
cu_seq_lens_cpu = torch.zeros(
[num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
)
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()
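# Build the per-token batch-index map without a Python loop: mark every
# position that coincides with a cumulative-length boundary, then a
# running sum over token positions yields, for each token slot, the
# index of the batch it belongs to.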
range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
idx_to_batch_tensor = idx_to_batch_tensor.sum(
dim=1
) # [num_chunks, max_cum_tokens]
token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)
chunk_context_metadata = AiterChunkContextMetadata(
workspace=self.extend_workspace,
cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
chunk_starts=chunk_starts.to(self.device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
num_chunks=num_chunks,
total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
swa_metadata=swa_metadata,
)
query_start_loc_device = common_attn_metadata.query_start_loc[
num_decodes : num_decodes + num_extends + 1
]
seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
cu_seq_lens = torch.zeros(
num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
)
torch.cumsum(
seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
)
extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
max_query_len=query_lens_for_extend.max().item(),
min_query_len=query_lens_for_extend.min().item(),
max_seq_len=seq_lens[num_extends_slice].max().item(),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
chunk_context_metadata=chunk_context_metadata,
)
num_actual_kv_tokens = torch.sum(seq_lens).item()
use_cascade = common_prefix_len > 0
attn_metadata = AiterFlashAttentionMetadata(
num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_actual_kv_tokens=num_actual_kv_tokens,
max_query_len=common_attn_metadata.max_query_len,
query_start_loc=common_attn_metadata.query_start_loc,
max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_extends=num_extends,
num_extend_tokens=num_extend_tokens,
decode_metadata=decode_metadata,
prefill_metadata=prefill_metadata,
extend_metadata=extend_metadata,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens,
)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
return AiterFlashAttentionImpl
@staticmethod
def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
return AiterFlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
class AiterFlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
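# flash-attn window convention: window_size = (left, right), where
# (-1, -1) disables the sliding window and (sliding_window - 1, 0)
# lets each query attend to itself plus the previous
# sliding_window - 1 tokens.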
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0.0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for FlashAttentionImpl"
)
def extend_for_sliding_window(
self,
attn_metadata: AiterFlashAttentionMetadata,
query: torch.Tensor,
key_cache,
value_cache,
output: torch.Tensor,
cu_seqlens_q: torch.Tensor,
max_seqlen_q: int,
block_table: torch.Tensor,
k_scale: float,
v_scale: float,
):
assert attn_metadata.extend_metadata is not None
assert attn_metadata.extend_metadata.chunk_context_metadata is not None
chunked_metadata = attn_metadata.extend_metadata.chunk_context_metadata
swa_metadata = chunked_metadata.swa_metadata
assert swa_metadata is not None
swa_cu_seqlens = swa_metadata.swa_cu_seqlens
swa_seq_starts = swa_metadata.swa_seq_starts
swa_token_to_batch = swa_metadata.swa_token_to_batch
swa_max_seqlens = swa_metadata.swa_max_seqlens
swa_total_tokens = swa_metadata.swa_total_tokens
key_fetched, value_fetched = (
swa_metadata.swa_workspace[0],
swa_metadata.swa_workspace[1],
)
cp_mha_gather_cache(
key_cache=key_cache,
value_cache=value_cache,
key=key_fetched,
value=value_fetched,
block_tables=block_table,
k_scales=k_scale,
v_scales=v_scale,
cu_seqlens_kv=swa_cu_seqlens,
token_to_batch=swa_token_to_batch,
seq_starts=swa_seq_starts,
dequant=False,
kv_cache_layout="NHD",
total_tokens=swa_total_tokens,
)
aiter.flash_attn_varlen_func(
q=query,
k=key_fetched,
v=value_fetched,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=swa_cu_seqlens,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=swa_max_seqlens,
min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
return_lse=False,
out=output,
)
def extend_forward(
self,
attn_metadata: AiterFlashAttentionMetadata,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
cu_seqlens_q: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
min_seqlen_q: int,
block_table: torch.Tensor,
slot_mapping: torch.Tensor,
k_scale: float,
v_scale: float,
):
if self.sliding_window[0] != -1:
self.extend_for_sliding_window(
attn_metadata,
query,
key_cache,
value_cache,
output,
cu_seqlens_q,
max_seqlen_q,
block_table,
k_scale,
v_scale,
)
return
out, lse = aiter.flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
min_seqlen_q=min_seqlen_q,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
return_lse=True,
)
assert attn_metadata.extend_metadata is not None
chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
num_chunks = chunk_context_metadata.num_chunks
workspace = chunk_context_metadata.workspace
cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
max_seqlens = chunk_context_metadata.max_seq_lens
chunk_starts = chunk_context_metadata.chunk_starts
token_to_batch = chunk_context_metadata.token_to_batch
total_token_per_batch = chunk_context_metadata.total_token_per_batch
key_fetched, value_fetched = workspace[0], workspace[1]
chunked_output = None
chunked_lse = None
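# Chunked context pass: `out`/`lse` above cover attention of the new
# tokens against themselves; now walk the cached context in fixed-size
# chunks, gathering each chunk from the paged cache and folding its
# partial result into a running (chunked_output, chunked_lse) pair.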
for chunk_idx in range(num_chunks):
cp_mha_gather_cache(
key_cache=key_cache,
value_cache=value_cache,
key=key_fetched,
value=value_fetched,
block_tables=block_table,
k_scales=k_scale,
v_scales=v_scale,
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx],
dequant=False,
kv_cache_layout="NHD",
total_tokens=total_token_per_batch[chunk_idx],
)
suf_out, suf_lse = aiter.flash_attn_varlen_func(
q=query,
k=key_fetched,
v=value_fetched,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_kv[chunk_idx],
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlens[chunk_idx],
min_seqlen_q=min_seqlen_q,
dropout_p=0.0,
softmax_scale=self.scale,
causal=False,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
return_lse=True,
)
if chunked_output is None:
chunked_output = suf_out
chunked_lse = suf_lse
else:
tmp_output = torch.empty_like(out)
tmp_lse = torch.empty_like(lse)
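# merge_attn_states combines two partial attention results using
# their log-sum-exp values; the merged output equals attention
# computed over the union of both key sets.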
merge_attn_states(
output=tmp_output,
output_lse=tmp_lse,
prefix_output=chunked_output,
prefix_lse=chunked_lse,
suffix_output=suf_out,
suffix_lse=suf_lse,
)
chunked_output = tmp_output
chunked_lse = tmp_lse
merge_attn_states(
output=output,
prefix_output=chunked_output,
prefix_lse=chunked_lse,
suffix_output=out,
suffix_lse=lse,
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AiterFlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: For FP8 quantization, flash-attn expects the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values.
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for FlashAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is
# executed in eager-mode PyTorch. Thus, we need to be careful
# about any CPU overhead in this method. For example, `view`
# and `slice` (or `[:n]`) operations are surprisingly slow even
# in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
# decode:extend:prefill
query = query[:num_actual_tokens]
if key is not None:
key = key[:num_actual_tokens]
if value is not None:
value = value[:num_actual_tokens]
output_actual_tokens = output[:num_actual_tokens]
num_decodes = attn_metadata.num_decodes
num_prefills = attn_metadata.num_prefills
num_extends = attn_metadata.num_extends
num_decode_tokens = attn_metadata.num_decode_tokens
num_extend_tokens = attn_metadata.num_extend_tokens
if not attn_metadata.use_cascade:
# calculate for pure prefills
if num_prefills > 0:
assert attn_metadata.prefill_metadata is not None
prefill_query = query[num_decode_tokens + num_extend_tokens :]
prefill_key = key[num_decode_tokens + num_extend_tokens :]
prefill_value = value[num_decode_tokens + num_extend_tokens :]
aiter.flash_attn_varlen_func(
q=prefill_query,
k=prefill_key,
v=prefill_value,
cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
)
# calculate for extends
if num_extends > 0:
assert attn_metadata.extend_metadata is not None
extend_tokens_slice = slice(
num_decode_tokens, num_decode_tokens + num_extend_tokens
)
extend_queries = query[extend_tokens_slice]
extend_keys = key[extend_tokens_slice]
extend_values = value[extend_tokens_slice]
extend_outputs = output[extend_tokens_slice]
self.extend_forward(
attn_metadata=attn_metadata,
query=extend_queries,
key=extend_keys,
value=extend_values,
key_cache=key_cache,
value_cache=value_cache,
output=extend_outputs,
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
min_seqlen_q=1,
block_table=attn_metadata.block_table[
num_decodes : num_decodes + num_extends
],
slot_mapping=attn_metadata.slot_mapping[
num_decodes : num_decodes + num_extends
],
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)
# calculate for decodes
if num_decodes > 0:
assert attn_metadata.decode_metadata is not None
if self.sliding_window[0] != -1:
from aiter.ops.triton.unified_attention import (
unified_attention,
)
descale_shape = (
num_decodes,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1],
max_seqlen_q=1, # optimize this
seqused_k=attn_metadata.seq_lens[:num_decodes],
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table[:num_decodes],
softcap=self.logits_soft_cap,
q_descale=None,
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
assert attn_metadata.decode_metadata is not None
_, num_heads, head_size = query.shape
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
num_seqs = attn_metadata.seq_lens.shape[0]
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
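# Workspace sizing (inferred from the expression below):
# per-partition partial outputs in the query dtype, plus two float32
# arrays (exp-sums and max-logits) of shape
# [num_seqs, num_heads, max_num_partitions]. num_seqs counts all
# sequences in the batch, so this is conservative for the decode-only
# slice actually used.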
workspace_buffer = torch.empty(
(num_seqs * num_heads * max_num_partitions * head_size)
* nbytes_per_qo_elem
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
dtype=torch.uint8,
device=output.device,
)
torch.ops.aiter.paged_attention_v1(
output[:num_decode_tokens],
workspace_buffer,
query[:num_decode_tokens],
key_cache,
value_cache,
self.scale,
attn_metadata.block_table[:num_decodes],
attn_metadata.query_start_loc[: num_decodes + 1],
attn_metadata.seq_lens[:num_decodes],
attn_metadata.max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
"NHD",
self.logits_soft_cap,
layer._k_scale,
layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
)
else:
raise NotImplementedError(
"Cascade attention is not implemented for ROCM AITER"
)
return output