[0.14.1][bugfix][sched] fix incompatibility of RecomputeScheduler with vllm v0.14.1 (#6286)
### What this PR does / why we need it?
This PR rebases the RecomputeScheduler codebase onto vllm tag v0.14.1 in
order to fix incompatibilities with vllm's upstream Scheduler and
AsyncScheduler. The main changes concern the multimodal-model and
speculative-decoding code paths.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
We tested this PR with a 2P1D (two prefill instances, one decode instance) E2E serving test case.
- vLLM version: v0.14.1
- vLLM main: d68209402d
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
```diff
@@ -22,6 +22,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass, fields
 
+import numpy as np
 from vllm._bc_linter import bc_linter_include
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
```
```diff
@@ -34,8 +35,9 @@ from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.core.sched.utils import check_stop, remove_all
+from vllm.v1.core.sched.utils import remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
+from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
```
```diff
@@ -81,7 +83,7 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
-    def schedule(self) -> SchedulerOutput:
+    def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and
```
```diff
@@ -299,7 +301,12 @@ class RecomputeScheduler(Scheduler):
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                 is_ready = self._update_waiting_for_remote_kv(request)
                 if is_ready:
-                    request.status = RequestStatus.WAITING
+                    if request.num_preemptions:
+                        # We must be loading for a resumed preemption
+                        # rather than a new request.
+                        request.status = RequestStatus.PREEMPTED
+                    else:
+                        request.status = RequestStatus.WAITING
                 else:
                     logger.debug(
                         "%s is still in WAITING_FOR_REMOTE_KVS state.",
```
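The hunk above changes how a request leaves the WAITING_FOR_REMOTE_KVS state: a request that had already been preempted must resume as PREEMPTED, so its previously computed tokens are kept, instead of re-entering WAITING as if it were brand new. A minimal sketch of that decision rule, with illustrative enum and function names rather than the vllm API:

```python
# Sketch of the status-restore rule above (illustrative names; the real
# logic lives inside RecomputeScheduler.schedule).
from enum import Enum, auto


class Status(Enum):
    WAITING = auto()
    PREEMPTED = auto()


def status_after_remote_kv_load(num_preemptions: int) -> Status:
    # We must be loading for a resumed preemption rather than a new request.
    return Status.PREEMPTED if num_preemptions else Status.WAITING


assert status_after_remote_kv_load(0) is Status.WAITING
assert status_after_remote_kv_load(3) is Status.PREEMPTED
```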
```diff
@@ -424,28 +431,28 @@ class RecomputeScheduler(Scheduler):
                 # of local and remote blocks.
                 effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
 
                 # Determine if we need to allocate cross-attention blocks.
-                if self.is_encoder_decoder and request.has_encoder_inputs:
-                    # TODO(russellb): For Whisper, we know that the input is
-                    # always padded to the maximum length. If we support other
-                    # encoder-decoder models, this will need to be updated if we
-                    # want to only allocate what is needed.
-                    num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
-                else:
-                    num_encoder_tokens = 0
+                num_encoder_tokens = (
+                    self._num_encoder_max_input_tokens if self.is_encoder_decoder and request.has_encoder_inputs else 0
+                )
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens + num_external_computed_tokens,
-                    num_new_local_computed_tokens,
-                    new_computed_blocks,
+                    num_new_tokens,
+                    num_new_computed_tokens=num_new_local_computed_tokens,
+                    new_computed_blocks=new_computed_blocks,
                     num_lookahead_tokens=effective_lookahead_tokens,
                     num_external_computed_tokens=num_external_computed_tokens,
                     delay_cache_blocks=load_kv_async,
                     num_encoder_tokens=num_encoder_tokens,
                 )
 
                 if new_blocks is None:
                     # The request cannot be scheduled.
 
                     # NOTE: we need to untouch the request from the encode cache
                     # manager
                     if request.has_encoder_inputs:
                         self.encoder_cache_manager.free(request)
                     break
 
                 # KVTransfer: the connector uses this info to determine
```
```diff
@@ -455,7 +462,7 @@ class RecomputeScheduler(Scheduler):
                 if self.connector is not None:
                     self.connector.update_state_after_alloc(
                         request,
-                        new_computed_blocks + new_blocks,
+                        self.kv_cache_manager.get_blocks(request.request_id),
                         num_external_computed_tokens,
                     )
 
```
```diff
@@ -471,7 +478,6 @@ class RecomputeScheduler(Scheduler):
 
                 self._update_connector_prefix_cache_stats(request)
-
                 req_index += 1
                 self.running.append(request)
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
```
```diff
@@ -606,6 +612,11 @@ class RecomputeScheduler(Scheduler):
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
         kv_connector_output = model_runner_output.kv_connector_output
+        cudagraph_stats = model_runner_output.cudagraph_stats
 
+        perf_stats: PerfStats | None = None
+        if self.perf_metrics and self.perf_metrics.is_enabled():
+            perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output)
+
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
```
```diff
@@ -644,7 +655,7 @@ class RecomputeScheduler(Scheduler):
         for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
             assert num_tokens_scheduled > 0
             if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
-                # Skip requests that were recovered from KV load failure
+                # skip failed or rescheduled requests from KV load failure
                 continue
             request = self.requests.get(req_id)
             if request is None:
```
```diff
@@ -676,25 +687,46 @@ class RecomputeScheduler(Scheduler):
                     spec_decoding_stats,
                     num_draft_tokens=num_draft_tokens,
                     num_accepted_tokens=num_accepted,
+                    num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
+                    request_id=req_id,
                 )
 
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
+            pooler_output = pooler_outputs[req_index] if pooler_outputs else None
             kv_transfer_params = None
             status_before_stop = request.status
 
             # Check for stop and update request status.
             if new_token_ids:
                 new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
+            elif request.pooling_params and pooler_output is not None:
+                # Pooling stops as soon as there is output.
+                request.status = RequestStatus.FINISHED_STOPPED
+                stopped = True
 
-            # Stop checking for pooler models.
-            pooler_output = None
-            if pooler_outputs:
-                pooler_output = pooler_outputs[req_index]
-                stopped = check_stop(request, self.max_model_len, pooler_output)
-
+            routed_experts = None
             if stopped:
+                if self.vllm_config.model_config.enable_return_routed_experts:
+                    kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
+                    block_ids = kv_blocks.get_block_ids()[0]
+                    num_tokens = request.num_tokens - 1
+
+                    # compute slot mapping
+                    block_ids_array = np.array(block_ids, dtype=np.int32)
+                    num_blocks = len(block_ids)
+                    block_size = self.block_size
+
+                    # generate block offsets
+                    block_offsets = np.arange(0, block_size)
+
+                    # compute slot mapping: slot = block_id * block_size + offset
+                    slot_mapping = (
+                        block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size
+                    ).flatten()[:num_tokens]
+
+                    routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping)
                 kv_transfer_params = self._free_request(request)
                 if status_before_stop == RequestStatus.RUNNING:
                     stopped_running_reqs.add(request)
```
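The slot-mapping arithmetic in the new routed-experts branch flattens per-block offsets into global KV-cache slot indices: token `i` of a request occupying blocks `[b0, b1, ...]` lives at slot `block_id * block_size + offset`. A standalone sketch with toy values (variable names mirror the diff; the numbers are purely illustrative):

```python
import numpy as np

block_size = 4            # tokens per KV-cache block (toy value)
block_ids = [7, 2, 9]     # blocks assigned to one request (toy values)
num_tokens = 10           # request.num_tokens - 1 in the diff above

block_ids_array = np.array(block_ids, dtype=np.int32)
num_blocks = len(block_ids)

# Each block contributes block_size consecutive slots:
# slot = block_id * block_size + offset, offset in [0, block_size).
block_offsets = np.arange(0, block_size)
slot_mapping = (
    block_offsets.reshape((1, block_size))
    + block_ids_array.reshape((num_blocks, 1)) * block_size
).flatten()[:num_tokens]

print(slot_mapping)  # -> [28 29 30 31  8  9 10 11 36 37]
```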
```diff
@@ -709,7 +741,13 @@ class RecomputeScheduler(Scheduler):
                 struct_output_request = request.structured_output_request
                 assert struct_output_request is not None
                 assert struct_output_request.grammar is not None
-                struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
+                ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
+                if not ok:
+                    logger.warning(
+                        "Unexpected: grammar rejected tokens %s for request %s.",
+                        new_token_ids,
+                        req_id,
+                    )
 
             if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                 request.num_nans_in_logits = num_nans_in_logits[req_id]
```
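The hunk above stops discarding the return value of `accept_tokens`: instead of failing silently when the structured-output grammar rejects tokens (which can happen once speculative decoding is in the picture), the scheduler now logs a warning and keeps going. A minimal illustration of the same check-and-warn pattern, using a toy grammar class that is not the vllm API:

```python
import logging

logger = logging.getLogger(__name__)


class ToyGrammar:
    """Illustrative stand-in for a structured-output grammar."""

    def accept_tokens(self, req_id: str, token_ids: list[int]) -> bool:
        # Pretend only even token ids are grammatically valid.
        return all(t % 2 == 0 for t in token_ids)


def advance_grammar(grammar: ToyGrammar, req_id: str, token_ids: list[int]) -> None:
    ok = grammar.accept_tokens(req_id, token_ids)
    if not ok:
        # Warn and keep the scheduler loop alive instead of crashing.
        logger.warning("Unexpected: grammar rejected tokens %s for request %s.", token_ids, req_id)


advance_grammar(ToyGrammar(), "req-0", [2, 4, 5])  # logs a warning, does not raise
```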
```diff
@@ -731,6 +769,7 @@ class RecomputeScheduler(Scheduler):
                     kv_transfer_params=kv_transfer_params,
                     trace_headers=request.trace_headers,
                     num_cached_tokens=request.num_cached_tokens,
+                    routed_experts=routed_experts,
                     num_nans_in_logits=request.num_nans_in_logits,
                 )
             )
```
```diff
@@ -745,6 +784,21 @@ class RecomputeScheduler(Scheduler):
             # This is a rare case and unlikely to impact performance.
             self.waiting.remove_requests(stopped_preempted_reqs)
 
+        if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
+            requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
+            self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
+            for request in requests:
+                outputs[request.client_index].append(
+                    EngineCoreOutput(
+                        request_id=request.request_id,
+                        new_token_ids=[],
+                        finish_reason=request.get_finished_reason(),
+                        events=request.take_events(),
+                        trace_headers=request.trace_headers,
+                        num_cached_tokens=request.num_cached_tokens,
+                    )
+                )
+
         # KV Connector: update state for finished KV Transfers.
         if kv_connector_output:
             self._update_from_kv_xfer_finished(kv_connector_output)
```
```diff
@@ -782,7 +836,7 @@ class RecomputeScheduler(Scheduler):
             engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
             finished_req_ids.clear()
 
-        if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
+        if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None:
             # Return stats to only one of the front-ends.
             if (eco := next(iter(engine_core_outputs.values()), None)) is None:
                 # We must return the stats even if there are no request
```