### What this PR does / why we need it? This is step 1 of refactoring the code to adapt to vllm main, and this PR is aligned with 17c540a993. 1. Refactor deepseek to the latest code arch as of 17c540a993. 2. Bunches of fixes due to vllm changes: - Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075 - Fix `AscendScheduler` init got an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296 - Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485 - Fix `MLAAttention` import, caused by https://github.com/vllm-project/vllm/pull/25103 - Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145 - Fix `LazyLoader` import, caused by https://github.com/vllm-project/vllm/pull/27022 - Fix `vllm.utils.swap_dict_values` import, caused by https://github.com/vllm-project/vllm/pull/26990 - Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893 - Fix `CompilationLevel` renaming to `CompilationMode` issue introduced by https://github.com/vllm-project/vllm/pull/26355 - Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097 - Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922 - Fix MRope because of the rename of `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172 - Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845 - Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229 - Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807 - Fix structured output break introduced by https://github.com/vllm-project/vllm/issues/26737 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? CI passed with existing tests.
- vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Icey <1790571317@qq.com>
615 lines
29 KiB
Python
615 lines
29 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
import os
|
|
import time
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_events import KVEventBatch
|
|
from vllm.logger import logger
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
|
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
|
create_request_queue)
|
|
from vllm.v1.core.sched.scheduler import Scheduler
|
|
from vllm.v1.engine import EngineCoreEventType
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import Request, RequestStatus
|
|
from vllm.v1.structured_output import StructuredOutputManager
|
|
|
|
from vllm_ascend.utils import vllm_version_is
|
|
|
|
|
|
class BudgetRefiner:
    """This budget refiner can make dynamic adjustment to the token budget
    in the chunked prefill scheduling strategy.

    The refiner loads ``profile_table.csv`` (located next to this module) and
    builds a lookup ``(ctx_len, d_num) -> chunk_size``. For each pair it keeps
    the largest ``chunk_size`` whose ``cost`` column does not exceed the SLO
    limit. At scheduling time, the average decode length and the number of
    decode requests are aligned up to the nearest table keys to pick a budget.
    """

    def __init__(self, default_budget, slo_limit=-1) -> None:
        """Create a refiner.

        Args:
            default_budget: Budget returned when the table has no usable
                entry for the current decode load (the scheduler in this
                module passes ``max_num_batched_tokens``).
            slo_limit: Upper bound applied to the table's ``cost`` column.
                A non-positive value disables dynamic batching entirely.
        """
        self.enabled = slo_limit > 0
        if not self.enabled:
            return
        logger.info(
            "Dynamic batch is enabled with SLO limit: %s, and chunked prefill "
            "is forced to be activated because dynamic batch relies on it",
            slo_limit)
        self.lookup: dict[tuple[int, int], int] = {}
        self.context_keys: set[int] = set()
        self.dnum_keys: set[int] = set()
        self.default_budget = default_budget
        self._read_lookup_table(slo_limit)

    def _read_lookup_table(self, slo_limit):
        """Load the lookup table for dynamic budget.

        For each ``(ctx_len, d_num)`` group, keep the largest ``chunk_size``
        whose ``cost`` is ``<= slo_limit``; groups with no such row are left
        out. If the table file is missing, an error is logged and the refiner
        is disabled.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        table_file_path = os.path.join(base_dir, "profile_table.csv")
        if not os.path.exists(table_file_path):
            # proceed without dynamic batch
            logger.error(
                "The dynamic batching feature requires the lookup table "
                "'profile_table.csv', but it was not found at '%s'. "
                "Please download the corresponding table file.",
                table_file_path)
            self.enabled = False
            return

        df = pd.read_csv(table_file_path)
        for (ctx_len, d_num), group in df.groupby(['ctx_len', 'd_num']):
            valid = group[group['cost'] <= slo_limit]
            if valid.empty:
                continue
            max_row = valid.loc[valid['chunk_size'].idxmax()]
            self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
            self.context_keys.add(ctx_len)
            self.dnum_keys.add(d_num)
        # NOTE: the key containers are sets and therefore unordered;
        # ``_align_key`` looks for the smallest qualifying key explicitly
        # instead of relying on iteration order.

    def _align_key(self, value, valid_keys):
        """Return the smallest key in ``valid_keys`` that is >= ``value``.

        Returns ``None`` when every key is smaller than ``value`` (or when
        ``valid_keys`` is empty). The minimum is computed explicitly because
        the keys are stored in a ``set``, whose iteration order is not sorted.
        """
        return min((k for k in valid_keys if k >= value), default=None)

    def _get_max_budget(self, num_deocde_tokens, num_decode):
        """Get the maximum budget according to the average decode length
        (``num_deocde_tokens``) and the number of decoding requests.

        Falls back to ``default_budget`` when either value is beyond the
        largest table key, or when the aligned pair has no table entry.
        """
        aligned_ctx = self._align_key(num_deocde_tokens, self.context_keys)
        aligned_dnum = self._align_key(num_decode, self.dnum_keys)
        if aligned_ctx is None or aligned_dnum is None:
            return self.default_budget
        budget = self.lookup.get((aligned_ctx, aligned_dnum))
        if budget is None:
            # The key sets are built only from pairs that have an entry, but
            # an aligned (ctx, dnum) combination can still be missing.
            logger.warning("Table miss for ctx,dnum%s",
                           (aligned_ctx, aligned_dnum))
            budget = self.default_budget
        return budget

    def refine_budget(self, running_request, budget):
        """Dynamically refine the token budget according to the running request.

        Args:
            running_request: Requests in the running queue. A request counts
                as decoding once ``num_computed_tokens`` has reached
                ``num_prompt_tokens``.
            budget: The scheduler's current token budget.

        Returns:
            ``budget`` unchanged when the refiner is disabled or there are no
            decoding requests; otherwise the table-derived budget, capped at
            ``budget``.
        """
        if not self.enabled:
            return budget
        # assume all running request will be scheduled.
        decode_lens = [
            req.num_tokens_with_spec for req in running_request
            if req.num_computed_tokens >= req.num_prompt_tokens
        ]
        if not decode_lens:
            return budget
        num_decode = len(decode_lens)
        avg_decode_tokens = sum(decode_lens) / num_decode
        # The scheduler asserts that the scheduled tokens never exceed
        # max_num_scheduled_tokens, so a profiled chunk size larger than the
        # incoming budget must not be returned.
        return min(self._get_max_budget(avg_decode_tokens, num_decode), budget)
|
|
|
|
|
|
class SchedulerDynamicBatch(Scheduler):
    """This Scheduler extends vllm's original v1 scheduler
    with dynamic batch.

    Each step's token budget is refined by a ``BudgetRefiner`` configured
    from ``scheduler_config.SLO_limits_for_dynamic_batch``. Running requests
    are scheduled decode-first (see ``schedule``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: KVCacheConfig,
        structured_output_manager: StructuredOutputManager,
        block_size: Optional[int] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        """Initialize the base v1 scheduler and the dynamic-batch refiner.

        Args:
            vllm_config: Engine configuration forwarded to the base scheduler.
            kv_cache_config: KV cache configuration forwarded to the base.
            structured_output_manager: Forwarded to the base scheduler.
            block_size: Forwarded to the base scheduler only when the
                installed vLLM is not 0.11.0; unused on vLLM 0.11.0.
            mm_registry: Multimodal registry forwarded to the base scheduler.
            include_finished_set: Forwarded to the base scheduler.
            log_stats: Forwarded to the base scheduler; also gates the
                PREEMPTED/SCHEDULED event recording in ``schedule``.
        """
        # The base Scheduler's positional signature differs between vLLM
        # releases: the 0.11.0 call below passes no ``block_size``.
        if vllm_version_is("0.11.0"):
            super().__init__(vllm_config, kv_cache_config,
                             structured_output_manager, mm_registry,
                             include_finished_set, log_stats)
        else:
            super().__init__(vllm_config, kv_cache_config,
                             structured_output_manager, block_size,
                             mm_registry, include_finished_set, log_stats)
        # (Re)initialize the running queue as an empty list.
        self.running: list[Request] = []
        # With a non-positive SLO limit the refiner is disabled and
        # ``refine_budget`` returns the budget it is given unchanged.
        self.budget_refiner = BudgetRefiner(
            default_budget=self.scheduler_config.max_num_batched_tokens,
            slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)

    def schedule(self) -> SchedulerOutput:
        """Produce the ``SchedulerOutput`` for one engine step.

        Order of work:
        1. Refine the token budget and reorder ``self.running`` so decode
           requests precede prefill requests.
        2. Schedule RUNNING requests, preempting requests when KV slots
           cannot be allocated.
        3. If nothing was preempted, admit WAITING requests while budget and
           running-queue capacity remain.
        4. Build the output, attach KV-connector metadata, publish KV cache
           events and call ``_update_after_schedule``.
        """
        # NOTE: This scheduling algorithm is developed based on the "super.schedule()"
        # with the implementations of the dynamic batch and some modifications:
        # 1. Token budget can be dynamically refined according to the self.running
        # through the BudgetRefiner;
        # 2. This scheduling algorithm follows decode-first chunked prefills and FCFS
        # strategy, which is slightly different to the "super.schedule()"
        # 3. Similar to the "super.schedule()", at each step, the scheduler tries to
        # assign tokens to the requests so that each request's num_computed_tokens can
        # catch up its num_tokens_with_spec.
        # 4. So far, the dynamic batch only supports 910B3 NPU. Further work will include
        # more devices and finer optimization strategy.

        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        num_scheduled_tokens: dict[str, int] = {}
        # Start from the static per-step budget, then let the refiner adjust
        # it based on the decode requests currently in self.running.
        token_budget = self.max_num_scheduled_tokens
        token_budget = self.budget_refiner.refine_budget(
            self.running, token_budget)

        # NOTE: We move the prefill requests to the end of the self.running
        # list and keep the relative order unchanged. This rearrangement makes this
        # scheduling algorithm a strict decode-first chunked prefills.
        d_lst = [
            req for req in self.running
            if req.num_computed_tokens >= req.num_prompt_tokens
        ]
        p_lst = [
            req for req in self.running
            if req.num_computed_tokens < req.num_prompt_tokens
        ]
        self.running = d_lst + p_lst

        # Encoder-related.
        scheduled_encoder_inputs: dict[str, list[int]] = {}
        encoder_compute_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}

        # For logging.
        scheduled_timestamp = time.monotonic()

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]

            # Tokens still to compute: all known tokens (including spec
            # tokens and output placeholders) minus those already computed.
            num_new_tokens = (request.num_tokens_with_spec +
                              request.num_output_placeholders -
                              request.num_computed_tokens)
            if (0 < self.scheduler_config.long_prefill_token_threshold <
                    num_new_tokens):
                num_new_tokens = (
                    self.scheduler_config.long_prefill_token_threshold)
            num_new_tokens = min(num_new_tokens, token_budget)

            # Make sure the input position does not exceed the max model len.
            # This is necessary when using spec decoding.
            num_new_tokens = min(
                num_new_tokens,
                self.max_model_len - 1 - request.num_computed_tokens)

            # Schedule encoder inputs.
            encoder_inputs_to_schedule = None
            new_encoder_compute_budget = encoder_compute_budget
            if request.has_encoder_inputs:
                (encoder_inputs_to_schedule, num_new_tokens,
                 new_encoder_compute_budget
                 ) = self._try_schedule_encoder_inputs(
                     request, request.num_computed_tokens, num_new_tokens,
                     encoder_compute_budget)

            if num_new_tokens == 0:
                # The request cannot be scheduled because one of the following
                # reasons:
                # 1. No new tokens to schedule. This may happen when
                #    (1) PP>1 and we have already scheduled all prompt tokens
                #    but they are not finished yet.
                #    (2) Async scheduling and the request has reached to either
                #    its max_total_tokens or max_model_len.
                # 2. The encoder budget is exhausted.
                # 3. The encoder cache is exhausted.
                # NOTE(woosuk): Here, by doing `break` instead of `continue` as
                # in v1 scheduler, we strictly follow the FCFS scheduling policy.
                req_index += 1
                break

            # Try to allocate KV slots for this request. On failure, preempt
            # one request (by max (priority, arrival_time) under PRIORITY,
            # otherwise the last entry of self.running) and retry, until the
            # allocation succeeds or this request itself has been preempted.
            while True:
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request,
                    num_new_tokens,
                    num_lookahead_tokens=self.num_lookahead_tokens)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    if self.policy == SchedulingPolicy.PRIORITY:
                        preempted_req = max(
                            self.running,
                            key=lambda r: (r.priority, r.arrival_time),
                        )
                        self.running.remove(preempted_req)
                        if preempted_req in scheduled_running_reqs:
                            scheduled_running_reqs.remove(preempted_req)
                    else:
                        preempted_req = self.running.pop()

                    # Release the preempted request's KV/encoder cache and
                    # put it back at the front of the waiting queue so it is
                    # recomputed from scratch later.
                    self.kv_cache_manager.free(preempted_req)
                    self.encoder_cache_manager.free(preempted_req)
                    preempted_req.status = RequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0
                    if self.log_stats:
                        preempted_req.record_event(
                            EngineCoreEventType.PREEMPTED, scheduled_timestamp)

                    self.waiting.prepend_request(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt.
                        can_schedule = False
                        break
                else:
                    # The request can be scheduled.
                    can_schedule = True
                    break
            # Every exit from the loop above assigns can_schedule.
            if not can_schedule:
                break
            assert new_blocks is not None

            # Schedule the request.
            scheduled_running_reqs.append(request)
            req_to_new_blocks[request.request_id] = new_blocks
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            # Speculative decode related.
            if request.spec_token_ids:
                # Spec tokens covered by this step's schedule: scheduled
                # tokens beyond the request's current num_tokens.
                num_scheduled_spec_tokens = (num_new_tokens +
                                             request.num_computed_tokens -
                                             request.num_tokens)
                if num_scheduled_spec_tokens > 0:
                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = (
                        request.spec_token_ids)

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_compute_budget = new_encoder_compute_budget

        # Record the LoRAs in scheduled_running_reqs
        scheduled_loras: set[int] = set()
        if self.lora_config:
            scheduled_loras = set(
                req.lora_request.lora_int_id for req in scheduled_running_reqs
                if req.lora_request and req.lora_request.lora_int_id > 0)
            assert len(scheduled_loras) <= self.lora_config.max_loras

        # Use a temporary RequestQueue to collect requests that need to be
        # skipped and put back at the head of the waiting queue later
        skipped_waiting_requests = create_request_queue(self.policy)

        # Next, schedule the WAITING requests.
        # Waiting requests are admitted only when nothing was preempted in
        # this step.
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                # Stop admitting once the running queue is full.
                if len(self.running) == self.max_num_running_reqs:
                    break

                # Only peek here; the request is popped once it is either
                # skipped or successfully allocated.
                request = self.waiting.peek_request()

                # KVTransfer: skip request if still waiting for remote kvs.
                if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                    is_ready = self._update_waiting_for_remote_kv(request)
                    if is_ready:
                        request.status = RequestStatus.WAITING
                    else:
                        logger.debug(
                            "%s is still in WAITING_FOR_REMOTE_KVS state.",
                            request.request_id)
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                # Skip request if the structured output request is still waiting
                # for FSM compilation.
                if request.status == RequestStatus.WAITING_FOR_FSM:
                    structured_output_req = request.structured_output_request
                    if structured_output_req and structured_output_req.grammar:
                        request.status = RequestStatus.WAITING
                    else:
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                # Check that adding the request still respects the max_loras
                # constraint.
                if (self.lora_config and request.lora_request and
                    (len(scheduled_loras) == self.lora_config.max_loras and
                     request.lora_request.lora_int_id not in scheduled_loras)):
                    # Scheduling would exceed max_loras, skip.
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

                num_external_computed_tokens = 0
                load_kv_async = False

                # Get already-cached tokens.
                if request.num_computed_tokens == 0:
                    # Get locally-cached tokens.
                    new_computed_blocks, num_new_local_computed_tokens = \
                        self.kv_cache_manager.get_computed_blocks(
                            request)

                    # Get externally-cached tokens if using a KVConnector.
                    if self.connector is not None:
                        num_external_computed_tokens, load_kv_async = (
                            self.connector.get_num_new_matched_tokens(
                                request, num_new_local_computed_tokens))

                        if num_external_computed_tokens is None:
                            # The request cannot be scheduled because
                            # the KVConnector couldn't determine
                            # the number of matched tokens.
                            self.waiting.pop_request()
                            skipped_waiting_requests.prepend_request(request)
                            continue

                    # Total computed tokens (local + external).
                    num_computed_tokens = (num_new_local_computed_tokens +
                                           num_external_computed_tokens)
                # KVTransfer: WAITING reqs have num_computed_tokens > 0
                # after async KV recvs are completed.
                else:
                    new_computed_blocks = (
                        self.kv_cache_manager.create_empty_block_list())
                    num_new_local_computed_tokens = 0
                    num_computed_tokens = request.num_computed_tokens

                encoder_inputs_to_schedule = None
                new_encoder_compute_budget = encoder_compute_budget

                # KVTransfer: loading remote KV, do not allocate for new work.
                if load_kv_async:
                    assert num_external_computed_tokens > 0
                    num_new_tokens = 0
                # Number of tokens to be scheduled.
                else:
                    # We use `request.num_tokens` instead of
                    # `request.num_prompt_tokens` to consider the resumed
                    # requests, which have output tokens.
                    num_new_tokens = request.num_tokens - num_computed_tokens
                    if (0 < self.scheduler_config.long_prefill_token_threshold
                            < num_new_tokens):
                        num_new_tokens = (
                            self.scheduler_config.long_prefill_token_threshold)

                    # chunked prefill has to be enabled explicitly to allow
                    # pooling requests to be chunked
                    if not self.scheduler_config.chunked_prefill_enabled and \
                        num_new_tokens > token_budget:
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                    num_new_tokens = min(num_new_tokens, token_budget)
                    assert num_new_tokens > 0

                    # Schedule encoder inputs.
                    if request.has_encoder_inputs:
                        (encoder_inputs_to_schedule, num_new_tokens,
                         new_encoder_compute_budget
                         ) = self._try_schedule_encoder_inputs(
                             request, num_computed_tokens, num_new_tokens,
                             encoder_compute_budget)
                        if num_new_tokens == 0:
                            # The request cannot be scheduled.
                            break

                # Handles an edge case when P/D Disaggregation
                # is used with Spec Decoding where an
                # extra block gets allocated which
                # creates a mismatch between the number
                # of local and remote blocks.
                effective_lookahead_tokens = (0 if request.num_computed_tokens
                                              == 0 else
                                              self.num_lookahead_tokens)

                # Determine if we need to allocate cross-attention blocks.
                if self.is_encoder_decoder and request.has_encoder_inputs:
                    # TODO(russellb): For Whisper, we know that the input is
                    # always padded to the maximum length. If we support other
                    # encoder-decoder models, this will need to be updated if we
                    # want to only allocate what is needed.
                    num_encoder_tokens =\
                        self.scheduler_config.max_num_encoder_input_tokens
                else:
                    num_encoder_tokens = 0

                # Externally computed tokens are included in the slot count;
                # locally cached blocks are passed in separately.
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request,
                    num_new_tokens + num_external_computed_tokens,
                    num_new_local_computed_tokens,
                    new_computed_blocks,
                    num_lookahead_tokens=effective_lookahead_tokens,
                    delay_cache_blocks=load_kv_async,
                    num_encoder_tokens=num_encoder_tokens,
                )

                if new_blocks is None:
                    # The request cannot be scheduled.
                    break

                # KVTransfer: the connector uses this info to determine
                # if a load is needed. Note that
                # This information is used to determine if a load is
                # needed for this request.
                if self.connector is not None:
                    self.connector.update_state_after_alloc(
                        request,
                        new_computed_blocks + new_blocks,
                        num_external_computed_tokens,
                    )

                # The request was only peeked above; remove it from the
                # waiting queue now that its slots have been allocated.
                request = self.waiting.pop_request()
                if load_kv_async:
                    # If loading async, allocate memory and put request
                    # into the WAITING_FOR_REMOTE_KV state.
                    skipped_waiting_requests.prepend_request(request)
                    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                    continue

                # NOTE(review): req_index is not read after the RUNNING loop,
                # so this increment appears to have no effect.
                req_index += 1
                self.running.append(request)
                if self.log_stats:
                    request.record_event(EngineCoreEventType.SCHEDULED,
                                         scheduled_timestamp)
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                if self.lora_config and request.lora_request:
                    scheduled_loras.add(request.lora_request.lora_int_id)
                req_to_new_blocks[request.request_id] = (
                    self.kv_cache_manager.get_blocks(request.request_id))
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens
                # Count the number of prefix cached tokens.
                if request.num_cached_tokens < 0:
                    request.num_cached_tokens = num_computed_tokens
                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = (
                        encoder_inputs_to_schedule)
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_compute_budget = new_encoder_compute_budget

        # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
            self.waiting.prepend_requests(skipped_waiting_requests)

        # Check if the scheduling constraints are satisfied.
        # NOTE(review): this assertion relies on the refined token_budget
        # never exceeding self.max_num_scheduled_tokens.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
        # this step, the total number of scheduled requests can be smaller than
        # len(self.running).
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) <= len(self.running))

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = [0] * len(
            self.kv_cache_config.kv_cache_groups)
        if self.running:
            any_request = self.running[0]
            # The KVCacheManager API differs across vLLM versions: 0.11.0
            # takes (request, num_running_requests), newer versions take only
            # the request id.
            if vllm_version_is("0.11.0"):
                num_common_prefix_blocks = (
                    self.kv_cache_manager.get_num_common_prefix_blocks(
                        any_request, len(self.running)))
            else:
                num_common_prefix_blocks = (
                    self.kv_cache_manager.get_num_common_prefix_blocks(
                        any_request.request_id))
        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(
                req, req_to_new_blocks[req.request_id].get_block_ids())
            for req in scheduled_new_reqs
        ]
        cached_reqs_data = self._make_cached_request_data(
            scheduled_running_reqs,
            scheduled_resumed_reqs,
            num_scheduled_tokens,
            scheduled_spec_decode_tokens,
            req_to_new_blocks,
        )
        scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
                              scheduled_resumed_reqs)
        structured_output_request_ids, grammar_bitmask = (
            self.get_grammar_bitmask(scheduled_requests,
                                     scheduled_spec_decode_tokens))
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=self.encoder_cache_manager.
            get_freed_mm_hashes(),
            structured_output_request_ids=structured_output_request_ids,
            grammar_bitmask=grammar_bitmask,
        )

        # NOTE(Kuntai): this function is designed for multiple purposes:
        # 1. Plan the KV cache store
        # 2. Wrap up all the KV cache load / save ops into an opaque object
        # 3. Clear the internal states of the connector
        if self.connector is not None:
            meta = self.connector.build_connector_meta(scheduler_output)
            scheduler_output.kv_connector_metadata = meta

        # collect KV cache events from KV cache manager
        events = self.kv_cache_manager.take_events()

        # collect KV cache events from connector
        if self.connector is not None:
            connector_events = self.connector.take_events()
            if connector_events:
                if events is None:
                    events = list(connector_events)
                else:
                    events.extend(connector_events)

        # publish collected KV cache events
        if events:
            batch = KVEventBatch(ts=time.time(), events=events)
            self.kv_event_publisher.publish(batch)

        self._update_after_schedule(scheduler_output)
        return scheduler_output
|