diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py index 61d4201..1336e5a 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py @@ -84,17 +84,18 @@ # # For more details, see the code and comments in this file. - import argparse import asyncio import functools import heapq +import json import os import sys -import uuid import threading +import uuid from contextlib import asynccontextmanager -from typing import List +from dataclasses import dataclass +from typing import Any, List import httpx from fastapi import FastAPI, Request @@ -106,6 +107,7 @@ logger = init_logger(__name__) # Add uvloop for faster event loop if available try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None: def with_cancellation(handler_func): - + @functools.wraps(handler_func) async def wrapper(*args, **kwargs): request = kwargs["request"] @@ -337,9 +339,9 @@ def with_cancellation(handler_func): if handler_task in done: return handler_task.result() return None - + return wrapper - + app = FastAPI(lifespan=lifespan) @@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient, "remote_host": None, "remote_port": None, "aborted_request": list(aborted_requests), - "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver" + "metaserver": + f"http://{global_args.host}:{global_args.port}/v1/metaserver" } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -455,72 +458,174 @@ def get_api_request_id(api, req_id): return "chatcmpl-" + req_id +async def _handle_select_instance(api: str, req_data: Any, + request_length: int): + prefiller_score = proxy_state.calculate_prefill_scores(request_length) + logger.debug( + f"Request length: {request_length}, Prefiller score: {prefiller_score}" + ) + request_id = await proxy_state.next_req_id() + # Select prefiller + prefiller_idx = proxy_state.select_prefiller(prefiller_score) + prefiller = proxy_state.prefillers[prefiller_idx] + result_future = asyncio.Future() # type: ignore + request_id_api = get_api_request_id(api, request_id) + proxy_state.req_id_future[request_id_api] = result_future + # Send request to prefiller + asyncio.get_running_loop().create_task( + send_request_to_service(prefiller.client, + prefiller_idx, + api, + req_data, + request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay)) + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + + response = await result_future + del proxy_state.req_id_future[request_id_api] + req_data["kv_transfer_params"] = response + + # Select decoder + decoder_score = proxy_state.calculate_decode_scores(request_length) + logger.debug("Decoder score: %f", decoder_score) + # Use the prefiller's kv_transfer_params to select decoder + decoder_idx = proxy_state.select_decoder(decoder_score) + decoder = proxy_state.decoders[decoder_idx] + logger.debug("Using %s %s", prefiller.url, decoder.url) + return InstanceInfo(request_id=request_id, + prefiller_idx=prefiller_idx, + prefiller_score=prefiller_score, + prefiller=prefiller, + decoder=decoder, + decoder_idx=decoder_idx, + decoder_score=decoder_score) + + +@dataclass +class InstanceInfo: + request_id: str + prefiller_idx: int + prefiller_score: 
float + prefiller: ServerState + decoder_idx: int + decoder_score: float + decoder: ServerState + + async def _handle_completions(api: str, request: Request): try: req_data = await request.json() req_body = await request.body() request_length = len(req_body) - prefiller_score = proxy_state.calculate_prefill_scores(request_length) - logger.debug( - f"Request length: {request_length}, Prefiller score: {prefiller_score}" - ) - request_id = await proxy_state.next_req_id() - # Select prefiller - prefiller_idx = proxy_state.select_prefiller(prefiller_score) - prefiller = proxy_state.prefillers[prefiller_idx] - result_future = asyncio.Future() # type: ignore - request_id_api = get_api_request_id(api, request_id) - proxy_state.req_id_future[request_id_api] = result_future - # Send request to prefiller - asyncio.get_running_loop().create_task(send_request_to_service( - prefiller.client, - prefiller_idx, - api, - req_data, - request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay)) - proxy_state.release_prefiller(prefiller_idx, prefiller_score) - - response = await result_future - del proxy_state.req_id_future[request_id_api] - req_data["kv_transfer_params"] = response + instance_info = await _handle_select_instance(api, req_data, + request_length) + stream_flag = bool(req_data.get("stream", False)) + chat_flag = "messages" in req_data + + if "prompt" in req_data: + origin_prompt = req_data["prompt"] + elif chat_flag: + messages = req_data["messages"] + origin_prompt = messages[0].get("content", "") + else: + origin_prompt = "" + # refer to vLLM sampling_params: max_token default value + origin_max_tokens = req_data.get("max_tokens", 16) - # Select decoder - decoder_score = proxy_state.calculate_decode_scores(request_length) - logger.debug("Decoder score: %f", decoder_score) - # Use the prefiller's kv_transfer_params to select decoder - decoder_idx = proxy_state.select_decoder(decoder_score) - decoder = proxy_state.decoders[decoder_idx] - logger.debug("Using %s %s", prefiller.url, decoder.url) - # Stream response from decoder - released_kv = False async def generate_stream(): - nonlocal released_kv + nonlocal instance_info + generated_token = "" + released_kv = False + retry_count = 0 + retry = True + completion_tokens = 0 # Only one await per chunk, minimal logic in loop try: - async for chunk in stream_service_response_with_retry( - decoder.client, - api, - req_data, - request_id=request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay): - if not released_kv and chunk: - proxy_state.release_prefiller_kv( - prefiller_idx, prefiller_score) - released_kv = True - yield chunk + while retry: + retry = False + async for chunk in stream_service_response_with_retry( + instance_info.decoder.client, + api, + req_data, + request_id=instance_info.request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay): + if not released_kv and chunk: + proxy_state.release_prefiller_kv( + instance_info.prefiller_idx, + instance_info.prefiller_score) + released_kv = True + chunk_str = chunk.decode("utf-8").strip() + if not chunk_str: + continue + if chunk_str.startswith("data: "): + chunk_str = chunk_str[len("data: "):] + try: + chunk_json = json.loads(chunk_str) + except json.JSONDecodeError: + # if chunk is [done], skip it. 
+                            logger.warning(
+                                f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (
+                            delta.get("content")
+                            or message.get("content")
+                            or choice.get("text")
+                            or ""
+                        )
+                        generated_token += content
+
+                        stop_reason = choice.get(
+                            "stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (completion_tokens + 1) if stream_flag else \
+                            (completion_tokens + usage.get("completion_tokens", 0))
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0][
+                                    "content"] = origin_prompt + generated_token
+                            else:
+                                req_data[
+                                    "prompt"] = origin_prompt + generated_token
+                            req_data[
+                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"][
+                                    "content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}, the aborted request {instance_info.request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)

             # After streaming done, release tokens
-            proxy_state.release_decoder(decoder_idx, decoder_score)
+            proxy_state.release_decoder(instance_info.decoder_idx,
+                                        instance_info.decoder_score)

         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
         result_future = proxy_state.req_id_future[request_id]
         result_future.set_result(req_data)
     except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")


 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)
diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
index fd1c7e5..0e28deb 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
@@ -84,16 +84,17 @@
 #
 # For more details, see the code and comments in this file.
- import argparse import asyncio import functools import heapq +import json import os import sys import uuid from contextlib import asynccontextmanager -from typing import List +from dataclasses import dataclass +from typing import Any, List import httpx from fastapi import FastAPI, Request @@ -105,6 +106,7 @@ logger = init_logger(__name__) # Add uvloop for faster event loop if available try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -443,69 +445,170 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient, raise e +async def _handle_select_instance(api: str, req_data: Any, + request_length: int): + prefiller_score = proxy_state.calculate_prefill_scores(request_length) + logger.debug( + f"Request length: {request_length}, Prefiller score: {prefiller_score}" + ) + request_id = await proxy_state.next_req_id() + # Select prefiller + prefiller_idx = proxy_state.select_prefiller(prefiller_score) + prefiller = proxy_state.prefillers[prefiller_idx] + # Send request to prefiller + response = await send_request_to_service( + prefiller.client, + prefiller_idx, + api, + req_data, + request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay) + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + # Select decoder + decoder_score = proxy_state.calculate_decode_scores(request_length) + logger.debug("Decoder score: %f", decoder_score) + # Use the prefiller's kv_transfer_params to select decoder + decoder_idx = proxy_state.select_decoder(decoder_score) + decoder = proxy_state.decoders[decoder_idx] + logger.debug("Using %s %s", prefiller.url, decoder.url) + return InstanceInfo(request_id=request_id, + prefiller_idx=prefiller_idx, + prefiller_score=prefiller_score, + prefiller=prefiller, + decoder=decoder, + decoder_idx=decoder_idx, + decoder_score=decoder_score) + + +@dataclass +class InstanceInfo: + request_id: str + prefiller_idx: int + prefiller_score: float + prefiller: ServerState + decoder_idx: int + decoder_score: float + decoder: ServerState + + async def _handle_completions(api: str, request: Request): try: req_data = await request.json() req_body = await request.body() request_length = len(req_body) - prefiller_score = proxy_state.calculate_prefill_scores(request_length) - logger.debug( - f"Request length: {request_length}, Prefiller score: {prefiller_score}" - ) - request_id = await proxy_state.next_req_id() - # Select prefiller - prefiller_idx = proxy_state.select_prefiller(prefiller_score) - prefiller = proxy_state.prefillers[prefiller_idx] - # Send request to prefiller - response = await send_request_to_service( - prefiller.client, - prefiller_idx, - api, - req_data, - request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay) - proxy_state.release_prefiller(prefiller_idx, prefiller_score) - response_json = response.json() - kv_transfer_params = response_json.get('kv_transfer_params', {}) - if kv_transfer_params: - req_data["kv_transfer_params"] = kv_transfer_params - # Select decoder - decoder_score = proxy_state.calculate_decode_scores(request_length) - logger.debug("Decoder score: %f", decoder_score) - # Use the prefiller's kv_transfer_params to select decoder - decoder_idx = proxy_state.select_decoder(decoder_score) - decoder = proxy_state.decoders[decoder_idx] - 
logger.debug("Using %s %s", prefiller.url, decoder.url) - # Stream response from decoder - released_kv = False + instance_info = await _handle_select_instance(api, req_data, + request_length) + stream_flag = bool(req_data.get("stream", False)) + chat_flag = "messages" in req_data + + if "prompt" in req_data: + origin_prompt = req_data["prompt"] + elif chat_flag: + messages = req_data["messages"] + origin_prompt = messages[0].get("content", "") + else: + origin_prompt = "" + # refer to vLLM sampling_params: max_token default value + origin_max_tokens = req_data.get("max_tokens", 16) async def generate_stream(): - nonlocal released_kv + nonlocal instance_info + generated_token = "" + released_kv = False + retry_count = 0 + retry = True + completion_tokens = 0 # Only one await per chunk, minimal logic in loop try: - async for chunk in stream_service_response_with_retry( - decoder.client, - api, - req_data, - request_id=request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay): - if not released_kv and chunk: - proxy_state.release_prefiller_kv( - prefiller_idx, prefiller_score) - released_kv = True - yield chunk + while retry: + retry = False + async for chunk in stream_service_response_with_retry( + instance_info.decoder.client, + api, + req_data, + request_id=instance_info.request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay): + if not released_kv and chunk: + proxy_state.release_prefiller_kv( + instance_info.prefiller_idx, + instance_info.prefiller_score) + released_kv = True + chunk_str = chunk.decode("utf-8").strip() + if not chunk_str: + continue + if chunk_str.startswith("data: "): + chunk_str = chunk_str[len("data: "):] + try: + chunk_json = json.loads(chunk_str) + except json.JSONDecodeError: + # if chunk is [done], skip it. 
+                            logger.warning(
+                                f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (
+                            delta.get("content")
+                            or message.get("content")
+                            or choice.get("text")
+                            or ""
+                        )
+                        generated_token += content
+
+                        stop_reason = choice.get(
+                            "stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (completion_tokens + 1) if stream_flag else \
+                            (completion_tokens + usage.get("completion_tokens", 0))
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0][
+                                    "content"] = origin_prompt + generated_token
+                            else:
+                                req_data[
+                                    "prompt"] = origin_prompt + generated_token
+                            req_data[
+                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choice["message"][
+                                    "content"] = generated_token
+                            else:
+                                choice["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}, the aborted request {instance_info.request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)

             # After streaming done, release tokens
-            proxy_state.release_decoder(decoder_idx, decoder_score)
+            proxy_state.release_decoder(instance_info.decoder_idx,
+                                        instance_info.decoder_score)

         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -544,4 +647,5 @@ if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
\ No newline at end of file
+
+    uvicorn.run(app, host=global_args.host, port=global_args.port)
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 6a60695..a265e96 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -70,6 +70,8 @@ class AscendConfig:
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
+        self.recompute_scheduler_enable = additional_config.get(
+            "recompute_scheduler_enable", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
         if self.lmhead_tensor_parallel_size is not None:
diff --git a/vllm_ascend/core/recompute_schedule_config.py b/vllm_ascend/core/recompute_schedule_config.py
new file mode 100644
index 0000000..be19a1c
--- /dev/null
+++ b/vllm_ascend/core/recompute_schedule_config.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from dataclasses import dataclass, fields +from typing import Type, Union + +from vllm.config import SchedulerConfig + +MAX_INT = 2147483647 + + +@dataclass +class RecomputeSchedulerConfig(SchedulerConfig): + scheduler_cls: Union[str, Type[object]] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + + @classmethod + def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig): + scheduler_config = { + field.name: getattr(vllm_scheduler_config, field.name) + for field in fields(vllm_scheduler_config) if field.init + } + scheduler_config["scheduler_cls"] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + return cls(**scheduler_config) diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py new file mode 100644 index 0000000..8946e2f --- /dev/null +++ b/vllm_ascend/core/recompute_scheduler.py @@ -0,0 +1,1392 @@ +## +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from __future__ import annotations + +import itertools +import time +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +from vllm.config import VllmConfig +from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ + KVConnectorStats +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_budget) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData +from vllm.v1.core.sched.request_queue import (SchedulingPolicy, + create_request_queue) +from vllm.v1.core.sched.utils import check_stop, remove_all +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs, FinishReason) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import ConstantList + +logger = init_logger(__name__) + + +class RecomputeScheduler(SchedulerInterface): + """This Scheduler extends vllm's original v1 scheduler of version 0.11 + to fix recomputing bug.""" + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.kv_cache_config = kv_cache_config + self.kv_events_config = vllm_config.kv_events_config + self.parallel_config = vllm_config.parallel_config + self.log_stats = log_stats + self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder + + # include_finished_set controls whether a separate set of finished + # request ids should be included in the EngineCoreOutputs returned + # by update_from_outputs(). This is currently used in the multi-engine + # case to track request lifetimes efficiently. + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + self.enable_kv_cache_events = ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events) + + # Create KVConnector for the Scheduler. 
Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") + self.connector = KVConnectorFactory.create_connector( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, + self.parallel_config.data_parallel_rank, + ) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + + self.block_size = self.cache_config.block_size + + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. + if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + + # req_id -> Request + self.requests: dict[str, Request] = {} + # Scheduling policy + if self.scheduler_config.policy == "priority": + self.policy = SchedulingPolicy.PRIORITY + elif self.scheduler_config.policy == "fcfs": + self.policy = SchedulingPolicy.FCFS + else: + raise ValueError( + f"Unknown scheduling policy: {self.scheduler_config.policy}") + # Priority queues for requests. + self.waiting = create_request_queue(self.policy) + self.running: list[Request] = [] + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: set[str] = set() + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() + + # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + mm_registry=mm_registry, + ) + + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed) for MM models as well as encoder-decoder + # transformers. + self.max_num_encoder_input_tokens = encoder_compute_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + + speculative_config = vllm_config.speculative_config + self.use_eagle = False + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.use_eagle(): + self.use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + + # Create the KV cache manager. 
+        self.kv_cache_manager = KVCacheManager(
+            kv_cache_config=kv_cache_config,
+            max_model_len=self.max_model_len,
+            enable_caching=self.cache_config.enable_prefix_caching,
+            use_eagle=self.use_eagle,
+            log_stats=self.log_stats,
+            enable_kv_cache_events=self.enable_kv_cache_events,
+            dcp_world_size=self.dcp_world_size,
+        )
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+
+    def schedule(self) -> RecomputeSchedulerOutput:
+        """This scheduler extends vLLM's original v1 scheduler
+        by introducing a recompute scheduling strategy for decode instances.
+        Specifically, if a request is preempted on the decode instance,
+        the request is finished with the "recomputed" stop reason and its
+        KV cache is recomputed on the prefill instance."""
+
+        scheduled_new_reqs: list[Request] = []
+        scheduled_resumed_reqs: list[Request] = []
+        scheduled_running_reqs: list[Request] = []
+        preempted_reqs: list[Request] = []
+        recomputed_reqs: list[RecomputeReqInfo] = []
+
+        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
+        num_scheduled_tokens: dict[str, int] = {}
+        token_budget = self.max_num_scheduled_tokens
+        # Encoder-related.
+        scheduled_encoder_inputs: dict[str, list[int]] = {}
+        encoder_compute_budget = self.max_num_encoder_input_tokens
+        # Spec decode-related.
+        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
+
+        # For logging.
+        scheduled_timestamp = time.monotonic()
+
+        # First, schedule the RUNNING requests.
+        req_index = 0
+        while req_index < len(self.running) and token_budget > 0:
+            request = self.running[req_index]
+
+            num_new_tokens = (request.num_tokens_with_spec +
+                              request.num_output_placeholders -
+                              request.num_computed_tokens)
+            if (0 < self.scheduler_config.long_prefill_token_threshold <
+                    num_new_tokens):
+                num_new_tokens = (
+                    self.scheduler_config.long_prefill_token_threshold)
+            num_new_tokens = min(num_new_tokens, token_budget)
+
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - 1 - request.num_computed_tokens)
+
+            # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_compute_budget = encoder_compute_budget
+            if request.has_encoder_inputs:
+                (encoder_inputs_to_schedule, num_new_tokens,
+                 new_encoder_compute_budget
+                 ) = self._try_schedule_encoder_inputs(
+                     request, request.num_computed_tokens, num_new_tokens,
+                     encoder_compute_budget)
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because of one of the
+                # following reasons:
+                # 1. No new tokens to schedule. This may happen when
+                #    (1) PP>1 and we have already scheduled all prompt tokens
+                #    but they are not finished yet.
+                #    (2) Async scheduling and the request has reached either
+                #    its max_total_tokens or max_model_len.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+ req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + transfer_config = self.vllm_config.kv_transfer_config + if transfer_config is not None and not transfer_config.is_kv_producer: + recomputed_req = self.running.pop() + self.kv_cache_manager.free(recomputed_req) + recomputed_reqs.append( + RecomputeReqInfo(recomputed_req.request_id, + recomputed_req.output_token_ids, + recomputed_req.client_index)) + if recomputed_req == request: + can_schedule = False + break + else: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + # KVTransfer: skip request if still waiting for remote kvs. 
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if not self.scheduler_config.chunked_prefill_enabled and \ + num_new_tokens > token_budget: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. 
+ if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_compute_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = \ + self.scheduler_config.max_num_encoder_input_tokens + else: + num_encoder_tokens = 0 + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + ) + + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. 
+ for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + + scheduled_resumed_reqs) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(scheduled_requests, + scheduled_spec_decode_tokens)) + scheduler_output = RecomputeSchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + recomputed_reqs=recomputed_reqs, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + self._update_after_schedule(scheduler_output) + return scheduler_output + + def _update_after_schedule( + self, + scheduler_output: RecomputeSchedulerOutput, + ) -> None: + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + request = self.requests[req_id] + request.num_computed_tokens += num_scheduled_token + + # NOTE: _free_encoder_inputs relies on num_computed_tokens, which + # may be updated again in _update_from_output for speculative + # decoding. However, it is safe to call the method here because + # encoder inputs are always part of the prompt, not the output, + # and thus are unaffected by speculative decoding. + if request.has_encoder_inputs: + self._free_encoder_inputs(request) + + # Clear the finished request IDs. + # NOTE: We shouldn't do self.finished_req_ids.clear() here because + # it will also affect the scheduler output. + self.finished_req_ids = set() + + def _make_cached_request_data( + self, + running_reqs: list[Request], + resumed_reqs: list[Request], + num_scheduled_tokens: dict[str, int], + spec_decode_tokens: dict[str, list[int]], + req_to_new_blocks: dict[str, KVCacheBlocks], + ) -> CachedRequestData: + req_ids: list[str] = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + num_computed_tokens: list[int] = [] + + use_connector = self.connector is not None + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id + req_ids.append(req_id) + num_tokens = (num_scheduled_tokens[req_id] - + len(spec_decode_tokens.get(req_id, ()))) + if self.use_pp: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. Otherwise, we don't + # need to send the sampled tokens back because the model runner + # will cache them. + token_ids = req.all_token_ids[req.num_computed_tokens:req. + num_computed_tokens + num_tokens] + new_token_ids.append(token_ids) + elif use_connector: + # When using a KVConnector, we add a placeholder to avoid index + # out of bounds errors. TODO: Remove this once the KVConnector + # is updated to handle token IDs properly. 
+ new_token_ids.append([]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + num_computed_tokens.append(req.num_computed_tokens) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + + return CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_compute_budget: int, + ) -> tuple[list[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). + """ + if num_new_tokens == 0 or not request.has_encoder_inputs: + return [], num_new_tokens, encoder_compute_budget + encoder_inputs_to_schedule: list[int] = [] + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. 
+ continue + + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue + + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue + + # If no encoder input chunking is allowed, we do not want to + # partially schedule a multimodal item. If the scheduled range would + # only cover part of the mm input, roll back to before the mm item. + if (self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens)): + num_new_tokens = start_pos - num_computed_tokens + break + + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): + # The encoder cache is full or the encoder budget is exhausted. + # NOTE(woosuk): We assume that the encoder input tokens should + # be processed altogether, as the encoder usually uses + # bidirectional attention. + if num_computed_tokens < start_pos: + # We only schedule the decoder tokens just before the + # encoder input. + num_new_tokens = start_pos - num_computed_tokens + else: + # Because of prefix caching, num_computed_tokens is greater + # than start_pos even though its encoder input is not + # available. In this case, we can't schedule any token for + # the request in this step. + num_new_tokens = 0 + break + + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_features[i].identifier) + encoder_inputs_to_schedule.append(i) + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will help us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
+ structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask + + def update_from_output( + self, + scheduler_output: RecomputeSchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats = (kv_connector_output.kv_connector_stats + if kv_connector_output else None) + # return recomputed requests as EngineCoreOutput + for req_info in scheduler_output.recomputed_reqs: + outputs[req_info.client_index].append( + EngineCoreOutput( + request_id=req_info.request_id, + finish_reason=FinishReason.STOP, + new_token_ids=[req_info.output_token_ids[-1]], + stop_reason="recomputed", + )) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. + request.num_computed_tokens -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids) + + # Stop checking for pooler models. 
+ pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, + pooler_output) + + if stopped: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # checked above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: + + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + )) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if model_runner_output.kv_connector_output: + self._update_from_kv_xfer_finished( + model_runner_output.kv_connector_output) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if (stats := self.make_stats(spec_decoding_stats, + kv_connector_stats)) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. 
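+                # In that case, attach the stats to a fresh EngineCoreOutputs
+                # for client 0.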
+ engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + return engine_core_outputs + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + stopped = False + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + del new_token_ids[num_new:] # Trim new tokens if needed. + break + return new_token_ids, stopped + + def _free_encoder_inputs(self, request: Request) -> None: + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if not cached_encoder_input_ids: + return + + # Here, we use list(set) to avoid modifying the set while iterating + # over it. + for input_id in list(cached_encoder_input_ids): + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. + request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. 
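+
+        Requests are collected first and removed from the running/waiting
+        queues in batch, then marked finished and freed one by one.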
+ """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + else: + request_ids = set(request_ids) + + running_requests_to_remove = set() + waiting_requests_to_remove = [] + valid_requests = [] + + # First pass: collect requests to remove from queues + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + + valid_requests.append(request) + if request.status == RequestStatus.RUNNING: + running_requests_to_remove.add(request) + else: + waiting_requests_to_remove.append(request) + + # Remove all requests from queues at once for better efficiency + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) + if waiting_requests_to_remove: + self.waiting.remove_requests(waiting_requests_to_remove) + + # Second pass: set status and free requests + for request in valid_requests: + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) + request_id = request.request_id + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + self.kv_cache_manager.free(request) + del self.requests[request.request_id] + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: + return len(self.finished_req_ids) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() + + def make_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats] = None, + kv_connector_stats: Optional[KVConnectorStats] = None, + ) -> Optional[SchedulerStats]: + if not self.log_stats: + return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None + return SchedulerStats(num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted + for req in self.running), + kv_connector_stats=kv_connector_stats.data + if kv_connector_stats else None) + + def make_spec_decoding_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats], + num_draft_tokens: int, + num_accepted_tokens: int, + ) -> Optional[SpecDecodingStats]: + if not self.log_stats: + return None + if spec_decoding_stats is None: + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats + + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return 
self.connector
+
+    def _connector_finished(
+            self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
+        """
+        Invoke the KV connector request_finished() method if applicable.
+
+        Returns optional kv transfer parameters to be included with the
+        request outputs.
+        """
+        if self.connector is None:
+            return False, None
+
+        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+        return self.connector.request_finished(request, block_ids)
+
+    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+        """
+        KV Connector: check if the request_id is finished_recving.
+
+        The finished_recving_kv_req_ids list is populated by the previous
+        step's update_from_output() based on the worker-side connector.
+
+        When the kv transfer is ready, we cache the blocks
+        and the request state will be moved back to WAITING from
+        WAITING_FOR_REMOTE_KV.
+        """
+        assert self.connector is not None
+        if request.request_id not in self.finished_recving_kv_req_ids:
+            return False
+
+        # Now that the blocks are ready, actually cache them.
+        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+        num_computed_tokens = len(block_ids) * self.block_size
+        # Handle the case where the number of request tokens is less than
+        # one block.
+        num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+        if num_computed_tokens == request.num_tokens:
+            num_computed_tokens -= 1
+        # This will cache the blocks iff caching is enabled.
+        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+        # Update the request state for scheduling.
+        request.num_computed_tokens = num_computed_tokens
+
+        # Return that we are ready.
+        self.finished_recving_kv_req_ids.remove(request.request_id)
+        return True
+
+    def _update_from_kv_xfer_finished(self,
+                                      kv_connector_output: KVConnectorOutput):
+        """
+        KV Connector: update the scheduler state based on the output.
+
+        The worker-side connectors add finished_recving and
+        finished_sending reqs to the output.
+        * if finished_sending: free the blocks
+        * if finished_recving: add to state so we can schedule the request
+          during the next step.
+        """
+
+        if self.connector is not None:
+            self.connector.update_connector_output(kv_connector_output)
+
+        # KV Connector: update recv and send status from last step.
+        for req_id in (kv_connector_output.finished_recving or ()):
+            logger.debug("Finished recving KV transfer for request %s", req_id)
+            self.finished_recving_kv_req_ids.add(req_id)
+        for req_id in (kv_connector_output.finished_sending or ()):
+            logger.debug("Finished sending KV transfer for request %s", req_id)
+            if req_id not in self.requests:
+                logger.warning(
+                    "Got finished sending KV transfer for request %s, "
+                    "but the request is already freed.", req_id)
+            else:
+                self._free_blocks(self.requests[req_id])
+
+
+@dataclass
+class RecomputeReqInfo:
+    request_id: str
+    output_token_ids: ConstantList
+    client_index: int = 0
+
+
+@dataclass
+class RecomputeSchedulerOutput:
+
+    # list of the requests that are scheduled for the first time.
+    # We cache the request's data in each worker process, so that we don't
+    # need to re-send it every scheduling step.
+    scheduled_new_reqs: list[NewRequestData]
+    # list of the requests that have been scheduled before.
+    # Since the request's data is already cached in the worker processes,
+    # we only send the diff to minimize the communication cost.
+    scheduled_cached_reqs: CachedRequestData
+
+    # req_id -> num_scheduled_tokens
+    # Number of tokens scheduled for each request.
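+    # For a request still in chunked prefill this is the size of the
+    # scheduled chunk, so it can be large even before any output token
+    # is produced.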
+    num_scheduled_tokens: dict[str, int]
+    # Total number of tokens scheduled for all requests.
+    # Equal to sum(num_scheduled_tokens.values())
+    total_num_scheduled_tokens: int
+    # req_id -> spec_token_ids
+    # If a request does not have any spec decode tokens, it will not be
+    # included in the dictionary.
+    scheduled_spec_decode_tokens: dict[str, list[int]]
+    # req_id -> encoder input indices that need processing.
+    # E.g., if a request has [0, 1], it could mean the vision encoder needs
+    # to process the request's 0-th and 1-st images in the current step.
+    scheduled_encoder_inputs: dict[str, list[int]]
+    # Number of common prefix blocks for all requests in each KV cache group.
+    # This can be used for cascade attention.
+    num_common_prefix_blocks: list[int]
+
+    # Request IDs that are finished in between the previous and the current
+    # steps. This is used to notify the workers about the finished requests
+    # so that they can free the cached states for those requests.
+    finished_req_ids: set[str]
+    # list of mm_hash strings associated with the encoder outputs to be
+    # freed from the encoder cache.
+    free_encoder_mm_hashes: list[str]
+
+    # Dict of request ids to their index within the batch
+    # for filling the next token bitmask.
+    structured_output_request_ids: dict[str, int]
+    # The bitmask for the whole batch.
+    grammar_bitmask: Optional[npt.NDArray[np.int32]]
+
+    # Requests that need to recompute KV.
+    recomputed_reqs: list[RecomputeReqInfo]
+
+    # KV Cache Connector metadata.
+    kv_connector_metadata: Optional[KVConnectorMetadata] = None
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index ac8f011..a67f054 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -300,6 +300,12 @@ class NPUPlatform(Platform):
                 vllm_config.scheduler_config,
                 ascend_config.ascend_scheduler_config)
             vllm_config.scheduler_config = ascend_scheduler_config
+        elif ascend_config.recompute_scheduler_enable:
+            from vllm_ascend.core.recompute_schedule_config import \
+                RecomputeSchedulerConfig
+            recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
+                vllm_config.scheduler_config)
+            vllm_config.scheduler_config = recompute_scheduler_config
 
     @classmethod
     def get_attn_backend_cls(