diff --git a/computility-run.yaml b/computility-run.yaml index 4152dbb..ae58bdf 100644 --- a/computility-run.yaml +++ b/computility-run.yaml @@ -19,7 +19,7 @@ command: - --disable-log-requests - --disable-frontend-multiprocessing - --max-num-batched-tokens - - '4096' + - '8192' - --enable-chunked-prefill - --max-seq-len-to-capture - '32768' diff --git a/qwen3_6_scripts/paged_attn.py b/qwen3_6_scripts/paged_attn.py index 63c9443..8590489 100644 --- a/qwen3_6_scripts/paged_attn.py +++ b/qwen3_6_scripts/paged_attn.py @@ -393,6 +393,20 @@ class PagedAttention: # -------------------------------------------------------------- if ctx_len > 0: num_ctx_blocks = (ctx_len + block_size - 1) // block_size + # Safety: if block_tables is too narrow this indicates a + # prefix_cache_hit + chunked-prefill bug in model_runner.py + # (Case 1 leaves prefix_cache_hit=True but block_table is + # only computed_block_nums, not the full context blocks). + # patch_model_runner.py fixes the root cause; this guard + # prevents a zero-dim amax() crash if it still slips through. + if num_ctx_blocks > block_tables.shape[1]: + print( + f"[paged_attn WARNING] seq {i}: num_ctx_blocks={num_ctx_blocks} " + f"> block_tables.shape[1]={block_tables.shape[1]}, ctx_len={ctx_len}. " + "Block table is undersized (prefix_cache_hit bug). " + "Capping context to available blocks — attention may be incorrect.", + file=sys.stderr, flush=True) + num_ctx_blocks = block_tables.shape[1] for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE): blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks) blk_ids = block_tables[i, tile_blk:blk_end] diff --git a/qwen3_6_scripts/patch_model_runner.py b/qwen3_6_scripts/patch_model_runner.py new file mode 100644 index 0000000..e10ad27 --- /dev/null +++ b/qwen3_6_scripts/patch_model_runner.py @@ -0,0 +1,78 @@ +""" +Fix: prefix_cache_hit stays True for chunked-prefill chunk 2+ even when past cache. + +Root cause: + model_runner.py _compute_for_prefix_cache_hit has three cases: + Case 1: prefix_cache_len <= context_len → "already past cache, do normal" + Case 2: context_len < prefix_cache_len < seq_len → partial hit, correct + Case 3: seq_len <= prefix_cache_len → full hit, reduce to 1 token + + Case 1 does nothing (leaves prefix_cache_hit = True). Then in utils.py: + if inter_data.prefix_cache_hit: + block_table = computed_block_nums ← ONLY the original prefix blocks! + + But context_len > prefix_cache_len means chunk 1 tokens (between prefix_cache_len + and context_len) are ALSO in KV cache and need to be in block_table. + block_table = computed_block_nums misses all chunk-1 blocks. + + In _forward_prefix_pytorch: + num_ctx_blocks = ceil(context_len / block_size) # e.g. 268 + block_tables.shape[1] = len(computed_block_nums) # e.g. 12 <-- too small! + At tile_blk >= 12: blk_ids is empty → k_t shape [..., 0] → amax crash. + +Fix: + Set prefix_cache_hit = False for Case 1, so utils.py falls through to: + elif chunked_prefill_enabled: + block_table = block_tables[seq_id] ← full block table (prefix + chunk1) +""" + +import re +import sys + +CANDIDATE_PATHS = [ + "/usr/local/corex/lib64/python3/dist-packages/vllm/worker/model_runner.py", + "/usr/local/corex/lib/python3/dist-packages/vllm/worker/model_runner.py", +] + +OLD_BLOCK = """\ + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass""" + +NEW_BLOCK = """\ + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + # Must clear prefix_cache_hit so _add_seq_group uses the full + # block_tables (prefix + previous-chunk blocks) instead of only + # computed_block_nums (prefix only). Without this, block_tables + # passed to _forward_prefix_pytorch is too narrow for context_len, + # causing an empty blk_ids slice and a zero-dim amax() crash. + inter_data.prefix_cache_hit = False""" + +import os + +patched = False +for path in CANDIDATE_PATHS: + if not os.path.exists(path): + continue + with open(path, "r") as f: + src = f.read() + if OLD_BLOCK not in src: + if NEW_BLOCK in src: + print(f"[patch_model_runner] already patched: {path}") + patched = True + break + print(f"[patch_model_runner] WARNING: expected block not found in {path}, skipping") + continue + patched_src = src.replace(OLD_BLOCK, NEW_BLOCK, 1) + with open(path, "w") as f: + f.write(patched_src) + print(f"[patch_model_runner] patched Case-1 prefix_cache_hit fix in: {path}") + patched = True + break + +if not patched: + print("[patch_model_runner] ERROR: could not find model_runner.py at any known path", file=sys.stderr) + sys.exit(1) diff --git a/qwen3_6_scripts/patch_ops.sh b/qwen3_6_scripts/patch_ops.sh index d354992..1035b2f 100755 --- a/qwen3_6_scripts/patch_ops.sh +++ b/qwen3_6_scripts/patch_ops.sh @@ -8,8 +8,6 @@ # are already correct for standard Triton 2.3.1 — do NOT overwrite them. # - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes # GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible. -# -# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs # Recommended server start command for TP=4 support 100K, need chunked prefill # CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ @@ -17,6 +15,14 @@ # --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \ # --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ # --max-num-batched-tokens 4096 --enable-chunked-prefill +# +# With prefix caching (GDN align-mode, requires chunked prefill): +# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \ +# --model /workspace/models/Qwen3.6-35B-A3B --port 1111 --served-model-name llm \ +# --max-model-len 150000 --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \ +# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \ +# --max-num-batched-tokens 8192 --enable-chunked-prefill --enable-prefix-caching \ +# --max-seq-len-to-capture 32768 # --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback ------- # The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently @@ -26,6 +32,15 @@ # when context length is high cp ./paged_attn.py /usr/local/corex/lib/python3/dist-packages/vllm/attention/ops/paged_attn.py +# --- model_runner.py: fix prefix_cache_hit stays True in chunked-prefill chunk 2+ --- +# Bug: _compute_for_prefix_cache_hit Case 1 (prefix_cache_len <= context_len) +# leaves prefix_cache_hit=True. Then _add_seq_group uses block_table=computed_block_nums +# (only the original prefix blocks), ignoring chunk-1 KV cache blocks. +# _forward_prefix_pytorch then gets an undersized block_tables and crashes with +# "amax(): Expected reduction dim -1 to have non-zero size" on the 2nd tile. +# Fix: set prefix_cache_hit=False for Case 1 so the full block_tables is used. +python3 ./patch_model_runner.py + # --- transformers: Qwen3_5 tokenizer / model files -------------------------- pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/ @@ -42,8 +57,15 @@ python3 ./patch_vllm_qwen3_5.py # returns _cached_all_token_ids[-0:] == [0:] (the ENTIRE prompt+output list). # Each prefill chunk step adds prompt_len to previous_num_tokens, so a 10K # prompt processed in 3 chunks inflates completion_tokens by ~30K. +# Also adds num_cached_tokens field to RequestMetrics for prefix-cache stats. cp ./sequence.py /usr/local/corex/lib/python3/dist-packages/vllm/sequence.py +# --- scheduler.py: record num_cached_tokens in RequestMetrics ---------------- +# Sets seq_group.metrics.num_cached_tokens = prefix_cache_len on first prefill +# when --enable-prefix-caching is active, so serving_chat.py can report it in +# usage.prompt_tokens_details.cached_tokens (OpenAI-compatible API response). +cp ./scheduler.py /usr/local/corex/lib/python3/dist-packages/vllm/core/scheduler.py + # --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------ # Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py. # Required because head_dim=256 > 128 and ixformer flash attention either diff --git a/qwen3_6_scripts/protocol.py b/qwen3_6_scripts/protocol.py index 3a6a557..456ad51 100644 --- a/qwen3_6_scripts/protocol.py +++ b/qwen3_6_scripts/protocol.py @@ -99,11 +99,16 @@ class ModelList(OpenAIBaseModel): data: List[ModelCard] = Field(default_factory=list) +class PromptTokensDetails(OpenAIBaseModel): + cached_tokens: int = 0 + + class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 reasoning_tokens: Optional[int] = None + prompt_tokens_details: Optional[PromptTokensDetails] = None class RequestResponseMetadata(BaseModel): diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py index 617416b..daaaa89 100644 --- a/qwen3_6_scripts/qwen3_5.py +++ b/qwen3_6_scripts/qwen3_5.py @@ -2,7 +2,8 @@ # Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency). # Text-only (no VL, no MTP). -from typing import Iterable, List, Optional, Tuple +from collections import OrderedDict +from typing import Dict, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -1033,6 +1034,15 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA): # Lazy initialised in first forward call self.mamba_cache: Optional[MambaCacheManager] = None + # GDN prefix state cache (align mode): stores (conv_states, temporal_states) snapshots + # at KV-block boundaries so that prefix-cache-hit requests can restore correct GDN state. + # Key: tuple of physical block IDs covering the cached prefix + # Value: (conv_states_cpu, temporal_states_cpu) each of shape (num_gdn_layers, ...) + self._gdn_prefix_cache: OrderedDict = OrderedDict() + self._gdn_prefix_cache_max: int = 16 # ~16 × 16 MB ≈ 256 MB CPU RAM + self._block_size: int = (cache_config.block_size + if cache_config is not None else 16) + def _get_mamba_cache_shape(self): tp_size = get_tensor_model_parallel_world_size() # Each sequence's state is stored in float32 @@ -1069,9 +1079,69 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA): # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim) conv_states, temporal_states = mamba_tensors + # ── GDN prefix-cache align mode: inject saved state on prefix hit ───── + # Conditions: prefill pass, batch=1, context_len > 0 (prefix cached or + # previous chunk already processed), block_tables available. + # We always attempt a lookup: for subsequent chunked-prefill chunks the + # key matches our own saved state (same data already in slot → no-op). + # For a true cross-request prefix hit the key matches a previous request. + _is_single_seq_prefill = ( + attn_metadata is not None + and attn_metadata.num_prefill_tokens > 0 + and conv_states.shape[1] == 1 # batch == 1 + and getattr(attn_metadata, 'context_lens_tensor', None) is not None + and getattr(attn_metadata, 'block_tables', None) is not None + and attn_metadata.block_tables.numel() > 0 + ) + if _is_single_seq_prefill: + context_len = int(attn_metadata.context_lens_tensor[0].item()) + if context_len > 0: + num_prefix_blocks = context_len // self._block_size + if (num_prefix_blocks > 0 + and attn_metadata.block_tables.shape[1] >= num_prefix_blocks): + lookup_key = tuple( + attn_metadata.block_tables[0, :num_prefix_blocks] + .cpu().tolist()) + if lookup_key in self._gdn_prefix_cache: + saved_conv, saved_temporal = self._gdn_prefix_cache[lookup_key] + conv_states[:, 0].copy_( + saved_conv.to(conv_states.device), non_blocking=True) + temporal_states[:, 0].copy_( + saved_temporal.to(temporal_states.device), non_blocking=True) + self._gdn_prefix_cache.move_to_end(lookup_key) + logger.debug("GDN prefix cache hit: prefix_len=%d blocks=%d", + context_len, num_prefix_blocks) + # ── End inject ────────────────────────────────────────────────────────── + hidden_states = self.model( input_ids, positions, kv_caches, attn_metadata, conv_states, temporal_states) + + # ── GDN prefix-cache align mode: save state after this prefill chunk ─── + # Save state keyed by ALL complete KV blocks processed so far. + # Next requests reusing this prefix will restore from here. + if _is_single_seq_prefill: + context_len = int(attn_metadata.context_lens_tensor[0].item()) + query_len = attn_metadata.num_prefill_tokens + total_processed = context_len + query_len + num_complete_blocks = total_processed // self._block_size + if (num_complete_blocks > 0 + and attn_metadata.block_tables.shape[1] >= num_complete_blocks): + save_key = tuple( + attn_metadata.block_tables[0, :num_complete_blocks] + .cpu().tolist()) + # Move to end (LRU: most recent = last) and update value + if save_key in self._gdn_prefix_cache: + self._gdn_prefix_cache.move_to_end(save_key) + self._gdn_prefix_cache[save_key] = ( + conv_states[:, 0].cpu().clone(), + temporal_states[:, 0].cpu().clone(), + ) + # Evict oldest entries beyond max + while len(self._gdn_prefix_cache) > self._gdn_prefix_cache_max: + self._gdn_prefix_cache.popitem(last=False) + # ── End save ──────────────────────────────────────────────────────────── + return hidden_states def compute_logits( diff --git a/qwen3_6_scripts/scheduler.py b/qwen3_6_scripts/scheduler.py new file mode 100644 index 0000000..3c06fd7 --- /dev/null +++ b/qwen3_6_scripts/scheduler.py @@ -0,0 +1,1656 @@ +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStatus) +from vllm.utils import Device, PyObjectCache + +logger = init_logger(__name__) + +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + token_budget: int + max_num_seqs: int + _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + assert num_new_tokens != 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + return + + self._request_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + self._request_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + return + + self._request_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + self._request_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + +@dataclass +class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + # Scheduled sequence groups. + scheduled_seq_groups: Iterable[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int + + def __post_init__(self): + # Swap in and swap out should never happen at the same time. + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + + self.num_loras: int = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self): + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } + + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + + +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. + """ + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in=[], + blocks_to_copy=[], + num_lookahead_slots=0, + infeasible_seq_groups=[], + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + # Selected sequences for prefill. + seq_groups: List[ScheduledSequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(SequenceGroup("", [], -1), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback: Optional[Callable] = None, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if (self.scheduler_config.embedding_mode + or self.cache_config.is_attention_free): + version = "placeholder" + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) + + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) + + # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + # Contain decode requests. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. + self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + # This is used to evict the finished requests from the Mamba cache. + self._finished_requests_ids: List[str] = list() + # Time at previous scheduling step + self.prev_time = 0.0 + # Did we schedule a prompt at previous step? + self.prev_prompt = False + # Latency of the last prompt step + self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode + + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + + # Used to cache python objects + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. + Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + if not request_ids: + # Using 'break' here may add two extra iterations, + # but is acceptable to reduce complexity. + break + if seq_group.request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + request_ids.remove(seq_group.request_id) + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. + self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + + def has_unfinished_seqs(self) -> bool: + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + + def _schedule_running( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerRunningOutputs: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerRunningOutputs. + """ + ret: SchedulerRunningOutputs = \ + self._scheduler_running_outputs_cache[self.cache_id].get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() + + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out + + running_queue = self.running + assert len(self._async_stopped) == 0 + while running_queue: + seq_group = running_queue[0] + num_running_tokens = self._get_num_new_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + + if num_running_tokens == 0: + # No budget => Stop + break + + running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if self.use_async_output_proc and seq_group.seqs[0].get_len( + ) > self.scheduler_config.max_model_len: + self._async_stopped.append(seq_group) + continue + + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + while not self._can_append_slots(seq_group, enable_chunking): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): + curr_loras.remove(seq_group.lora_int_id) + + # Determine victim sequence + cont_loop = True + if running_queue: + # Preempt the lowest-priority sequence group. + victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) + + if not cont_loop: + break + else: + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = \ + self._scheduled_seq_group_cache[self.cache_id].get_object() + scheduled_seq_group.seq_group = seq_group + if is_prefill: + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) + else: + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.add(seq_group.lora_int_id) + + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() + + return ret + + def _schedule_swapped( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerSwappedInOutputs: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + infeasible_seq_groups: List[SequenceGroup] = [] + + swapped_queue = self.swapped + + leftover_swapped: Deque[SequenceGroup] = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + is_prefill = seq_group.is_prefill() + alloc_status = self.block_manager.can_swap_in( + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) + if alloc_status == AllocStatus.LATER: + break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.SWAPPED, + enable_chunking, budget) + + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup(seq_group, + token_chunk_size=num_new_tokens)) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if self.scheduler_config.chunked_prefill_enabled and \ + not self.scheduler_config.is_multi_step: + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if (seq_group.lora_request + and seq_group.lora_request.long_lora_max_len): + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + + def _schedule_prefills( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerPrefillOutputs: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerPrefillOutputs. + """ + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] + + waiting_queue = self.waiting + + leftover_waiting_sequences: Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + enable_chunking, budget) + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", num_new_tokens, prompt_limit) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, num_lookahead_slots) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to optimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None + + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) + + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. + if len(prefills.seq_groups) == 0: + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + preempted = (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. + assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, + ) + + def _schedule_chunked_prefill(self) -> SchedulerOutputs: + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to be blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras: Set[int] = set() + + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # Decoding should be always scheduled first by fcfs. + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) + + # Schedule swapped out requests. + # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + # Schedule new prefills. + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + + # Update new running requests. + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + return SchedulerOutputs( + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), + num_prefill_groups=(len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)), + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), + ) + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking + + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # async_output_proc is allowed only when we have a single sequence + # in the sequence group + no_single_seq = seq_group.sampling_params is None or ( + seq_group.sampling_params.n == 1) + return no_single_seq + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_start_time = time.perf_counter() + + scheduler_outputs: SchedulerOutputs = self._schedule() + now = time.time() + + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + + allow_async_output_proc: bool = self.use_async_output_proc + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.maybe_set_first_scheduled_time(now) + + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + + # seq_id -> SequenceData + seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers + block_tables: Dict[int, List[int]] = {} + + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) + + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + + do_sample = True + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 + if (is_first_prefill + and self.cache_config.enable_prefix_caching + and seq_group.metrics is not None): + seq_group.metrics.num_cached_tokens = ( + len(common_computed_block_nums) + * self.cache_config.block_size) + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + num_computed_tokens < + seqs[0].data.get_len()): + do_sample = False + + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups > 0 else None, + mm_processor_kwargs=seq_group.mm_processor_kwargs, + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) + seq_group_metadata_list.append(seq_group_metadata) + + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) + + self._seq_group_metadata_cache[self.next_cache_id].reset() + + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. + for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" + self.block_manager.free(seq) + + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + def free_finished_seq_groups(self) -> None: + remaining: Deque[SequenceGroup] = deque() + for seq_group in self.running: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): + remaining.append(seq_group) + + self.running = remaining + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slots(self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. + enable_chunking (bool): True if chunked prefill is enabled. + """ + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + if len(cows) > 0: + blocks_to_copy.extend(cows) + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + preemption_mode: Optional[PreemptionMode] = None, + ) -> PreemptionMode: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if self.user_specified_preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + return preemption_mode + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.free_seq(seq) + seq.reset_state_for_recompute() + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: List[Tuple[int, int]], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED + + def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + # Delay scheduling prompts to let waiting queue fill up + if self.scheduler_config.delay_factor > 0 and self.waiting: + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) + passed_delay = ( + (now - earliest_arrival_time) > + (self.scheduler_config.delay_factor * self.last_prompt_latency) + or not self.running) + else: + passed_delay = True + return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. + """ + if is_prefill: + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_tokens(self, seq_group: SequenceGroup, + status: SequenceStatus, enable_chunking: bool, + budget: SchedulingBudget) -> int: + """Get the next new tokens to compute for a given sequence group + that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. + """ + num_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + for seq in seqs: + num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + if enable_chunking and len(seqs) == 1: + remaining_token_budget = budget.remaining_token_budget() + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens + elif self.cache_config.enable_prefix_caching: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = self.cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + if remaining_token_budget < num_new_tokens: + num_new_tokens = (remaining_token_budget // + block_size) * block_size + else: + num_new_tokens = min(num_new_tokens, remaining_token_budget) + return num_new_tokens diff --git a/qwen3_6_scripts/sequence.py b/qwen3_6_scripts/sequence.py index a3e59f3..6c08613 100644 --- a/qwen3_6_scripts/sequence.py +++ b/qwen3_6_scripts/sequence.py @@ -119,6 +119,7 @@ class RequestMetrics: scheduler_time: Optional[float] = None model_forward_time: Optional[float] = None model_execute_time: Optional[float] = None + num_cached_tokens: Optional[int] = None class SequenceDataDelta( diff --git a/qwen3_6_scripts/serving_chat.py b/qwen3_6_scripts/serving_chat.py index 6404ad4..c4d24fa 100644 --- a/qwen3_6_scripts/serving_chat.py +++ b/qwen3_6_scripts/serving_chat.py @@ -25,7 +25,7 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, - ToolCall, UsageInfo) + PromptTokensDetails, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, @@ -179,6 +179,16 @@ class OpenAIServingChat(OpenAIServing): logger.exception("Error in loading multi-modal data") return self.create_error_response(str(e)) + # n > max_num_seqs deadlock guard: scheduler uses break (not continue) + # when can_schedule(num_new_seqs=n) fails, so an n that exceeds + # max_num_seqs permanently blocks the entire waiting queue with no error. + _sched_cfg = await self.engine_client.get_scheduler_config() + _max_seqs = _sched_cfg.max_num_seqs + if request.n is not None and request.n > _max_seqs: + return self.create_error_response( + f"n={request.n} exceeds max_num_seqs={_max_seqs}. " + f"Use n<={_max_seqs} or omit n.") + # validation for OpenAI tools # tool_choice = "required" is not supported if request.tool_choice == "required": @@ -318,6 +328,7 @@ class OpenAIServingChat(OpenAIServing): previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices num_prompt_tokens = 0 + num_cached_tokens: Optional[int] = None if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -385,6 +396,10 @@ class OpenAIServingChat(OpenAIServing): num_prompt_tokens = len(res.prompt_token_ids) if res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(res.encoder_prompt_token_ids) + if (num_cached_tokens is None + and res.metrics is not None + and res.metrics.num_cached_tokens is not None): + num_cached_tokens = res.metrics.num_cached_tokens # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST @@ -691,6 +706,9 @@ class OpenAIServingChat(OpenAIServing): completion_tokens=completion_tokens, total_tokens=num_prompt_tokens + completion_tokens, reasoning_tokens=total_reasoning, + prompt_tokens_details=( + PromptTokensDetails(cached_tokens=num_cached_tokens) + if num_cached_tokens is not None else None), ) final_usage_chunk = ChatCompletionStreamResponse( @@ -713,6 +731,10 @@ class OpenAIServingChat(OpenAIServing): total_tokens=num_prompt_tokens + num_completion_tokens, reasoning_tokens=total_reasoning) + except asyncio.CancelledError: + # Client disconnected; abort the engine request so GPU is freed. + await self.engine_client.abort(request_id) + return except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) @@ -739,6 +761,7 @@ class OpenAIServingChat(OpenAIServing): async for res in result_generator: final_res = res except asyncio.CancelledError: + await self.engine_client.abort(request_id) return self.create_error_response("Client disconnected") assert final_res is not None @@ -881,11 +904,16 @@ class OpenAIServingChat(OpenAIServing): total_reasoning_tokens = sum( rp.count_reasoning_tokens(list(output.token_ids)) for output in final_res.outputs) + num_cached_tokens = (final_res.metrics.num_cached_tokens + if final_res.metrics is not None else None) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, reasoning_tokens=total_reasoning_tokens, + prompt_tokens_details=( + PromptTokensDetails(cached_tokens=num_cached_tokens) + if num_cached_tokens is not None else None), ) request_metadata.final_usage_info = usage