[Scheduler][MTP] Add support for speculative decoding in AscendScheduler. (#943)
This PR adds support for speculative decoding in AscendScheduler. It also includes partial support for disaggregated prefill; full support will land in a follow-up PR. --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
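Background for the diff below: with speculative decoding (MTP), a running request is scheduled for its next sampled token plus any pending draft tokens, so the per-step token accounting changes. A minimal sketch of that accounting, using plain ints in place of vllm's `Request` fields (illustrative only, not a drop-in for the scheduler code):

```python
# Illustrative sketch of the decode-side accounting in the diff below.
def tokens_to_schedule(num_tokens_with_spec: int, num_computed_tokens: int,
                       token_budget: int, max_model_len: int) -> int:
    # Sampled tokens plus draft tokens that still need a forward pass.
    num_new_tokens = num_tokens_with_spec - num_computed_tokens
    # Stay within this step's token budget.
    num_new_tokens = min(num_new_tokens, token_budget)
    # Keep the input position inside the model's context window.
    num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
    return num_new_tokens
```

Without spec decoding the result is always 1 for a running request, which is what the two asserts removed from the decode loop below used to enforce.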
@@ -14,16 +14,19 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+import time
 from collections import deque
 from typing import Iterable, Union
 
 from vllm.config import VllmConfig
+from vllm.distributed.kv_events import KVEventBatch
 from vllm.logger import logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -49,11 +52,6 @@ class AscendScheduler(Scheduler):
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
-        if self.vllm_config.kv_transfer_config is not None and \
-            self.vllm_config.kv_transfer_config.is_kv_consumer:
-            raise ValueError(
-                "AscendScheduler cannot be used for decode nodes. ")
-
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -68,6 +66,9 @@ class AscendScheduler(Scheduler):
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
 
+        # For logging.
+        scheduled_timestamp = time.monotonic()
+
         # Record scheduled LoRA requests.
         scheduled_loras: set[int] = set()
 
@@ -86,6 +87,18 @@ class AscendScheduler(Scheduler):
                 self.waiting.popleft()
                 skipped_waiting_requests.appendleft(request)
 
+            num_prealloc_computed_tokens = 0
+            # P/D: skip request if still waiting for remote kvs.
+            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+                is_ready = self._update_waiting_for_remote_kv(request)
+                if is_ready:
+                    request.status = RequestStatus.WAITING
+                    num_prealloc_computed_tokens = (
+                        request.num_computed_tokens)
+                else:
+                    skip_cur_request()
+                    continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (self.lora_config and request.lora_request and
@@ -95,39 +108,72 @@ class AscendScheduler(Scheduler):
                 skip_cur_request()
                 continue
 
-            prompt_limit = self._get_prompt_limit(request)
-            # Get already-cached tokens.
-            computed_blocks, num_computed_tokens = (
-                self.kv_cache_manager.get_computed_blocks(request))
-            num_new_tokens = request.num_tokens - num_computed_tokens
-            if (0 < self.scheduler_config.long_prefill_token_threshold <
-                    num_new_tokens):
-                num_new_tokens = (
-                    self.scheduler_config.long_prefill_token_threshold)
-            max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
-                                     self.block_size)
-            prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
-
-            # Finish request that exceeds prompt_limit or kv cache size.
-            if num_new_tokens > prompt_limit:
-                logger.warning(
-                    "Input prompt (%d tokens) is too long"
-                    " and exceeds limit of %d",
-                    num_new_tokens,
-                    prompt_limit,
-                )
-                request.status = RequestStatus.FINISHED_IGNORED
-                self.finished_req_ids.add(request.request_id)  # type: ignore
-                self.waiting.popleft()
-                continue
-
-            if num_new_tokens > token_budget:
-                # Scheduling would exceed token_budget, skip.
-                skip_cur_request()
-                continue
-            assert num_new_tokens > 0
-            blocks = computed_blocks.blocks[0]
+            num_external_computed_tokens = 0
+            load_kv_async = False
+            if num_prealloc_computed_tokens == 0:
+                new_computed_blocks, num_native_computed_tokens = \
+                    self.kv_cache_manager.get_computed_blocks(
+                        request)
+
+                # Get externally-cached tokens if using a KVConnector.
+                if self.connector is not None:
+                    num_external_computed_tokens, load_kv_async = (
+                        self.connector.get_num_new_matched_tokens(
+                            request, num_native_computed_tokens))
+
+                # Total computed tokens (local + external).
+                num_computed_tokens = (num_native_computed_tokens +
+                                       num_external_computed_tokens)
+            else:
+                # P/D: skip checking prefix cache if loaded from remote kvs.
+                new_computed_blocks = KVCacheBlocks.create_empty()
+                num_native_computed_tokens = 0
+
+                # Total computed tokens (allocated in prior step).
+                num_computed_tokens = num_prealloc_computed_tokens
+
+            # P/D: loading remote KV, do not allocate for new work.
+            if load_kv_async:
+                assert num_external_computed_tokens > 0
+                num_new_tokens = 0
+                blocks = None
+            # Number of tokens to be scheduled.
+            else:
+                prompt_limit = self._get_prompt_limit(request)
+                # Get already-cached tokens.
+                computed_blocks, num_computed_tokens = (
+                    self.kv_cache_manager.get_computed_blocks(request))
+                # We use `request.num_tokens` instead of
+                # `request.num_prompt_tokens` to consider the resumed
+                # requests, which have output tokens.
+                num_new_tokens = request.num_tokens - num_computed_tokens
+                max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
+                                         self.block_size)
+                prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
+
+                # Finish request that exceeds prompt_limit or kv cache size.
+                if num_new_tokens > prompt_limit:
+                    logger.warning(
+                        "Input prompt (%d tokens) is too long"
+                        " and exceeds limit of %d",
+                        num_new_tokens,
+                        prompt_limit,
+                    )
+                    request.status = RequestStatus.FINISHED_IGNORED
+                    self.finished_req_ids.add(  # type: ignore
+                        request.request_id)  # type: ignore
+                    self.waiting.popleft()
+                    continue
+
+                if num_new_tokens > token_budget:
+                    # Scheduling would exceed token_budget, skip.
+                    skip_cur_request()
+                    continue
+                assert num_new_tokens > 0
+                blocks = computed_blocks.blocks[0]
 
             watermark = getattr(self.scheduler_config, "watermark", 0.01)
             if not self._check_watermark_for_prefill(request, num_new_tokens,
                                                      blocks, watermark):
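The reworked waiting loop above splits prefill admission into three paths. A compact sketch of the dispatch, purely for exposition (the helper function and its return strings are invented here; the real logic is inlined in the hunk):

```python
# Hypothetical helper summarizing the three prefill paths above.
def prefill_path(num_prealloc_computed_tokens: int, load_kv_async: bool) -> str:
    if load_kv_async:
        # Remote KV still in flight: allocate blocks now, schedule no new
        # tokens, and park the request in WAITING_FOR_REMOTE_KVS.
        return "defer"
    if num_prealloc_computed_tokens > 0:
        # Remote KV already arrived in a prior step: reuse those tokens and
        # skip the local prefix-cache lookup.
        return "schedule-remote"
    # Default: local prefix cache plus an optional KV-connector match.
    return "schedule-local"
```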
@@ -136,13 +182,38 @@ class AscendScheduler(Scheduler):
                 continue
 
             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens, new_computed_blocks=computed_blocks)
+                request,
+                num_new_tokens + num_external_computed_tokens,
+                num_native_computed_tokens,
+                new_computed_blocks=computed_blocks,
+                num_lookahead_tokens=self.num_lookahead_tokens,
+                delay_cache_blocks=load_kv_async)
             if new_blocks is None:
                 # The request cannot be scheduled.
                 break
 
+            # KVConnector: update internal state after allocation.
+            # This information is used to determine if a load is
+            # needed for this request.
+            if num_external_computed_tokens:
+                assert self.connector is not None
+                self.connector.update_state_after_alloc(
+                    request,
+                    new_computed_blocks + new_blocks,
+                    num_external_computed_tokens,
+                )
+
             self.waiting.popleft()
+            if load_kv_async:
+                # If loading async, allocate memory and put request
+                # into the WAITING_FOR_REMOTE_KV state.
+                skipped_waiting_requests.appendleft(request)
+                request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                continue
+
             self.running.append(request)
+            if self.log_stats:
+                request.record_event(EngineCoreEventType.SCHEDULED,
+                                     scheduled_timestamp)
             self.scheduled_req_ids.add(request.request_id)
             # Check request status.
             if request.status == RequestStatus.WAITING:
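One reason `allocate_slots` now takes `num_lookahead_tokens`: under speculative decoding the KV cache must have room for the draft positions verified this step, even though rejected drafts never become sampled tokens. A small worked example with assumed numbers (the scheduler itself uses vllm's `cdiv` for this ceiling division):

```python
from math import ceil

# Assumed example values; only the arithmetic is the point.
block_size = 128
num_computed_tokens = 255   # already in the KV cache
num_new_tokens = 1          # next sampled token
num_lookahead_tokens = 3    # slots reserved for draft verification

without_lookahead = ceil((num_computed_tokens + num_new_tokens) / block_size)
with_lookahead = ceil(
    (num_computed_tokens + num_new_tokens + num_lookahead_tokens) / block_size)
print(without_lookahead, with_lookahead)  # 2 3
```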
@@ -161,6 +232,9 @@ class AscendScheduler(Scheduler):
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
+                # Count the number of prefix cached tokens.
+                if request.num_cached_tokens < 0:
+                    request.num_cached_tokens = num_computed_tokens
 
         # Put back any skipped requests at the head of the waiting queue
         if skipped_waiting_requests:
@@ -179,16 +253,45 @@ class AscendScheduler(Scheduler):
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
                 num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
-            assert (request.num_tokens - request.num_computed_tokens) == 1
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens == 1
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - request.num_computed_tokens)
+            # Check that adding the request still respects the max_loras
+            # constraint.
+            if self.lora_config and request.lora_request and (
+                    len(scheduled_loras) == self.lora_config.max_loras
+                    and request.lora_request.lora_int_id
+                    not in scheduled_loras):
+                # Scheduling would exceed max_loras, skip.
+                num_new_tokens = 0
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because of one of the
+                # following reasons:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. Adding the request exceeds the max_loras constraint.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
+
+            num_draft_tokens = max(
+                num_new_tokens + request.num_computed_tokens -
+                request.num_tokens, 0)
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens)
+                    request,
+                    num_new_tokens,
+                    num_draft_tokens=num_draft_tokens,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
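The `num_draft_tokens` computed above is the slice of the scheduled window that lies beyond the tokens the request actually owns, i.e. the draft tokens still awaiting verification. A worked example of the formula with assumed values:

```python
# Assumed example values illustrating the num_draft_tokens formula above.
num_computed_tokens = 100      # tokens already processed
num_tokens = 101               # prompt + accepted output tokens
num_tokens_with_spec = 104     # num_tokens + 3 MTP draft tokens

num_new_tokens = num_tokens_with_spec - num_computed_tokens   # 4
num_draft_tokens = max(
    num_new_tokens + num_computed_tokens - num_tokens, 0)     # 3
```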
@@ -196,6 +299,10 @@ class AscendScheduler(Scheduler):
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
+                    if self.log_stats:
+                        preempted_req.record_event(
+                            EngineCoreEventType.PREEMPTED,
+                            scheduled_timestamp)
                     self.waiting.appendleft(preempted_req)
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
@@ -230,6 +337,10 @@ class AscendScheduler(Scheduler):
                 scheduled_spec_decode_tokens[request.request_id] = (
                     request.spec_token_ids)
 
+            # Record scheduled LoRA requests.
+            if self.lora_config and request.lora_request:
+                scheduled_loras.add(request.lora_request.lora_int_id)
+
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@@ -297,6 +408,11 @@ class AscendScheduler(Scheduler):
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
 
+        events = self.kv_cache_manager.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
@@ -388,7 +504,8 @@ class AscendScheduler(Scheduler):
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
                 continue
-            self.scheduled_req_ids.remove(req_id)
+            if req_id in self.scheduled_req_ids:
+                self.scheduled_req_ids.remove(req_id)
 
         return super().update_from_output(scheduler_output,
                                           model_runner_output)
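The guarded removal in the last hunk is needed because a request can reach `update_from_output` without having been scheduled by this scheduler in the current step (for example, while parked waiting for remote KVs), in which case an unconditional `remove` would raise `KeyError`. The guard is equivalent to `set.discard`:

```python
# The membership check above is equivalent to set.discard:
scheduled_req_ids = {"req-1"}
scheduled_req_ids.discard("req-2")  # no-op, no KeyError
```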