diff --git a/qwen3_6_scripts/paged_attn.py b/qwen3_6_scripts/paged_attn.py index 64d97d8..4e58486 100644 --- a/qwen3_6_scripts/paged_attn.py +++ b/qwen3_6_scripts/paged_attn.py @@ -85,6 +85,85 @@ class PagedAttention: v_scale, ) + @staticmethod + def _forward_decode_pytorch( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + ) -> torch.Tensor: + """Pure-PyTorch decode attention for long contexts (no hardware kernel). + + paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to + shared memory limits. For decode, q_len=1 per sequence so no Q-tiling + is needed — the attention weight tensor is [H, 1, seq_len] which is + trivially small (~5 MB at 50K). + + Shapes + ------ + query : [num_seqs, num_heads, head_dim] + key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x] + value_cache : [num_blocks, num_kv_heads, head_dim, block_size] + block_tables: [num_seqs, max_blocks_per_seq] + seq_lens : [num_seqs] + """ + num_seqs, num_heads, head_dim = query.shape + num_kv_heads = key_cache.shape[1] + block_size = value_cache.shape[3] + gqa_ratio = num_heads // num_kv_heads + orig_dtype = query.dtype + + output = torch.empty_like(query) + + try: + for i in range(num_seqs): + seq_len = int(seq_lens[i].item()) + num_blocks = (seq_len + block_size - 1) // block_size + blk_ids = block_tables[i, :num_blocks] + + # Gather K from paged cache: [seq_len, num_kv_heads, head_dim] + k_seq = (key_cache[blk_ids] + .permute(0, 3, 1, 2, 4) + .contiguous() + .view(-1, num_kv_heads, head_dim))[:seq_len] + + # Gather V from paged cache: [seq_len, num_kv_heads, head_dim] + v_seq = (value_cache[blk_ids] + .permute(0, 3, 1, 2) + .contiguous() + .view(-1, num_kv_heads, head_dim))[:seq_len] + + if gqa_ratio > 1: + k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1) + v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1) + + # [H, head_dim, seq_len] and [H, seq_len, head_dim] + k_t = k_seq.permute(1, 2, 0).float() + v_t = v_seq.permute(1, 0, 2).float() + + # q: [H, 1, head_dim]; attn_w: [H, 1, seq_len] + q_i = query[i].float().unsqueeze(1) + attn_w = torch.matmul(q_i * scale, k_t) + attn_w = torch.softmax(attn_w, dim=-1) + + out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim] + output[i] = out_i.squeeze(1).to(orig_dtype) + except Exception as e: + print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}", + file=sys.stderr, flush=True) + traceback.print_exc(file=sys.stderr) + raise + + return output + + # paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to + # shared memory limits; use pure-PyTorch fallback above this threshold. + # Set to a large value to disable for now (50K decode confirmed working via + # hardware kernel); lower to 32768 if kernel hangs are observed at long contexts. + _PYTORCH_DECODE_THRESHOLD = 10_000_000 + @staticmethod def forward_decode( query: torch.Tensor, @@ -105,6 +184,10 @@ class PagedAttention: blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: + if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD: + return PagedAttention._forward_decode_pytorch( + query, key_cache, value_cache, block_tables, seq_lens, scale) + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) diff --git a/qwen3_6_scripts/patch_xformers_sdpa_seq.py b/qwen3_6_scripts/patch_xformers_sdpa_seq.py index 8919502..181d184 100644 --- a/qwen3_6_scripts/patch_xformers_sdpa_seq.py +++ b/qwen3_6_scripts/patch_xformers_sdpa_seq.py @@ -24,6 +24,12 @@ flash attention kernel(ixformer / cudnnFlashAttnForward)。 max-model-len=8192 → 峰值 ~800 MB max-model-len=16384 → 峰值 ~3.2 GB +额外 patch(arg_utils.py): + vllm 0.6.3 在 max_model_len > 32K 时会自动开启 chunked prefill(无命令行 + 关闭选项),原意是防止 profiling OOM。但 _run_sdpa_fallback 已通过 Q-tiling + 解决了该问题,chunked prefill 反而会把推理路径从 _run_sdpa_fallback 切换到 + _forward_prefix_pytorch,属于不必要的行为变更,因此一并禁用该自动逻辑。 + Deploy: python3 modified_scripts/patch_xformers_sdpa_seq.py """ @@ -33,6 +39,33 @@ XFORMERS_PATH = ( "vllm/attention/backends/xformers.py" ) +ARG_UTILS_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/engine/arg_utils.py" +) + +# vllm 0.6.3 自动开启 chunked prefill 的原始块 +_ARG_OLD_BLOCK = """\ + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter): + self.enable_chunked_prefill = True + logger.warning( + "Chunked prefill is enabled by default for models with " + "max_model_len > 32K. Currently, chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable chunked prefill " + "by setting --enable-chunked-prefill=False.")\ +""" + +_ARG_NEW_BLOCK = """\ + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter): + pass # skip auto-enable: Q-tiling in _run_sdpa_fallback + # handles long-context memory without chunked prefill\ +""" + FALLBACK_METHOD = ''' def _run_sdpa_fallback( self, @@ -203,10 +236,35 @@ def patch_file(path): print(f" Written: {path}") +def patch_arg_utils(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "skip auto-enable: Q-tiling" in content: + print(" [skip] chunked-prefill auto-enable already disabled") + elif _ARG_OLD_BLOCK in content: + content = content.replace(_ARG_OLD_BLOCK, _ARG_NEW_BLOCK, 1) + print(" [ok] disabled chunked-prefill auto-enable for 32K+") + changed = True + else: + print(" [warn] target block not found — check arg_utils.py version") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + def main(): print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===") print(f"Target: {XFORMERS_PATH}") patch_file(XFORMERS_PATH) + + print("\n=== patch_arg_utils (disable chunked-prefill auto-enable) ===") + print(f"Target: {ARG_UTILS_PATH}") + patch_arg_utils(ARG_UTILS_PATH) + print("\nDone.") diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py index 78f6603..2b724c9 100644 --- a/qwen3_6_scripts/qwen3_5.py +++ b/qwen3_6_scripts/qwen3_5.py @@ -412,6 +412,9 @@ class GatedDeltaNet(nn.Module): else: # Decode: one token per sequence + with open("/tmp/vllm_decode_debug.log", "a") as _f: + _f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n") + _f.flush() num_seqs = hidden_states.shape[0] weight_2d = self.conv1d_weight.squeeze(1) @@ -847,6 +850,12 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA): hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: + # Non-driver TP ranks have seq_groups=None in sampling_metadata (normal + # TP behavior); they must still call logits_processor to participate in + # the NCCL gather inside lm_head. logits_processor returns None for + # non-driver ranks after the gather, safely skipping _apply_logits_processors. + # Rank 0 (driver) always has seq_groups != None given + # --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits). return self.logits_processor(self.lm_head, hidden_states, sampling_metadata)