Utilize chunked prefill + K-tiling techniques to ensure 100K context

This commit is contained in:
2026-06-05 17:00:41 +08:00
parent 2d1ef50992
commit c2de1c83b0
2 changed files with 153 additions and 118 deletions

View File

@@ -309,28 +309,23 @@ class PagedAttention:
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
) -> torch.Tensor:
"""Pure-PyTorch prefix-attention with query-chunking (no Triton).
"""Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
For each sequence, gathers the context KV from the paged KV cache,
concatenates with the current-chunk K/V, then computes scaled-dot-
product attention with a causal mask.
Memory complexity: O(q_len), independent of kv_len.
With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
per layer ≈ 96 MB regardless of context length.
Memory optimisation — GQA-aware Q-tiling
-----------------------------------------
Two complementary tricks keep peak activation memory well below 1 GB
even for 100K context on TP=4 (kv_h=1, q_h=6):
Algorithm: Flash Attention online softmax.
Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
K-tiles. For each tile a running (m, l, o) accumulator is updated —
the [q_len × kv_len] attention matrix is NEVER materialised in full.
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
resolution and GQA grouping is handled via 4D reshape+broadcast
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
vs the old expand-then-float32 approach:
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
New: [1, 100K, 256] fp32 = 98 MB each for K and V
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
is bounded to ~148 MB at 100K instead of growing with q_len.
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
o_acc same shape 24 MB (held all tiles)
s same shape 24 MB (per tile, freed before exp_s)
exp_s same shape 24 MB (per tile, brief overlap with s)
Peak ≈ 96 MB (s and exp_s briefly coexist during update).
Shapes
------
@@ -344,29 +339,24 @@ class PagedAttention:
seq_lens_tensor: [batch_size] total length (context + query)
context_lens : [batch_size] tokens already in KV cache
"""
# Memory-efficient query-chunked attention.
# Key optimisation: do NOT expand KV heads for GQA before materialising
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
# resolution saves ~6× memory vs expanding to q_h first:
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
# reshaping, so no extra tensors are created.
try:
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
# Paged-block tiles for context phase.
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
# Same tile size reused for the current-chunk phase.
_BLOCKS_PER_TILE = 32
batch_size = seq_lens_tensor.shape[0]
num_q_heads = query.shape[1]
num_kv_heads = key_cache.shape[1]
head_dim = query.shape[2]
gqa_ratio = num_q_heads // num_kv_heads
# value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
block_size = value_cache.shape[3]
scale = 1.0 / (head_dim ** 0.5)
output = torch.empty_like(query)
orig_dtype = query.dtype
block_size = value_cache.shape[3]
tile_sz = _BLOCKS_PER_TILE * block_size
scale = head_dim ** -0.5
orig_dtype = query.dtype
output = torch.empty_like(query)
dev = query.device
for i in range(batch_size):
ctx_len = int(context_lens[i].item())
@@ -374,96 +364,147 @@ class PagedAttention:
q_end = int(query_start_loc[i + 1].item())
q_len = q_end - q_start
q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim]
k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim]
q_i = query[q_start:q_end] # [q_len, q_h, d]
k_i = key [q_start:q_end] # [q_len, kv_h, d]
v_i = value[q_start:q_end]
# --- Build full K/V (context from cache + current chunk) ----
# Q reshaped and scaled once; held for all K-tiles.
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
q_seq = (q_i.permute(1, 0, 2)
.float()
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
.mul_(scale))
# Flash-Attention online-softmax accumulators.
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
m = torch.full((num_kv_heads, gqa_ratio, q_len),
float('-inf'), dtype=torch.float32, device=dev)
l = torch.zeros_like(m)
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
dtype=torch.float32, device=dev)
# --------------------------------------------------------------
# Phase 1 — context tokens (positions 0 … ctx_len-1).
#
# Every context key has absolute position < ctx_len; every
# query has position ≥ ctx_len. k_pos < q_pos is always True
# → no causal mask needed for pure context tiles.
# --------------------------------------------------------------
if ctx_len > 0:
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
blk_ids = block_tables[i, :num_ctx_blocks]
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
blk_ids = block_tables[i, tile_blk:blk_end]
# key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x]
# → permute(0,3,1,2,4) → contiguous → view → [:ctx_len]
k_ctx = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:ctx_len]
# Gather K/V for this tile.
# key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
k_tile = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4)
.contiguous()
.view(-1, num_kv_heads, head_dim))
v_tile = (value_cache[blk_ids]
.permute(0, 3, 1, 2)
.contiguous()
.view(-1, num_kv_heads, head_dim))
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
# → permute(0,3,1,2) → contiguous → view → [:ctx_len]
v_ctx = (value_cache[blk_ids]
.permute(0, 3, 1, 2)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:ctx_len]
# Trim padding in the last block of the tile.
valid = (min(blk_end * block_size, ctx_len)
- tile_blk * block_size)
k_tile = k_tile[:valid] # [valid, kv_h, d]
v_tile = v_tile[:valid]
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
v_full = torch.cat([v_ctx, v_i], dim=0)
del k_ctx, v_ctx
else:
k_full = k_i
v_full = v_i
# k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
# v_t: [kv_h, 1, valid, d]
k_t = (k_tile.permute(1, 0, 2)
.unsqueeze(1)
.transpose(-1, -2)
.float())
v_t = (v_tile.permute(1, 0, 2)
.unsqueeze(1)
.float())
del k_tile, v_tile
kv_len = k_full.shape[0] # ctx_len + q_len
# Scores: [kv_h, gqa, q_len, valid]
s = torch.matmul(q_seq, k_t)
del k_t
# No causal mask: all context keys precede all queries.
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
# attn_w is computed in fp32 (q cast to fp32 before matmul, then
# k cast inline) so softmax precision is unaffected.
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
del k_full
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
del v_full
# Online softmax update — Flash-Attention Algorithm 1.
# exp_s = s - new_max (in-place exp after del s)
m_blk = s.amax(dim=-1)
m_new = torch.maximum(m, m_blk)
exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# k_pos used for causal mask: shape [kv_len]
k_pos = torch.arange(kv_len, device=query.device)
# --------------------------------------------------------------
# Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
#
# Causal mask: query at relative position j sees key at relative
# position k only when k ≤ j. Tiles of tile_sz tokens each.
# --------------------------------------------------------------
for kc_start in range(0, q_len, tile_sz):
kc_end = min(kc_start + tile_sz, q_len)
kc_len = kc_end - kc_start
# --- Query-chunked attention with lazy GQA grouping ----------
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
qc = qc_end - qc_start
k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
v_blk = v_i[kc_start:kc_end]
# [kv_h, gqa_ratio, qc, d]
q_t_chunk = (q_i[qc_start:qc_end]
.permute(1, 0, 2) # [q_h, qc, d]
.float()
.view(num_kv_heads, gqa_ratio, qc, head_dim))
k_t = (k_blk.permute(1, 0, 2)
.unsqueeze(1)
.transpose(-1, -2)
.float()) # [kv_h, 1, d, kc_len]
v_t = (v_blk.permute(1, 0, 2)
.unsqueeze(1)
.float()) # [kv_h, 1, kc_len, d]
# [kv_h, gqa_ratio, qc, kv_len]
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
# Cast k slice to fp32 inline; the temporary is freed after matmul.
attn_w = torch.matmul(q_t_chunk * scale,
k_t.unsqueeze(1).transpose(-1, -2).float())
s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
del k_t
# Causal mask for this sub-chunk:
# query absolute position = ctx_len + qc_start..qc_end-1
qc_q_pos = torch.arange(qc_start, qc_end,
device=query.device)
# mask[j, k] = True → future key, block it
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
attn_w.masked_fill_(
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
# Causal mask: key at (kc_start+k) must not exceed query j.
k_rel = torch.arange(kc_start, kc_end, device=dev)
q_rel = torch.arange(q_len, device=dev)
mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
del mask, k_rel, q_rel
# In-place numerically stable softmax — avoids allocating a
# new 150 MB tensor (same size as attn_w) that torch.softmax
# would create, which exhausts the fragmented GPU pool.
attn_w -= attn_w.amax(dim=-1, keepdim=True)
attn_w.exp_()
attn_w /= attn_w.sum(dim=-1, keepdim=True)
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
out_c = torch.matmul(attn_w,
v_t.unsqueeze(1).float())
# reshape to [q_h, qc, d] then [qc, q_h, d]
out_c = out_c.view(num_q_heads, qc, head_dim)
# Online softmax update (identical to context phase).
m_blk = s.amax(dim=-1)
m_new = torch.maximum(m, m_blk)
exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# --------------------------------------------------------------
# Finalize: normalize running output by normalization factor.
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
# --------------------------------------------------------------
o.div_(l.unsqueeze(-1))
output[q_start:q_end] = (
o.view(num_q_heads, q_len, head_dim)
.permute(1, 0, 2)
.to(orig_dtype)
)
output[q_start + qc_start : q_start + qc_end] = (
out_c.to(orig_dtype).permute(1, 0, 2))
except Exception as e:
print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True)
print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr)
raise
return output