some modifications to ensure 50K context input
This commit is contained in:
@@ -85,6 +85,85 @@ class PagedAttention:
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
def _forward_decode_pytorch(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Pure-PyTorch decode attention for long contexts (no hardware kernel).

    Intended as a fallback for paged_attention_v1, which has been reported
    to hang on BI-V100 when max_seq_len > ~32K due to shared memory limits.
    For decode there is exactly one query token per sequence, so no
    Q-tiling is needed: the attention-weight tensor for one sequence is only
    [num_kv_heads, gqa_ratio, seq_len] float32 values.

    Each sequence is processed independently:

    1. gather its K/V blocks from the paged caches via ``block_tables``,
    2. flatten them to ``[seq_len, num_kv_heads, head_dim]``,
    3. compute softmax(q @ K^T * scale) @ V in float32,
    4. cast the result back to ``query.dtype``.

    Grouped-query attention is handled by viewing the query heads as
    ``[num_kv_heads, gqa_ratio, head_dim]`` so each KV head is shared by its
    group through a batched matmul, without replicating K or V.

    The caches are cast directly to float32; no KV-cache scale factors are
    applied here.

    Shapes
    ------
    query       : [num_seqs, num_heads, head_dim]
    key_cache   : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
    value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
    block_tables: [num_seqs, max_blocks_per_seq]
    seq_lens    : [num_seqs]

    Args:
        query: Decode queries, one token per sequence.
        key_cache: Paged key cache.
        value_cache: Paged value cache.
        block_tables: Per-sequence cache block ids, in sequence order.
        seq_lens: Number of cached tokens to attend over for each sequence.
        scale: Multiplier applied to the query before the QK^T product.

    Returns:
        Tensor with the same shape, dtype and device as ``query``.

    Raises:
        Exception: Any error raised while computing attention is logged to
            stderr (message plus traceback) and then re-raised unchanged.
    """
    num_seqs, num_heads, head_dim = query.shape
    num_kv_heads = key_cache.shape[1]
    block_size = value_cache.shape[3]
    gqa_ratio = num_heads // num_kv_heads
    orig_dtype = query.dtype

    output = torch.empty_like(query)

    try:
        # Single device->host transfer for all lengths instead of one
        # blocking .item() call per sequence.
        host_seq_lens = seq_lens.tolist()

        for i in range(num_seqs):
            seq_len = int(host_seq_lens[i])
            num_blocks = (seq_len + block_size - 1) // block_size
            # Block tables are commonly int32; use int64 gather indices.
            blk_ids = block_tables[i, :num_blocks].long()

            # Gather K from paged cache: [seq_len, num_kv_heads, head_dim].
            # [nb, kvh, D/x, bs, x] -> [nb, bs, kvh, D/x, x] -> flatten; the
            # slice drops the unused tail of the last (partial) block.
            k_seq = (key_cache[blk_ids]
                     .permute(0, 3, 1, 2, 4)
                     .reshape(-1, num_kv_heads, head_dim))[:seq_len]

            # Gather V from paged cache: [seq_len, num_kv_heads, head_dim].
            v_seq = (value_cache[blk_ids]
                     .permute(0, 3, 1, 2)
                     .reshape(-1, num_kv_heads, head_dim))[:seq_len]

            # [kvh, head_dim, seq_len] and [kvh, seq_len, head_dim]
            k_t = k_seq.permute(1, 2, 0).float()
            v_t = v_seq.permute(1, 0, 2).float()

            # Query head h uses KV head h // gqa_ratio, so viewing q as
            # [kvh, gqa_ratio, head_dim] lines each group up with its KV
            # head for a batched matmul.
            q_i = query[i].float().reshape(num_kv_heads, gqa_ratio, head_dim)

            # attn_w: [kvh, gqa_ratio, seq_len]
            attn_w = torch.matmul(q_i * scale, k_t)
            attn_w = torch.softmax(attn_w, dim=-1)

            out_i = torch.matmul(attn_w, v_t)  # [kvh, gqa_ratio, head_dim]
            output[i] = out_i.reshape(num_heads, head_dim).to(orig_dtype)
    except Exception as e:
        # Imported locally so the diagnostic path cannot fail with a
        # NameError that would mask the real error.
        import sys
        import traceback

        print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
              file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        raise

    return output
|
||||
|
||||
# max_seq_len threshold for the pure-PyTorch decode fallback: forward_decode
# routes a batch to _forward_decode_pytorch only when max_seq_len is strictly
# greater than this value.
#
# Background: paged_attention_v1 on BI-V100 was reported to hang when
# max_seq_len exceeds ~32K due to shared memory limits.
#
# The fallback is effectively disabled for now by using a very large value
# (50K decode was confirmed working via the hardware kernel); lower this to
# 32768 if kernel hangs are observed at long contexts.
_PYTORCH_DECODE_THRESHOLD = 10_000_000
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
@@ -105,6 +184,10 @@ class PagedAttention:
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
|
||||
return PagedAttention._forward_decode_pytorch(
|
||||
query, key_cache, value_cache, block_tables, seq_lens, scale)
|
||||
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
|
||||
@@ -24,6 +24,12 @@ flash attention kernel(ixformer / cudnnFlashAttnForward)。
|
||||
max-model-len=8192 → 峰值 ~800 MB
|
||||
max-model-len=16384 → 峰值 ~3.2 GB
|
||||
|
||||
额外 patch(arg_utils.py):
|
||||
vllm 0.6.3 在 max_model_len > 32K 时会自动开启 chunked prefill(无命令行
|
||||
关闭选项),原意是防止 profiling OOM。但 _run_sdpa_fallback 已通过 Q-tiling
|
||||
解决了该问题,chunked prefill 反而会把推理路径从 _run_sdpa_fallback 切换到
|
||||
_forward_prefix_pytorch,属于不必要的行为变更,因此一并禁用该自动逻辑。
|
||||
|
||||
Deploy:
|
||||
python3 modified_scripts/patch_xformers_sdpa_seq.py
|
||||
"""
|
||||
@@ -33,6 +39,33 @@ XFORMERS_PATH = (
|
||||
"vllm/attention/backends/xformers.py"
|
||||
)
|
||||
|
||||
# Absolute path of the installed vLLM engine-argument module; main() passes
# it to patch_arg_utils(), which rewrites the file in place.
ARG_UTILS_PATH = (
    "/usr/local/corex/lib64/python3/dist-packages/"
    "vllm/engine/arg_utils.py"
)
|
||||
|
||||
# vllm 0.6.3 自动开启 chunked prefill 的原始块
|
||||
_ARG_OLD_BLOCK = """\
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter):
|
||||
self.enable_chunked_prefill = True
|
||||
logger.warning(
|
||||
"Chunked prefill is enabled by default for models with "
|
||||
"max_model_len > 32K. Currently, chunked prefill might "
|
||||
"not work with some features or models. If you "
|
||||
"encounter any issues, please disable chunked prefill "
|
||||
"by setting --enable-chunked-prefill=False.")\
|
||||
"""
|
||||
|
||||
_ARG_NEW_BLOCK = """\
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter):
|
||||
pass # skip auto-enable: Q-tiling in _run_sdpa_fallback
|
||||
# handles long-context memory without chunked prefill\
|
||||
"""
|
||||
|
||||
FALLBACK_METHOD = '''
|
||||
def _run_sdpa_fallback(
|
||||
self,
|
||||
@@ -203,10 +236,35 @@ def patch_file(path):
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def patch_arg_utils(path):
    """Disable the automatic chunked-prefill enablement in ``arg_utils.py``.

    Reads the file at ``path`` and replaces the first occurrence of
    ``_ARG_OLD_BLOCK`` with ``_ARG_NEW_BLOCK``, writing the result back in
    place. The operation is idempotent: if the marker comment that the
    replacement introduces is already present, the file is left untouched.
    If neither the marker nor the expected original block is found, a
    warning is printed and nothing is written.

    The file is read and written as UTF-8 (the default encoding of Python
    source files) so the result does not depend on the process locale.

    Args:
        path: Filesystem path of the ``arg_utils.py`` file to patch.

    Raises:
        OSError: If the file cannot be read or written.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    # Marker text that only exists once the patch has been applied.
    if "skip auto-enable: Q-tiling" in content:
        print(" [skip] chunked-prefill auto-enable already disabled")
        return

    if _ARG_OLD_BLOCK not in content:
        print(" [warn] target block not found — check arg_utils.py version")
        return

    patched = content.replace(_ARG_OLD_BLOCK, _ARG_NEW_BLOCK, 1)
    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)

    # Report success only after the write has actually completed.
    print(" [ok] disabled chunked-prefill auto-enable for 32K+")
    print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
    """Run both patch steps in order, printing a banner and target for each.

    First the xformers backend file is patched via ``patch_file``, then the
    engine argument file via ``patch_arg_utils``; a final "Done." line is
    printed once both have returned.
    """
    steps = (
        (
            "=== patch_xformers_sdpa_seq (sequential, pure-math) ===",
            XFORMERS_PATH,
            patch_file,
        ),
        (
            "\n=== patch_arg_utils (disable chunked-prefill auto-enable) ===",
            ARG_UTILS_PATH,
            patch_arg_utils,
        ),
    )

    for banner, target, apply_patch in steps:
        print(banner)
        print(f"Target: {target}")
        apply_patch(target)

    print("\nDone.")
|
||||
|
||||
|
||||
|
||||
@@ -412,6 +412,9 @@ class GatedDeltaNet(nn.Module):
|
||||
|
||||
else:
|
||||
# Decode: one token per sequence
|
||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
||||
_f.flush()
|
||||
num_seqs = hidden_states.shape[0]
|
||||
weight_2d = self.conv1d_weight.squeeze(1)
|
||||
|
||||
@@ -847,6 +850,12 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
|
||||
# TP behavior); they must still call logits_processor to participate in
|
||||
# the NCCL gather inside lm_head. logits_processor returns None for
|
||||
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
|
||||
# Rank 0 (driver) always has seq_groups != None given
|
||||
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user