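Fix chunked-prefill prefix-cache crash (undersized block table); add GDN prefix-state cache, cached-token usage reporting, n > max_num_seqs guard, and abort-on-disconnect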

This commit is contained in:
2026-06-26 12:55:02 +08:00
parent 3d62430fd7
commit c84151eef9
9 changed files with 1879 additions and 5 deletions

View File

@@ -19,7 +19,7 @@ command:
- --disable-log-requests - --disable-log-requests
- --disable-frontend-multiprocessing - --disable-frontend-multiprocessing
- --max-num-batched-tokens - --max-num-batched-tokens
- '4096' - '8192'
- --enable-chunked-prefill - --enable-chunked-prefill
- --max-seq-len-to-capture - --max-seq-len-to-capture
- '32768' - '32768'

View File

@@ -393,6 +393,20 @@ class PagedAttention:
# -------------------------------------------------------------- # --------------------------------------------------------------
if ctx_len > 0: if ctx_len > 0:
num_ctx_blocks = (ctx_len + block_size - 1) // block_size num_ctx_blocks = (ctx_len + block_size - 1) // block_size
# Safety: if block_tables is too narrow this indicates a
# prefix_cache_hit + chunked-prefill bug in model_runner.py
# (Case 1 leaves prefix_cache_hit=True but block_table is
# only computed_block_nums, not the full context blocks).
# patch_model_runner.py fixes the root cause; this guard
# prevents a zero-dim amax() crash if it still slips through.
if num_ctx_blocks > block_tables.shape[1]:
print(
f"[paged_attn WARNING] seq {i}: num_ctx_blocks={num_ctx_blocks} "
f"> block_tables.shape[1]={block_tables.shape[1]}, ctx_len={ctx_len}. "
"Block table is undersized (prefix_cache_hit bug). "
"Capping context to available blocks — attention may be incorrect.",
file=sys.stderr, flush=True)
num_ctx_blocks = block_tables.shape[1]
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE): for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks) blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
blk_ids = block_tables[i, tile_blk:blk_end] blk_ids = block_tables[i, tile_blk:blk_end]

View File

@@ -0,0 +1,78 @@
"""
Fix: prefix_cache_hit stays True for chunked-prefill chunk 2+ even when past cache.

Root cause:
    model_runner.py _compute_for_prefix_cache_hit has three cases:
        Case 1: prefix_cache_len <= context_len → "already past cache, do normal"
        Case 2: context_len < prefix_cache_len < seq_len → partial hit, correct
        Case 3: seq_len <= prefix_cache_len → full hit, reduce to 1 token
    Case 1 does nothing (leaves prefix_cache_hit = True). Then in utils.py:
        if inter_data.prefix_cache_hit:
            block_table = computed_block_nums ← ONLY the original prefix blocks!
    But context_len > prefix_cache_len means chunk 1 tokens (between prefix_cache_len
    and context_len) are ALSO in KV cache and need to be in block_table.
    block_table = computed_block_nums misses all chunk-1 blocks.
    In _forward_prefix_pytorch:
        num_ctx_blocks = ceil(context_len / block_size)  # e.g. 268
        block_tables.shape[1] = len(computed_block_nums)  # e.g. 12 <-- too small!
    At tile_blk >= 12: blk_ids is empty → k_t shape [..., 0] → amax crash.

Fix:
    Set prefix_cache_hit = False for Case 1, so utils.py falls through to:
        elif chunked_prefill_enabled:
            block_table = block_tables[seq_id] ← full block table (prefix + chunk1)

Usage:
    python3 ./patch_model_runner.py

    The first existing file in CANDIDATE_PATHS that contains the Case-1 block is
    rewritten in place. Exit status is 0 when a file was patched (or was already
    patched) and 1 otherwise.
"""
import os
import re
import sys

CANDIDATE_PATHS = [
    "/usr/local/corex/lib64/python3/dist-packages/vllm/worker/model_runner.py",
    "/usr/local/corex/lib/python3/dist-packages/vllm/worker/model_runner.py",
]

# Both blocks are written flush-left. They are matched line by line with
# indentation ignored, and the indentation actually used by the target file is
# re-applied to the inserted lines at patch time.
OLD_BLOCK = """\
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass"""
NEW_BLOCK = """\
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
# Must clear prefix_cache_hit so _add_seq_group uses the full
# block_tables (prefix + previous-chunk blocks) instead of only
# computed_block_nums (prefix only). Without this, block_tables
# passed to _forward_prefix_pytorch is too narrow for context_len,
# causing an empty blk_ids slice and a zero-dim amax() crash.
inter_data.prefix_cache_hit = False"""

# Status values returned by apply_patch().
PATCHED = "patched"
ALREADY_PATCHED = "already"
NOT_FOUND = "missing"


def _block_pattern(block):
    """Compile a regex matching *block* line by line, ignoring indentation.

    Leading and trailing spaces/tabs on every line are ignored. The
    indentation of the block's last line is captured as the named group
    ``body_indent`` so replacement lines can be indented to match the file.
    """
    stripped = [line.strip() for line in block.splitlines()]
    parts = [r"[ \t]*" + re.escape(line) for line in stripped[:-1]]
    parts.append(r"(?P<body_indent>[ \t]*)" + re.escape(stripped[-1]))
    return re.compile(
        "^" + r"[ \t]*\n".join(parts) + r"[ \t]*$", re.MULTILINE)


_OLD_PATTERN = _block_pattern(OLD_BLOCK)
_NEW_PATTERN = _block_pattern(NEW_BLOCK)

# NEW_BLOCK repeats every line of OLD_BLOCK except its final ``pass``; the
# remaining NEW_BLOCK lines are what gets written in place of that ``pass``.
_REPLACEMENT_LINES = [
    line.strip() for line in NEW_BLOCK.splitlines()
][len(OLD_BLOCK.splitlines()) - 1:]


def apply_patch(src):
    """Apply the Case-1 fix to the text of model_runner.py.

    Args:
        src: Full source text of the file to patch.

    Returns:
        A ``(text, status)`` tuple. ``status`` is PATCHED when the first
        occurrence of OLD_BLOCK was rewritten (``text`` is the new source),
        ALREADY_PATCHED when NEW_BLOCK is already present, or NOT_FOUND when
        neither block occurs; in the last two cases ``text`` is ``src``
        unchanged.
    """
    match = _OLD_PATTERN.search(src)
    if match is None:
        status = ALREADY_PATCHED if _NEW_PATTERN.search(src) else NOT_FOUND
        return src, status
    indent = match.group("body_indent")
    tail = "\n".join(indent + line for line in _REPLACEMENT_LINES)
    # Everything before the ``pass`` line is kept verbatim; only that line
    # (including any trailing blanks) is replaced.
    return src[:match.start("body_indent")] + tail + src[match.end():], PATCHED


def main(paths=None):
    """Patch the first candidate file that contains the Case-1 block.

    Args:
        paths: Files to try, in order. Defaults to CANDIDATE_PATHS.

    Returns:
        0 if a file was patched or found already patched, 1 otherwise.
    """
    paths = CANDIDATE_PATHS if paths is None else paths
    found_any = False
    for path in paths:
        if not os.path.exists(path):
            continue
        found_any = True
        with open(path, "r", encoding="utf-8") as f:
            src = f.read()
        patched_src, status = apply_patch(src)
        if status == ALREADY_PATCHED:
            print(f"[patch_model_runner] already patched: {path}")
            return 0
        if status == NOT_FOUND:
            print(f"[patch_model_runner] WARNING: expected block not found in {path}, skipping")
            continue
        with open(path, "w", encoding="utf-8") as f:
            f.write(patched_src)
        print(f"[patch_model_runner] patched Case-1 prefix_cache_hit fix in: {path}")
        return 0
    if found_any:
        # The file exists but its contents differ from what this patch expects;
        # report that instead of claiming the file is absent.
        print("[patch_model_runner] ERROR: model_runner.py was found but does not "
              "contain the expected Case-1 block", file=sys.stderr)
    else:
        print("[patch_model_runner] ERROR: could not find model_runner.py at any known path", file=sys.stderr)
    return 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -8,8 +8,6 @@
# are already correct for standard Triton 2.3.1 — do NOT overwrite them. # are already correct for standard Triton 2.3.1 — do NOT overwrite them.
# - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes # - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
#
# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs
# Recommended server start command for TP=4 support 100K, need chunked prefill # Recommended server start command for TP=4 support 100K, need chunked prefill
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ # CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
@@ -17,6 +15,14 @@
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \ # --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ # --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --max-num-batched-tokens 4096 --enable-chunked-prefill # --max-num-batched-tokens 4096 --enable-chunked-prefill
#
# With prefix caching (GDN align-mode, requires chunked prefill):
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
# --model /workspace/models/Qwen3.6-35B-A3B --port 1111 --served-model-name llm \
# --max-model-len 150000 --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
# --max-num-batched-tokens 8192 --enable-chunked-prefill --enable-prefix-caching \
# --max-seq-len-to-capture 32768
# --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback ------- # --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback -------
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
@@ -26,6 +32,15 @@
# when context length is high # when context length is high
cp ./paged_attn.py /usr/local/corex/lib/python3/dist-packages/vllm/attention/ops/paged_attn.py cp ./paged_attn.py /usr/local/corex/lib/python3/dist-packages/vllm/attention/ops/paged_attn.py
# --- model_runner.py: fix prefix_cache_hit stays True in chunked-prefill chunk 2+ ---
# Bug: _compute_for_prefix_cache_hit Case 1 (prefix_cache_len <= context_len)
# leaves prefix_cache_hit=True. Then _add_seq_group uses block_table=computed_block_nums
# (only the original prefix blocks), ignoring chunk-1 KV cache blocks.
# _forward_prefix_pytorch then gets an undersized block_tables and crashes with
# "amax(): Expected reduction dim -1 to have non-zero size" on the 2nd tile.
# Fix: set prefix_cache_hit=False for Case 1 so the full block_tables is used.
python3 ./patch_model_runner.py
# --- transformers: Qwen3_5 tokenizer / model files -------------------------- # --- transformers: Qwen3_5 tokenizer / model files --------------------------
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/ cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
@@ -42,8 +57,15 @@ python3 ./patch_vllm_qwen3_5.py
# returns _cached_all_token_ids[-0:] == [0:] (the ENTIRE prompt+output list). # returns _cached_all_token_ids[-0:] == [0:] (the ENTIRE prompt+output list).
# Each prefill chunk step adds prompt_len to previous_num_tokens, so a 10K # Each prefill chunk step adds prompt_len to previous_num_tokens, so a 10K
# prompt processed in 3 chunks inflates completion_tokens by ~30K. # prompt processed in 3 chunks inflates completion_tokens by ~30K.
# Also adds num_cached_tokens field to RequestMetrics for prefix-cache stats.
cp ./sequence.py /usr/local/corex/lib/python3/dist-packages/vllm/sequence.py cp ./sequence.py /usr/local/corex/lib/python3/dist-packages/vllm/sequence.py
# --- scheduler.py: record num_cached_tokens in RequestMetrics ----------------
# Sets seq_group.metrics.num_cached_tokens = prefix_cache_len on first prefill
# when --enable-prefix-caching is active, so serving_chat.py can report it in
# usage.prompt_tokens_details.cached_tokens (OpenAI-compatible API response).
cp ./scheduler.py /usr/local/corex/lib/python3/dist-packages/vllm/core/scheduler.py
# --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------ # --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------
# Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py. # Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py.
# Required because head_dim=256 > 128 and ixformer flash attention either # Required because head_dim=256 > 128 and ixformer flash attention either

View File

@@ -99,11 +99,16 @@ class ModelList(OpenAIBaseModel):
data: List[ModelCard] = Field(default_factory=list) data: List[ModelCard] = Field(default_factory=list)
# Breakdown of prompt-token usage; carried in UsageInfo.prompt_tokens_details.
class PromptTokensDetails(OpenAIBaseModel):
# Prompt tokens reported as cached. The chat-serving code fills this from
# RequestMetrics.num_cached_tokens when that value is not None.
cached_tokens: int = 0
class UsageInfo(OpenAIBaseModel): class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0 prompt_tokens: int = 0
total_tokens: int = 0 total_tokens: int = 0
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
reasoning_tokens: Optional[int] = None reasoning_tokens: Optional[int] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
class RequestResponseMetadata(BaseModel): class RequestResponseMetadata(BaseModel):

View File

@@ -2,7 +2,8 @@
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency). # Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
# Text-only (no VL, no MTP). # Text-only (no VL, no MTP).
from typing import Iterable, List, Optional, Tuple from collections import OrderedDict
from typing import Dict, Iterable, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -1033,6 +1034,15 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# Lazy initialised in first forward call # Lazy initialised in first forward call
self.mamba_cache: Optional[MambaCacheManager] = None self.mamba_cache: Optional[MambaCacheManager] = None
# GDN prefix state cache (align mode): stores (conv_states, temporal_states) snapshots
# at KV-block boundaries so that prefix-cache-hit requests can restore correct GDN state.
# Key: tuple of physical block IDs covering the cached prefix
# Value: (conv_states_cpu, temporal_states_cpu) each of shape (num_gdn_layers, ...)
self._gdn_prefix_cache: OrderedDict = OrderedDict()
self._gdn_prefix_cache_max: int = 16 # ~16 × 16 MB ≈ 256 MB CPU RAM
self._block_size: int = (cache_config.block_size
if cache_config is not None else 16)
def _get_mamba_cache_shape(self): def _get_mamba_cache_shape(self):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
# Each sequence's state is stored in float32 # Each sequence's state is stored in float32
@@ -1069,9 +1079,69 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim) # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
conv_states, temporal_states = mamba_tensors conv_states, temporal_states = mamba_tensors
# ── GDN prefix-cache align mode: inject saved state on prefix hit ─────
# Conditions: prefill pass, batch=1, context_len > 0 (prefix cached or
# previous chunk already processed), block_tables available.
# We always attempt a lookup: for subsequent chunked-prefill chunks the
# key matches our own saved state (same data already in slot → no-op).
# For a true cross-request prefix hit the key matches a previous request.
_is_single_seq_prefill = (
attn_metadata is not None
and attn_metadata.num_prefill_tokens > 0
and conv_states.shape[1] == 1 # batch == 1
and getattr(attn_metadata, 'context_lens_tensor', None) is not None
and getattr(attn_metadata, 'block_tables', None) is not None
and attn_metadata.block_tables.numel() > 0
)
if _is_single_seq_prefill:
context_len = int(attn_metadata.context_lens_tensor[0].item())
if context_len > 0:
num_prefix_blocks = context_len // self._block_size
if (num_prefix_blocks > 0
and attn_metadata.block_tables.shape[1] >= num_prefix_blocks):
lookup_key = tuple(
attn_metadata.block_tables[0, :num_prefix_blocks]
.cpu().tolist())
if lookup_key in self._gdn_prefix_cache:
saved_conv, saved_temporal = self._gdn_prefix_cache[lookup_key]
conv_states[:, 0].copy_(
saved_conv.to(conv_states.device), non_blocking=True)
temporal_states[:, 0].copy_(
saved_temporal.to(temporal_states.device), non_blocking=True)
self._gdn_prefix_cache.move_to_end(lookup_key)
logger.debug("GDN prefix cache hit: prefix_len=%d blocks=%d",
context_len, num_prefix_blocks)
# ── End inject ──────────────────────────────────────────────────────────
hidden_states = self.model( hidden_states = self.model(
input_ids, positions, kv_caches, attn_metadata, input_ids, positions, kv_caches, attn_metadata,
conv_states, temporal_states) conv_states, temporal_states)
# ── GDN prefix-cache align mode: save state after this prefill chunk ───
# Save state keyed by ALL complete KV blocks processed so far.
# Next requests reusing this prefix will restore from here.
if _is_single_seq_prefill:
context_len = int(attn_metadata.context_lens_tensor[0].item())
query_len = attn_metadata.num_prefill_tokens
total_processed = context_len + query_len
num_complete_blocks = total_processed // self._block_size
if (num_complete_blocks > 0
and attn_metadata.block_tables.shape[1] >= num_complete_blocks):
save_key = tuple(
attn_metadata.block_tables[0, :num_complete_blocks]
.cpu().tolist())
# Move to end (LRU: most recent = last) and update value
if save_key in self._gdn_prefix_cache:
self._gdn_prefix_cache.move_to_end(save_key)
self._gdn_prefix_cache[save_key] = (
conv_states[:, 0].cpu().clone(),
temporal_states[:, 0].cpu().clone(),
)
# Evict oldest entries beyond max
while len(self._gdn_prefix_cache) > self._gdn_prefix_cache_max:
self._gdn_prefix_cache.popitem(last=False)
# ── End save ────────────────────────────────────────────────────────────
return hidden_states return hidden_states
def compute_logits( def compute_logits(

1656
qwen3_6_scripts/scheduler.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -119,6 +119,7 @@ class RequestMetrics:
scheduler_time: Optional[float] = None scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
num_cached_tokens: Optional[int] = None
class SequenceDataDelta( class SequenceDataDelta(

View File

@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo) PromptTokensDetails, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath, from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath, LoRAModulePath,
OpenAIServing, OpenAIServing,
@@ -179,6 +179,16 @@ class OpenAIServingChat(OpenAIServing):
logger.exception("Error in loading multi-modal data") logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# n > max_num_seqs deadlock guard: scheduler uses break (not continue)
# when can_schedule(num_new_seqs=n) fails, so an n that exceeds
# max_num_seqs permanently blocks the entire waiting queue with no error.
_sched_cfg = await self.engine_client.get_scheduler_config()
_max_seqs = _sched_cfg.max_num_seqs
if request.n is not None and request.n > _max_seqs:
return self.create_error_response(
f"n={request.n} exceeds max_num_seqs={_max_seqs}. "
f"Use n<={_max_seqs} or omit n.")
# validation for OpenAI tools # validation for OpenAI tools
# tool_choice = "required" is not supported # tool_choice = "required" is not supported
if request.tool_choice == "required": if request.tool_choice == "required":
@@ -318,6 +328,7 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens = [0] * num_choices previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0 num_prompt_tokens = 0
num_cached_tokens: Optional[int] = None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name tool_choice_function_name = request.tool_choice.function.name
@@ -385,6 +396,10 @@ class OpenAIServingChat(OpenAIServing):
num_prompt_tokens = len(res.prompt_token_ids) num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None: if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids) num_prompt_tokens += len(res.encoder_prompt_token_ids)
if (num_cached_tokens is None
and res.metrics is not None
and res.metrics.num_cached_tokens is not None):
num_cached_tokens = res.metrics.num_cached_tokens
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
@@ -691,6 +706,9 @@ class OpenAIServingChat(OpenAIServing):
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens, total_tokens=num_prompt_tokens + completion_tokens,
reasoning_tokens=total_reasoning, reasoning_tokens=total_reasoning,
prompt_tokens_details=(
PromptTokensDetails(cached_tokens=num_cached_tokens)
if num_cached_tokens is not None else None),
) )
final_usage_chunk = ChatCompletionStreamResponse( final_usage_chunk = ChatCompletionStreamResponse(
@@ -713,6 +731,10 @@ class OpenAIServingChat(OpenAIServing):
total_tokens=num_prompt_tokens + num_completion_tokens, total_tokens=num_prompt_tokens + num_completion_tokens,
reasoning_tokens=total_reasoning) reasoning_tokens=total_reasoning)
except asyncio.CancelledError:
# Client disconnected; abort the engine request so GPU is freed.
await self.engine_client.abort(request_id)
return
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e) logger.error("error in chat completion stream generator: %s", e)
@@ -739,6 +761,7 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator: async for res in result_generator:
final_res = res final_res = res
except asyncio.CancelledError: except asyncio.CancelledError:
await self.engine_client.abort(request_id)
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
assert final_res is not None assert final_res is not None
@@ -881,11 +904,16 @@ class OpenAIServingChat(OpenAIServing):
total_reasoning_tokens = sum( total_reasoning_tokens = sum(
rp.count_reasoning_tokens(list(output.token_ids)) rp.count_reasoning_tokens(list(output.token_ids))
for output in final_res.outputs) for output in final_res.outputs)
num_cached_tokens = (final_res.metrics.num_cached_tokens
if final_res.metrics is not None else None)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens, completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens,
reasoning_tokens=total_reasoning_tokens, reasoning_tokens=total_reasoning_tokens,
prompt_tokens_details=(
PromptTokensDetails(cached_tokens=num_cached_tokens)
if num_cached_tokens is not None else None),
) )
request_metadata.final_usage_info = usage request_metadata.final_usage_info = usage