some modifications to ensure 50K context input

This commit is contained in:
2026-06-04 17:56:29 +08:00
parent 1c33ef1355
commit 8c047a70ea
3 changed files with 150 additions and 0 deletions

View File

@@ -85,6 +85,85 @@ class PagedAttention:
v_scale,
)
@staticmethod
def _forward_decode_pytorch(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""Pure-PyTorch decode attention for long contexts (no hardware kernel).
paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to
shared memory limits. For decode, q_len=1 per sequence so no Q-tiling
is needed — the attention weight tensor is [H, 1, seq_len] which is
trivially small (~5 MB at 50K).
Shapes
------
query : [num_seqs, num_heads, head_dim]
key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
block_tables: [num_seqs, max_blocks_per_seq]
seq_lens : [num_seqs]
"""
num_seqs, num_heads, head_dim = query.shape
num_kv_heads = key_cache.shape[1]
block_size = value_cache.shape[3]
gqa_ratio = num_heads // num_kv_heads
orig_dtype = query.dtype
output = torch.empty_like(query)
try:
for i in range(num_seqs):
seq_len = int(seq_lens[i].item())
num_blocks = (seq_len + block_size - 1) // block_size
blk_ids = block_tables[i, :num_blocks]
# Gather K from paged cache: [seq_len, num_kv_heads, head_dim]
k_seq = (key_cache[blk_ids]
.permute(0, 3, 1, 2, 4)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len]
# Gather V from paged cache: [seq_len, num_kv_heads, head_dim]
v_seq = (value_cache[blk_ids]
.permute(0, 3, 1, 2)
.contiguous()
.view(-1, num_kv_heads, head_dim))[:seq_len]
if gqa_ratio > 1:
k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1)
v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1)
# [H, head_dim, seq_len] and [H, seq_len, head_dim]
k_t = k_seq.permute(1, 2, 0).float()
v_t = v_seq.permute(1, 0, 2).float()
# q: [H, 1, head_dim]; attn_w: [H, 1, seq_len]
q_i = query[i].float().unsqueeze(1)
attn_w = torch.matmul(q_i * scale, k_t)
attn_w = torch.softmax(attn_w, dim=-1)
out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim]
output[i] = out_i.squeeze(1).to(orig_dtype)
except Exception as e:
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr)
raise
return output
# paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to
# shared memory limits; use pure-PyTorch fallback above this threshold.
# Set to a large value to disable for now (50K decode confirmed working via
# hardware kernel); lower to 32768 if kernel hangs are observed at long contexts.
_PYTORCH_DECODE_THRESHOLD = 10_000_000
@staticmethod
def forward_decode(
query: torch.Tensor,
@@ -105,6 +184,10 @@ class PagedAttention:
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
return PagedAttention._forward_decode_pytorch(
query, key_cache, value_cache, block_tables, seq_lens, scale)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)

View File

@@ -24,6 +24,12 @@ flash attention kernelixformer / cudnnFlashAttnForward
max-model-len=8192 → 峰值 ~800 MB
max-model-len=16384 → 峰值 ~3.2 GB
额外 patcharg_utils.py
vllm 0.6.3 在 max_model_len > 32K 时会自动开启 chunked prefill无命令行
关闭选项),原意是防止 profiling OOM。但 _run_sdpa_fallback 已通过 Q-tiling
解决了该问题chunked prefill 反而会把推理路径从 _run_sdpa_fallback 切换到
_forward_prefix_pytorch属于不必要的行为变更因此一并禁用该自动逻辑。
Deploy:
python3 modified_scripts/patch_xformers_sdpa_seq.py
"""
@@ -33,6 +39,33 @@ XFORMERS_PATH = (
"vllm/attention/backends/xformers.py"
)
ARG_UTILS_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/engine/arg_utils.py"
)
# vllm 0.6.3 自动开启 chunked prefill 的原始块
_ARG_OLD_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "
"max_model_len > 32K. Currently, chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable chunked prefill "
"by setting --enable-chunked-prefill=False.")\
"""
_ARG_NEW_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter):
pass # skip auto-enable: Q-tiling in _run_sdpa_fallback
# handles long-context memory without chunked prefill\
"""
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
@@ -203,10 +236,35 @@ def patch_file(path):
print(f" Written: {path}")
def patch_arg_utils(path):
with open(path, "r") as f:
content = f.read()
changed = False
if "skip auto-enable: Q-tiling" in content:
print(" [skip] chunked-prefill auto-enable already disabled")
elif _ARG_OLD_BLOCK in content:
content = content.replace(_ARG_OLD_BLOCK, _ARG_NEW_BLOCK, 1)
print(" [ok] disabled chunked-prefill auto-enable for 32K+")
changed = True
else:
print(" [warn] target block not found — check arg_utils.py version")
if changed:
with open(path, "w") as f:
f.write(content)
print(f" Written: {path}")
def main():
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
print(f"Target: {XFORMERS_PATH}")
patch_file(XFORMERS_PATH)
print("\n=== patch_arg_utils (disable chunked-prefill auto-enable) ===")
print(f"Target: {ARG_UTILS_PATH}")
patch_arg_utils(ARG_UTILS_PATH)
print("\nDone.")

View File

@@ -412,6 +412,9 @@ class GatedDeltaNet(nn.Module):
else:
# Decode: one token per sequence
with open("/tmp/vllm_decode_debug.log", "a") as _f:
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
_f.flush()
num_seqs = hidden_states.shape[0]
weight_2d = self.conv1d_weight.squeeze(1)
@@ -847,6 +850,12 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
# TP behavior); they must still call logits_processor to participate in
# the NCCL gather inside lm_head. logits_processor returns None for
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
# Rank 0 (driver) always has seq_groups != None given
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)