### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
This commit is contained in:
@@ -21,26 +21,21 @@ from __future__ import annotations
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Type, Union
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
from vllm.config import SchedulerConfig, VllmConfig
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \
|
||||
KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs, FinishReason)
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
@@ -51,26 +46,22 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class RecomputeSchedulerConfig(SchedulerConfig):
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
|
||||
scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(cls, vllm_config: VllmConfig):
|
||||
vllm_scheduler_config = vllm_config.scheduler_config
|
||||
scheduler_config = {
|
||||
field.name: getattr(vllm_scheduler_config, field.name)
|
||||
for field in fields(vllm_scheduler_config) if field.init
|
||||
for field in fields(vllm_scheduler_config)
|
||||
if field.init
|
||||
}
|
||||
if vllm_scheduler_config.async_scheduling:
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler")
|
||||
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler"
|
||||
else:
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
|
||||
scheduler_config[
|
||||
"max_model_len"] = vllm_config.model_config.max_model_len
|
||||
scheduler_config[
|
||||
"is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
|
||||
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
|
||||
scheduler_config["max_model_len"] = vllm_config.model_config.max_model_len
|
||||
scheduler_config["is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
|
||||
return cls(**scheduler_config)
|
||||
|
||||
|
||||
@@ -125,33 +116,32 @@ class RecomputeScheduler(Scheduler):
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
|
||||
if (request.num_output_placeholders > 0
|
||||
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
||||
# Since output placeholders are also included in the computed tokens
|
||||
# count, we subtract (num_output_placeholders - 1) to remove any draft
|
||||
# tokens, so that we can be sure no further steps are needed even if
|
||||
# they are all rejected.
|
||||
and request.num_computed_tokens + 2 -
|
||||
request.num_output_placeholders
|
||||
>= request.num_prompt_tokens + request.max_tokens):
|
||||
if (
|
||||
request.num_output_placeholders > 0
|
||||
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
||||
# Since output placeholders are also included in the computed tokens
|
||||
# count, we subtract (num_output_placeholders - 1) to remove any draft
|
||||
# tokens, so that we can be sure no further steps are needed even if
|
||||
# they are all rejected.
|
||||
and request.num_computed_tokens + 2 - request.num_output_placeholders
|
||||
>= request.num_prompt_tokens + request.max_tokens
|
||||
):
|
||||
# Async scheduling: Avoid scheduling an extra step when we are sure that
|
||||
# the previous step has reached request.max_tokens. We don't schedule
|
||||
# partial draft tokens since this prevents uniform decode optimizations.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = (request.num_tokens_with_spec +
|
||||
request.num_output_placeholders -
|
||||
request.num_computed_tokens)
|
||||
num_new_tokens = (
|
||||
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
|
||||
)
|
||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - 1 - request.num_computed_tokens)
|
||||
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
@@ -209,9 +199,10 @@ class RecomputeScheduler(Scheduler):
|
||||
recomputed_req = self.running.pop()
|
||||
self.kv_cache_manager.free(recomputed_req)
|
||||
recomputed_reqs.append(
|
||||
RecomputeReqInfo(recomputed_req.request_id,
|
||||
recomputed_req.output_token_ids,
|
||||
recomputed_req.client_index))
|
||||
RecomputeReqInfo(
|
||||
recomputed_req.request_id, recomputed_req.output_token_ids, recomputed_req.client_index
|
||||
)
|
||||
)
|
||||
if recomputed_req == request:
|
||||
break
|
||||
else:
|
||||
@@ -223,28 +214,23 @@ class RecomputeScheduler(Scheduler):
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
token_budget += num_scheduled_tokens[
|
||||
preempted_req.request_id]
|
||||
token_budget += num_scheduled_tokens[preempted_req.request_id]
|
||||
req_to_new_blocks.pop(preempted_req.request_id)
|
||||
num_scheduled_tokens.pop(
|
||||
preempted_req.request_id)
|
||||
scheduled_spec_decode_tokens.pop(
|
||||
preempted_req.request_id, None)
|
||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
|
||||
preempted_req.request_id, None)
|
||||
num_scheduled_tokens.pop(preempted_req.request_id)
|
||||
scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
|
||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
|
||||
if preempted_encoder_inputs:
|
||||
# Restore encoder compute budget if the preempted
|
||||
# request had encoder inputs scheduled in this step.
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs)
|
||||
preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self._preempt_request(preempted_req,
|
||||
scheduled_timestamp)
|
||||
self._preempt_request(preempted_req, scheduled_timestamp)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt. Cannot schedule this request.
|
||||
@@ -263,23 +249,20 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens -
|
||||
request.num_output_placeholders)
|
||||
num_scheduled_spec_tokens = (
|
||||
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
|
||||
)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
|
||||
# New spec tokens will be set in `update_draft_token_ids` before the
|
||||
# next step when applicable.
|
||||
request.spec_token_ids = []
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -294,8 +277,10 @@ class RecomputeScheduler(Scheduler):
|
||||
scheduled_loras: set[int] = set()
|
||||
if self.lora_config:
|
||||
scheduled_loras = set(
|
||||
req.lora_request.lora_int_id for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
req.lora_request.lora_int_id
|
||||
for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0
|
||||
)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
@@ -337,9 +322,14 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
if (
|
||||
self.lora_config
|
||||
and request.lora_request
|
||||
and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras
|
||||
)
|
||||
):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
@@ -351,14 +341,15 @@ class RecomputeScheduler(Scheduler):
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
# Get locally-cached tokens.
|
||||
new_computed_blocks, num_new_local_computed_tokens = (
|
||||
self.kv_cache_manager.get_computed_blocks(request))
|
||||
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
|
||||
request
|
||||
)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
ext_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens
|
||||
)
|
||||
|
||||
if ext_tokens is None:
|
||||
# The request cannot be scheduled because
|
||||
@@ -372,8 +363,7 @@ class RecomputeScheduler(Scheduler):
|
||||
num_external_computed_tokens = ext_tokens
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
|
||||
else:
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
@@ -401,8 +391,7 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# chunked prefill has to be enabled explicitly to allow
|
||||
# pooling requests to be chunked
|
||||
if (not self.scheduler_config.enable_chunked_prefill
|
||||
and num_new_tokens > token_budget):
|
||||
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
|
||||
# If chunked_prefill is disabled,
|
||||
# we can stop the scheduling here.
|
||||
break
|
||||
@@ -433,9 +422,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# extra block gets allocated which
|
||||
# creates a mismatch between the number
|
||||
# of local and remote blocks.
|
||||
effective_lookahead_tokens = (0 if request.num_computed_tokens
|
||||
== 0 else
|
||||
self.num_lookahead_tokens)
|
||||
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
@@ -443,8 +430,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
num_encoder_tokens = (
|
||||
self.scheduler_config.max_num_encoder_input_tokens)
|
||||
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
|
||||
@@ -488,20 +474,17 @@ class RecomputeScheduler(Scheduler):
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Invalid request status: {request.status}")
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
req_to_new_blocks[request.request_id] = (
|
||||
self.kv_cache_manager.get_blocks(request.request_id))
|
||||
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
@@ -511,8 +494,7 @@ class RecomputeScheduler(Scheduler):
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -522,8 +504,7 @@ class RecomputeScheduler(Scheduler):
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(
|
||||
request, i)
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
@@ -537,20 +518,15 @@ class RecomputeScheduler(Scheduler):
|
||||
# Since some requests in the RUNNING queue may not be scheduled in
|
||||
# this step, the total number of scheduled requests can be smaller than
|
||||
# len(self.running).
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs) <= len(self.running)
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
with record_function_or_nullcontext(
|
||||
"schedule: get_num_common_prefix_blocks"):
|
||||
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request.request_id))
|
||||
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
|
||||
|
||||
# Construct the scheduler output.
|
||||
if self.use_v2_model_runner:
|
||||
@@ -561,17 +537,16 @@ class RecomputeScheduler(Scheduler):
|
||||
req,
|
||||
req_to_new_blocks[req.request_id].get_block_ids(),
|
||||
req._all_token_ids,
|
||||
) for req in scheduled_new_reqs
|
||||
)
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
else:
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
with record_function_or_nullcontext(
|
||||
"schedule: make_cached_request_data"):
|
||||
with record_function_or_nullcontext("schedule: make_cached_request_data"):
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
@@ -592,15 +567,13 @@ class RecomputeScheduler(Scheduler):
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
preempted_req_ids={req.request_id
|
||||
for req in preempted_reqs},
|
||||
preempted_req_ids={req.request_id for req in preempted_reqs},
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
recomputed_reqs=recomputed_reqs,
|
||||
)
|
||||
|
||||
@@ -609,14 +582,12 @@ class RecomputeScheduler(Scheduler):
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(
|
||||
scheduler_output)
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# Build the connector meta for ECConnector
|
||||
if self.ec_connector is not None:
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
|
||||
scheduler_output)
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.ec_connector_metadata = ec_meta
|
||||
|
||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||
@@ -639,8 +610,8 @@ class RecomputeScheduler(Scheduler):
|
||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||
spec_decoding_stats: SpecDecodingStats | None = None
|
||||
kv_connector_stats: KVConnectorStats | None = (
|
||||
kv_connector_output.kv_connector_stats
|
||||
if kv_connector_output else None)
|
||||
kv_connector_output.kv_connector_stats if kv_connector_output else None
|
||||
)
|
||||
if kv_connector_stats and self.connector:
|
||||
kv_stats = self.connector.get_kv_connector_stats()
|
||||
if kv_stats:
|
||||
@@ -651,8 +622,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# These blocks contain externally computed tokens that failed to
|
||||
# load. Identify affected requests and adjust their computed token
|
||||
# count to trigger recomputation of the invalid blocks.
|
||||
failed_kv_load_req_ids = self._handle_invalid_blocks(
|
||||
kv_connector_output.invalid_block_ids)
|
||||
failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
|
||||
|
||||
# return recomputed requests as EngineCoreOutput
|
||||
if scheduler_output.recomputed_reqs is not None:
|
||||
@@ -663,7 +633,8 @@ class RecomputeScheduler(Scheduler):
|
||||
finish_reason=FinishReason.STOP,
|
||||
new_token_ids=[req_info.output_token_ids[-1]],
|
||||
stop_reason="recomputed",
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
|
||||
# the below loop can be a performance bottleneck. We should do our best
|
||||
@@ -683,11 +654,9 @@ class RecomputeScheduler(Scheduler):
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = (sampled_token_ids[req_index]
|
||||
if sampled_token_ids else [])
|
||||
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
|
||||
if scheduled_spec_token_ids:
|
||||
num_draft_tokens = len(scheduled_spec_token_ids)
|
||||
num_accepted = len(generated_token_ids) - 1
|
||||
@@ -717,15 +686,13 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Check for stop and update request status.
|
||||
if new_token_ids:
|
||||
new_token_ids, stopped = self._update_request_with_output(
|
||||
request, new_token_ids)
|
||||
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
|
||||
|
||||
# Stop checking for pooler models.
|
||||
pooler_output = None
|
||||
if pooler_outputs:
|
||||
pooler_output = pooler_outputs[req_index]
|
||||
stopped = check_stop(request, self.max_model_len,
|
||||
pooler_output)
|
||||
stopped = check_stop(request, self.max_model_len, pooler_output)
|
||||
|
||||
if stopped:
|
||||
kv_transfer_params = self._free_request(request)
|
||||
@@ -735,19 +702,14 @@ class RecomputeScheduler(Scheduler):
|
||||
stopped_preempted_reqs.add(request)
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if (request.sampling_params is not None
|
||||
and request.sampling_params.logprobs is not None
|
||||
and logprobs):
|
||||
new_logprobs = logprobs.slice_request(req_index,
|
||||
len(new_token_ids))
|
||||
if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
|
||||
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
|
||||
|
||||
if new_token_ids and self.structured_output_manager.should_advance(
|
||||
request):
|
||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||
struct_output_request = request.structured_output_request
|
||||
assert struct_output_request is not None
|
||||
assert struct_output_request.grammar is not None
|
||||
struct_output_request.grammar.accept_tokens(
|
||||
req_id, new_token_ids)
|
||||
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
|
||||
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
@@ -770,7 +732,8 @@ class RecomputeScheduler(Scheduler):
|
||||
trace_headers=request.trace_headers,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
num_nans_in_logits=request.num_nans_in_logits,
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
@@ -805,10 +768,7 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Create EngineCoreOutputs for all clients that have requests with
|
||||
# outputs in this step.
|
||||
engine_core_outputs = {
|
||||
client_index: EngineCoreOutputs(outputs=outs)
|
||||
for client_index, outs in outputs.items()
|
||||
}
|
||||
engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
|
||||
|
||||
finished_req_ids = self.finished_req_ids_dict
|
||||
if finished_req_ids:
|
||||
@@ -819,12 +779,10 @@ class RecomputeScheduler(Scheduler):
|
||||
if (eco := engine_core_outputs.get(client_index)) is not None:
|
||||
eco.finished_requests = finished_set
|
||||
else:
|
||||
engine_core_outputs[client_index] = EngineCoreOutputs(
|
||||
finished_requests=finished_set)
|
||||
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
|
||||
finished_req_ids.clear()
|
||||
|
||||
if (stats := self.make_stats(spec_decoding_stats,
|
||||
kv_connector_stats)) is not None:
|
||||
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
|
||||
# Return stats to only one of the front-ends.
|
||||
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
|
||||
# We must return the stats even if there are no request
|
||||
@@ -836,6 +794,5 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
|
||||
class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user