Utilize chunked prefill + K-tiling techniques to ensure 100K context

This commit is contained in:
2026-06-05 17:00:41 +08:00
parent 2d1ef50992
commit c2de1c83b0
2 changed files with 153 additions and 118 deletions

View File

@@ -309,28 +309,23 @@ class PagedAttention:
seq_lens_tensor: torch.Tensor, seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor, context_lens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
"""Pure-PyTorch prefix-attention with query-chunking (no Triton). """Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
For each sequence, gathers the context KV from the paged KV cache, Memory complexity: O(q_len), independent of kv_len.
concatenates with the current-chunk K/V, then computes scaled-dot- With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
product attention with a causal mask. per layer ≈ 96 MB regardless of context length.
Memory optimisation — GQA-aware Q-tiling Algorithm: Flash Attention online softmax.
----------------------------------------- Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
Two complementary tricks keep peak activation memory well below 1 GB K-tiles. For each tile a running (m, l, o) accumulator is updated —
even for 100K context on TP=4 (kv_h=1, q_h=6): the [q_len × kv_len] attention matrix is NEVER materialised in full.
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d] Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
resolution and GQA grouping is handled via 4D reshape+broadcast q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6× o_acc same shape 24 MB (held all tiles)
vs the old expand-then-float32 approach: s same shape 24 MB (per tile, freed before exp_s)
Old: [6, 100K, 256] fp32 = 586 MB each for K and V exp_s same shape 24 MB (per tile, brief overlap with s)
New: [1, 100K, 256] fp32 = 98 MB each for K and V Peak ≈ 96 MB (s and exp_s briefly coexist during update).
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
is bounded to ~148 MB at 100K instead of growing with q_len.
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
Shapes Shapes
------ ------
@@ -344,29 +339,24 @@ class PagedAttention:
seq_lens_tensor: [batch_size] total length (context + query) seq_lens_tensor: [batch_size] total length (context + query)
context_lens : [batch_size] tokens already in KV cache context_lens : [batch_size] tokens already in KV cache
""" """
# Memory-efficient query-chunked attention.
# Key optimisation: do NOT expand KV heads for GQA before materialising
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
# resolution saves ~6× memory vs expanding to q_h first:
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
# reshaping, so no extra tensors are created.
try: try:
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB # Paged-block tiles for context phase.
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
# Same tile size reused for the current-chunk phase.
_BLOCKS_PER_TILE = 32
batch_size = seq_lens_tensor.shape[0] batch_size = seq_lens_tensor.shape[0]
num_q_heads = query.shape[1] num_q_heads = query.shape[1]
num_kv_heads = key_cache.shape[1] num_kv_heads = key_cache.shape[1]
head_dim = query.shape[2] head_dim = query.shape[2]
gqa_ratio = num_q_heads // num_kv_heads gqa_ratio = num_q_heads // num_kv_heads
# value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
tile_sz = _BLOCKS_PER_TILE * block_size
scale = 1.0 / (head_dim ** 0.5) scale = head_dim ** -0.5
output = torch.empty_like(query)
orig_dtype = query.dtype orig_dtype = query.dtype
output = torch.empty_like(query)
dev = query.device
for i in range(batch_size): for i in range(batch_size):
ctx_len = int(context_lens[i].item()) ctx_len = int(context_lens[i].item())
@@ -374,96 +364,147 @@ class PagedAttention:
q_end = int(query_start_loc[i + 1].item()) q_end = int(query_start_loc[i + 1].item())
q_len = q_end - q_start q_len = q_end - q_start
q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim] q_i = query[q_start:q_end] # [q_len, q_h, d]
k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim] k_i = key [q_start:q_end] # [q_len, kv_h, d]
v_i = value[q_start:q_end] v_i = value[q_start:q_end]
# --- Build full K/V (context from cache + current chunk) ---- # Q reshaped and scaled once; held for all K-tiles.
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
q_seq = (q_i.permute(1, 0, 2)
.float()
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
.mul_(scale))
# Flash-Attention online-softmax accumulators.
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
m = torch.full((num_kv_heads, gqa_ratio, q_len),
float('-inf'), dtype=torch.float32, device=dev)
l = torch.zeros_like(m)
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
dtype=torch.float32, device=dev)
# --------------------------------------------------------------
# Phase 1 — context tokens (positions 0 … ctx_len-1).
#
# Every context key has absolute position < ctx_len; every
# query has position ≥ ctx_len. k_pos < q_pos is always True
# → no causal mask needed for pure context tiles.
# --------------------------------------------------------------
if ctx_len > 0: if ctx_len > 0:
num_ctx_blocks = (ctx_len + block_size - 1) // block_size num_ctx_blocks = (ctx_len + block_size - 1) // block_size
blk_ids = block_tables[i, :num_ctx_blocks] for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
blk_ids = block_tables[i, tile_blk:blk_end]
# key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x] # Gather K/V for this tile.
# → permute(0,3,1,2,4) → contiguous → view → [:ctx_len] # key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
k_ctx = (key_cache[blk_ids] # value_cache[blk_ids]: [n, kv_h, d, blk_sz]
k_tile = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4) .permute(0, 3, 1, 2, 4)
.contiguous() .contiguous()
.view(-1, num_kv_heads, head_dim))[:ctx_len] .view(-1, num_kv_heads, head_dim))
v_tile = (value_cache[blk_ids]
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
# → permute(0,3,1,2) → contiguous → view → [:ctx_len]
v_ctx = (value_cache[blk_ids]
.permute(0, 3, 1, 2) .permute(0, 3, 1, 2)
.contiguous() .contiguous()
.view(-1, num_kv_heads, head_dim))[:ctx_len] .view(-1, num_kv_heads, head_dim))
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d] # Trim padding in the last block of the tile.
v_full = torch.cat([v_ctx, v_i], dim=0) valid = (min(blk_end * block_size, ctx_len)
del k_ctx, v_ctx - tile_blk * block_size)
else: k_tile = k_tile[:valid] # [valid, kv_h, d]
k_full = k_i v_tile = v_tile[:valid]
v_full = v_i
kv_len = k_full.shape[0] # ctx_len + q_len # k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
# v_t: [kv_h, 1, valid, d]
k_t = (k_tile.permute(1, 0, 2)
.unsqueeze(1)
.transpose(-1, -2)
.float())
v_t = (v_tile.permute(1, 0, 2)
.unsqueeze(1)
.float())
del k_tile, v_tile
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16). # Scores: [kv_h, gqa, q_len, valid]
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory. s = torch.matmul(q_seq, k_t)
# attn_w is computed in fp32 (q cast to fp32 before matmul, then del k_t
# k cast inline) so softmax precision is unaffected. # No causal mask: all context keys precede all queries.
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
del k_full
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
del v_full
# k_pos used for causal mask: shape [kv_len] # Online softmax update — Flash-Attention Algorithm 1.
k_pos = torch.arange(kv_len, device=query.device) # exp_s = s - new_max (in-place exp after del s)
m_blk = s.amax(dim=-1)
m_new = torch.maximum(m, m_blk)
exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# --- Query-chunked attention with lazy GQA grouping ---------- # --------------------------------------------------------------
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with # Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives #
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies. # Causal mask: query at relative position j sees key at relative
for qc_start in range(0, q_len, _ATTN_Q_CHUNK): # position k only when k ≤ j. Tiles of tile_sz tokens each.
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len) # --------------------------------------------------------------
qc = qc_end - qc_start for kc_start in range(0, q_len, tile_sz):
kc_end = min(kc_start + tile_sz, q_len)
kc_len = kc_end - kc_start
# [kv_h, gqa_ratio, qc, d] k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
q_t_chunk = (q_i[qc_start:qc_end] v_blk = v_i[kc_start:kc_end]
.permute(1, 0, 2) # [q_h, qc, d]
.float()
.view(num_kv_heads, gqa_ratio, qc, head_dim))
# [kv_h, gqa_ratio, qc, kv_len] k_t = (k_blk.permute(1, 0, 2)
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio. .unsqueeze(1)
# Cast k slice to fp32 inline; the temporary is freed after matmul. .transpose(-1, -2)
attn_w = torch.matmul(q_t_chunk * scale, .float()) # [kv_h, 1, d, kc_len]
k_t.unsqueeze(1).transpose(-1, -2).float()) v_t = (v_blk.permute(1, 0, 2)
.unsqueeze(1)
.float()) # [kv_h, 1, kc_len, d]
# Causal mask for this sub-chunk: s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
# query absolute position = ctx_len + qc_start..qc_end-1 del k_t
qc_q_pos = torch.arange(qc_start, qc_end,
device=query.device)
# mask[j, k] = True → future key, block it
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
attn_w.masked_fill_(
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
# In-place numerically stable softmax — avoids allocating a # Causal mask: key at (kc_start+k) must not exceed query j.
# new 150 MB tensor (same size as attn_w) that torch.softmax k_rel = torch.arange(kc_start, kc_end, device=dev)
# would create, which exhausts the fragmented GPU pool. q_rel = torch.arange(q_len, device=dev)
attn_w -= attn_w.amax(dim=-1, keepdim=True) mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
attn_w.exp_() s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn_w /= attn_w.sum(dim=-1, keepdim=True) del mask, k_rel, q_rel
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
out_c = torch.matmul(attn_w, # Online softmax update (identical to context phase).
v_t.unsqueeze(1).float()) m_blk = s.amax(dim=-1)
# reshape to [q_h, qc, d] then [qc, q_h, d] m_new = torch.maximum(m, m_blk)
out_c = out_c.view(num_q_heads, qc, head_dim) exp_s = s - m_new.unsqueeze(-1)
del s
exp_s.exp_()
corr = torch.exp(m - m_new)
m.copy_(m_new)
del m_blk, m_new
l.mul_(corr).add_(exp_s.sum(dim=-1))
o.mul_(corr.unsqueeze(-1)).add_(
torch.matmul(exp_s, v_t))
del exp_s, v_t, corr
# --------------------------------------------------------------
# Finalize: normalize running output by normalization factor.
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
# --------------------------------------------------------------
o.div_(l.unsqueeze(-1))
output[q_start:q_end] = (
o.view(num_q_heads, q_len, head_dim)
.permute(1, 0, 2)
.to(orig_dtype)
)
output[q_start + qc_start : q_start + qc_end] = (
out_c.to(orig_dtype).permute(1, 0, 2))
except Exception as e: except Exception as e:
print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True) print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr) traceback.print_exc(file=sys.stderr)
raise raise
return output return output

View File

@@ -10,18 +10,11 @@
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
# #
# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs # Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs
#
# Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism:
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
# --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --max-num-batched-tokens 50000
# Recommended server start command for TP=4 support 100K, need chunked prefill # Recommended server start command for TP=4 support 100K, need chunked prefill
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ # CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \ # --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 8 --gpu-memory-utilization 0.95 \ # --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ # --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --max-num-batched-tokens 4096 --enable-chunked-prefill # --max-num-batched-tokens 4096 --enable-chunked-prefill
@@ -29,8 +22,8 @@
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
# (standard Triton 2.3.1 PTX is not supported by the corex runtime either). # (standard Triton 2.3.1 PTX is not supported by the corex runtime either).
# Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which # Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which
# also implements query-chunking (_ATTN_Q_CHUNK=256) to keep peak attention # utilizes K-tiling techniques, and also have _forward_decode_pytorch to bypass kernel
# memory at O(256 × kv_len) instead of O(q_len × kv_len). # when context length is high
cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py
# --- transformers: Qwen3_5 tokenizer / model files -------------------------- # --- transformers: Qwen3_5 tokenizer / model files --------------------------
@@ -49,4 +42,5 @@ python3 ./patch_vllm_qwen3_5.py
# crashes (is_causal=True) or produces wrong output (attn_mask path). # crashes (is_causal=True) or produces wrong output (attn_mask path).
# The fallback uses query_start_loc to derive actual query lengths, so it # The fallback uses query_start_loc to derive actual query lengths, so it
# works correctly during profiling runs with chunked-prefill-style batches. # works correctly during profiling runs with chunked-prefill-style batches.
# also bypasses auto chunked prefill on
python3 ./patch_xformers_sdpa_seq.py python3 ./patch_xformers_sdpa_seq.py