from functools import partial
import os
import queue
import threading
import traceback
from typing import Callable, Dict, List, Mapping, Optional, Type, Union

import torch

from vllm import envs
from vllm.config import DecodingConfig, ObservabilityConfig, VllmConfig
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.engine.llm_engine import (_LOCAL_LOGGING_INTERVAL_SEC, LLMEngine,
                                    SchedulerContext, SchedulerOutputState)
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.logger import init_logger
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import ProcessorInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.inputs.registry import InputRegistry
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.registry import MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata)
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.sequence import ZeroOverheadSequence
from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import (SpecStepKind, get_accepted_token_ids,
                                      get_spec_step, is_zero_no_thread,
                                      set_spec_step, zero_overhead_stream)

logger = init_logger(__name__)


class ZeroOverheadEngine(LLMEngine):
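    """An LLMEngine variant that hides scheduler and output-processing
    overhead behind model execution.

    A worker thread (or an inline single-threaded fallback) schedules and
    launches each model step one iteration ahead, emitting placeholder token
    ids; the previous step's sampled token ids are copied device-to-host
    asynchronously and patched into the sequences when step() consumes the
    recorded outputs.
    """
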
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs
        self.thread_running = False

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues.
        def get_tokenizer_for_seq(
                sequence: ZeroOverheadSequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stats collection is enabled, report relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues.
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=ZeroOverheadStopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False
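        # Zero-overhead pipeline state: `last_record` snapshots the previous
        # step (outputs, sampler record, metadata, scheduler outputs);
        # `async_d2h` holds the non-blocking device-to-host copy of the
        # sampled token ids and `async_event` marks when that copy finishes;
        # `q_recorder` hands each recorded step from the scheduler thread
        # back to the caller of step().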
        self.async_d2h = None
        self.last_record = None
        self.async_event = torch.cuda.Event(enable_timing=False)
        self.q_recorder = queue.Queue()
        self.use_stream = zero_overhead_stream(
            self.model_executor.device_config.device)
        if not is_zero_no_thread():
            self.zero_thread = threading.Thread(
                target=self.thread_zero_overhead)
            self.thread_running = True
            self.sem_m2s = threading.Semaphore(0)  # main to scheduler thread
            self.zero_thread.start()
        profile.StartTracer()

    def __del__(self):
        self.finish_thread()
        return super().__del__()

    def finish_thread(self):
        if self.thread_running:
            self.thread_running = False
            self.sem_m2s.release()

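    # The zero-overhead loop: a dedicated thread runs scheduling and model
    # execution one step ahead of output processing. Each iteration first
    # launches the async D2H copy for the *previous* step's sampled tokens
    # and publishes that step via q_recorder, then schedules and executes
    # the next step using placeholder token ids that are patched later by
    # _fix_last_step / _fix_spec_decode_steps. finish_thread() releases the
    # semaphore with thread_running=False so the loop can exit cleanly.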
    def thread_zero_overhead(self):
        logger.info('Zero-overhead scheduler thread started.')
        last_sampler = get_last_sampler()
        last_sampler.seq_ids.clear()
        try:
            with torch.cuda.stream(self.use_stream):
                while True:
                    self.sem_m2s.acquire()
                    if not self.thread_running:
                        logger.debug("Stopping remote worker execution loop.")
                        self.model_executor.stop_remote_worker_execution_loop()
                        break

                    virtual_engine = 0

                    # Schedule iteration
                    (seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc
                     ) = self.scheduler[virtual_engine].schedule()
                    if self.last_record is not None:
                        last_sampler = self.last_record[1]
                        spec_step = get_spec_step()
                        if spec_step == SpecStepKind.KIND_DEFAULT:
                            if last_sampler.sampled_token_ids_tensor is not None:
                                self.async_d2h = \
                                    last_sampler.sampled_token_ids_tensor.to(
                                        'cpu', non_blocking=True)
                            else:
                                self.async_d2h = None
                        elif spec_step == SpecStepKind.SCORE_DECODE:
                            self.async_d2h = last_sampler.to(
                                'cpu', non_blocking=True)
                        self.async_event.record()
                        self.q_recorder.put(self.last_record)
                    else:
                        self.q_recorder.put(None)
                    if len(seq_group_metadata_list) == 0:
                        self.last_record = None
                        continue

                    finished_requests_ids = self.scheduler[
                        virtual_engine].get_and_reset_finished_requests_ids()

                    assert seq_group_metadata_list is not None
                    assert scheduler_outputs is not None
                    last_sampled_token_ids = \
                        self._get_last_sampled_token_ids(virtual_engine)

                    execute_model_req = ExecuteModelRequest(
                        seq_group_metadata_list=seq_group_metadata_list,
                        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                        blocks_to_copy=scheduler_outputs.blocks_to_copy,
                        num_lookahead_slots=scheduler_outputs.
                        num_lookahead_slots,
                        running_queue_size=scheduler_outputs.running_queue_size,
                        finished_requests_ids=finished_requests_ids,
                        # We use ExecuteModelRequest to pass the last
                        # sampled_token_ids to each of the non-last PP stages
                        # for in-place prepare_input.
                        last_sampled_token_ids=last_sampled_token_ids)
                    outputs = self.model_executor.execute_model(
                        execute_model_req=execute_model_req)

                    for output in outputs:
                        self._advance_to_next_step(
                            output, seq_group_metadata_list,
                            scheduler_outputs.scheduled_seq_groups)
                    # Shallow-copy the list so that later scheduler mutations
                    # do not affect the recorded snapshot.
                    scheduler_outputs.scheduled_seq_groups = list(
                        scheduler_outputs.scheduled_seq_groups)
                    last_sampler = None
                    spec_step = get_spec_step()
                    if spec_step == SpecStepKind.KIND_DEFAULT:
                        last_sampler = get_last_sampler()
                    elif spec_step == SpecStepKind.SCORE_DECODE:
                        last_sampler, _ = get_accepted_token_ids()
                    self.last_record = [
                        outputs, last_sampler, seq_group_metadata_list,
                        scheduler_outputs, spec_step
                    ]

        except Exception as e:
            logger.error("thread_zero_overhead error: %s", e)
            traceback.print_exc()
    def zero_overhead_step(
        self
    ) -> Optional[List[Union[RequestOutput, PoolingRequestOutput]]]:
        if not self.thread_running:
            self.zero_thread.join()
            self.thread_running = True
            self.zero_thread = threading.Thread(
                target=self.thread_zero_overhead)
            self.zero_thread.start()
        self.sem_m2s.release()
        record_output = self.q_recorder.get()
        if record_output is None:  # None is produced by the first step
            return None
        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        ctx.request_outputs.clear()
        (outputs, last_sampler, seq_group_metadata_list, scheduler_outputs,
         spec_step) = record_output
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs
        if spec_step == SpecStepKind.KIND_DEFAULT:
            self.async_event.synchronize()
            if self.async_d2h is not None:
                self._fix_last_step(
                    outputs, last_sampler, seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)
        elif spec_step == SpecStepKind.SCORE_DECODE:
            self.async_event.synchronize()
            self._fix_spec_decode_steps(
                outputs, seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=True,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        # Check if need to run the usual non-async path
        # if not allow_async_output_proc:
        self._process_model_outputs(ctx=ctx)

        # profile.ProfRangeAutoPush('has_unfinish')
        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            # logger.debug("Stopping remote worker execution loop.")
            # self.model_executor.stop_remote_worker_execution_loop()
            self.finish_thread()
        return ctx.request_outputs

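    # Patch the default-sampling path: the outputs handed back by the worker
    # thread contain placeholder token ids, so once the async D2H copy has
    # completed we overwrite each sample's output_token (and the sequence's
    # last token) with the real sampled id, matched by seq_id.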
    def _fix_last_step(
            self, output: List[SamplerOutput],
            last_sampler: SampleRecorder,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:

        # sample_out_list = output[0].sampler_out_tensor.cpu().tolist()
        sample_out_list = self.async_d2h.tolist()
        sample_out_ids = last_sampler.seq_ids
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
                zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            if seq_group_metadata.do_sample:
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq: ZeroOverheadSequence = seq_group.seqs[0]
                for token_id, seq_id in zip(sample_out_list, sample_out_ids):
                    if seq.seq_id == seq_id:
                        if isinstance(token_id, list):
                            sample.output_token = token_id[0]
                        else:
                            sample.output_token = token_id
                        seq.fix_last_token_id(sample.output_token)
                        break

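    # Patch the speculative-decoding path: accept_token_ids uses -1 for draft
    # tokens that were rejected; each -1 removes one placeholder from the
    # sequence, while accepted ids overwrite the placeholders in order.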
    def _fix_spec_decode_steps(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:

        sample_out_list = self.async_d2h.tolist()
        group_idx = 0
        for seq_group_metadata, accept_token_ids, scheduled_seq_group in \
                zip(seq_group_metadata_list, sample_out_list,
                    scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                group_idx += 1
                continue

            if seq_group_metadata.do_sample:
                assert len(seq_group.seqs) == 1
                seq: ZeroOverheadSequence = seq_group.seqs[0]
                remove_count = 0
                for token_id in accept_token_ids:
                    if token_id == -1:
                        remove_count += 1
                    else:
                        seq.fix_last_token_id(token_id)
                seq.remove_last_place_holder(remove_count)
            group_idx += 1

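    # Single-threaded variant of the same pipeline, used when
    # is_zero_no_thread() is set: the one-step-behind recording happens
    # inline, so each call processes the outputs recorded by the previous
    # call.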
    def no_thread_step(
        self
    ) -> Optional[List[Union[RequestOutput, PoolingRequestOutput]]]:
        virtual_engine = 0

        # Schedule iteration
        (seq_group_metadata_list, scheduler_outputs,
         allow_async_output_proc
         ) = self.scheduler[virtual_engine].schedule()
        if self.last_record is not None:
            last_sampler = self.last_record[1]
            self.async_d2h = last_sampler.sampled_token_ids_tensor.to(
                'cpu', non_blocking=True)
            self.async_event.record()
            self.q_recorder.put(self.last_record)
        else:
            self.q_recorder.put(None)
        if len(seq_group_metadata_list) == 0:
            self.last_record = None
        else:
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)
            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            if len(outputs) == 1:
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)
            # Shallow-copy the list so that later scheduler mutations do not
            # affect the recorded snapshot.
            scheduler_outputs.scheduled_seq_groups = list(
                scheduler_outputs.scheduled_seq_groups)
            last_sampler = get_last_sampler()
            self.last_record = [
                outputs, last_sampler, seq_group_metadata_list,
                scheduler_outputs
            ]

        record_output = self.q_recorder.get()
        if record_output is None:  # None is produced by the first step
            return None
        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        ctx.request_outputs.clear()
        (outputs, last_sampler, seq_group_metadata_list,
         scheduler_outputs) = record_output
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs
        self.async_event.synchronize()
        self._fix_last_step(
            outputs, last_sampler, seq_group_metadata_list,
            scheduler_outputs.scheduled_seq_groups)

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=True,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        # Check if need to run the usual non-async path
        # if not allow_async_output_proc:
        self._process_model_outputs(ctx=ctx)

        # profile.ProfRangeAutoPush('has_unfinish')
        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()
        return ctx.request_outputs

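    # step() drives one pipeline iteration. The very first call only primes
    # the pipeline (the worker records a step but there is nothing to process
    # yet), so a None result triggers one immediate retry.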
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        with torch.cuda.stream(self.use_stream):
            if is_zero_no_thread():
                out = self.no_thread_step()
                if out is None:  # the first step needs to be launched twice
                    out = self.no_thread_step()
            else:
                out = self.zero_overhead_step()
                if out is None:  # the first step needs to be launched twice
                    out = self.zero_overhead_step()
        return out

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Returns the created sequence group, or None when the request is
        expanded into parallel samples (n > 1).
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = ZeroOverheadSequence(seq_id, decoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        encoder_seq = (None if encoder_inputs is None else ZeroOverheadSequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        costs = [
            scheduler.get_num_unfinished_seq_groups()
            for scheduler in self.scheduler
        ]
        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group
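

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the engine). It
# assumes the standard vLLM v0 flow: EngineArgs builds the VllmConfig and the
# executor class is resolved the same way LLMEngine does it. Helper names
# such as `_get_executor_cls` and the model name below follow the v0
# LLMEngine API at the time of writing and may differ across vLLM versions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config = engine_args.create_engine_config()
    executor_class = LLMEngine._get_executor_cls(vllm_config)

    engine = ZeroOverheadEngine(vllm_config=vllm_config,
                                executor_class=executor_class,
                                log_stats=False)
    engine.add_request("req-0", "Hello, my name is",
                       SamplingParams(max_tokens=16))

    # step() returns finished/updated RequestOutputs once the pipeline is
    # primed; the first call transparently retries internally.
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)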