from functools import partial
import os
import queue
import threading
import traceback
from typing import Callable, Dict, List, Mapping, Optional, Type, Union
import torch
from vllm import envs
from vllm.config import DecodingConfig, ObservabilityConfig, VllmConfig
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.engine.llm_engine import _LOCAL_LOGGING_INTERVAL_SEC, LLMEngine, SchedulerContext, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.logger import init_logger
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import ProcessorInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.inputs.registry import InputRegistry
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.registry import MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, ParallelSampleSequenceGroup, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.sequence import ZeroOverheadSequence
from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step, zero_overhead_stream
logger = init_logger(__name__)
class ZeroOverheadEngine(LLMEngine):
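    """LLMEngine variant that overlaps CPU-side scheduling and output
    processing with GPU model execution.

    A background thread (or, with is_zero_no_thread(), the caller's own
    thread) runs the schedule/execute loop one step ahead, records sampler
    state, and copies sampled token ids device-to-host asynchronously on a
    side CUDA stream; placeholder token ids are patched afterwards in
    _fix_last_step / _fix_spec_decode_steps.
    """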
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config # noqa
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
logger.info(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
VLLM_VERSION,
vllm_config,
use_cached_outputs,
)
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
self.thread_running = False
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: ZeroOverheadSequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
        self.model_executor = executor_class(vllm_config=vllm_config)
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(self.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(self.model_config.dtype),
"tensor_parallel_size":
self.parallel_config.tensor_parallel_size,
"block_size":
self.cache_config.block_size,
"gpu_memory_utilization":
self.cache_config.gpu_memory_utilization,
# Quantization
"quantization":
self.model_config.quantization,
"kv_cache_dtype":
str(self.cache_config.cache_dtype),
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
})
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
if self.model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
self.vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
self.scheduler = [
Scheduler(
self.scheduler_config, self.cache_config, self.lora_config,
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
vllm_config=vllm_config),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(
model_name=self.model_config.served_model_name),
vllm_config=vllm_config),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=ZeroOverheadStopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self._skip_scheduling_next_step = False
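        # State for the one-step-ahead pipeline: last_record holds the
        # previous step's (outputs, sampler, metadata, scheduler_outputs),
        # async_d2h the in-flight D2H copy of the sampled token ids, and
        # async_event gates consumers until that copy has completed.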
self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
self.q_recorder = queue.Queue()
self.use_stream = zero_overhead_stream(self.model_executor.device_config.device)
if not is_zero_no_thread():
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start()
profile.StartTracer()
def __del__(self):
self.finish_thread()
return super().__del__()
def finish_thread(self):
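        """Signal the background thread to exit its loop; the thread stops
        the remote worker execution loop on its way out.
        """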
if self.thread_running:
self.thread_running = False
self.sem_m2s.release()
def thread_zero_overhead(self):
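        """Producer loop for the pipelined engine: on each wake-up, schedule
        a step, start the async D2H copy of the previous step's sampled
        token ids, execute the model, and queue the step record for
        zero_overhead_step to finalize.
        """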
        logger.info('zero overhead thread started')
last_sampler = get_last_sampler()
last_sampler.seq_ids.clear()
try:
with torch.cuda.stream(self.use_stream):
while True:
self.sem_m2s.acquire()
if not self.thread_running:
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
break
virtual_engine = 0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
if self.last_record is not None:
last_sampler = self.last_record[1]
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
if last_sampler.sampled_token_ids_tensor is not None:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
else:
self.async_d2h = None
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
for output in outputs:
self._advance_to_next_step(
output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
                    scheduler_outputs.scheduled_seq_groups = list(
                        scheduler_outputs.scheduled_seq_groups)  # shallow copy: keep this step's list stable
last_sampler = None
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE:
last_sampler, _ = get_accepted_token_ids()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
        except Exception as e:
            logger.error("thread_zero_overhead error: %s", e)
            traceback.print_exc()
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
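        """Consumer half of the pipeline: wake the producer to run the next
        step, then finalize the previous step's record (wait for the D2H
        copy, patch token ids, run output processing). Returns None on the
        priming (first) step.
        """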
if not self.thread_running:
self.zero_thread.join()
self.thread_running = True
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.zero_thread.start()
self.sem_m2s.release()
        record_output = self.q_recorder.get()
        if record_output is None:  # None is queued for the priming (first) step
            return None
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
        outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step = record_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_event.synchronize()
if self.async_d2h is not None:
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_event.synchronize()
self._fix_spec_decode_steps(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
        # Run output processing unconditionally; the usual
        # allow_async_output_proc check is not needed in the recorded-step
        # pipeline.
        self._process_model_outputs(ctx=ctx)
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
            # The background thread issues the actual stop when it exits.
            self.finish_thread()
return ctx.request_outputs
def _fix_last_step(
self, output: List[SamplerOutput],
last_sampler: SampleRecorder,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
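        """Overwrite the placeholder token ids appended in the previous step
        with the real sampled ids, now that the async D2H copy is complete.
        """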
sample_out_list = self.async_d2h.tolist()
sample_out_ids = last_sampler.seq_ids
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
continue
if seq_group_metadata.do_sample:
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
                seq: ZeroOverheadSequence = seq_group.seqs[0]
for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id:
if type(token_id) is list:
sample.output_token = token_id[0]
else:
sample.output_token = token_id
seq.fix_last_token_id(sample.output_token)
break
def _fix_spec_decode_steps(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
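        """Reconcile speculative-decoding results: write back the accepted
        token ids and drop the placeholder slots for rejected draft
        positions (marked -1 in the accepted-token tensor).
        """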
sample_out_list = self.async_d2h.tolist()
for seq_group_metadata, accept_token_ids, scheduled_seq_group in \
zip(seq_group_metadata_list, sample_out_list, scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                continue
if seq_group_metadata.do_sample:
assert len(seq_group.seqs) == 1
                seq: ZeroOverheadSequence = seq_group.seqs[0]
remove_count = 0
for token_id in accept_token_ids:
if token_id == -1:
remove_count += 1
else:
seq.fix_last_token_id(token_id)
seq.remove_last_place_holder(remove_count)
    def no_thread_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
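        """Single-threaded variant of zero_overhead_step: runs both the
        schedule/execute half and the finalize half in the caller's thread,
        still pipelined one step behind via q_recorder. Returns None on the
        priming step.
        """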
virtual_engine = 0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
if self.last_record is not None:
last_sampler = self.last_record[1]
            if last_sampler.sampled_token_ids_tensor is not None:
                self.async_d2h = last_sampler.sampled_token_ids_tensor.to(
                    'cpu', non_blocking=True)
            else:
                self.async_d2h = None
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
else:
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
if len(outputs) == 1:
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
            scheduler_outputs.scheduled_seq_groups = list(
                scheduler_outputs.scheduled_seq_groups)  # shallow copy: keep this step's list stable
last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
        record_output = self.q_recorder.get()
        if record_output is None:  # None is queued for the priming (first) step
            return None
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
        outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = record_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
        # Run output processing unconditionally; the usual
        # allow_async_output_proc check is not needed in the recorded-step
        # pipeline.
        self._process_model_outputs(ctx=ctx)
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
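        """Engine step entry point. Because the pipeline runs one step
        ahead, the very first call returns no outputs and the step must be
        launched twice.
        """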
with torch.cuda.stream(self.use_stream):
if is_zero_no_thread():
out = self.no_thread_step()
                if out is None:  # the priming (first) step must be launched twice
out = self.no_thread_step()
else:
out = self.zero_overhead_step()
                if out is None:  # the priming (first) step must be launched twice
out = self.zero_overhead_step()
return out
def _add_processed_request(
self,
request_id: str,
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
processed_inputs=processed_inputs,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = ZeroOverheadSequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
encoder_seq = (None if encoder_inputs is None else ZeroOverheadSequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler with least unfinished seqs.
costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
return seq_group