chunked prefill support and memory opts
This commit is contained in:
@@ -44,6 +44,31 @@ ARG_UTILS_PATH = (
|
||||
"vllm/engine/arg_utils.py"
|
||||
)
|
||||
|
||||
LOGITS_PROC_PATH = (
|
||||
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||
"vllm/model_executor/layers/logits_processor.py"
|
||||
)
|
||||
|
||||
# _apply_logits_processors crashes when seq_groups is None (intermediate
|
||||
# chunked-prefill chunks on the driver rank). Add an early-return guard.
|
||||
_LP_OLD_BLOCK = """\
|
||||
def _apply_logits_processors(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
found_logits_processors = False\
|
||||
"""
|
||||
|
||||
_LP_NEW_BLOCK = """\
|
||||
def _apply_logits_processors(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.seq_groups is None: # intermediate chunked-prefill chunk
|
||||
return logits
|
||||
found_logits_processors = False\
|
||||
"""
|
||||
|
||||
# vllm 0.6.3 自动开启 chunked prefill 的原始块
|
||||
_ARG_OLD_BLOCK = """\
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
@@ -256,6 +281,26 @@ def patch_arg_utils(path):
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def patch_logits_processor(path):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
changed = False
|
||||
|
||||
if "intermediate chunked-prefill chunk" in content:
|
||||
print(" [skip] seq_groups=None guard already present")
|
||||
elif _LP_OLD_BLOCK in content:
|
||||
content = content.replace(_LP_OLD_BLOCK, _LP_NEW_BLOCK, 1)
|
||||
print(" [ok] added seq_groups=None guard in _apply_logits_processors")
|
||||
changed = True
|
||||
else:
|
||||
print(" [warn] target block not found — check logits_processor.py version")
|
||||
|
||||
if changed:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
|
||||
print(f"Target: {XFORMERS_PATH}")
|
||||
@@ -265,6 +310,10 @@ def main():
|
||||
print(f"Target: {ARG_UTILS_PATH}")
|
||||
patch_arg_utils(ARG_UTILS_PATH)
|
||||
|
||||
print("\n=== patch_logits_processor (seq_groups=None guard for chunked prefill) ===")
|
||||
print(f"Target: {LOGITS_PROC_PATH}")
|
||||
patch_logits_processor(LOGITS_PROC_PATH)
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user