chunked prefill support and memory opts

This commit is contained in:
2026-06-05 16:03:34 +08:00
parent 8c047a70ea
commit 2d1ef50992
4 changed files with 166 additions and 86 deletions

View File

@@ -44,6 +44,31 @@ ARG_UTILS_PATH = (
"vllm/engine/arg_utils.py"
)
# Path of the vLLM logits_processor.py (under dist-packages) that
# patch_logits_processor() is pointed at from main().
LOGITS_PROC_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/model_executor/layers/logits_processor.py"
)
# _apply_logits_processors crashes when seq_groups is None (intermediate
# chunked-prefill chunks on the driver rank). Add an early-return guard.
#
# _LP_OLD_BLOCK is the exact text searched for in the target file: the
# signature of _apply_logits_processors plus its first statement. The
# backslashes right after the opening quotes and right before the closing
# quotes are line continuations, so the literal neither starts nor ends with
# a newline.
# NOTE(review): the match is a plain substring test, so the whitespace inside
# these literals must equal the target file's byte-for-byte — confirm the
# leading indentation of each line against the installed logits_processor.py.
_LP_OLD_BLOCK = """\
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
found_logits_processors = False\
"""
# _LP_NEW_BLOCK is the replacement text: the same lines as _LP_OLD_BLOCK with
# an early return inserted for the seq_groups-is-None case. The trailing
# comment on the guard line also serves as the marker that
# patch_logits_processor() looks for to detect an already-patched file, so it
# must not be reworded without updating that check.
_LP_NEW_BLOCK = """\
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.seq_groups is None: # intermediate chunked-prefill chunk
return logits
found_logits_processors = False\
"""
# Original block from vllm 0.6.3 that automatically enables chunked prefill
_ARG_OLD_BLOCK = """\
if (is_gpu and not use_sliding_window and not use_spec_decode
@@ -256,6 +281,26 @@ def patch_arg_utils(path):
print(f" Written: {path}")
def patch_logits_processor(path):
    """Add the ``seq_groups is None`` early-return guard to the file at *path*.

    The file is read once and one of three things happens:

    * the marker text ``intermediate chunked-prefill chunk`` is already
      present: the file is treated as patched and left untouched;
    * ``_LP_OLD_BLOCK`` is found: its first occurrence is replaced with
      ``_LP_NEW_BLOCK`` and the file is rewritten in place;
    * neither is found: a warning is printed and the file is left untouched.

    Progress is reported on stdout; nothing is returned.

    Args:
        path: Filesystem path of the ``logits_processor.py`` to patch.

    Raises:
        OSError: If the file cannot be opened for reading or writing.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # The target is Python source, which is UTF-8 by default. Pin the
    # encoding so the result does not depend on the process locale (a
    # non-UTF-8 locale could otherwise fail on, or mangle, non-ASCII bytes).
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    changed = False
    if "intermediate chunked-prefill chunk" in content:
        # Marker comes from the guard line in _LP_NEW_BLOCK: already patched.
        print(" [skip] seq_groups=None guard already present")
    elif _LP_OLD_BLOCK in content:
        # Replace only the first occurrence; the function is defined once.
        content = content.replace(_LP_OLD_BLOCK, _LP_NEW_BLOCK, 1)
        print(" [ok] added seq_groups=None guard in _apply_logits_processors")
        changed = True
    else:
        print(" [warn] target block not found — check logits_processor.py version")

    # Only touch the file on disk when something was actually replaced.
    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
def main():
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
print(f"Target: {XFORMERS_PATH}")
@@ -265,6 +310,10 @@ def main():
print(f"Target: {ARG_UTILS_PATH}")
patch_arg_utils(ARG_UTILS_PATH)
print("\n=== patch_logits_processor (seq_groups=None guard for chunked prefill) ===")
print(f"Target: {LOGITS_PROC_PATH}")
patch_logits_processor(LOGITS_PROC_PATH)
print("\nDone.")