# enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_model_runner.py

# 2026-06-26 12:55:02 +08:00
"""
Fix: prefix_cache_hit stays True for chunked-prefill chunks 2+ even after
context_len has moved past the cached prefix.

Root cause:
  model_runner.py _compute_for_prefix_cache_hit has three cases:
    Case 1: prefix_cache_len <= context_len           -> "already past cache, do normal"
    Case 2: context_len < prefix_cache_len < seq_len  -> partial hit, handled correctly
    Case 3: seq_len <= prefix_cache_len               -> full hit, reduce to 1 token

  Case 1 does nothing (it leaves prefix_cache_hit = True). Then, in utils.py:
    if inter_data.prefix_cache_hit:
        block_table = computed_block_nums   <-- ONLY the original prefix blocks!

  But context_len > prefix_cache_len means the chunk-1 tokens (between
  prefix_cache_len and context_len) are ALSO in the KV cache and must appear
  in block_table. block_table = computed_block_nums misses every chunk-1 block.

  In _forward_prefix_pytorch:
    num_ctx_blocks        = ceil(context_len / block_size)   # e.g. 268
    block_tables.shape[1] = len(computed_block_nums)         # e.g. 12  <-- too small!
  At tile_blk >= 12, blk_ids is empty -> k_t has shape [..., 0] -> amax() crashes.

Fix:
  Set prefix_cache_hit = False in Case 1, so utils.py falls through to:
    elif chunked_prefill_enabled:
        block_table = block_tables[seq_id]   <-- full block table (prefix + chunk 1)
"""
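
# Illustration only; nothing below is called by the patch. A minimal,
# simplified model of the failure described in the docstring, assuming the
# block-table width is the only thing that matters. classify_prefix_case,
# select_block_table and block_table_is_wide_enough are hypothetical helpers
# invented for this sketch; they mirror, but are not, the vLLM code in
# model_runner.py and utils.py.
import math


def classify_prefix_case(prefix_cache_len, context_len, seq_len):
    """Return 1, 2 or 3, matching Case 1/2/3 in the docstring."""
    if prefix_cache_len <= context_len:
        return 1  # already past the cached prefix
    if prefix_cache_len < seq_len:
        return 2  # partial hit: context_len < prefix_cache_len < seq_len
    return 3      # full hit: seq_len <= prefix_cache_len


def select_block_table(prefix_cache_hit, chunked_prefill_enabled,
                       computed_block_nums, full_block_table):
    """Simplified version of the utils.py branch quoted in the docstring."""
    if prefix_cache_hit:
        return computed_block_nums   # prefix blocks only
    if chunked_prefill_enabled:
        return full_block_table      # prefix + previous-chunk blocks
    return []


def block_table_is_wide_enough(block_table, context_len, block_size):
    """True if the table has num_ctx_blocks = ceil(context_len / block_size) entries."""
    return len(block_table) >= math.ceil(context_len / block_size)


# Worked example with the docstring's numbers, assuming block_size = 16:
# 12 prefix blocks give prefix_cache_len = 192, and context_len = 4280 needs
# ceil(4280 / 16) = 268 blocks. Case 1 applies, so before the fix
# select_block_table(True, True, ...) returns the 12-entry prefix table
# (too narrow), while after the fix select_block_table(False, True, ...)
# returns the full table.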
import os
import sys

# Known install locations of vllm/worker/model_runner.py, tried in order.
CANDIDATE_PATHS = [
    "/usr/local/corex/lib64/python3/dist-packages/vllm/worker/model_runner.py",
    "/usr/local/corex/lib/python3/dist-packages/vllm/worker/model_runner.py",
]

# Case-1 branch of _compute_for_prefix_cache_hit as shipped. The leading
# whitespace must match model_runner.py exactly, or the replace below is a no-op.
OLD_BLOCK = """\
        if prefix_cache_len <= context_len:
            # We already passed the cache hit region,
            # so do normal computation.
            pass"""

NEW_BLOCK = """\
        if prefix_cache_len <= context_len:
            # We already passed the cache hit region,
            # so do normal computation.
            # Must clear prefix_cache_hit so _add_seq_group uses the full
            # block_tables (prefix + previous-chunk blocks) instead of only
            # computed_block_nums (prefix only). Without this, block_tables
            # passed to _forward_prefix_pytorch is too narrow for context_len,
            # causing an empty blk_ids slice and a zero-dim amax() crash.
            inter_data.prefix_cache_hit = False"""

patched = False
for path in CANDIDATE_PATHS:
    if not os.path.exists(path):
        continue
    with open(path, "r") as f:
        src = f.read()
    if OLD_BLOCK not in src:
        # Idempotency: a previous run already applied the fix.
        if NEW_BLOCK in src:
            print(f"[patch_model_runner] already patched: {path}")
            patched = True
            break
        print(f"[patch_model_runner] WARNING: expected block not found in {path}, skipping")
        continue
    # Replace at most one occurrence (count=1).
    patched_src = src.replace(OLD_BLOCK, NEW_BLOCK, 1)
    with open(path, "w") as f:
        f.write(patched_src)
    print(f"[patch_model_runner] patched Case-1 prefix_cache_hit fix in: {path}")
    patched = True
    break

if not patched:
    print("[patch_model_runner] ERROR: could not find a patchable model_runner.py at any known path", file=sys.stderr)
    sys.exit(1)