chunked prefill support and memory opts
This commit is contained in:
@@ -123,33 +123,36 @@ class PagedAttention:
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
blk_ids = block_tables[i, :num_blocks]
|
||||
|
||||
# Gather K from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||
k_seq = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||
# Gather K: [kv_h, head_dim, seq_len] fp32 — no GQA expansion.
|
||||
# With kv_h=1 and seq_len=100K this is 98 MB vs 586 MB if expanded.
|
||||
k_t = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len] \
|
||||
.permute(1, 2, 0).contiguous().float() # [kv_h, d, seq_len]
|
||||
|
||||
# Gather V from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||
v_seq = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||
# Gather V: [kv_h, seq_len, head_dim] fp32
|
||||
v_t = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len] \
|
||||
.permute(1, 0, 2).contiguous().float() # [kv_h, seq_len, d]
|
||||
|
||||
if gqa_ratio > 1:
|
||||
k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||
v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||
# Reshape Q for lazy GQA: [kv_h, gqa_ratio, 1, d]
|
||||
q_grouped = (query[i].float()
|
||||
.view(num_kv_heads, gqa_ratio, head_dim)
|
||||
.unsqueeze(2))
|
||||
|
||||
# [H, head_dim, seq_len] and [H, seq_len, head_dim]
|
||||
k_t = k_seq.permute(1, 2, 0).float()
|
||||
v_t = v_seq.permute(1, 0, 2).float()
|
||||
|
||||
# q: [H, 1, head_dim]; attn_w: [H, 1, seq_len]
|
||||
q_i = query[i].float().unsqueeze(1)
|
||||
attn_w = torch.matmul(q_i * scale, k_t)
|
||||
# [kv_h, gqa_ratio, 1, seq_len]
|
||||
attn_w = torch.matmul(
|
||||
q_grouped * scale, # [kv_h, gqa, 1, d]
|
||||
k_t.unsqueeze(1)) # [kv_h, 1, d, seq_len]
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
|
||||
out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim]
|
||||
output[i] = out_i.squeeze(1).to(orig_dtype)
|
||||
# [kv_h, gqa_ratio, 1, d] → [num_heads, head_dim]
|
||||
out_i = torch.matmul(attn_w, v_t.unsqueeze(1))
|
||||
output[i] = out_i.view(num_heads, head_dim).to(orig_dtype)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
|
||||
file=sys.stderr, flush=True)
|
||||
@@ -158,11 +161,10 @@ class PagedAttention:
|
||||
|
||||
return output
|
||||
|
||||
# paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to
|
||||
# shared memory limits; use pure-PyTorch fallback above this threshold.
|
||||
# Set to a large value to disable for now (50K decode confirmed working via
|
||||
# hardware kernel); lower to 32768 if kernel hangs are observed at long contexts.
|
||||
_PYTORCH_DECODE_THRESHOLD = 10_000_000
|
||||
# paged_attention_v1 on BI-V100 fails for long contexts.
|
||||
# Route on actual sequence length (seq_lens.max()), not the max_seq_len
|
||||
# parameter which is inflated to max_model_len in CUDA graph mode.
|
||||
_PYTORCH_DECODE_THRESHOLD = 32768
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
@@ -184,7 +186,8 @@ class PagedAttention:
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
|
||||
actual_max = int(seq_lens.max().item()) if seq_lens.numel() > 0 else max_seq_len
|
||||
if actual_max > PagedAttention._PYTORCH_DECODE_THRESHOLD:
|
||||
return PagedAttention._forward_decode_pytorch(
|
||||
query, key_cache, value_cache, block_tables, seq_lens, scale)
|
||||
|
||||
@@ -312,18 +315,22 @@ class PagedAttention:
|
||||
concatenates with the current-chunk K/V, then computes scaled-dot-
|
||||
product attention with a causal mask.
|
||||
|
||||
Memory optimisation — query chunking
|
||||
------------------------------------
|
||||
A full-sequence attention matrix is O(q_len × kv_len) in float32.
|
||||
For long sequences (e.g., q_len = kv_len = 20 000) that blows up
|
||||
to ~9 GB per layer. Instead we tile the query axis in sub-chunks
|
||||
of _ATTN_Q_CHUNK tokens and accumulate the output; peak attn memory
|
||||
becomes O(_ATTN_Q_CHUNK × kv_len), e.g. 123 MB per layer for
|
||||
chunk=256 and kv_len=20 000.
|
||||
Memory optimisation — GQA-aware Q-tiling
|
||||
-----------------------------------------
|
||||
Two complementary tricks keep peak activation memory well below 1 GB
|
||||
even for 100K context on TP=4 (kv_h=1, q_h=6):
|
||||
|
||||
This replaces the need for vllm's --enable-chunked-prefill flag
|
||||
(which the vendor's vllm 0.6.3 does not properly support for
|
||||
has_inner_state=True models on BI-V100).
|
||||
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
|
||||
resolution and GQA grouping is handled via 4D reshape+broadcast
|
||||
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
|
||||
vs the old expand-then-float32 approach:
|
||||
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
|
||||
New: [1, 100K, 256] fp32 = 98 MB each for K and V (transient per-matmul cast; the resident copy stays in fp16/bf16, ~49 MB)
|
||||
|
||||
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
|
||||
is bounded to ~148 MB at 100K instead of growing with q_len.
|
||||
|
||||
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
|
||||
|
||||
Shapes
|
||||
------
|
||||
@@ -337,13 +344,16 @@ class PagedAttention:
|
||||
seq_lens_tensor: [batch_size] total length (context + query)
|
||||
context_lens : [batch_size] tokens already in KV cache
|
||||
"""
|
||||
# Maximum query tokens to process at once per attention step.
|
||||
# Tune this to balance memory vs kernel-launch overhead:
|
||||
# 256 → ~120 MB peak attn memory (conservative, safe for 20K ctx)
|
||||
# 512 → ~240 MB peak attn memory
|
||||
# 1024 → ~490 MB peak attn memory
|
||||
# Memory-efficient query-chunked attention.
|
||||
# Key optimisation: do NOT expand KV heads for GQA before materialising
|
||||
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
|
||||
# resolution saves ~6× memory vs expanding to q_h first:
|
||||
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
|
||||
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB (transient cast; resident K/V stay fp16/bf16, ~49 MB)
|
||||
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
|
||||
# reshaping, so no extra tensors are created.
|
||||
try:
|
||||
_ATTN_Q_CHUNK = 256
|
||||
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
|
||||
|
||||
batch_size = seq_lens_tensor.shape[0]
|
||||
num_q_heads = query.shape[1]
|
||||
@@ -389,50 +399,66 @@ class PagedAttention:
|
||||
|
||||
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
|
||||
v_full = torch.cat([v_ctx, v_i], dim=0)
|
||||
del k_ctx, v_ctx
|
||||
else:
|
||||
k_full = k_i
|
||||
v_full = v_i
|
||||
|
||||
kv_len = k_full.shape[0] # ctx_len + q_len
|
||||
|
||||
# GQA: expand KV heads to match Q heads
|
||||
if gqa_ratio > 1:
|
||||
k_full = k_full.repeat_interleave(gqa_ratio, dim=1)
|
||||
v_full = v_full.repeat_interleave(gqa_ratio, dim=1)
|
||||
|
||||
k_t = k_full.permute(1, 0, 2).float() # [H, kv_len, d]
|
||||
v_t = v_full.permute(1, 0, 2).float() # [H, kv_len, d]
|
||||
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
|
||||
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
|
||||
# attn_w is computed in fp32 (q cast to fp32 before matmul, then
|
||||
# k cast inline) so softmax precision is unaffected.
|
||||
# Do NOT expand GQA heads here either — that saves a further gqa_ratio× in memory.
|
||||
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del k_full
|
||||
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del v_full
|
||||
|
||||
# k_pos used for causal mask: shape [kv_len]
|
||||
k_pos = torch.arange(kv_len, device=query.device)
|
||||
|
||||
# --- Query-chunked attention --------------------------------
|
||||
# Process _ATTN_Q_CHUNK query tokens at a time.
|
||||
# Peak attn tensor: [H, _ATTN_Q_CHUNK, kv_len] float32
|
||||
# instead of [H, q_len, kv_len] float32.
|
||||
# --- Query-chunked attention with lazy GQA grouping ----------
|
||||
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
|
||||
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
|
||||
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
|
||||
for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
|
||||
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
|
||||
qc = qc_end - qc_start
|
||||
|
||||
# [H, qc, d]
|
||||
# [kv_h, gqa_ratio, qc, d]
|
||||
q_t_chunk = (q_i[qc_start:qc_end]
|
||||
.permute(1, 0, 2)
|
||||
.float())
|
||||
.permute(1, 0, 2) # [q_h, qc, d]
|
||||
.float()
|
||||
.view(num_kv_heads, gqa_ratio, qc, head_dim))
|
||||
|
||||
# [H, qc, kv_len]
|
||||
# [kv_h, gqa_ratio, qc, kv_len]
|
||||
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
|
||||
# Cast k slice to fp32 inline; the temporary is freed after matmul.
|
||||
attn_w = torch.matmul(q_t_chunk * scale,
|
||||
k_t.transpose(-1, -2))
|
||||
k_t.unsqueeze(1).transpose(-1, -2).float())
|
||||
|
||||
# Causal mask for this sub-chunk:
|
||||
# query absolute position = ctx_len + qc_start..qc_end-1
|
||||
# can attend to k_pos <= its own absolute position
|
||||
qc_q_pos = torch.arange(qc_start, qc_end,
|
||||
device=query.device)
|
||||
# mask[j, k] = True → future key, block it
|
||||
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
|
||||
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float('-inf'))
|
||||
attn_w.masked_fill_(
|
||||
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||
|
||||
attn_w = torch.softmax(attn_w, dim=-1) # [H, qc, kv_len]
|
||||
out_c = torch.matmul(attn_w, v_t) # [H, qc, d]
|
||||
# In-place numerically stable softmax — avoids allocating a
|
||||
# new 150 MB tensor (same size as attn_w) that torch.softmax
|
||||
# would create, which exhausts the fragmented GPU pool.
|
||||
attn_w -= attn_w.amax(dim=-1, keepdim=True)
|
||||
attn_w.exp_()
|
||||
attn_w /= attn_w.sum(dim=-1, keepdim=True)
|
||||
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
|
||||
out_c = torch.matmul(attn_w,
|
||||
v_t.unsqueeze(1).float())
|
||||
# reshape to [q_h, qc, d] then [qc, q_h, d]
|
||||
out_c = out_c.view(num_q_heads, qc, head_dim)
|
||||
|
||||
output[q_start + qc_start : q_start + qc_end] = (
|
||||
out_c.to(orig_dtype).permute(1, 0, 2))
|
||||
|
||||
Reference in New Issue
Block a user