diff --git a/qwen3_6_scripts/paged_attn.py b/qwen3_6_scripts/paged_attn.py index ae4a81c..63c9443 100644 --- a/qwen3_6_scripts/paged_attn.py +++ b/qwen3_6_scripts/paged_attn.py @@ -309,28 +309,23 @@ class PagedAttention: seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, ) -> torch.Tensor: - """Pure-PyTorch prefix-attention with query-chunking (no Triton). + """Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax). - For each sequence, gathers the context KV from the paged KV cache, - concatenates with the current-chunk K/V, then computes scaled-dot- - product attention with a causal mask. + Memory complexity: O(q_len), independent of kv_len. + With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak + per layer ≈ 96 MB regardless of context length. - Memory optimisation — GQA-aware Q-tiling - ----------------------------------------- - Two complementary tricks keep peak activation memory well below 1 GB - even for 100K context on TP=4 (kv_h=1, q_h=6): + Algorithm: Flash Attention online softmax. + Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all + K-tiles. For each tile a running (m, l, o) accumulator is updated — + the [q_len × kv_len] attention matrix is NEVER materialised in full. - 1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d] - resolution and GQA grouping is handled via 4D reshape+broadcast - inside the matmul. With kv_h=1 and kv_len=100K this saves ~6× - vs the old expand-then-float32 approach: - Old: [6, 100K, 256] fp32 = 586 MB each for K and V - New: [1, 100K, 256] fp32 = 98 MB each for K and V - - 2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32 - is bounded to ~148 MB at 100K instead of growing with q_len. - - Combined peak per layer (100K): ~352 MB vs ~1200 MB previously. + Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens): + q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles) + o_acc same shape 24 MB (held all tiles) + s same shape 24 MB (per tile, freed before exp_s) + exp_s same shape 24 MB (per tile, brief overlap with s) + Peak ≈ 96 MB (s and exp_s briefly coexist during update). Shapes ------ @@ -344,29 +339,24 @@ class PagedAttention: seq_lens_tensor: [batch_size] total length (context + query) context_lens : [batch_size] tokens already in KV cache """ - # Memory-efficient query-chunked attention. - # Key optimisation: do NOT expand KV heads for GQA before materialising - # k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h - # resolution saves ~6× memory vs expanding to q_h first: - # Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB - # New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB - # GQA grouping is handled lazily inside the Q-tile matmul via 4D - # reshaping, so no extra tensors are created. try: - _ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB + # Paged-block tiles for context phase. + # tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens). + # Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile. + # Same tile size reused for the current-chunk phase. + _BLOCKS_PER_TILE = 32 batch_size = seq_lens_tensor.shape[0] num_q_heads = query.shape[1] num_kv_heads = key_cache.shape[1] head_dim = query.shape[2] gqa_ratio = num_q_heads // num_kv_heads - - # value_cache: [num_blocks, num_kv_heads, head_dim, block_size] - block_size = value_cache.shape[3] - - scale = 1.0 / (head_dim ** 0.5) - output = torch.empty_like(query) - orig_dtype = query.dtype + block_size = value_cache.shape[3] + tile_sz = _BLOCKS_PER_TILE * block_size + scale = head_dim ** -0.5 + orig_dtype = query.dtype + output = torch.empty_like(query) + dev = query.device for i in range(batch_size): ctx_len = int(context_lens[i].item()) @@ -374,96 +364,147 @@ class PagedAttention: q_end = int(query_start_loc[i + 1].item()) q_len = q_end - q_start - q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim] - k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim] + q_i = query[q_start:q_end] # [q_len, q_h, d] + k_i = key [q_start:q_end] # [q_len, kv_h, d] v_i = value[q_start:q_end] - # --- Build full K/V (context from cache + current chunk) ---- + # Q reshaped and scaled once; held for all K-tiles. + # [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256 + q_seq = (q_i.permute(1, 0, 2) + .float() + .view(num_kv_heads, gqa_ratio, q_len, head_dim) + .mul_(scale)) + + # Flash-Attention online-softmax accumulators. + # m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB + # o : [kv_h, gqa, q_len, d] fp32 — 24 MB + m = torch.full((num_kv_heads, gqa_ratio, q_len), + float('-inf'), dtype=torch.float32, device=dev) + l = torch.zeros_like(m) + o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim), + dtype=torch.float32, device=dev) + + # -------------------------------------------------------------- + # Phase 1 — context tokens (positions 0 … ctx_len-1). + # + # Every context key has absolute position < ctx_len; every + # query has position ≥ ctx_len. k_pos < q_pos is always True + # → no causal mask needed for pure context tiles. + # -------------------------------------------------------------- if ctx_len > 0: num_ctx_blocks = (ctx_len + block_size - 1) // block_size - blk_ids = block_tables[i, :num_ctx_blocks] + for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE): + blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks) + blk_ids = block_tables[i, tile_blk:blk_end] - # key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x] - # → permute(0,3,1,2,4) → contiguous → view → [:ctx_len] - k_ctx = (key_cache[blk_ids] - .permute(0, 3, 1, 2, 4) - .contiguous() - .view(-1, num_kv_heads, head_dim))[:ctx_len] + # Gather K/V for this tile. + # key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x] + # value_cache[blk_ids]: [n, kv_h, d, blk_sz] + k_tile = (key_cache[blk_ids] + .permute(0, 3, 1, 2, 4) + .contiguous() + .view(-1, num_kv_heads, head_dim)) + v_tile = (value_cache[blk_ids] + .permute(0, 3, 1, 2) + .contiguous() + .view(-1, num_kv_heads, head_dim)) - # value_cache[blk_ids]: [n, kv_h, d, blk_sz] - # → permute(0,3,1,2) → contiguous → view → [:ctx_len] - v_ctx = (value_cache[blk_ids] - .permute(0, 3, 1, 2) - .contiguous() - .view(-1, num_kv_heads, head_dim))[:ctx_len] + # Trim padding in the last block of the tile. + valid = (min(blk_end * block_size, ctx_len) + - tile_blk * block_size) + k_tile = k_tile[:valid] # [valid, kv_h, d] + v_tile = v_tile[:valid] - k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d] - v_full = torch.cat([v_ctx, v_i], dim=0) - del k_ctx, v_ctx - else: - k_full = k_i - v_full = v_i + # k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio) + # v_t: [kv_h, 1, valid, d] + k_t = (k_tile.permute(1, 0, 2) + .unsqueeze(1) + .transpose(-1, -2) + .float()) + v_t = (v_tile.permute(1, 0, 2) + .unsqueeze(1) + .float()) + del k_tile, v_tile - kv_len = k_full.shape[0] # ctx_len + q_len + # Scores: [kv_h, gqa, q_len, valid] + s = torch.matmul(q_seq, k_t) + del k_t + # No causal mask: all context keys precede all queries. - # Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16). - # Do NOT cast to fp32 here — k/v stay in fp16 to halve memory. - # attn_w is computed in fp32 (q cast to fp32 before matmul, then - # k cast inline) so softmax precision is unaffected. - # Do NOT expand GQA heads here either — gqa_ratio x memory savings. - k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16 - del k_full - v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16 - del v_full + # Online softmax update — Flash-Attention Algorithm 1. + # exp_s = s - new_max (in-place exp after del s) + m_blk = s.amax(dim=-1) + m_new = torch.maximum(m, m_blk) + exp_s = s - m_new.unsqueeze(-1) + del s + exp_s.exp_() + corr = torch.exp(m - m_new) + m.copy_(m_new) + del m_blk, m_new + l.mul_(corr).add_(exp_s.sum(dim=-1)) + o.mul_(corr.unsqueeze(-1)).add_( + torch.matmul(exp_s, v_t)) + del exp_s, v_t, corr - # k_pos used for causal mask: shape [kv_len] - k_pos = torch.arange(kv_len, device=query.device) + # -------------------------------------------------------------- + # Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1). + # + # Causal mask: query at relative position j sees key at relative + # position k only when k ≤ j. Tiles of tile_sz tokens each. + # -------------------------------------------------------------- + for kc_start in range(0, q_len, tile_sz): + kc_end = min(kc_start + tile_sz, q_len) + kc_len = kc_end - kc_start - # --- Query-chunked attention with lazy GQA grouping ---------- - # q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with - # k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives - # attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies. - for qc_start in range(0, q_len, _ATTN_Q_CHUNK): - qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len) - qc = qc_end - qc_start + k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d] + v_blk = v_i[kc_start:kc_end] - # [kv_h, gqa_ratio, qc, d] - q_t_chunk = (q_i[qc_start:qc_end] - .permute(1, 0, 2) # [q_h, qc, d] - .float() - .view(num_kv_heads, gqa_ratio, qc, head_dim)) + k_t = (k_blk.permute(1, 0, 2) + .unsqueeze(1) + .transpose(-1, -2) + .float()) # [kv_h, 1, d, kc_len] + v_t = (v_blk.permute(1, 0, 2) + .unsqueeze(1) + .float()) # [kv_h, 1, kc_len, d] - # [kv_h, gqa_ratio, qc, kv_len] - # k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio. - # Cast k slice to fp32 inline; the temporary is freed after matmul. - attn_w = torch.matmul(q_t_chunk * scale, - k_t.unsqueeze(1).transpose(-1, -2).float()) + s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len] + del k_t - # Causal mask for this sub-chunk: - # query absolute position = ctx_len + qc_start..qc_end-1 - qc_q_pos = torch.arange(qc_start, qc_end, - device=query.device) - # mask[j, k] = True → future key, block it - mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1)) - attn_w.masked_fill_( - mask.unsqueeze(0).unsqueeze(0), float('-inf')) + # Causal mask: key at (kc_start+k) must not exceed query j. + k_rel = torch.arange(kc_start, kc_end, device=dev) + q_rel = torch.arange(q_len, device=dev) + mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len] + s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf')) + del mask, k_rel, q_rel - # In-place numerically stable softmax — avoids allocating a - # new 150 MB tensor (same size as attn_w) that torch.softmax - # would create, which exhausts the fragmented GPU pool. - attn_w -= attn_w.amax(dim=-1, keepdim=True) - attn_w.exp_() - attn_w /= attn_w.sum(dim=-1, keepdim=True) - # [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline - out_c = torch.matmul(attn_w, - v_t.unsqueeze(1).float()) - # reshape to [q_h, qc, d] then [qc, q_h, d] - out_c = out_c.view(num_q_heads, qc, head_dim) + # Online softmax update (identical to context phase). + m_blk = s.amax(dim=-1) + m_new = torch.maximum(m, m_blk) + exp_s = s - m_new.unsqueeze(-1) + del s + exp_s.exp_() + corr = torch.exp(m - m_new) + m.copy_(m_new) + del m_blk, m_new + l.mul_(corr).add_(exp_s.sum(dim=-1)) + o.mul_(corr.unsqueeze(-1)).add_( + torch.matmul(exp_s, v_t)) + del exp_s, v_t, corr + + # -------------------------------------------------------------- + # Finalize: normalize running output by normalization factor. + # o: [kv_h, gqa, q_len, d] → [q_len, q_h, d] + # -------------------------------------------------------------- + o.div_(l.unsqueeze(-1)) + output[q_start:q_end] = ( + o.view(num_q_heads, q_len, head_dim) + .permute(1, 0, 2) + .to(orig_dtype) + ) - output[q_start + qc_start : q_start + qc_end] = ( - out_c.to(orig_dtype).permute(1, 0, 2)) except Exception as e: - print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True) + print(f"[paged_attn ERROR] {type(e).__name__}: {e}", + file=sys.stderr, flush=True) traceback.print_exc(file=sys.stderr) raise return output diff --git a/qwen3_6_scripts/patch_ops.sh b/qwen3_6_scripts/patch_ops.sh index cb9d427..7b6acba 100755 --- a/qwen3_6_scripts/patch_ops.sh +++ b/qwen3_6_scripts/patch_ops.sh @@ -10,18 +10,11 @@ # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. # # Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs -# -# Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism: -# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ -# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \ -# --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \ -# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ -# --max-num-batched-tokens 50000 # Recommended server start command for TP=4 support 100K, need chunked prefill # CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ # --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \ -# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 8 --gpu-memory-utilization 0.95 \ +# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \ # --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ # --max-num-batched-tokens 4096 --enable-chunked-prefill @@ -29,8 +22,8 @@ # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently # (standard Triton 2.3.1 PTX is not supported by the corex runtime either). # Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which -# also implements query-chunking (_ATTN_Q_CHUNK=256) to keep peak attention -# memory at O(256 × kv_len) instead of O(q_len × kv_len). +# utilizes K-tiling techniques, and also have _forward_decode_pytorch to bypass kernel +# when context length is high cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py # --- transformers: Qwen3_5 tokenizer / model files -------------------------- @@ -49,4 +42,5 @@ python3 ./patch_vllm_qwen3_5.py # crashes (is_causal=True) or produces wrong output (attn_mask path). # The fallback uses query_start_loc to derive actual query lengths, so it # works correctly during profiling runs with chunked-prefill-style batches. +# also bypasses auto chunked prefill on python3 ./patch_xformers_sdpa_seq.py