# enginex-vllm-bi100-qwen36/qwen3_6_scripts/paged_attn.py
# (Captured from a repository web view: 493 lines, 20 KiB, Python.
#  The viewer warned that this file contains Unicode characters that might
#  be confused with other characters. They are typographic symbols used in
#  comments and docstrings, e.g. "→", "—", "×", "≤".)

from dataclasses import dataclass
from typing import List, Optional, Tuple
import sys
import torch
import traceback
from vllm import _custom_ops as ops
# from vllm.attention.ops.prefix_prefill import context_attention_fwd
# NOTE: context_attention_fwd (Triton kernel from prefix_prefill.py) is NOT
# imported here. On Iluvatar BI-V100 that kernel hangs the GPU card
# permanently. Chunked-prefill / prefix-caching attention is handled by
# _forward_prefix_pytorch below (pure PyTorch, no Triton dependency).
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    Fields, in positional order: seq_lens_tensor, max_decode_seq_len,
    block_tables.
    """

    # Shape (batch_size,): per-sequence length, i.e. every token seen so far.
    seq_lens_tensor: Optional[torch.Tensor]

    # Longest sequence length in the batch; 0 for a prefill-only batch.
    max_decode_seq_len: int

    # Shape (batch_size, max_blocks_per_seq): physical KV-cache block ids of
    # each sequence. A row of [0, 1, 2] means the sequence's tokens live in
    # blocks 0, 1 and 2, each holding at most block_size tokens. The second
    # dimension is padded to max_blocks_per_seq when the batch is captured
    # in a CUDA graph.
    block_tables: Optional[torch.Tensor]
class PagedAttention:
    """Paged-attention entry points adapted for Iluvatar BI-V100.

    Short-context decode goes to the vendor ``ops.paged_attention_v1``
    kernel. Long-context decode and all prefix (chunked-prefill /
    prefix-caching) attention use pure-PyTorch fallbacks that read the paged
    KV cache directly, because the corresponding kernels hang on this
    hardware.
    """

    # paged_attention_v1 on BI-V100 fails for long contexts. A decode batch
    # whose longest *actual* sequence (seq_lens.max()) exceeds this goes to
    # the PyTorch fallback. The max_seq_len argument is not used for routing
    # because it is inflated to max_model_len in CUDA graph mode.
    _PYTORCH_DECODE_THRESHOLD = 32768

    # forward_decode always runs paged_attention_v1; the upstream V1/V2
    # heuristic is bypassed on this port. Presumably V2 is unusable on
    # BI-V100 — TODO(review): confirm before setting this to False.
    _FORCE_PAGED_ATTN_V1 = True

    # Query rows per tile in _forward_prefix_pytorch. Bounds the fp32 score
    # tensor to [kv_h, gqa_ratio, _ATTN_Q_CHUNK, kv_len].
    _ATTN_Q_CHUNK = 64

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the attention head sizes this backend accepts."""
        return [64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the flat KV-cache allocation.

        Index 0 of the leading dimension holds keys and index 1 holds values
        (see split_kv_cache). Each block is one flat row of
        ``block_size * num_kv_heads * head_size`` elements.
        """
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (key_cache, value_cache) views of a flat KV cache.

        key_cache   : [num_blocks, num_kv_heads, head_size // x, block_size, x]
        value_cache : [num_blocks, num_kv_heads, head_size, block_size]

        ``x = 16 // element_size`` is the number of elements that fit in
        16 bytes and forms the innermost key dimension. Both results are
        views sharing storage with ``kv_cache``.
        """
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]
        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Store ``key``/``value`` into the paged caches at ``slot_mapping``.

        Thin wrapper over ``ops.reshape_and_cache``; the slot mapping is
        flattened first.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def _check_pytorch_fallback_supported(key_cache: torch.Tensor) -> None:
        """Reject KV caches the PyTorch fallbacks cannot read correctly.

        The fallbacks up-cast cache entries with ``.float()`` and never apply
        ``k_scale``/``v_scale``. A non-floating-point cache (e.g. a uint8
        quantised cache) would therefore be misread as raw integers.

        Raises:
            NotImplementedError: if ``key_cache`` is not floating point.
        """
        if not key_cache.is_floating_point():
            raise NotImplementedError(
                "PyTorch paged-attention fallback does not support "
                f"quantized KV caches (key_cache dtype {key_cache.dtype}).")

    @staticmethod
    def _forward_decode_pytorch(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        scale: float,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pure-PyTorch decode attention for long contexts (no hardware kernel).

        paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to
        shared memory limits. For decode, q_len=1 per sequence, so no Q-tiling
        is needed. The attention weight tensor is [H, 1, seq_len], which is
        small (~5 MB at 50K).

        Shapes
        ------
        query       : [num_seqs, num_heads, head_dim]
        key_cache   : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
        value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
        block_tables: [num_seqs, max_blocks_per_seq]
        seq_lens    : [num_seqs]
        alibi_slopes: optional [num_heads]. When given,
                      ``slope[h] * (key_pos - (seq_len - 1))`` is added to
                      the scaled scores before softmax.

        Returns a tensor with the shape and dtype of ``query``.

        Raises NotImplementedError for a non-floating-point KV cache.
        """
        PagedAttention._check_pytorch_fallback_supported(key_cache)
        num_seqs, num_heads, head_dim = query.shape
        num_kv_heads = key_cache.shape[1]
        block_size = value_cache.shape[3]
        gqa_ratio = num_heads // num_kv_heads
        orig_dtype = query.dtype
        output = torch.empty_like(query)
        # A single host transfer instead of one .item() per sequence.
        seq_len_list = seq_lens.tolist()
        slopes = None
        if alibi_slopes is not None:
            # [kv_h, gqa_ratio, 1, 1]: query head h = kv_idx * gqa + g, which
            # matches the grouped query layout used below.
            slopes = (alibi_slopes.to(device=query.device, dtype=torch.float32)
                      .reshape(num_kv_heads, gqa_ratio, 1, 1))
        try:
            for i in range(num_seqs):
                seq_len = int(seq_len_list[i])
                num_blocks = (seq_len + block_size - 1) // block_size
                blk_ids = block_tables[i, :num_blocks]
                # Gather K as [kv_h, head_dim, seq_len] fp32, with no GQA
                # expansion. With kv_h=1 and seq_len=100K this is 98 MB
                # instead of 586 MB if expanded.
                k_t = (key_cache[blk_ids]
                       .permute(0, 3, 1, 2, 4)
                       .contiguous()
                       .view(-1, num_kv_heads, head_dim))[:seq_len]
                k_t = k_t.permute(1, 2, 0).contiguous().float()
                # Gather V as [kv_h, seq_len, head_dim] fp32.
                v_t = (value_cache[blk_ids]
                       .permute(0, 3, 1, 2)
                       .contiguous()
                       .view(-1, num_kv_heads, head_dim))[:seq_len]
                v_t = v_t.permute(1, 0, 2).contiguous().float()
                # Lazy GQA: Q grouped as [kv_h, gqa_ratio, 1, d].
                q_grouped = (query[i].float()
                             .reshape(num_kv_heads, gqa_ratio, head_dim)
                             .unsqueeze(2))
                # [kv_h, gqa, 1, d] @ [kv_h, 1, d, seq_len]
                #   -> [kv_h, gqa, 1, seq_len]
                attn_w = torch.matmul(q_grouped * scale, k_t.unsqueeze(1))
                if slopes is not None:
                    rel_pos = (torch.arange(seq_len, device=query.device,
                                            dtype=torch.float32)
                               - (seq_len - 1))
                    attn_w = attn_w + slopes * rel_pos
                attn_w = torch.softmax(attn_w, dim=-1)
                # [kv_h, gqa_ratio, 1, d] -> [num_heads, head_dim]
                out_i = torch.matmul(attn_w, v_t.unsqueeze(1))
                output[i] = out_i.reshape(num_heads, head_dim).to(orig_dtype)
        except Exception as e:
            print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
                  file=sys.stderr, flush=True)
            traceback.print_exc(file=sys.stderr)
            raise
        return output

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Decode attention over the paged KV cache.

        ``query`` is [num_seqs, num_heads, head_size]; the result has the
        same shape and dtype.

        If the longest actual sequence exceeds _PYTORCH_DECODE_THRESHOLD,
        the batch goes to _forward_decode_pytorch. That path honours
        ``scale`` and ``alibi_slopes`` but not ``kv_cache_dtype``,
        ``k_scale``, ``v_scale`` or the blocksparse options, and it raises
        NotImplementedError for a non-floating-point cache. Otherwise
        ``ops.paged_attention_v1`` runs (see _FORCE_PAGED_ATTN_V1).
        """
        # NOTE: .item() copies to the host (a sync when seq_lens is on GPU).
        actual_max = (int(seq_lens.max().item()) if seq_lens.numel() > 0
                      else max_seq_len)
        if actual_max > PagedAttention._PYTORCH_DECODE_THRESHOLD:
            return PagedAttention._forward_decode_pytorch(
                query, key_cache, value_cache, block_tables, seq_lens, scale,
                alibi_slopes)
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")
        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        if PagedAttention._FORCE_PAGED_ATTN_V1:
            use_v1 = True
        else:
            # NOTE(woosuk): Upstream heuristic. If the number of partitions
            # is 1, use V1 to avoid the overhead of reduction; if the number
            # of sequences or heads is large, use V1 since there is enough
            # work to parallelize. For context len > 8192, use V2 to avoid
            # shared memory shortage.
            use_v1 = (max_seq_len <= 8192
                      and (max_num_partitions == 1
                           or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            # NOTE(review): this call passes no kv_cache_dtype/k_scale/
            # v_scale or blocksparse arguments; confirm against the vendor
            # ops signature.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: float,
        v_scale: float,
    ) -> torch.Tensor:
        """Prefix attention (chunked prefill / prefix caching).

        The Triton context_attention_fwd kernel hangs on Iluvatar BI-V100
        (same class of issue as cudnnFlashAttnForward). This therefore
        delegates to the pure-PyTorch _forward_prefix_pytorch, which reads
        the paged KV cache directly and applies ``alibi_slopes`` when given.

        NOTE(review): ``kv_cache_dtype``, ``k_scale``, ``v_scale``,
        ``max_query_len`` and ``sliding_window`` are accepted for interface
        compatibility but are not used. In particular, no sliding-window
        mask is applied. Non-floating-point caches raise
        NotImplementedError.
        """
        return PagedAttention._forward_prefix_pytorch(
            query, key, value,
            key_cache, value_cache,
            block_tables, query_start_loc,
            seq_lens_tensor, context_lens,
            alibi_slopes=alibi_slopes,
        )

    @staticmethod
    def _forward_prefix_pytorch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pure-PyTorch prefix-attention with query-chunking (no Triton).

        For each sequence, this gathers the context KV from the paged KV
        cache, concatenates it with the current-chunk K/V, and computes
        causal scaled-dot-product attention. The scale is fixed at
        1/sqrt(head_dim), since no scale argument is passed in. When
        ``alibi_slopes`` ([num_q_heads]) is given,
        ``slope[h] * (key_pos - query_pos)`` is added to the scaled scores.

        Memory notes (GQA-aware Q-tiling)
        ---------------------------------
        1. No GQA pre-expansion. K/V stay at native [kv_h, kv_len, d]
           resolution, and GQA grouping is a 4D reshape + broadcast inside
           the matmul. With kv_h=1, kv_len=100K, d=256 each fp32 tensor is
           ~98 MB, versus ~586 MB when expanded to 6 query heads.
        2. K/V are up-cast to fp32 once per sequence. Casting per Q tile
           instead would hold fp16 K and V plus a full-length fp32 copy at
           the same moment, giving the same peak while re-running a
           full-length cast for every tile.
        3. Q-tiling (_ATTN_Q_CHUNK rows). The fp32 score tensor
           [kv_h, gqa, Q_CHUNK, kv_len] is bounded (~148 MB at 100K with
           the default 64) instead of growing with q_len.
        Per-sequence K/V are released before the next sequence is gathered.

        Shapes
        ------
        query          : [total_q_tokens, num_q_heads, head_dim]
        key            : [total_q_tokens, num_kv_heads, head_dim]
        value          : [total_q_tokens, num_kv_heads, head_dim]
        key_cache      : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
        value_cache    : [num_blocks, num_kv_heads, head_dim, block_size]
        block_tables   : [batch_size, max_blocks_per_seq]
        query_start_loc: [batch_size + 1]
        seq_lens_tensor: [batch_size] total length (context + query)
        context_lens   : [batch_size] tokens already in KV cache

        Returns a tensor with the shape and dtype of ``query``.

        Raises NotImplementedError for a non-floating-point KV cache.
        """
        PagedAttention._check_pytorch_fallback_supported(key_cache)
        try:
            q_chunk = PagedAttention._ATTN_Q_CHUNK
            batch_size = seq_lens_tensor.shape[0]
            num_q_heads = query.shape[1]
            num_kv_heads = key_cache.shape[1]
            head_dim = query.shape[2]
            gqa_ratio = num_q_heads // num_kv_heads
            # value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
            block_size = value_cache.shape[3]
            scale = 1.0 / (head_dim ** 0.5)
            output = torch.empty_like(query)
            orig_dtype = query.dtype
            # One host transfer each instead of .item() per sequence.
            ctx_lens = context_lens.tolist()
            q_bounds = query_start_loc.tolist()
            slopes = None
            if alibi_slopes is not None:
                # [kv_h, gqa_ratio, 1, 1], matching the grouped query layout.
                slopes = (alibi_slopes.to(device=query.device,
                                          dtype=torch.float32)
                          .reshape(num_kv_heads, gqa_ratio, 1, 1))
            for i in range(batch_size):
                ctx_len = int(ctx_lens[i])
                q_start = int(q_bounds[i])
                q_end = int(q_bounds[i + 1])
                q_len = q_end - q_start
                q_i = query[q_start:q_end]  # [q_len, num_q_heads, head_dim]
                k_i = key[q_start:q_end]    # [q_len, num_kv_heads, head_dim]
                v_i = value[q_start:q_end]
                # --- Build full K/V (context from cache + current chunk) ---
                if ctx_len > 0:
                    num_ctx_blocks = (ctx_len + block_size - 1) // block_size
                    blk_ids = block_tables[i, :num_ctx_blocks]
                    # key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x]
                    #   -> [n * blk_sz, kv_h, d] -> first ctx_len rows
                    k_ctx = (key_cache[blk_ids]
                             .permute(0, 3, 1, 2, 4)
                             .contiguous()
                             .view(-1, num_kv_heads, head_dim))[:ctx_len]
                    # value_cache[blk_ids]: [n, kv_h, d, blk_sz]
                    #   -> [n * blk_sz, kv_h, d] -> first ctx_len rows
                    v_ctx = (value_cache[blk_ids]
                             .permute(0, 3, 1, 2)
                             .contiguous()
                             .view(-1, num_kv_heads, head_dim))[:ctx_len]
                    k_full = torch.cat([k_ctx, k_i], dim=0)  # [kv_len, kv_h, d]
                    v_full = torch.cat([v_ctx, v_i], dim=0)
                    del k_ctx, v_ctx
                else:
                    k_full = k_i
                    v_full = v_i
                kv_len = k_full.shape[0]  # ctx_len + q_len
                # [kv_h, kv_len, d] fp32, cast once per sequence (see
                # docstring note 2). GQA heads are NOT expanded here.
                k_t = k_full.permute(1, 0, 2).contiguous().float()
                del k_full
                v_t = v_full.permute(1, 0, 2).contiguous().float()
                del v_full
                # Key positions for the causal mask: shape [kv_len].
                k_pos = torch.arange(kv_len, device=query.device)
                # --- Query-chunked attention with lazy GQA grouping ---
                for qc_start in range(0, q_len, q_chunk):
                    qc_end = min(qc_start + q_chunk, q_len)
                    qc = qc_end - qc_start
                    # [qc, q_h, d] -> [q_h, qc, d] -> [kv_h, gqa_ratio, qc, d]
                    q_t_chunk = (q_i[qc_start:qc_end]
                                 .permute(1, 0, 2)
                                 .float()
                                 .reshape(num_kv_heads, gqa_ratio, qc,
                                          head_dim))
                    # k_t viewed as [kv_h, 1, d, kv_len] broadcasts over the
                    # gqa dim -> attn_w [kv_h, gqa_ratio, qc, kv_len].
                    attn_w = torch.matmul(q_t_chunk * scale,
                                          k_t.unsqueeze(1).transpose(-1, -2))
                    # Absolute query positions: ctx_len + qc_start..qc_end-1.
                    q_pos = ctx_len + torch.arange(qc_start, qc_end,
                                                   device=query.device)
                    if slopes is not None:
                        rel_pos = (k_pos.unsqueeze(0)
                                   - q_pos.unsqueeze(1)).to(torch.float32)
                        attn_w += slopes * rel_pos
                    # mask[j, k] = True -> key k is in query j's future.
                    mask = k_pos.unsqueeze(0) > q_pos.unsqueeze(1)
                    attn_w.masked_fill_(
                        mask.unsqueeze(0).unsqueeze(0), float('-inf'))
                    # In-place numerically stable softmax. This avoids
                    # allocating a second tensor the size of attn_w, which
                    # torch.softmax would create.
                    attn_w -= attn_w.amax(dim=-1, keepdim=True)
                    attn_w.exp_()
                    attn_w /= attn_w.sum(dim=-1, keepdim=True)
                    # [kv_h, gqa_ratio, qc, d]
                    out_c = torch.matmul(attn_w, v_t.unsqueeze(1))
                    # -> [q_h, qc, d] -> [qc, q_h, d]
                    out_c = out_c.view(num_q_heads, qc, head_dim)
                    output[q_start + qc_start:q_start + qc_end] = (
                        out_c.to(orig_dtype).permute(1, 0, 2))
                # Release this sequence's K/V before gathering the next one.
                del k_t, v_t
        except Exception as e:
            print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True)
            traceback.print_exc(file=sys.stderr)
            raise
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Apply ``ops.swap_blocks`` to the key half, then the value half."""
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Call ``ops.copy_blocks`` on the key and value halves of every cache."""
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)