Implement save kv cache logic for v1 disaggregated prefill in ascend scheduler

This PR adds support for saving the KV cache in the Ascend scheduler, which is part of the v1 disaggregated prefill design. The load functionality is not yet implemented.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
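For context, a minimal sketch of the producer-side setup this save path assumes (the model name and connector choice below are illustrative placeholders, not part of this PR): the scheduler only attaches KV-connector metadata when `kv_transfer_config` is set, and it refuses to start on KV-consumer (decode) nodes.

```python
# Hypothetical prefill-node (KV producer) configuration; names are assumptions.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="SharedStorageConnector",  # placeholder v1 connector
        kv_role="kv_producer",  # "kv_consumer" (decode) nodes are rejected
    ),
)
# With a config like this, AscendScheduler.schedule() calls
# connector.build_connector_meta(scheduler_output) each step so the worker
# knows which KV cache blocks to save; loading is left to a follow-up PR.
```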
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from collections import deque
from typing import Iterable, Optional, Union

from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager


class AscendScheduler(Scheduler):
    """This scheduler extends vLLM's original v1 scheduler
    with a prefill-first scheduling strategy."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: KVCacheConfig,
        structured_output_manager: StructuredOutputManager,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        super().__init__(vllm_config, kv_cache_config,
                         structured_output_manager, mm_registry,
                         include_finished_set, log_stats)
        self.scheduled_req_ids: set[str] = set()
        self.running: list[Request] = []

        if self.vllm_config.kv_transfer_config is not None and \
            self.vllm_config.kv_transfer_config.is_kv_consumer:
            raise ValueError(
                "AscendScheduler cannot be used for decode nodes.")

    def schedule(self) -> SchedulerOutput:
        if self.scheduler_config.chunked_prefill_enabled:
            return super().schedule()
        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        req_to_new_block_ids: dict[str, list[int]] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}

        # Record scheduled LoRA requests.
        scheduled_loras: set[int] = set()

        # Use a temporary deque to collect requests that need to be skipped
        # and put back at the head of the waiting queue later.
        skipped_waiting_requests: deque[Request] = deque()

        # Schedule prefill requests first.
        while self.waiting and token_budget > 0:
            if len(self.running) == self.max_num_running_reqs:
                break

            request = self.waiting[0]

            def skip_cur_request():
                self.waiting.popleft()
                skipped_waiting_requests.appendleft(request)

            # Check that adding the request still respects the max_loras
            # constraint.
            if (self.lora_config and request.lora_request and
                (len(scheduled_loras) == self.lora_config.max_loras
                 and request.lora_request.lora_int_id not in scheduled_loras)):
                # Scheduling would exceed max_loras, skip.
                skip_cur_request()
                continue

            prompt_limit = self._get_prompt_limit(request)
            # Get already-cached tokens.
            computed_blocks, num_computed_tokens = (
                self.kv_cache_manager.get_computed_blocks(request))
            num_new_tokens = request.num_tokens - num_computed_tokens
            if (0 < self.scheduler_config.long_prefill_token_threshold <
                    num_new_tokens):
                num_new_tokens = (
                    self.scheduler_config.long_prefill_token_threshold)
            max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
                                     self.block_size)
            prompt_limit = min(prompt_limit, max_tokens_in_kvcache)

            # Finish requests that exceed the prompt limit or kv cache size.
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d",
                    num_new_tokens,
                    prompt_limit,
                )
                request.status = RequestStatus.FINISHED_IGNORED
                self.finished_req_ids.add(request.request_id)  # type: ignore
                self.waiting.popleft()
                continue

            if num_new_tokens > token_budget:
                # Scheduling would exceed token_budget, skip.
                skip_cur_request()
                continue

            assert num_new_tokens > 0
            watermark = getattr(self.scheduler_config, "watermark", 0.01)
            if not self._check_watermark_for_prefill(
                    request, num_new_tokens, computed_blocks, watermark):
                # Scheduling would exceed watermark, skip.
                skip_cur_request()
                continue

            new_blocks = self.kv_cache_manager.allocate_slots(
                request, num_new_tokens, computed_blocks)
            if new_blocks is None:
                # The request cannot be scheduled.
                break

            self.waiting.popleft()
            self.running.append(request)
            self.scheduled_req_ids.add(request.request_id)
            # Check request status.
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
                raise RuntimeError(f"Invalid request status: {request.status}")

            if self.lora_config and request.lora_request:
                scheduled_loras.add(request.lora_request.lora_int_id)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in computed_blocks + new_blocks
            ]
            # Update request info.
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens

        # Put back any skipped requests at the head of the waiting queue.
        if skipped_waiting_requests:
            self.waiting.extendleft(skipped_waiting_requests)

        # If no prefill requests were scheduled,
        # schedule decode requests next.
        if len(self.scheduled_req_ids) == 0:
            req_index = 0
            while req_index < len(self.running) and token_budget > 0:
                request = self.running[req_index]
                if request.request_id in self.scheduled_req_ids:
                    # This request has already been scheduled.
                    req_index += 1
                    continue

                num_new_tokens = (request.num_tokens_with_spec -
                                  request.num_computed_tokens)
                if (0 < self.scheduler_config.long_prefill_token_threshold <
                        num_new_tokens):
                    num_new_tokens = (
                        self.scheduler_config.long_prefill_token_threshold)
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens == 1

                while True:
                    new_blocks = self.kv_cache_manager.allocate_slots(
                        request, num_new_tokens)
                    if new_blocks is None:
                        # The request cannot be scheduled.
                        # Preempt the lowest-priority request.
                        preempted_req = self.running.pop()
                        self.kv_cache_manager.free(preempted_req)
                        preempted_req.status = RequestStatus.PREEMPTED
                        preempted_req.num_computed_tokens = 0
                        self.waiting.appendleft(preempted_req)
                        preempted_reqs.append(preempted_req)
                        if preempted_req == request:
                            # No more requests to preempt.
                            can_schedule = False
                            break
                    else:
                        # The request can be scheduled.
                        can_schedule = True
                        break
                if not can_schedule:
                    break
                assert new_blocks is not None

                # Schedule the request.
                scheduled_running_reqs.append(request)
                self.scheduled_req_ids.add(request.request_id)
                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in new_blocks
                ]
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                req_index += 1

                # Speculative decode related.
                if request.spec_token_ids:
                    num_scheduled_spec_tokens = (num_new_tokens +
                                                 request.num_computed_tokens -
                                                 request.num_tokens)
                    if num_scheduled_spec_tokens > 0:
                        # Trim spec_token_ids list to num_scheduled_spec_tokens.
                        del request.spec_token_ids[num_scheduled_spec_tokens:]
                        scheduled_spec_decode_tokens[request.request_id] = (
                            request.spec_token_ids)

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
            scheduled_running_reqs) <= len(self.running)

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id])
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=True,
            ) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=False,
            ) for req in scheduled_running_reqs
        ]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,  # type: ignore
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
            structured_output_request_ids={},
            grammar_bitmask=None,
        )

        # NOTE(Kuntai): this function is designed for multiple purposes:
        # 1. Plan the KV cache store
        # 2. Wrap up all the KV cache load / save ops into an opaque object
        # 3. Clear the internal states of the connector
        if self.connector is not None:
            meta = self.connector.build_connector_meta(scheduler_output)
            scheduler_output.kv_connector_metadata = meta

        # Advance the number of computed tokens for the request AFTER
        # the request is scheduled.
        # 1. The scheduler_output of the current step has to include the
        #    original number of scheduled tokens to determine input IDs.
        # 2. Advancing the number of computed tokens here allows us to
        #    schedule the prefill request again immediately in the next
        #    scheduling step.
        # 3. If some tokens (e.g. spec tokens) are rejected later, the number
        #    of computed tokens will be adjusted in update_from_output.
        for req_id, num_scheduled_token in num_scheduled_tokens.items():
            self.requests[req_id].num_computed_tokens += num_scheduled_token

        self.finished_req_ids = set()  # type: ignore
        return scheduler_output

    def _check_watermark_for_prefill(self,
                                     request,
                                     num_new_tokens,
                                     computed_blocks,
                                     watermark=0.01):
        """Return True if, after allocating the blocks needed for this
        prefill, the number of free blocks stays above the watermark
        fraction of all KV cache blocks."""
        computed_blocks = computed_blocks or []
        watermark_blocks = self.kv_cache_config.num_blocks * watermark
        num_computed_tokens = (request.num_computed_tokens +
                               len(computed_blocks) * self.block_size)
        num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
                                   self.block_size)
        req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
        num_new_blocks = (num_required_blocks - len(req_blocks) -
                          len(computed_blocks))
        num_evictable_computed_blocks = sum(1 for blk in computed_blocks
                                            if blk.ref_cnt == 0)
        # If the number of free blocks would fall below the watermark after
        # allocating, don't allocate.
        if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
                num_evictable_computed_blocks -
                num_new_blocks) < watermark_blocks:
            return False
        return True

    def _get_prompt_limit(self, request: Request) -> int:
        if (self.scheduler_config.chunked_prefill_enabled
                and not self.scheduler_config.is_multi_step):
            prompt_limit = self.scheduler_config.max_model_len
        else:
            prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens,
            )

        # Model is fine-tuned with long context. Return the fine-tuned max_len.
        if request.lora_request and request.lora_request.long_lora_max_len:
            assert prompt_limit <= request.lora_request.long_lora_max_len
            return request.lora_request.long_lora_max_len
        else:
            return prompt_limit

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        else:
            request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
                self.scheduled_req_ids.discard(request.request_id)
            else:
                self.waiting.remove(request)
            request.status = finished_status
            self._free_request(request)

    def update_from_output(
        self,
        scheduler_output: SchedulerOutput,
        model_runner_output: ModelRunnerOutput,
    ) -> EngineCoreOutputs:
        sampled_token_ids = model_runner_output.sampled_token_ids
        spec_token_ids = model_runner_output.spec_token_ids
        logprobs = model_runner_output.logprobs
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens

        new_running: list[Request] = []
        outputs: list[EngineCoreOutput] = []
        spec_decoding_stats: Optional[SpecDecodingStats] = None

        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
        # loop can be a performance bottleneck. We should do our best to avoid
        # expensive operations inside the loop.
        for request in self.running:
            req_id = request.request_id
            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
            if num_tokens_scheduled == 0:
                # The request was not scheduled in this step.
                new_running.append(request)
                continue

            req_index = model_runner_output.req_id_to_index[req_id]
            generated_token_ids = sampled_token_ids[req_index]

            scheduled_spec_token_ids = (
                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
            if scheduled_spec_token_ids:
                # num_computed_tokens represents the number of tokens
                # processed in the current step, considering scheduled
                # tokens and rejections. If some tokens are rejected,
                # num_computed_tokens is decreased by the number of rejected
                # tokens, which is given by:
                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                       len(generated_token_ids))
                request.num_computed_tokens -= num_tokens_rejected
                spec_decoding_stats = self.make_spec_decoding_stats(
                    spec_decoding_stats,
                    num_draft_tokens=len(scheduled_spec_token_ids),
                    num_accepted_tokens=len(generated_token_ids) - 1)

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
            # OPTIMIZATION: Avoid list(set) if the set is empty.
            if cached_encoder_input_ids:
                for input_id in list(cached_encoder_input_ids):
                    mm_positions = request.mm_positions[input_id]
                    start_pos = mm_positions.offset
                    num_tokens = mm_positions.length
                    if start_pos + num_tokens <= request.num_computed_tokens:
                        # The encoder output is already processed and stored
                        # in the decoder's KV cache.
                        self.encoder_cache_manager.free_encoder_input(
                            request, input_id)

            stopped = False
            new_logprobs = None
            new_token_ids = generated_token_ids

            # Append generated tokens and check for stop. Note that if
            # a request is still being prefilled, we expect the model runner
            # to return empty token ids for the request.
            for num_new, output_token_id in enumerate(new_token_ids, 1):
                request.append_output_token_ids(output_token_id)

                # Check for stop and update request state.
                # This must be called before we make the EngineCoreOutput.
                stopped = check_stop(request, self.max_model_len)
                if stopped:
                    self._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

            # Extract sample logprobs if needed.
            if request.sampling_params.logprobs is not None and logprobs:
                # NOTE: once we support N tokens per step (spec decode),
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_index, req_index + 1)

            if new_token_ids and request.use_structured_output:
                # NOTE: structured_output_request should not be None if
                # use_structured_output is set; we have checked above, so it
                # is safe to ignore the type warning.
                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                    req_id, new_token_ids)

            # Add newly generated spec token ids to the request.
            if spec_token_ids is not None:
                if request.use_structured_output:
                    metadata = request.structured_output_request
                    assert metadata is not None and metadata.grammar is not None
                    # Needs to happen after new_token_ids are accepted.
                    request.spec_token_ids = metadata.grammar.validate_tokens(
                        spec_token_ids[req_index])
                else:
                    request.spec_token_ids = spec_token_ids[req_index]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids:
                # Add EngineCoreOutput for this Request.
                outputs.append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        stop_reason=request.stop_reason,
                        events=request.take_events()))
            else:
                # Invariant: EngineCore returns no partial prefill outputs.
                assert not prompt_logprobs_tensors

            self.scheduled_req_ids.remove(req_id)
            if not stopped:
                new_running.append(request)

        # Return the cached request data to the queue so it can be reused.
        for req_data in scheduler_output.scheduled_cached_reqs:
            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
            # to _cached_reqs_data will cause a memory leak.
            if req_data.req_id not in self.finished_req_ids:
                self._cached_reqs_data[req_data.req_id].append(req_data)

        self.running = new_running
        engine_core_outputs = EngineCoreOutputs(
            outputs=outputs,
            scheduler_stats=self.make_stats(spec_decoding_stats),
        )
        if self.include_finished_set:
            # TODO: currently sending duplicates here, improve this.
            engine_core_outputs.finished_requests = (
                scheduler_output.finished_req_ids | self.finished_req_ids)

        return engine_core_outputs