Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -5,8 +5,6 @@ from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -29,7 +27,6 @@ else:
|
||||
Request = object
|
||||
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class NewRequestData:
|
||||
req_id: str
|
||||
@@ -109,7 +106,6 @@ class NewRequestData:
|
||||
)
|
||||
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class CachedRequestData:
|
||||
req_ids: list[str]
|
||||
@@ -179,7 +175,6 @@ class CachedRequestData:
|
||||
)
|
||||
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class SchedulerOutput:
|
||||
# list of the requests that are scheduled for the first time.
|
||||
@@ -217,6 +212,9 @@ class SchedulerOutput:
|
||||
# freed from the encoder cache.
|
||||
free_encoder_mm_hashes: list[str]
|
||||
|
||||
# Request IDs that are resumed from preemption in this step.
|
||||
scheduled_resumed_reqs: list[str] | None = None
|
||||
|
||||
# Request IDs that are preempted in this step.
|
||||
# Only used for v2 model runner.
|
||||
preempted_req_ids: set[str] | None = None
|
||||
@@ -238,6 +236,11 @@ class SchedulerOutput:
|
||||
# EC Cache Connector metadata
|
||||
ec_connector_metadata: ECConnectorMetadata | None = None
|
||||
|
||||
# Block IDs freshly allocated from the pool during this scheduling step.
|
||||
# The worker zeros the corresponding GPU memory before the blocks are used,
|
||||
# preventing stale NaN/data from corrupting attention or SSM computation.
|
||||
new_block_ids_to_zero: list[int] | None = None
|
||||
|
||||
@classmethod
|
||||
def make_empty(cls) -> "SchedulerOutput":
|
||||
return cls(
|
||||
|
||||
@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
|
||||
return any(
|
||||
isinstance(group_spec.kv_cache_spec, MambaSpec)
|
||||
for group_spec in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
|
||||
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
|
||||
self.has_mamba_layers = kv_cache_config.has_mamba_layers
|
||||
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
|
||||
self.need_mamba_block_aligned_split = (
|
||||
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
|
||||
)
|
||||
@@ -320,6 +315,9 @@ class Scheduler(SchedulerInterface):
|
||||
return num_new_tokens
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if envs.VLLM_ENABLE_PP_MIX_ILU_SCHEDULING:
|
||||
return self.schedule_opt()
|
||||
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
# Each request just has the num_computed_tokens and
|
||||
@@ -413,7 +411,7 @@ class Scheduler(SchedulerInterface):
|
||||
request, num_new_tokens
|
||||
)
|
||||
|
||||
if num_new_tokens == 0:
|
||||
if num_new_tokens <= 0:
|
||||
# The request cannot be scheduled because one of the following
|
||||
# reasons:
|
||||
# 1. No new tokens to schedule. This may happen when
|
||||
@@ -425,6 +423,8 @@ class Scheduler(SchedulerInterface):
|
||||
# 3. The encoder cache is exhausted.
|
||||
# 4. Insufficient budget for a block-aligned chunk in hybrid
|
||||
# models with mamba cache mode \"align\".
|
||||
# 5. num_computed_tokens > num_tokens_with_spec due to PP
|
||||
# timing: schedule() runs before update_from_output().
|
||||
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
|
||||
# we do not strictly follow the FCFS scheduling policy and
|
||||
# allow the lower-priority requests to be scheduled.
|
||||
@@ -670,7 +670,7 @@ class Scheduler(SchedulerInterface):
|
||||
# If chunked_prefill is disabled,
|
||||
# we can stop the scheduling here.
|
||||
break
|
||||
temp_num_new_tokens = num_new_tokens
|
||||
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
assert num_new_tokens > 0
|
||||
|
||||
@@ -688,7 +688,7 @@ class Scheduler(SchedulerInterface):
|
||||
encoder_compute_budget,
|
||||
shift_computed_tokens=1 if self.use_eagle else 0,
|
||||
)
|
||||
if num_new_tokens == 0 or num_new_tokens < temp_num_new_tokens:
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
@@ -723,6 +723,35 @@ class Scheduler(SchedulerInterface):
|
||||
for i in encoder_inputs_to_schedule
|
||||
)
|
||||
|
||||
if not load_kv_async:
|
||||
enable_chunked = self.scheduler_config.enable_chunked_prefill
|
||||
tokens_still_to_compute = (
|
||||
request.num_tokens - num_computed_tokens
|
||||
)
|
||||
is_chunked = (
|
||||
enable_chunked
|
||||
and tokens_still_to_compute > num_new_tokens
|
||||
)
|
||||
if is_chunked:
|
||||
assert (
|
||||
request.num_tokens <= self.max_model_len
|
||||
), "request.num_tokens must not exceed max_model_len"
|
||||
num_tokens_need_slot = min(
|
||||
request.num_tokens + effective_lookahead_tokens,
|
||||
self.max_model_len,
|
||||
)
|
||||
blocks_needed = (
|
||||
self.kv_cache_manager.get_num_blocks_needed_for_tokens(
|
||||
request.request_id,
|
||||
num_tokens_need_slot,
|
||||
new_computed_blocks,
|
||||
num_encoder_tokens,
|
||||
)
|
||||
)
|
||||
num_free = self.kv_cache_manager.get_num_free_blocks()
|
||||
if num_free < blocks_needed:
|
||||
break
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
@@ -871,6 +900,12 @@ class Scheduler(SchedulerInterface):
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
|
||||
|
||||
new_block_ids_to_zero = (
|
||||
(self.kv_cache_manager.take_new_block_ids() or None)
|
||||
if self.needs_kv_cache_zeroing
|
||||
else None
|
||||
)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=cached_reqs_data,
|
||||
@@ -886,6 +921,7 @@ class Scheduler(SchedulerInterface):
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
new_block_ids_to_zero=new_block_ids_to_zero,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
@@ -909,6 +945,527 @@ class Scheduler(SchedulerInterface):
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
|
||||
def schedule_opt(self) -> SchedulerOutput:
|
||||
"""PP mix ILU scheduling variant of schedule()."""
|
||||
|
||||
scheduled_new_reqs: list[Request] = []
|
||||
scheduled_resumed_reqs: list[Request] = []
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
if self._pause_state == PauseState.PAUSED_ALL:
|
||||
token_budget = 0
|
||||
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||
encoder_compute_budget = self.max_num_encoder_input_tokens
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
self.kv_cache_manager.new_step_starts()
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
req_index = 0
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
|
||||
if (
|
||||
request.num_output_placeholders > 0
|
||||
and request.num_computed_tokens + 2 - request.num_output_placeholders
|
||||
>= request.num_prompt_tokens + request.max_tokens
|
||||
):
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = (
|
||||
request.num_tokens_with_spec
|
||||
+ request.num_output_placeholders
|
||||
- request.num_computed_tokens
|
||||
)
|
||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
num_new_tokens = min(
|
||||
num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
|
||||
)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
external_load_encoder_input: list[int] = []
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
if request.has_encoder_inputs:
|
||||
(
|
||||
encoder_inputs_to_schedule,
|
||||
num_new_tokens,
|
||||
new_encoder_compute_budget,
|
||||
external_load_encoder_input,
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request,
|
||||
request.num_computed_tokens,
|
||||
num_new_tokens,
|
||||
encoder_compute_budget,
|
||||
shift_computed_tokens=1 if self.use_eagle else 0,
|
||||
)
|
||||
|
||||
if self.need_mamba_block_aligned_split:
|
||||
num_new_tokens = self._mamba_block_aligned_split(
|
||||
request, num_new_tokens
|
||||
)
|
||||
|
||||
if num_new_tokens <= 0:
|
||||
# The request cannot be scheduled because one of the following
|
||||
# reasons:
|
||||
# 1. No new tokens to schedule. This may happen when
|
||||
# (1) PP>1 and we have already scheduled all prompt tokens
|
||||
# but they are not finished yet.
|
||||
# (2) Async scheduling and the request has reached to either
|
||||
# its max_total_tokens or max_model_len.
|
||||
# 2. The encoder budget is exhausted.
|
||||
# 3. The encoder cache is exhausted.
|
||||
# 4. num_computed_tokens > num_tokens_with_spec due to PP
|
||||
# timing: schedule() runs before update_from_output().
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
# Schedule newly needed KV blocks for the request.
|
||||
with record_function_or_nullcontext("schedule: allocate_slots"):
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens,
|
||||
)
|
||||
|
||||
if new_blocks is not None:
|
||||
break
|
||||
|
||||
if self.policy == SchedulingPolicy.PRIORITY:
|
||||
preempted_req = max(
|
||||
self.running,
|
||||
key=lambda r: (r.priority, r.arrival_time),
|
||||
)
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
preempted_req_id = preempted_req.request_id
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
token_budget += num_scheduled_tokens.pop(preempted_req_id)
|
||||
req_to_new_blocks.pop(preempted_req_id)
|
||||
scheduled_spec_decode_tokens.pop(preempted_req_id, None)
|
||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
|
||||
preempted_req_id, None
|
||||
)
|
||||
if preempted_encoder_inputs:
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self._preempt_request(preempted_req, scheduled_timestamp)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
break
|
||||
|
||||
if new_blocks is None:
|
||||
break
|
||||
|
||||
scheduled_running_reqs.append(request)
|
||||
request_id = request.request_id
|
||||
req_to_new_blocks[request_id] = new_blocks
|
||||
num_scheduled_tokens[request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
req_index += 1
|
||||
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (
|
||||
num_new_tokens
|
||||
+ request.num_computed_tokens
|
||||
- request.num_tokens
|
||||
- request.num_output_placeholders
|
||||
)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
spec_token_ids = request.spec_token_ids
|
||||
if len(spec_token_ids) > num_scheduled_spec_tokens:
|
||||
spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
|
||||
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
|
||||
|
||||
request.spec_token_ids = []
|
||||
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
if external_load_encoder_input:
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
|
||||
# Record the LoRAs in scheduled_running_reqs
|
||||
scheduled_loras: set[int] = set()
|
||||
if self.lora_config:
|
||||
scheduled_loras = set(
|
||||
req.lora_request.lora_int_id
|
||||
for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0
|
||||
)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
|
||||
skipped_waiting_requests = create_request_queue(self.policy)
|
||||
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting.peek_request()
|
||||
request_id = request.request_id
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
if request.num_preemptions:
|
||||
request.status = RequestStatus.PREEMPTED
|
||||
else:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request_id,
|
||||
)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
||||
assert not request.streaming_queue
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
if (
|
||||
self.lora_config
|
||||
and request.lora_request
|
||||
and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras
|
||||
)
|
||||
):
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
load_kv_async = False
|
||||
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
|
||||
|
||||
if request.num_computed_tokens == 0:
|
||||
new_computed_blocks, num_new_local_computed_tokens = (
|
||||
self.kv_cache_manager.get_computed_blocks(request)
|
||||
)
|
||||
|
||||
if self.connector is not None:
|
||||
ext_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens
|
||||
)
|
||||
)
|
||||
|
||||
if ext_tokens is None:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
request.num_external_computed_tokens = ext_tokens
|
||||
num_external_computed_tokens = ext_tokens
|
||||
|
||||
connector_prefix_cache_queries = (
|
||||
request.num_tokens - num_new_local_computed_tokens
|
||||
)
|
||||
connector_prefix_cache_hits = num_external_computed_tokens
|
||||
|
||||
num_computed_tokens = (
|
||||
num_new_local_computed_tokens + num_external_computed_tokens
|
||||
)
|
||||
else:
|
||||
new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
encoder_inputs_to_schedule = None
|
||||
external_load_encoder_input = []
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
|
||||
if load_kv_async:
|
||||
assert num_external_computed_tokens > 0
|
||||
num_new_tokens = 0
|
||||
else:
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
threshold = self.scheduler_config.long_prefill_token_threshold
|
||||
if 0 < threshold < num_new_tokens:
|
||||
num_new_tokens = threshold
|
||||
|
||||
if (
|
||||
not self.scheduler_config.enable_chunked_prefill
|
||||
and num_new_tokens > token_budget
|
||||
):
|
||||
break
|
||||
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
assert num_new_tokens > 0
|
||||
|
||||
if request.has_encoder_inputs:
|
||||
(
|
||||
encoder_inputs_to_schedule,
|
||||
num_new_tokens,
|
||||
new_encoder_compute_budget,
|
||||
external_load_encoder_input,
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request,
|
||||
num_computed_tokens,
|
||||
num_new_tokens,
|
||||
encoder_compute_budget,
|
||||
shift_computed_tokens=1 if self.use_eagle else 0,
|
||||
)
|
||||
if num_new_tokens == 0:
|
||||
break
|
||||
|
||||
if self.need_mamba_block_aligned_split:
|
||||
num_new_tokens = self._mamba_block_aligned_split(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_new_local_computed_tokens,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
if num_new_tokens == 0:
|
||||
break
|
||||
|
||||
effective_lookahead_tokens = (
|
||||
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||
)
|
||||
|
||||
num_encoder_tokens = 0
|
||||
if (
|
||||
self.is_encoder_decoder
|
||||
and request.has_encoder_inputs
|
||||
and encoder_inputs_to_schedule
|
||||
):
|
||||
num_encoder_tokens = sum(
|
||||
request.get_num_encoder_embeds(i)
|
||||
for i in encoder_inputs_to_schedule
|
||||
)
|
||||
|
||||
if not load_kv_async:
|
||||
enable_chunked = self.scheduler_config.enable_chunked_prefill
|
||||
tokens_still_to_compute = (
|
||||
request.num_tokens - num_computed_tokens
|
||||
)
|
||||
is_chunked = (
|
||||
enable_chunked
|
||||
and tokens_still_to_compute > num_new_tokens
|
||||
)
|
||||
if is_chunked:
|
||||
assert (
|
||||
request.num_tokens <= self.max_model_len
|
||||
), "request.num_tokens must not exceed max_model_len"
|
||||
num_tokens_need_slot = min(
|
||||
request.num_tokens + effective_lookahead_tokens,
|
||||
self.max_model_len,
|
||||
)
|
||||
blocks_needed = (
|
||||
self.kv_cache_manager.get_num_blocks_needed_for_tokens(
|
||||
request.request_id,
|
||||
num_tokens_need_slot,
|
||||
new_computed_blocks,
|
||||
num_encoder_tokens,
|
||||
)
|
||||
)
|
||||
num_free = self.kv_cache_manager.get_num_free_blocks()
|
||||
if num_free < blocks_needed:
|
||||
break
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_new_computed_tokens=num_new_local_computed_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
num_lookahead_tokens=effective_lookahead_tokens,
|
||||
num_external_computed_tokens=num_external_computed_tokens,
|
||||
delay_cache_blocks=load_kv_async,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
|
||||
if new_blocks is None:
|
||||
if request.has_encoder_inputs:
|
||||
self.encoder_cache_manager.free(request)
|
||||
break
|
||||
|
||||
if self.connector is not None:
|
||||
self.connector.update_state_after_alloc(
|
||||
request,
|
||||
self.kv_cache_manager.get_blocks(request_id),
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
if (
|
||||
self.connector_prefix_cache_stats is not None
|
||||
and connector_prefix_cache_queries != 0
|
||||
):
|
||||
self.connector_prefix_cache_stats.record(
|
||||
num_tokens=connector_prefix_cache_queries,
|
||||
num_hits=connector_prefix_cache_hits,
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
request = self.waiting.pop_request()
|
||||
if load_kv_async:
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(
|
||||
EngineCoreEventType.SCHEDULED, scheduled_timestamp
|
||||
)
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(
|
||||
request_id
|
||||
)
|
||||
num_scheduled_tokens[request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
if external_load_encoder_input:
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= self.max_num_running_reqs
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs
|
||||
) <= len(self.running)
|
||||
|
||||
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
|
||||
if self.running:
|
||||
any_request_id = self.running[0].request_id
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
|
||||
)
|
||||
|
||||
if self.use_v2_model_runner:
|
||||
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
|
||||
scheduled_resumed_reqs = []
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req,
|
||||
req_to_new_blocks[req.request_id].get_block_ids(),
|
||||
req._all_token_ids,
|
||||
)
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
else:
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids()
|
||||
)
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
with record_function_or_nullcontext("schedule: make_cached_request_data"):
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
|
||||
|
||||
new_block_ids_to_zero = (
|
||||
(self.kv_cache_manager.take_new_block_ids() or None)
|
||||
if self.needs_kv_cache_zeroing
|
||||
else None
|
||||
)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=cached_reqs_data,
|
||||
scheduled_resumed_reqs=[r.request_id for r in scheduled_resumed_reqs],
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
preempted_req_ids={req.request_id for req in preempted_reqs},
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
new_block_ids_to_zero=new_block_ids_to_zero,
|
||||
)
|
||||
|
||||
if self.connector is not None:
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(
|
||||
scheduler_output
|
||||
)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
if self.ec_connector is not None:
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
|
||||
scheduler_output
|
||||
)
|
||||
scheduler_output.ec_connector_metadata = ec_meta
|
||||
|
||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
|
||||
def _preempt_request(self, request: Request, timestamp: float) -> None:
|
||||
"""Preempt a request and put it back to the waiting queue.
|
||||
|
||||
@@ -1193,7 +1750,6 @@ class Scheduler(SchedulerInterface):
|
||||
# available. In this case, we can't schedule any token for
|
||||
# the request in this step.
|
||||
num_new_tokens = 0
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
# Calculate the number of embeddings to schedule in the current range
|
||||
@@ -1508,6 +2064,9 @@ class Scheduler(SchedulerInterface):
|
||||
# outputs this step.
|
||||
engine_core_outputs[0] = eco = EngineCoreOutputs()
|
||||
eco.scheduler_stats = stats
|
||||
|
||||
if model_runner_output.draft_token_ids is not None:
|
||||
self.update_draft_token_ids(model_runner_output.draft_token_ids)
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
|
||||
@@ -1,10 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.sampling_params import RepetitionDetectionParams
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
|
||||
def _has_repeating_pattern(
|
||||
token_ids: Sequence[int],
|
||||
pattern_len: int,
|
||||
repetition_min_count: int,
|
||||
) -> bool:
|
||||
"""Check if the tail of token_ids contains a repeating pattern.
|
||||
|
||||
Compares the last pattern_len tokens against the preceding
|
||||
(repetition_min_count - 1) repetitions of the same length.
|
||||
"""
|
||||
for n in range(1, pattern_len + 1):
|
||||
target_token = token_ids[-n]
|
||||
for m in range(1, repetition_min_count):
|
||||
if token_ids[-(pattern_len * m + n)] != target_token:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_sequence_repetition(
|
||||
token_ids: Sequence[int],
|
||||
params: RepetitionDetectionParams,
|
||||
) -> bool:
|
||||
"""Check if a sequence of token IDs has a repetition pattern.
|
||||
Args:
|
||||
token_ids: List of token IDs
|
||||
params: Repetition detection parameters.
|
||||
Returns:
|
||||
True if a repetition pattern is found, False otherwise.
|
||||
"""
|
||||
max_pattern_size = params.max_pattern_size
|
||||
min_pattern_size = params.min_pattern_size
|
||||
min_count = params.min_count
|
||||
|
||||
if min_pattern_size <= 0:
|
||||
min_pattern_size = 1
|
||||
|
||||
if max_pattern_size <= 0 or min_count < 2 or min_pattern_size > max_pattern_size:
|
||||
return False
|
||||
|
||||
for pattern_len in range(
|
||||
min_pattern_size,
|
||||
max_pattern_size + 1,
|
||||
):
|
||||
if pattern_len * min_count > len(token_ids):
|
||||
return False
|
||||
|
||||
if _has_repeating_pattern(token_ids, pattern_len, min_count):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def remove_all(lst: list, items_to_remove: set) -> list:
|
||||
"""Remove all items from a list that are in the items_to_remove set.
|
||||
|
||||
@@ -61,4 +115,16 @@ def check_stop(request: Request, max_model_len: int) -> bool:
|
||||
):
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
return True
|
||||
|
||||
repetition_detection = sampling_params.repetition_detection
|
||||
if repetition_detection is not None and (
|
||||
check_sequence_repetition(
|
||||
request.output_token_ids,
|
||||
repetition_detection,
|
||||
)
|
||||
):
|
||||
request.status = RequestStatus.FINISHED_REPETITION
|
||||
request.stop_reason = "repetition_detected"
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user