chunked prefill support and memory opts

This commit is contained in:
2026-06-05 16:03:34 +08:00
parent 8c047a70ea
commit 2d1ef50992
4 changed files with 166 additions and 86 deletions

View File

@@ -123,33 +123,36 @@ class PagedAttention:
num_blocks = (seq_len + block_size - 1) // block_size num_blocks = (seq_len + block_size - 1) // block_size
blk_ids = block_tables[i, :num_blocks] blk_ids = block_tables[i, :num_blocks]
# Gather K from paged cache: [seq_len, num_kv_heads, head_dim] # Gather K: [kv_h, head_dim, seq_len] fp32 — no GQA expansion.
k_seq = (key_cache[blk_ids] # With kv_h=1 and seq_len=100K this is 98 MB vs 586 MB if expanded.
k_t = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4) .permute(0, 3, 1, 2, 4)
.contiguous() .contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len] .view(-1, num_kv_heads, head_dim))[:seq_len] \
.permute(1, 2, 0).contiguous().float() # [kv_h, d, seq_len]
# Gather V from paged cache: [seq_len, num_kv_heads, head_dim] # Gather V: [kv_h, seq_len, head_dim] fp32
v_seq = (value_cache[blk_ids] v_t = (value_cache[blk_ids]
.permute(0, 3, 1, 2) .permute(0, 3, 1, 2)
.contiguous() .contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len] .view(-1, num_kv_heads, head_dim))[:seq_len] \
.permute(1, 0, 2).contiguous().float() # [kv_h, seq_len, d]
if gqa_ratio > 1: # Reshape Q for lazy GQA: [kv_h, gqa_ratio, 1, d]
k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1) q_grouped = (query[i].float()
v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1) .view(num_kv_heads, gqa_ratio, head_dim)
.unsqueeze(2))
# [H, head_dim, seq_len] and [H, seq_len, head_dim] # [kv_h, gqa_ratio, 1, seq_len]
k_t = k_seq.permute(1, 2, 0).float() attn_w = torch.matmul(
v_t = v_seq.permute(1, 0, 2).float() q_grouped * scale, # [kv_h, gqa, 1, d]
k_t.unsqueeze(1)) # [kv_h, 1, d, seq_len]
# q: [H, 1, head_dim]; attn_w: [H, 1, seq_len]
q_i = query[i].float().unsqueeze(1)
attn_w = torch.matmul(q_i * scale, k_t)
attn_w = torch.softmax(attn_w, dim=-1) attn_w = torch.softmax(attn_w, dim=-1)
out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim] # [kv_h, gqa_ratio, 1, d] → [num_heads, head_dim]
output[i] = out_i.squeeze(1).to(orig_dtype) out_i = torch.matmul(attn_w, v_t.unsqueeze(1))
output[i] = out_i.view(num_heads, head_dim).to(orig_dtype)
except Exception as e: except Exception as e:
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}", print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True) file=sys.stderr, flush=True)
@@ -158,11 +161,10 @@ class PagedAttention:
return output return output
# paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to # paged_attention_v1 on BI-V100 fails for long contexts.
# shared memory limits; use pure-PyTorch fallback above this threshold. # Route on actual sequence length (seq_lens.max()), not the max_seq_len
# Set to a large value to disable for now (50K decode confirmed working via # parameter which is inflated to max_model_len in CUDA graph mode.
# hardware kernel); lower to 32768 if kernel hangs are observed at long contexts. _PYTORCH_DECODE_THRESHOLD = 32768
_PYTORCH_DECODE_THRESHOLD = 10_000_000
@staticmethod @staticmethod
def forward_decode( def forward_decode(
@@ -184,7 +186,8 @@ class PagedAttention:
blocksparse_block_size: int = 64, blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0, blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD: actual_max = int(seq_lens.max().item()) if seq_lens.numel() > 0 else max_seq_len
if actual_max > PagedAttention._PYTORCH_DECODE_THRESHOLD:
return PagedAttention._forward_decode_pytorch( return PagedAttention._forward_decode_pytorch(
query, key_cache, value_cache, block_tables, seq_lens, scale) query, key_cache, value_cache, block_tables, seq_lens, scale)
@@ -312,18 +315,22 @@ class PagedAttention:
concatenates with the current-chunk K/V, then computes scaled-dot- concatenates with the current-chunk K/V, then computes scaled-dot-
product attention with a causal mask. product attention with a causal mask.
Memory optimisation — query chunking Memory optimisation — GQA-aware Q-tiling
------------------------------------ -----------------------------------------
A full-sequence attention matrix is O(q_len × kv_len) in float32. Two complementary tricks keep peak activation memory well below 1 GB
For long sequences (e.g., q_len = kv_len = 20 000) that blows up even for 100K context on TP=4 (kv_h=1, q_h=6):
to ~9 GB per layer. Instead we tile the query axis in sub-chunks
of _ATTN_Q_CHUNK tokens and accumulate the output; peak attn memory
becomes O(_ATTN_Q_CHUNK × kv_len), e.g. 123 MB per layer for
chunk=256 and kv_len=20 000.
This replaces the need for vllm's --enable-chunked-prefill flag 1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
(which the vendor's vllm 0.6.3 does not properly support for resolution and GQA grouping is handled via 4D reshape+broadcast
has_inner_state=True models on BI-V100). inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
vs the old expand-then-float32 approach:
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
New: [1, 100K, 256] fp32 = 98 MB each for K and V
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
is bounded to ~148 MB at 100K instead of growing with q_len.
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
Shapes Shapes
------ ------
@@ -337,13 +344,16 @@ class PagedAttention:
seq_lens_tensor: [batch_size] total length (context + query) seq_lens_tensor: [batch_size] total length (context + query)
context_lens : [batch_size] tokens already in KV cache context_lens : [batch_size] tokens already in KV cache
""" """
# Maximum query tokens to process at once per attention step. # Memory-efficient query-chunked attention.
# Tune this to balance memory vs kernel-launch overhead: # Key optimisation: do NOT expand KV heads for GQA before materialising
# 256 → ~120 MB peak attn memory (conservative, safe for 20K ctx) # k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
# 512 → ~240 MB peak attn memory # resolution saves ~6× memory vs expanding to q_h first:
# 1024 → ~490 MB peak attn memory # Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
# reshaping, so no extra tensors are created.
try: try:
_ATTN_Q_CHUNK = 256 _ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
batch_size = seq_lens_tensor.shape[0] batch_size = seq_lens_tensor.shape[0]
num_q_heads = query.shape[1] num_q_heads = query.shape[1]
@@ -389,50 +399,66 @@ class PagedAttention:
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d] k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
v_full = torch.cat([v_ctx, v_i], dim=0) v_full = torch.cat([v_ctx, v_i], dim=0)
del k_ctx, v_ctx
else: else:
k_full = k_i k_full = k_i
v_full = v_i v_full = v_i
kv_len = k_full.shape[0] # ctx_len + q_len kv_len = k_full.shape[0] # ctx_len + q_len
# GQA: expand KV heads to match Q heads # Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
if gqa_ratio > 1: # Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
k_full = k_full.repeat_interleave(gqa_ratio, dim=1) # attn_w is computed in fp32 (q cast to fp32 before matmul, then
v_full = v_full.repeat_interleave(gqa_ratio, dim=1) # k cast inline) so softmax precision is unaffected.
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
k_t = k_full.permute(1, 0, 2).float() # [H, kv_len, d] k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
v_t = v_full.permute(1, 0, 2).float() # [H, kv_len, d] del k_full
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
del v_full
# k_pos used for causal mask: shape [kv_len] # k_pos used for causal mask: shape [kv_len]
k_pos = torch.arange(kv_len, device=query.device) k_pos = torch.arange(kv_len, device=query.device)
# --- Query-chunked attention -------------------------------- # --- Query-chunked attention with lazy GQA grouping ----------
# Process _ATTN_Q_CHUNK query tokens at a time. # q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
# Peak attn tensor: [H, _ATTN_Q_CHUNK, kv_len] float32 # k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
# instead of [H, q_len, kv_len] float32. # attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
for qc_start in range(0, q_len, _ATTN_Q_CHUNK): for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len) qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
qc = qc_end - qc_start
# [H, qc, d] # [kv_h, gqa_ratio, qc, d]
q_t_chunk = (q_i[qc_start:qc_end] q_t_chunk = (q_i[qc_start:qc_end]
.permute(1, 0, 2) .permute(1, 0, 2) # [q_h, qc, d]
.float()) .float()
.view(num_kv_heads, gqa_ratio, qc, head_dim))
# [H, qc, kv_len] # [kv_h, gqa_ratio, qc, kv_len]
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
# Cast k slice to fp32 inline; the temporary is freed after matmul.
attn_w = torch.matmul(q_t_chunk * scale, attn_w = torch.matmul(q_t_chunk * scale,
k_t.transpose(-1, -2)) k_t.unsqueeze(1).transpose(-1, -2).float())
# Causal mask for this sub-chunk: # Causal mask for this sub-chunk:
# query absolute position = ctx_len + qc_start..qc_end-1 # query absolute position = ctx_len + qc_start..qc_end-1
# can attend to k_pos <= its own absolute position
qc_q_pos = torch.arange(qc_start, qc_end, qc_q_pos = torch.arange(qc_start, qc_end,
device=query.device) device=query.device)
# mask[j, k] = True → future key, block it # mask[j, k] = True → future key, block it
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1)) mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float('-inf')) attn_w.masked_fill_(
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn_w = torch.softmax(attn_w, dim=-1) # [H, qc, kv_len] # In-place numerically stable softmax — avoids allocating a
out_c = torch.matmul(attn_w, v_t) # [H, qc, d] # new 150 MB tensor (same size as attn_w) that torch.softmax
# would create, which exhausts the fragmented GPU pool.
attn_w -= attn_w.amax(dim=-1, keepdim=True)
attn_w.exp_()
attn_w /= attn_w.sum(dim=-1, keepdim=True)
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
out_c = torch.matmul(attn_w,
v_t.unsqueeze(1).float())
# reshape to [q_h, qc, d] then [qc, q_h, d]
out_c = out_c.view(num_q_heads, qc, head_dim)
output[q_start + qc_start : q_start + qc_end] = ( output[q_start + qc_start : q_start + qc_end] = (
out_c.to(orig_dtype).permute(1, 0, 2)) out_c.to(orig_dtype).permute(1, 0, 2))

View File

@@ -9,20 +9,21 @@
# - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes # - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
# #
# Chunked prefill note: # Important note: to deploy Qwen3.6-27B on 8 GPUs, use the TP=4, PP=2 combination
# --enable-chunked-prefill is NOT supported by the vendor's vllm 0.6.3 for
# has_inner_state=True models on BI-V100. It causes "Engine loop has died"
# immediately on first request. Do NOT use that flag.
# Long-context memory is instead handled by query-chunking inside
# _forward_prefix_pytorch (see paged_attn.py, _ATTN_Q_CHUNK=256).
# #
# Recommended server start command: # Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism:
# python3 -m vllm.entrypoints.openai.api_server \ # CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
# --model /workspace/models/Qwen3.6-27B --port 1111 \ # --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
# --served-model-name llm --max-model-len 20000 \ # --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \
# --enforce-eager --trust-remote-code -tp 4 \ # --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --gpu-memory-utilization 0.95 # --max-num-batched-tokens 50000
# (No --enable-chunked-prefill, no --max-num-batched-tokens)
# Recommended server start command for TP=4 with 100K context (requires chunked prefill):
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --max-num-batched-tokens 4096 --enable-chunked-prefill
# --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback ------- # --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback -------
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently

View File

@@ -44,6 +44,31 @@ ARG_UTILS_PATH = (
"vllm/engine/arg_utils.py" "vllm/engine/arg_utils.py"
) )
# Location of the installed vllm logits_processor.py that patch_logits_processor
# rewrites in place.
LOGITS_PROC_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/model_executor/layers/logits_processor.py"
)
# _apply_logits_processors crashes when seq_groups is None (intermediate
# chunked-prefill chunks on the driver rank). Add an early-return guard.
#
# _LP_OLD_BLOCK is located with a plain substring match and swapped for
# _LP_NEW_BLOCK, so it must match the target file verbatim (whitespace
# included). The trailing backslash before the closing quotes is a line
# continuation: both strings end right after "False" with no final newline.
_LP_OLD_BLOCK = """\
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
found_logits_processors = False\
"""
# Same header as _LP_OLD_BLOCK plus the early return; the trailing comment on
# the guard line doubles as the "already patched" marker that
# patch_logits_processor looks for.
_LP_NEW_BLOCK = """\
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.seq_groups is None: # intermediate chunked-prefill chunk
return logits
found_logits_processors = False\
"""
# vllm 0.6.3 自动开启 chunked prefill 的原始块 # vllm 0.6.3 自动开启 chunked prefill 的原始块
_ARG_OLD_BLOCK = """\ _ARG_OLD_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode if (is_gpu and not use_sliding_window and not use_spec_decode
@@ -256,6 +281,26 @@ def patch_arg_utils(path):
print(f" Written: {path}") print(f" Written: {path}")
def patch_logits_processor(path):
    """Add a ``seq_groups is None`` early return to ``_apply_logits_processors``.

    Reads the file at *path*, and if it has not been patched yet, replaces the
    first occurrence of ``_LP_OLD_BLOCK`` with ``_LP_NEW_BLOCK`` and writes the
    result back in place.  Progress is reported on stdout.

    The operation is idempotent: a file that already contains the marker
    comment ``intermediate chunked-prefill chunk`` is left untouched.  If the
    expected block is not found, a warning is printed and nothing is written.

    Args:
        path: Filesystem path of the ``logits_processor.py`` file to patch.

    Returns:
        None.
    """
    # Python source files are UTF-8 by default; pin the encoding so the patch
    # does not depend on the locale of the machine running it.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    # The marker comment is part of _LP_NEW_BLOCK, so its presence means the
    # guard has already been inserted by an earlier run.
    if "intermediate chunked-prefill chunk" in content:
        print(" [skip] seq_groups=None guard already present")
        return

    if _LP_OLD_BLOCK not in content:
        print(" [warn] target block not found — check logits_processor.py version")
        return

    # Replace only the first match; the function header is expected once.
    content = content.replace(_LP_OLD_BLOCK, _LP_NEW_BLOCK, 1)
    print(" [ok] added seq_groups=None guard in _apply_logits_processors")

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" Written: {path}")
def main(): def main():
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===") print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
print(f"Target: {XFORMERS_PATH}") print(f"Target: {XFORMERS_PATH}")
@@ -265,6 +310,10 @@ def main():
print(f"Target: {ARG_UTILS_PATH}") print(f"Target: {ARG_UTILS_PATH}")
patch_arg_utils(ARG_UTILS_PATH) patch_arg_utils(ARG_UTILS_PATH)
print("\n=== patch_logits_processor (seq_groups=None guard for chunked prefill) ===")
print(f"Target: {LOGITS_PROC_PATH}")
patch_logits_processor(LOGITS_PROC_PATH)
print("\nDone.") print("\nDone.")

View File

@@ -334,6 +334,11 @@ class GatedDeltaNet(nn.Module):
.transpose(0, 1).unsqueeze(0) .transpose(0, 1).unsqueeze(0)
.to(weight_2d.dtype)) .to(weight_2d.dtype))
# Load prev conv state BEFORE overwriting (needed for causal conv padding).
# For first prefill of a request: mamba_cache is zeros → correct.
# For chunked prefill chunk 2+: carries last state_len tokens from prev chunk.
prev_conv = conv_state[si:si + 1].clone().to(weight_2d.dtype) # [1, local_conv_dim, state_len]
# Save conv state (last state_len positions) # Save conv state (last state_len positions)
if seq_len >= state_len: if seq_len >= state_len:
conv_state[si].copy_(mixed_qkv[0, :, -state_len:]) conv_state[si].copy_(mixed_qkv[0, :, -state_len:])
@@ -342,8 +347,8 @@ class GatedDeltaNet(nn.Module):
mixed_qkv[0]) mixed_qkv[0])
conv_state[si, :, :state_len - seq_len] = 0 conv_state[si, :, :state_len - seq_len] = 0
# Causal conv (left-pad with zeros, then convolve) # Causal conv: left-pad with previous conv state (not zeros).
padded = F.pad(mixed_qkv, (state_len, 0)) padded = torch.cat([prev_conv, mixed_qkv], dim=2)
mixed_qkv_conv = F.conv1d( mixed_qkv_conv = F.conv1d(
padded, self.conv1d_weight, padded, self.conv1d_weight,
bias=None, padding=0, groups=local_conv_dim) bias=None, padding=0, groups=local_conv_dim)
@@ -850,12 +855,11 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal # All TP ranks must call logits_processor to participate in the NCCL
# TP behavior); they must still call logits_processor to participate in # gather inside lm_head. Non-driver ranks return None after the gather.
# the NCCL gather inside lm_head. logits_processor returns None for # With chunked prefill, intermediate chunks have seq_groups=None on all
# non-driver ranks after the gather, safely skipping _apply_logits_processors. # ranks; _apply_logits_processors is guarded against this in
# Rank 0 (driver) always has seq_groups != None given # logits_processor.py (patched by patch_xformers_sdpa_seq.py).
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
return self.logits_processor(self.lm_head, hidden_states, return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)