some modifications to ensure 50K context input

This commit is contained in:
2026-06-04 17:56:29 +08:00
parent 1c33ef1355
commit 8c047a70ea
3 changed files with 150 additions and 0 deletions

View File

@@ -24,6 +24,12 @@ flash attention kernelixformer / cudnnFlashAttnForward
max-model-len=8192 → 峰值 ~800 MB
max-model-len=16384 → 峰值 ~3.2 GB
额外 patcharg_utils.py
vllm 0.6.3 在 max_model_len > 32K 时会自动开启 chunked prefill无命令行
关闭选项),原意是防止 profiling OOM。但 _run_sdpa_fallback 已通过 Q-tiling
解决了该问题chunked prefill 反而会把推理路径从 _run_sdpa_fallback 切换到
_forward_prefix_pytorch属于不必要的行为变更因此一并禁用该自动逻辑。
Deploy:
python3 modified_scripts/patch_xformers_sdpa_seq.py
"""
@@ -33,6 +39,33 @@ XFORMERS_PATH = (
"vllm/attention/backends/xformers.py"
)
ARG_UTILS_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/engine/arg_utils.py"
)
# vllm 0.6.3 自动开启 chunked prefill 的原始块
_ARG_OLD_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "
"max_model_len > 32K. Currently, chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable chunked prefill "
"by setting --enable-chunked-prefill=False.")\
"""
_ARG_NEW_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter):
pass # skip auto-enable: Q-tiling in _run_sdpa_fallback
# handles long-context memory without chunked prefill\
"""
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
@@ -203,10 +236,35 @@ def patch_file(path):
print(f" Written: {path}")
def patch_arg_utils(path):
with open(path, "r") as f:
content = f.read()
changed = False
if "skip auto-enable: Q-tiling" in content:
print(" [skip] chunked-prefill auto-enable already disabled")
elif _ARG_OLD_BLOCK in content:
content = content.replace(_ARG_OLD_BLOCK, _ARG_NEW_BLOCK, 1)
print(" [ok] disabled chunked-prefill auto-enable for 32K+")
changed = True
else:
print(" [warn] target block not found — check arg_utils.py version")
if changed:
with open(path, "w") as f:
f.write(content)
print(f" Written: {path}")
def main():
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
print(f"Target: {XFORMERS_PATH}")
patch_file(XFORMERS_PATH)
print("\n=== patch_arg_utils (disable chunked-prefill auto-enable) ===")
print(f"Target: {ARG_UTILS_PATH}")
patch_arg_utils(ARG_UTILS_PATH)
print("\nDone.")