69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.core.sched.scheduler import Scheduler
|
|
from vllm.v1.request import Request, RequestStatus
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class AsyncScheduler(Scheduler):
    """Scheduler subclass used when scheduling runs ahead of token generation.

    The worker produces output tokens asynchronously, so this scheduler
    reserves placeholder slots (``Request.num_output_placeholders``) for
    tokens that have been scheduled but not yet materialized, and reconciles
    them once the real token ids arrive.
    """

    def _update_after_schedule(
        self,
        scheduler_output: SchedulerOutput,
    ) -> None:
        """Post-schedule bookkeeping: reserve placeholder slots for tokens
        that this step will generate, and record whether any structured-output
        request still has placeholders outstanding."""
        super()._update_after_schedule(scheduler_output)

        any_pending_structured = False
        spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            req = self.requests[req_id]
            if req.use_structured_output and req.num_output_placeholders > 0:
                any_pending_structured = True

            num_spec = len(spec_tokens_by_req.get(req_id, ()))
            scheduled_through = (
                req.num_tokens + req.num_output_placeholders + num_spec
            )
            if req.num_computed_tokens == scheduled_through:
                # The request will generate one new token plus num_spec
                # tokens in this scheduling step.
                req.num_output_placeholders += 1 + num_spec
                # Reserve placeholder entries in spec_token_ids; the actual
                # spec token ids are filled in by the worker process.
                req.spec_token_ids = [-1] * self.num_spec_tokens

        scheduler_output.pending_structured_output_tokens = (
            any_pending_structured
        )

    def _update_request_with_output(
        self,
        request: Request,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        """Reconcile newly produced token ids against reserved placeholders.

        Returns the (possibly truncated) new token ids and whether the
        request has stopped, mirroring the parent implementation.
        """
        # A request force-preempted in reset_prefix_cache must drop its
        # in-flight async token rather than accept it.
        if request.discard_latest_async_tokens:
            request.discard_latest_async_tokens = False
            return [], False

        status_before_update = request.status
        new_token_ids, stopped = super()._update_request_with_output(
            request, new_token_ids
        )

        # Every real token that arrived fills one placeholder slot.
        request.num_output_placeholders -= len(new_token_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens' KV blocks; requests preempted before this
        # update are skipped.
        if status_before_update == RequestStatus.RUNNING:
            self.kv_cache_manager.cache_blocks(
                request,
                request.num_computed_tokens - request.num_output_placeholders,
            )
        return new_token_ids, stopped
|