Utilize chunked prefill + K-tiling techniques to ensure 100K context
This commit is contained in:
@@ -309,28 +309,23 @@ class PagedAttention:
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Pure-PyTorch prefix-attention with query-chunking (no Triton).
|
||||
"""Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
|
||||
|
||||
For each sequence, gathers the context KV from the paged KV cache,
|
||||
concatenates with the current-chunk K/V, then computes scaled-dot-
|
||||
product attention with a causal mask.
|
||||
Memory complexity: O(q_len), independent of kv_len.
|
||||
With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
|
||||
per layer ≈ 96 MB regardless of context length.
|
||||
|
||||
Memory optimisation — GQA-aware Q-tiling
|
||||
-----------------------------------------
|
||||
Two complementary tricks keep peak activation memory well below 1 GB
|
||||
even for 100K context on TP=4 (kv_h=1, q_h=6):
|
||||
Algorithm: Flash Attention online softmax.
|
||||
Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
|
||||
K-tiles. For each tile a running (m, l, o) accumulator is updated —
|
||||
the [q_len × kv_len] attention matrix is NEVER materialised in full.
|
||||
|
||||
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
|
||||
resolution and GQA grouping is handled via 4D reshape+broadcast
|
||||
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
|
||||
vs the old expand-then-float32 approach:
|
||||
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
|
||||
New: [1, 100K, 256] fp32 = 98 MB each for K and V
|
||||
|
||||
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
|
||||
is bounded to ~148 MB at 100K instead of growing with q_len.
|
||||
|
||||
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
|
||||
Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
|
||||
q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
|
||||
o_acc same shape 24 MB (held all tiles)
|
||||
s same shape 24 MB (per tile, freed before exp_s)
|
||||
exp_s same shape 24 MB (per tile, brief overlap with s)
|
||||
Peak ≈ 96 MB (s and exp_s briefly coexist during update).
|
||||
|
||||
Shapes
|
||||
------
|
||||
@@ -344,29 +339,24 @@ class PagedAttention:
|
||||
seq_lens_tensor: [batch_size] total length (context + query)
|
||||
context_lens : [batch_size] tokens already in KV cache
|
||||
"""
|
||||
# Memory-efficient query-chunked attention.
|
||||
# Key optimisation: do NOT expand KV heads for GQA before materialising
|
||||
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
|
||||
# resolution saves ~6× memory vs expanding to q_h first:
|
||||
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
|
||||
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
|
||||
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
|
||||
# reshaping, so no extra tensors are created.
|
||||
try:
|
||||
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
|
||||
# Paged-block tiles for context phase.
|
||||
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
|
||||
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
|
||||
# Same tile size reused for the current-chunk phase.
|
||||
_BLOCKS_PER_TILE = 32
|
||||
|
||||
batch_size = seq_lens_tensor.shape[0]
|
||||
num_q_heads = query.shape[1]
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
head_dim = query.shape[2]
|
||||
gqa_ratio = num_q_heads // num_kv_heads
|
||||
|
||||
# value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
output = torch.empty_like(query)
|
||||
orig_dtype = query.dtype
|
||||
block_size = value_cache.shape[3]
|
||||
tile_sz = _BLOCKS_PER_TILE * block_size
|
||||
scale = head_dim ** -0.5
|
||||
orig_dtype = query.dtype
|
||||
output = torch.empty_like(query)
|
||||
dev = query.device
|
||||
|
||||
for i in range(batch_size):
|
||||
ctx_len = int(context_lens[i].item())
|
||||
@@ -374,96 +364,147 @@ class PagedAttention:
|
||||
q_end = int(query_start_loc[i + 1].item())
|
||||
q_len = q_end - q_start
|
||||
|
||||
q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim]
|
||||
k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim]
|
||||
q_i = query[q_start:q_end] # [q_len, q_h, d]
|
||||
k_i = key [q_start:q_end] # [q_len, kv_h, d]
|
||||
v_i = value[q_start:q_end]
|
||||
|
||||
# --- Build full K/V (context from cache + current chunk) ----
|
||||
# Q reshaped and scaled once; held for all K-tiles.
|
||||
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
|
||||
q_seq = (q_i.permute(1, 0, 2)
|
||||
.float()
|
||||
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
|
||||
.mul_(scale))
|
||||
|
||||
# Flash-Attention online-softmax accumulators.
|
||||
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
|
||||
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
|
||||
m = torch.full((num_kv_heads, gqa_ratio, q_len),
|
||||
float('-inf'), dtype=torch.float32, device=dev)
|
||||
l = torch.zeros_like(m)
|
||||
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
|
||||
dtype=torch.float32, device=dev)
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Phase 1 — context tokens (positions 0 … ctx_len-1).
|
||||
#
|
||||
# Every context key has absolute position < ctx_len; every
|
||||
# query has position ≥ ctx_len. k_pos < q_pos is always True
|
||||
# → no causal mask needed for pure context tiles.
|
||||
# --------------------------------------------------------------
|
||||
if ctx_len > 0:
|
||||
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
|
||||
blk_ids = block_tables[i, :num_ctx_blocks]
|
||||
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
|
||||
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
|
||||
blk_ids = block_tables[i, tile_blk:blk_end]
|
||||
|
||||
# key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
||||
# → permute(0,3,1,2,4) → contiguous → view → [:ctx_len]
|
||||
k_ctx = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
||||
# Gather K/V for this tile.
|
||||
# key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
||||
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
||||
k_tile = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))
|
||||
v_tile = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))
|
||||
|
||||
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
||||
# → permute(0,3,1,2) → contiguous → view → [:ctx_len]
|
||||
v_ctx = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
||||
# Trim padding in the last block of the tile.
|
||||
valid = (min(blk_end * block_size, ctx_len)
|
||||
- tile_blk * block_size)
|
||||
k_tile = k_tile[:valid] # [valid, kv_h, d]
|
||||
v_tile = v_tile[:valid]
|
||||
|
||||
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
|
||||
v_full = torch.cat([v_ctx, v_i], dim=0)
|
||||
del k_ctx, v_ctx
|
||||
else:
|
||||
k_full = k_i
|
||||
v_full = v_i
|
||||
# k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
|
||||
# v_t: [kv_h, 1, valid, d]
|
||||
k_t = (k_tile.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.transpose(-1, -2)
|
||||
.float())
|
||||
v_t = (v_tile.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.float())
|
||||
del k_tile, v_tile
|
||||
|
||||
kv_len = k_full.shape[0] # ctx_len + q_len
|
||||
# Scores: [kv_h, gqa, q_len, valid]
|
||||
s = torch.matmul(q_seq, k_t)
|
||||
del k_t
|
||||
# No causal mask: all context keys precede all queries.
|
||||
|
||||
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
|
||||
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
|
||||
# attn_w is computed in fp32 (q cast to fp32 before matmul, then
|
||||
# k cast inline) so softmax precision is unaffected.
|
||||
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
|
||||
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del k_full
|
||||
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del v_full
|
||||
# Online softmax update — Flash-Attention Algorithm 1.
|
||||
# exp_s = s - new_max (in-place exp after del s)
|
||||
m_blk = s.amax(dim=-1)
|
||||
m_new = torch.maximum(m, m_blk)
|
||||
exp_s = s - m_new.unsqueeze(-1)
|
||||
del s
|
||||
exp_s.exp_()
|
||||
corr = torch.exp(m - m_new)
|
||||
m.copy_(m_new)
|
||||
del m_blk, m_new
|
||||
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||
o.mul_(corr.unsqueeze(-1)).add_(
|
||||
torch.matmul(exp_s, v_t))
|
||||
del exp_s, v_t, corr
|
||||
|
||||
# k_pos used for causal mask: shape [kv_len]
|
||||
k_pos = torch.arange(kv_len, device=query.device)
|
||||
# --------------------------------------------------------------
|
||||
# Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
|
||||
#
|
||||
# Causal mask: query at relative position j sees key at relative
|
||||
# position k only when k ≤ j. Tiles of tile_sz tokens each.
|
||||
# --------------------------------------------------------------
|
||||
for kc_start in range(0, q_len, tile_sz):
|
||||
kc_end = min(kc_start + tile_sz, q_len)
|
||||
kc_len = kc_end - kc_start
|
||||
|
||||
# --- Query-chunked attention with lazy GQA grouping ----------
|
||||
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
|
||||
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
|
||||
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
|
||||
for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
|
||||
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
|
||||
qc = qc_end - qc_start
|
||||
k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
|
||||
v_blk = v_i[kc_start:kc_end]
|
||||
|
||||
# [kv_h, gqa_ratio, qc, d]
|
||||
q_t_chunk = (q_i[qc_start:qc_end]
|
||||
.permute(1, 0, 2) # [q_h, qc, d]
|
||||
.float()
|
||||
.view(num_kv_heads, gqa_ratio, qc, head_dim))
|
||||
k_t = (k_blk.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.transpose(-1, -2)
|
||||
.float()) # [kv_h, 1, d, kc_len]
|
||||
v_t = (v_blk.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.float()) # [kv_h, 1, kc_len, d]
|
||||
|
||||
# [kv_h, gqa_ratio, qc, kv_len]
|
||||
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
|
||||
# Cast k slice to fp32 inline; the temporary is freed after matmul.
|
||||
attn_w = torch.matmul(q_t_chunk * scale,
|
||||
k_t.unsqueeze(1).transpose(-1, -2).float())
|
||||
s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
|
||||
del k_t
|
||||
|
||||
# Causal mask for this sub-chunk:
|
||||
# query absolute position = ctx_len + qc_start..qc_end-1
|
||||
qc_q_pos = torch.arange(qc_start, qc_end,
|
||||
device=query.device)
|
||||
# mask[j, k] = True → future key, block it
|
||||
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
|
||||
attn_w.masked_fill_(
|
||||
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||
# Causal mask: key at (kc_start+k) must not exceed query j.
|
||||
k_rel = torch.arange(kc_start, kc_end, device=dev)
|
||||
q_rel = torch.arange(q_len, device=dev)
|
||||
mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
|
||||
s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||
del mask, k_rel, q_rel
|
||||
|
||||
# In-place numerically stable softmax — avoids allocating a
|
||||
# new 150 MB tensor (same size as attn_w) that torch.softmax
|
||||
# would create, which exhausts the fragmented GPU pool.
|
||||
attn_w -= attn_w.amax(dim=-1, keepdim=True)
|
||||
attn_w.exp_()
|
||||
attn_w /= attn_w.sum(dim=-1, keepdim=True)
|
||||
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
|
||||
out_c = torch.matmul(attn_w,
|
||||
v_t.unsqueeze(1).float())
|
||||
# reshape to [q_h, qc, d] then [qc, q_h, d]
|
||||
out_c = out_c.view(num_q_heads, qc, head_dim)
|
||||
# Online softmax update (identical to context phase).
|
||||
m_blk = s.amax(dim=-1)
|
||||
m_new = torch.maximum(m, m_blk)
|
||||
exp_s = s - m_new.unsqueeze(-1)
|
||||
del s
|
||||
exp_s.exp_()
|
||||
corr = torch.exp(m - m_new)
|
||||
m.copy_(m_new)
|
||||
del m_blk, m_new
|
||||
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||
o.mul_(corr.unsqueeze(-1)).add_(
|
||||
torch.matmul(exp_s, v_t))
|
||||
del exp_s, v_t, corr
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Finalize: normalize running output by normalization factor.
|
||||
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
|
||||
# --------------------------------------------------------------
|
||||
o.div_(l.unsqueeze(-1))
|
||||
output[q_start:q_end] = (
|
||||
o.view(num_q_heads, q_len, head_dim)
|
||||
.permute(1, 0, 2)
|
||||
.to(orig_dtype)
|
||||
)
|
||||
|
||||
output[q_start + qc_start : q_start + qc_end] = (
|
||||
out_c.to(orig_dtype).permute(1, 0, 2))
|
||||
except Exception as e:
|
||||
print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True)
|
||||
print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
|
||||
file=sys.stderr, flush=True)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
raise
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user