feat: support torchair graph mode in v1 engine (#789)
### What this PR does / why we need it? support torchair graph mode with v1 engine --------- Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -15,18 +15,42 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from collections import deque
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
|
||||
class AscendScheduler(Scheduler):
|
||||
"""This Scheduler extends vllm's original v1 scheduler
|
||||
with prefill-first scheduling strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vllm_config, kv_cache_config,
|
||||
structured_output_manager, mm_registry,
|
||||
include_finished_set, log_stats)
|
||||
self.scheduled_req_ids: set[str] = set()
|
||||
self.running: list[Request] = []
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
return super().schedule()
|
||||
@@ -317,3 +341,175 @@ class AscendScheduler(Scheduler):
|
||||
return request.lora_request.long_lora_max_len
|
||||
else:
|
||||
return prompt_limit
|
||||
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: RequestStatus,
|
||||
) -> None:
|
||||
"""Handles the finish signal from outside the scheduler.
|
||||
|
||||
For example, the API server can abort a request when the client
|
||||
disconnects.
|
||||
"""
|
||||
assert RequestStatus.is_finished(finished_status)
|
||||
if isinstance(request_ids, str):
|
||||
request_ids = (request_ids, )
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
self.running.remove(request)
|
||||
self.scheduled_req_ids.discard(request.request_id)
|
||||
else:
|
||||
self.waiting.remove(request)
|
||||
request.status = finished_status
|
||||
self._free_request(request)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> EngineCoreOutputs:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
spec_token_ids = model_runner_output.spec_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
new_running: list[Request] = []
|
||||
outputs: list[EngineCoreOutput] = []
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
|
||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||
# loop can be a performance bottleneck. We should do our best to avoid
|
||||
# expensive operations inside the loop.
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||
if num_tokens_scheduled == 0:
|
||||
# The request was not scheduled in this step.
|
||||
new_running.append(request)
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[req_index]
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
if scheduled_spec_token_ids:
|
||||
# num_computed_tokens represents the number of tokens
|
||||
# processed in the current step, considering scheduled
|
||||
# tokens and rejections. If some tokens are rejected,
|
||||
# num_computed_tokens is decreased by the number of rejected
|
||||
# tokens, where is given by:
|
||||
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
||||
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
|
||||
len(generated_token_ids))
|
||||
request.num_computed_tokens -= num_tokens_rejected
|
||||
spec_decoding_stats = self.make_spec_decoding_stats(
|
||||
spec_decoding_stats,
|
||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||
|
||||
cached_encoder_input_ids = (
|
||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
||||
if cached_encoder_input_ids:
|
||||
for input_id in list(cached_encoder_input_ids):
|
||||
mm_positions = request.mm_positions[input_id]
|
||||
start_pos = mm_positions.offset
|
||||
num_tokens = mm_positions.length
|
||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
self.encoder_cache_manager.free_encoder_input(
|
||||
request, input_id)
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = generated_token_ids
|
||||
|
||||
# Append generated tokens and check for stop. Note that if
|
||||
# a request is still being prefilled, we expect the model runner
|
||||
# to return empty token ids for the request.
|
||||
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
||||
request.append_output_token_ids(output_token_id)
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
stopped = check_stop(request, self.max_model_len)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||
break
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params.logprobs is not None and logprobs:
|
||||
# NOTE: once we support N tokens per step (spec decode),
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and request.use_structured_output:
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# check above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids)
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if spec_token_ids is not None:
|
||||
if request.use_structured_output:
|
||||
metadata = request.structured_output_request
|
||||
assert metadata is not None and metadata.grammar is not None
|
||||
# Needs to happen after new_token_ids are accepted.
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens(
|
||||
spec_token_ids[req_index])
|
||||
else:
|
||||
request.spec_token_ids = spec_token_ids[req_index]
|
||||
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
if new_token_ids:
|
||||
# Add EngineCoreOutput for this Request.
|
||||
outputs.append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids,
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
stop_reason=request.stop_reason,
|
||||
events=request.take_events()))
|
||||
else:
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
|
||||
self.scheduled_req_ids.remove(req_id)
|
||||
if not stopped:
|
||||
new_running.append(request)
|
||||
|
||||
# Return the cached request data to the queue so they can be reused.
|
||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
|
||||
# to _cached_reqs_data will cause a memory leak.
|
||||
if req_data.req_id not in self.finished_req_ids:
|
||||
self._cached_reqs_data[req_data.req_id].append(req_data)
|
||||
|
||||
self.running = new_running
|
||||
engine_core_outputs = EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
scheduler_stats=self.make_stats(spec_decoding_stats),
|
||||
)
|
||||
if self.include_finished_set:
|
||||
#TODO currently sending duplicates here, improve this
|
||||
engine_core_outputs.finished_requests = (
|
||||
scheduler_output.finished_req_ids | self.finished_req_ids)
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
Reference in New Issue
Block a user