Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/paged_attn.py

534 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass
from typing import List, Optional, Tuple
import sys
import torch
import traceback
from vllm import _custom_ops as ops
# from vllm.attention.ops.prefix_prefill import context_attention_fwd
# NOTE: context_attention_fwd (Triton kernel from prefix_prefill.py) is NOT
# imported here. On Iluvatar BI-V100 that kernel hangs the GPU card
# permanently. Chunked-prefill / prefix-caching attention is handled by
# _forward_prefix_pytorch below (pure PyTorch, no Triton dependency).
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def _forward_decode_pytorch(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""Pure-PyTorch decode attention for long contexts (no hardware kernel).
paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to
shared memory limits. For decode, q_len=1 per sequence so no Q-tiling
is needed — the attention weight tensor is [H, 1, seq_len] which is
trivially small (~5 MB at 50K).
Shapes
------
query : [num_seqs, num_heads, head_dim]
key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
block_tables: [num_seqs, max_blocks_per_seq]
seq_lens : [num_seqs]
"""
num_seqs, num_heads, head_dim = query.shape
num_kv_heads = key_cache.shape[1]
block_size = value_cache.shape[3]
gqa_ratio = num_heads // num_kv_heads
orig_dtype = query.dtype
output = torch.empty_like(query)
try:
for i in range(num_seqs):
seq_len = int(seq_lens[i].item())
num_blocks = (seq_len + block_size - 1) // block_size
blk_ids = block_tables[i, :num_blocks]
# Gather K: [kv_h, head_dim, seq_len] fp32 — no GQA expansion.
# With kv_h=1 and seq_len=100K this is 98 MB vs 586 MB if expanded.
k_t = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len] \
.permute(1, 2, 0).contiguous().float() # [kv_h, d, seq_len]
# Gather V: [kv_h, seq_len, head_dim] fp32
v_t = (value_cache[blk_ids]
.permute(0, 3, 1, 2)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len] \
.permute(1, 0, 2).contiguous().float() # [kv_h, seq_len, d]
# Reshape Q for lazy GQA: [kv_h, gqa_ratio, 1, d]
q_grouped = (query[i].float()
.view(num_kv_heads, gqa_ratio, head_dim)
.unsqueeze(2))
# [kv_h, gqa_ratio, 1, seq_len]
attn_w = torch.matmul(
q_grouped * scale, # [kv_h, gqa, 1, d]
k_t.unsqueeze(1)) # [kv_h, 1, d, seq_len]
attn_w = torch.softmax(attn_w, dim=-1)
# [kv_h, gqa_ratio, 1, d] → [num_heads, head_dim]
out_i = torch.matmul(attn_w, v_t.unsqueeze(1))
output[i] = out_i.view(num_heads, head_dim).to(orig_dtype)
except Exception as e:
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr)
raise
return output
# paged_attention_v1 on BI-V100 fails for long contexts.
# Route on actual sequence length (seq_lens.max()), not the max_seq_len
# parameter which is inflated to max_model_len in CUDA graph mode.
_PYTORCH_DECODE_THRESHOLD = 32768
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
actual_max = int(seq_lens.max().item()) if seq_lens.numel() > 0 else max_seq_len
if actual_max > PagedAttention._PYTORCH_DECODE_THRESHOLD:
return PagedAttention._forward_decode_pytorch(
query, key_cache, value_cache, block_tables, seq_lens, scale)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = True
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
# NOTE: The Triton context_attention_fwd kernel hangs on Iluvatar
# BI-V100 hardware (same class of issue as cudnnFlashAttnForward).
# Use a pure-PyTorch fallback that reads the paged KV cache directly.
return PagedAttention._forward_prefix_pytorch(
query, key, value,
key_cache, value_cache,
block_tables, query_start_loc,
seq_lens_tensor, context_lens,
)
@staticmethod
def _forward_prefix_pytorch(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
) -> torch.Tensor:
"""Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
Memory complexity: O(q_len), independent of kv_len.
With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
per layer ≈ 96 MB regardless of context length.
Algorithm: Flash Attention online softmax.
Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
K-tiles. For each tile a running (m, l, o) accumulator is updated —
the [q_len × kv_len] attention matrix is NEVER materialised in full.
Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
o_acc same shape 24 MB (held all tiles)
s same shape 24 MB (per tile, freed before exp_s)
exp_s same shape 24 MB (per tile, brief overlap with s)
Peak ≈ 96 MB (s and exp_s briefly coexist during update).
Shapes
------
query : [total_q_tokens, num_q_heads, head_dim]
key : [total_q_tokens, num_kv_heads, head_dim]
value : [total_q_tokens, num_kv_heads, head_dim]
key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
block_tables : [batch_size, max_blocks_per_seq]
query_start_loc: [batch_size + 1]
seq_lens_tensor: [batch_size] total length (context + query)
context_lens : [batch_size] tokens already in KV cache
"""
try:
# Paged-block tiles for context phase.
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
# Same tile size reused for the current-chunk phase.
_BLOCKS_PER_TILE = 32
batch_size = seq_lens_tensor.shape[0]
num_q_heads = query.shape[1]
num_kv_heads = key_cache.shape[1]
head_dim = query.shape[2]
gqa_ratio = num_q_heads // num_kv_heads
block_size = value_cache.shape[3]
tile_sz = _BLOCKS_PER_TILE * block_size
scale = head_dim ** -0.5
orig_dtype = query.dtype
output = torch.empty_like(query)
dev = query.device
for i in range(batch_size):
ctx_len = int(context_lens[i].item())
q_start = int(query_start_loc[i].item())
q_end = int(query_start_loc[i + 1].item())
q_len = q_end - q_start
q_i = query[q_start:q_end] # [q_len, q_h, d]
k_i = key [q_start:q_end] # [q_len, kv_h, d]
v_i = value[q_start:q_end]
# Q reshaped and scaled once; held for all K-tiles.
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
q_seq = (q_i.permute(1, 0, 2)
.float()
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
.mul_(scale))
# Flash-Attention online-softmax accumulators.
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
m = torch.full((num_kv_heads, gqa_ratio, q_len),
float('-inf'), dtype=torch.float32, device=dev)
l = torch.zeros_like(m)
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
dtype=torch.float32, device=dev)
# --------------------------------------------------------------
# Phase 1 — context tokens (positions 0 … ctx_len-1).
#
# Every context key has absolute position < ctx_len; every
# query has position ≥ ctx_len. k_pos < q_pos is always True
# → no causal mask needed for pure context tiles.
# --------------------------------------------------------------
if ctx_len > 0:
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
blk_ids = block_tables[i, tile_blk:blk_end]
# Gather K/V for this tile.
# key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
k_tile = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4)
.contiguous()
.view(-1, num_kv_heads, head_dim))
v_tile = (value_cache[blk_ids]
.permute(0, 3, 1, 2)
.contiguous()
.view(-1, num_kv_heads, head_dim))
# Trim padding in the last block of the tile.
valid = (min(blk_end * block_size, ctx_len)
- tile_blk * block_size)
k_tile = k_tile[:valid] # [valid, kv_h, d]
v_tile = v_tile[:valid]
# k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
# v_t: [kv_h, 1, valid, d]
k_t = (k_tile.permute(1, 0, 2)
.unsqueeze(1)
.transpose(-1, -2)
.float())
v_t = (v_tile.permute(1, 0, 2)
.unsqueeze(1)
.float())
del k_tile, v_tile
# Scores: [kv_h, gqa, q_len, valid]
s = torch.matmul(q_seq, k_t)
del k_t
# No causal mask: all context keys precede all queries.
# Online softmax update — Flash-Attention Algorithm 1.
# exp_s = s - new_max (in-place exp after del s)
m_blk = s.amax(dim=-1)
m_new = torch.maximum(m, m_blk)
exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# --------------------------------------------------------------
# Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
#
# Causal mask: query at relative position j sees key at relative
# position k only when k ≤ j. Tiles of tile_sz tokens each.
# --------------------------------------------------------------
for kc_start in range(0, q_len, tile_sz):
kc_end = min(kc_start + tile_sz, q_len)
kc_len = kc_end - kc_start
k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
v_blk = v_i[kc_start:kc_end]
k_t = (k_blk.permute(1, 0, 2)
.unsqueeze(1)
.transpose(-1, -2)
.float()) # [kv_h, 1, d, kc_len]
v_t = (v_blk.permute(1, 0, 2)
.unsqueeze(1)
.float()) # [kv_h, 1, kc_len, d]
s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
del k_t
# Causal mask: key at (kc_start+k) must not exceed query j.
k_rel = torch.arange(kc_start, kc_end, device=dev)
q_rel = torch.arange(q_len, device=dev)
mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
del mask, k_rel, q_rel
# Online softmax update (identical to context phase).
m_blk = s.amax(dim=-1)
m_new = torch.maximum(m, m_blk)
exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# --------------------------------------------------------------
# Finalize: normalize running output by normalization factor.
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
# --------------------------------------------------------------
o.div_(l.unsqueeze(-1))
output[q_start:q_end] = (
o.view(num_q_heads, q_len, head_dim)
.permute(1, 0, 2)
.to(orig_dtype)
)
except Exception as e:
print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr)
raise
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)