From 2d1ef509921d8ca5a65c1459147b7f349c8f8da2 Mon Sep 17 00:00:00 2001 From: Lu Xinlong Date: Fri, 5 Jun 2026 16:03:34 +0800 Subject: [PATCH] chunked prefill support and memory opts --- qwen3_6_scripts/paged_attn.py | 156 ++++++++++++--------- qwen3_6_scripts/patch_ops.sh | 27 ++-- qwen3_6_scripts/patch_xformers_sdpa_seq.py | 49 +++++++ qwen3_6_scripts/qwen3_5.py | 20 +-- 4 files changed, 166 insertions(+), 86 deletions(-) diff --git a/qwen3_6_scripts/paged_attn.py b/qwen3_6_scripts/paged_attn.py index 4e58486..ae4a81c 100644 --- a/qwen3_6_scripts/paged_attn.py +++ b/qwen3_6_scripts/paged_attn.py @@ -123,33 +123,36 @@ class PagedAttention: num_blocks = (seq_len + block_size - 1) // block_size blk_ids = block_tables[i, :num_blocks] - # Gather K from paged cache: [seq_len, num_kv_heads, head_dim] - k_seq = (key_cache[blk_ids] - .permute(0, 3, 1, 2, 4) - .contiguous() - .view(-1, num_kv_heads, head_dim))[:seq_len] + # Gather K: [kv_h, head_dim, seq_len] fp32 — no GQA expansion. + # With kv_h=1 and seq_len=100K this is 98 MB vs 586 MB if expanded. + k_t = (key_cache[blk_ids] + .permute(0, 3, 1, 2, 4) + .contiguous() + .view(-1, num_kv_heads, head_dim))[:seq_len] \ + .permute(1, 2, 0).contiguous().float() # [kv_h, d, seq_len] - # Gather V from paged cache: [seq_len, num_kv_heads, head_dim] - v_seq = (value_cache[blk_ids] - .permute(0, 3, 1, 2) - .contiguous() - .view(-1, num_kv_heads, head_dim))[:seq_len] + # Gather V: [kv_h, seq_len, head_dim] fp32 + v_t = (value_cache[blk_ids] + .permute(0, 3, 1, 2) + .contiguous() + .view(-1, num_kv_heads, head_dim))[:seq_len] \ + .permute(1, 0, 2).contiguous().float() # [kv_h, seq_len, d] - if gqa_ratio > 1: - k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1) - v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1) + # Reshape Q for lazy GQA: [kv_h, gqa_ratio, 1, d] + q_grouped = (query[i].float() + .view(num_kv_heads, gqa_ratio, head_dim) + .unsqueeze(2)) - # [H, head_dim, seq_len] and [H, seq_len, head_dim] - k_t = k_seq.permute(1, 2, 0).float() - v_t = v_seq.permute(1, 0, 2).float() - - # q: [H, 1, head_dim]; attn_w: [H, 1, seq_len] - q_i = query[i].float().unsqueeze(1) - attn_w = torch.matmul(q_i * scale, k_t) + # [kv_h, gqa_ratio, 1, seq_len] + attn_w = torch.matmul( + q_grouped * scale, # [kv_h, gqa, 1, d] + k_t.unsqueeze(1)) # [kv_h, 1, d, seq_len] attn_w = torch.softmax(attn_w, dim=-1) - out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim] - output[i] = out_i.squeeze(1).to(orig_dtype) + # [kv_h, gqa_ratio, 1, d] → [num_heads, head_dim] + out_i = torch.matmul(attn_w, v_t.unsqueeze(1)) + output[i] = out_i.view(num_heads, head_dim).to(orig_dtype) + except Exception as e: print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True) @@ -158,11 +161,10 @@ class PagedAttention: return output - # paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to - # shared memory limits; use pure-PyTorch fallback above this threshold. - # Set to a large value to disable for now (50K decode confirmed working via - # hardware kernel); lower to 32768 if kernel hangs are observed at long contexts. - _PYTORCH_DECODE_THRESHOLD = 10_000_000 + # paged_attention_v1 on BI-V100 fails for long contexts. + # Route on actual sequence length (seq_lens.max()), not the max_seq_len + # parameter which is inflated to max_model_len in CUDA graph mode. + _PYTORCH_DECODE_THRESHOLD = 32768 @staticmethod def forward_decode( @@ -184,7 +186,8 @@ class PagedAttention: blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: - if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD: + actual_max = int(seq_lens.max().item()) if seq_lens.numel() > 0 else max_seq_len + if actual_max > PagedAttention._PYTORCH_DECODE_THRESHOLD: return PagedAttention._forward_decode_pytorch( query, key_cache, value_cache, block_tables, seq_lens, scale) @@ -312,18 +315,22 @@ class PagedAttention: concatenates with the current-chunk K/V, then computes scaled-dot- product attention with a causal mask. - Memory optimisation — query chunking - ------------------------------------ - A full-sequence attention matrix is O(q_len × kv_len) in float32. - For long sequences (e.g., q_len = kv_len = 20 000) that blows up - to ~9 GB per layer. Instead we tile the query axis in sub-chunks - of _ATTN_Q_CHUNK tokens and accumulate the output; peak attn memory - becomes O(_ATTN_Q_CHUNK × kv_len), e.g. 123 MB per layer for - chunk=256 and kv_len=20 000. + Memory optimisation — GQA-aware Q-tiling + ----------------------------------------- + Two complementary tricks keep peak activation memory well below 1 GB + even for 100K context on TP=4 (kv_h=1, q_h=6): - This replaces the need for vllm's --enable-chunked-prefill flag - (which the vendor's vllm 0.6.3 does not properly support for - has_inner_state=True models on BI-V100). + 1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d] + resolution and GQA grouping is handled via 4D reshape+broadcast + inside the matmul. With kv_h=1 and kv_len=100K this saves ~6× + vs the old expand-then-float32 approach: + Old: [6, 100K, 256] fp32 = 586 MB each for K and V + New: [1, 100K, 256] fp32 = 98 MB each for K and V + + 2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32 + is bounded to ~148 MB at 100K instead of growing with q_len. + + Combined peak per layer (100K): ~352 MB vs ~1200 MB previously. Shapes ------ @@ -337,13 +344,16 @@ class PagedAttention: seq_lens_tensor: [batch_size] total length (context + query) context_lens : [batch_size] tokens already in KV cache """ - # Maximum query tokens to process at once per attention step. - # Tune this to balance memory vs kernel-launch overhead: - # 256 → ~120 MB peak attn memory (conservative, safe for 20K ctx) - # 512 → ~240 MB peak attn memory - # 1024 → ~490 MB peak attn memory + # Memory-efficient query-chunked attention. + # Key optimisation: do NOT expand KV heads for GQA before materialising + # k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h + # resolution saves ~6× memory vs expanding to q_h first: + # Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB + # New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB + # GQA grouping is handled lazily inside the Q-tile matmul via 4D + # reshaping, so no extra tensors are created. try: - _ATTN_Q_CHUNK = 256 + _ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB batch_size = seq_lens_tensor.shape[0] num_q_heads = query.shape[1] @@ -389,50 +399,66 @@ class PagedAttention: k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d] v_full = torch.cat([v_ctx, v_i], dim=0) + del k_ctx, v_ctx else: k_full = k_i v_full = v_i kv_len = k_full.shape[0] # ctx_len + q_len - # GQA: expand KV heads to match Q heads - if gqa_ratio > 1: - k_full = k_full.repeat_interleave(gqa_ratio, dim=1) - v_full = v_full.repeat_interleave(gqa_ratio, dim=1) - - k_t = k_full.permute(1, 0, 2).float() # [H, kv_len, d] - v_t = v_full.permute(1, 0, 2).float() # [H, kv_len, d] + # Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16). + # Do NOT cast to fp32 here — k/v stay in fp16 to halve memory. + # attn_w is computed in fp32 (q cast to fp32 before matmul, then + # k cast inline) so softmax precision is unaffected. + # Do NOT expand GQA heads here either — gqa_ratio x memory savings. + k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16 + del k_full + v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16 + del v_full # k_pos used for causal mask: shape [kv_len] k_pos = torch.arange(kv_len, device=query.device) - # --- Query-chunked attention -------------------------------- - # Process _ATTN_Q_CHUNK query tokens at a time. - # Peak attn tensor: [H, _ATTN_Q_CHUNK, kv_len] float32 - # instead of [H, q_len, kv_len] float32. + # --- Query-chunked attention with lazy GQA grouping ---------- + # q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with + # k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives + # attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies. for qc_start in range(0, q_len, _ATTN_Q_CHUNK): qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len) + qc = qc_end - qc_start - # [H, qc, d] + # [kv_h, gqa_ratio, qc, d] q_t_chunk = (q_i[qc_start:qc_end] - .permute(1, 0, 2) - .float()) + .permute(1, 0, 2) # [q_h, qc, d] + .float() + .view(num_kv_heads, gqa_ratio, qc, head_dim)) - # [H, qc, kv_len] + # [kv_h, gqa_ratio, qc, kv_len] + # k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio. + # Cast k slice to fp32 inline; the temporary is freed after matmul. attn_w = torch.matmul(q_t_chunk * scale, - k_t.transpose(-1, -2)) + k_t.unsqueeze(1).transpose(-1, -2).float()) # Causal mask for this sub-chunk: # query absolute position = ctx_len + qc_start..qc_end-1 - # can attend to k_pos <= its own absolute position qc_q_pos = torch.arange(qc_start, qc_end, device=query.device) # mask[j, k] = True → future key, block it mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1)) - attn_w = attn_w.masked_fill(mask.unsqueeze(0), float('-inf')) + attn_w.masked_fill_( + mask.unsqueeze(0).unsqueeze(0), float('-inf')) - attn_w = torch.softmax(attn_w, dim=-1) # [H, qc, kv_len] - out_c = torch.matmul(attn_w, v_t) # [H, qc, d] + # In-place numerically stable softmax — avoids allocating a + # new 150 MB tensor (same size as attn_w) that torch.softmax + # would create, which exhausts the fragmented GPU pool. + attn_w -= attn_w.amax(dim=-1, keepdim=True) + attn_w.exp_() + attn_w /= attn_w.sum(dim=-1, keepdim=True) + # [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline + out_c = torch.matmul(attn_w, + v_t.unsqueeze(1).float()) + # reshape to [q_h, qc, d] then [qc, q_h, d] + out_c = out_c.view(num_q_heads, qc, head_dim) output[q_start + qc_start : q_start + qc_end] = ( out_c.to(orig_dtype).permute(1, 0, 2)) diff --git a/qwen3_6_scripts/patch_ops.sh b/qwen3_6_scripts/patch_ops.sh index c31b7f6..cb9d427 100755 --- a/qwen3_6_scripts/patch_ops.sh +++ b/qwen3_6_scripts/patch_ops.sh @@ -9,20 +9,21 @@ # - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. # -# Chunked prefill note: -# --enable-chunked-prefill is NOT supported by the vendor's vllm 0.6.3 for -# has_inner_state=True models on BI-V100. It causes "Engine loop has died" -# immediately on first request. Do NOT use that flag. -# Long-context memory is instead handled by query-chunking inside -# _forward_prefix_pytorch (see paged_attn.py, _ATTN_Q_CHUNK=256). +# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs # -# Recommended server start command: -# python3 -m vllm.entrypoints.openai.api_server \ -# --model /workspace/models/Qwen3.6-27B --port 1111 \ -# --served-model-name llm --max-model-len 20000 \ -# --enforce-eager --trust-remote-code -tp 4 \ -# --gpu-memory-utilization 0.95 -# (No --enable-chunked-prefill, no --max-num-batched-tokens) +# Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism: +# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ +# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \ +# --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \ +# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ +# --max-num-batched-tokens 50000 + +# Recommended server start command for TP=4 support 100K, need chunked prefill +# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ +# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \ +# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 8 --gpu-memory-utilization 0.95 \ +# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ +# --max-num-batched-tokens 4096 --enable-chunked-prefill # --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback ------- # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently diff --git a/qwen3_6_scripts/patch_xformers_sdpa_seq.py b/qwen3_6_scripts/patch_xformers_sdpa_seq.py index 181d184..496abc1 100644 --- a/qwen3_6_scripts/patch_xformers_sdpa_seq.py +++ b/qwen3_6_scripts/patch_xformers_sdpa_seq.py @@ -44,6 +44,31 @@ ARG_UTILS_PATH = ( "vllm/engine/arg_utils.py" ) +LOGITS_PROC_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/model_executor/layers/logits_processor.py" +) + +# _apply_logits_processors crashes when seq_groups is None (intermediate +# chunked-prefill chunks on the driver rank). Add an early-return guard. +_LP_OLD_BLOCK = """\ +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + found_logits_processors = False\ +""" + +_LP_NEW_BLOCK = """\ +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + if sampling_metadata.seq_groups is None: # intermediate chunked-prefill chunk + return logits + found_logits_processors = False\ +""" + # vllm 0.6.3 自动开启 chunked prefill 的原始块 _ARG_OLD_BLOCK = """\ if (is_gpu and not use_sliding_window and not use_spec_decode @@ -256,6 +281,26 @@ def patch_arg_utils(path): print(f" Written: {path}") +def patch_logits_processor(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "intermediate chunked-prefill chunk" in content: + print(" [skip] seq_groups=None guard already present") + elif _LP_OLD_BLOCK in content: + content = content.replace(_LP_OLD_BLOCK, _LP_NEW_BLOCK, 1) + print(" [ok] added seq_groups=None guard in _apply_logits_processors") + changed = True + else: + print(" [warn] target block not found — check logits_processor.py version") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + def main(): print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===") print(f"Target: {XFORMERS_PATH}") @@ -265,6 +310,10 @@ def main(): print(f"Target: {ARG_UTILS_PATH}") patch_arg_utils(ARG_UTILS_PATH) + print("\n=== patch_logits_processor (seq_groups=None guard for chunked prefill) ===") + print(f"Target: {LOGITS_PROC_PATH}") + patch_logits_processor(LOGITS_PROC_PATH) + print("\nDone.") diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py index 2b724c9..371be71 100644 --- a/qwen3_6_scripts/qwen3_5.py +++ b/qwen3_6_scripts/qwen3_5.py @@ -334,6 +334,11 @@ class GatedDeltaNet(nn.Module): .transpose(0, 1).unsqueeze(0) .to(weight_2d.dtype)) + # Load prev conv state BEFORE overwriting (needed for causal conv padding). + # For first prefill of a request: mamba_cache is zeros → correct. + # For chunked prefill chunk 2+: carries last state_len tokens from prev chunk. + prev_conv = conv_state[si:si + 1].clone().to(weight_2d.dtype) # [1, local_conv_dim, state_len] + # Save conv state (last state_len positions) if seq_len >= state_len: conv_state[si].copy_(mixed_qkv[0, :, -state_len:]) @@ -342,8 +347,8 @@ class GatedDeltaNet(nn.Module): mixed_qkv[0]) conv_state[si, :, :state_len - seq_len] = 0 - # Causal conv (left-pad with zeros, then convolve) - padded = F.pad(mixed_qkv, (state_len, 0)) + # Causal conv: left-pad with previous conv state (not zeros). + padded = torch.cat([prev_conv, mixed_qkv], dim=2) mixed_qkv_conv = F.conv1d( padded, self.conv1d_weight, bias=None, padding=0, groups=local_conv_dim) @@ -850,12 +855,11 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA): hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - # Non-driver TP ranks have seq_groups=None in sampling_metadata (normal - # TP behavior); they must still call logits_processor to participate in - # the NCCL gather inside lm_head. logits_processor returns None for - # non-driver ranks after the gather, safely skipping _apply_logits_processors. - # Rank 0 (driver) always has seq_groups != None given - # --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits). + # All TP ranks must call logits_processor to participate in the NCCL + # gather inside lm_head. Non-driver ranks return None after the gather. + # With chunked prefill, intermediate chunks have seq_groups=None on all + # ranks; _apply_logits_processors is guarded against this in + # logits_processor.py (patched by patch_xformers_sdpa_seq.py). return self.logits_processor(self.lm_head, hidden_states, sampling_metadata)