[gpt-oss] Add gpt-oss bf16 support

2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

vllm/engine/__init__.py (new file, 0 lines)

vllm/engine/arg_utils.py (new file, 1708 lines)
File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Workaround for https://github.com/python/cpython/issues/86296
#
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
# Licensed under the Apache License (Apache-2.0)
import asyncio
import enum
import sys
from types import TracebackType
from typing import Any, Optional, Type
if sys.version_info[:2] >= (3, 11):
from asyncio import timeout as asyncio_timeout
else:
def asyncio_timeout(delay: Optional[float]) -> "Timeout":
"""timeout context manager.
Useful when you want to apply timeout logic around a block
of code, or when asyncio.wait_for is not suitable. For example:
>>> async with timeout(0.001):
... async with aiohttp.get('https://github.com') as r:
... await r.text()
delay - value in seconds or None to disable timeout logic
"""
loop = asyncio.get_running_loop()
deadline = loop.time() + delay if delay is not None else None
return Timeout(deadline, loop)
class _State(enum.Enum):
INIT = "INIT"
ENTER = "ENTER"
TIMEOUT = "TIMEOUT"
EXIT = "EXIT"
class Timeout:
# Internal class, please don't instantiate it directly
# Use timeout() and timeout_at() public factories instead.
#
# Implementation note: `async with timeout()` is preferred
# over `with timeout()`.
# While technically the Timeout class implementation
# doesn't need to be async at all,
# the `async with` statement explicitly points that
# the context manager should be used from async function context.
#
# This design helps avoid many silly misuses.
#
# TimeoutError is raised immediately when scheduled
# if the deadline is passed.
# The purpose is to time out as soon as possible
# without waiting for the next await expression.
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
def __init__(self, deadline: Optional[float],
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._state = _State.INIT
self._timeout_handler = None # type: Optional[asyncio.Handle]
if deadline is None:
self._deadline = None # type: Optional[float]
else:
self.update(deadline)
async def __aenter__(self) -> "Timeout":
self._do_enter()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self._do_exit(exc_type)
return None
@property
def expired(self) -> bool:
"""Is timeout expired during execution?"""
return self._state == _State.TIMEOUT
@property
def deadline(self) -> Optional[float]:
return self._deadline
def reject(self) -> None:
"""Reject scheduled timeout if any."""
# cancel() might be a better name, but
# task.cancel() raises CancelledError in asyncio world.
if self._state not in (_State.INIT, _State.ENTER):
raise RuntimeError(f"invalid state {self._state.value}")
self._reject()
def _reject(self) -> None:
if self._timeout_handler is not None:
self._timeout_handler.cancel()
self._timeout_handler = None
def shift(self, delay: float) -> None:
"""Advance timeout on delay seconds.
The delay can be negative.
Raise RuntimeError if shift is called when deadline is not scheduled
"""
deadline = self._deadline
if deadline is None:
raise RuntimeError(
"cannot shift timeout if deadline is not scheduled")
self.update(deadline + delay)
def update(self, deadline: float) -> None:
"""Set deadline to absolute value.
deadline argument points on the time in the same clock system
as loop.time().
If new deadline is in the past the timeout is raised immediately.
Please note: it is not POSIX time but a time with
undefined starting base, e.g. the time of the system power on.
"""
if self._state == _State.EXIT:
raise RuntimeError(
"cannot reschedule after exit from context manager")
if self._state == _State.TIMEOUT:
raise RuntimeError("cannot reschedule expired timeout")
if self._timeout_handler is not None:
self._timeout_handler.cancel()
self._deadline = deadline
if self._state != _State.INIT:
self._reschedule()
def _reschedule(self) -> None:
assert self._state == _State.ENTER
deadline = self._deadline
if deadline is None:
return
now = self._loop.time()
if self._timeout_handler is not None:
self._timeout_handler.cancel()
task = asyncio.current_task()
if deadline <= now:
self._timeout_handler = self._loop.call_soon(
self._on_timeout, task)
else:
self._timeout_handler = self._loop.call_at(
deadline, self._on_timeout, task)
def _do_enter(self) -> None:
if self._state != _State.INIT:
raise RuntimeError(f"invalid state {self._state.value}")
self._state = _State.ENTER
self._reschedule()
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
if exc_type is asyncio.CancelledError and \
self._state == _State.TIMEOUT:
self._timeout_handler = None
raise asyncio.TimeoutError
# timeout has not expired
self._state = _State.EXIT
self._reject()
return None
def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
if task:
task.cancel()
self._state = _State.TIMEOUT
# drop the reference early
self._timeout_handler = None
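
For orientation, a minimal usage sketch of the shim above. The import path is an assumption (the file name is not shown in this diff); on Python 3.11+ asyncio_timeout resolves to the stdlib asyncio.timeout, on older versions to the Timeout class defined here.

import asyncio

# Hypothetical import path for the module shown above.
from vllm.engine.async_timeout import asyncio_timeout

async def fetch_with_deadline() -> str:
    # Raises asyncio.TimeoutError if the body takes longer than 5 seconds.
    async with asyncio_timeout(5.0):
        await asyncio.sleep(0.1)  # stand-in for real async I/O
        return "done"

asyncio.run(fetch_with_deadline())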

vllm/engine/llm_engine.py (new file, 2097 lines)
File diff suppressed because it is too large.

vllm/engine/metrics.py (new file, 629 lines)

@@ -0,0 +1,629 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Type, Union, cast
import numpy as np
import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
if ray is not None:
from ray.util import metrics as ray_metrics
else:
ray_metrics = None
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
prometheus_client.disable_created_metrics()
# The begin-* and end-* markers here are used by the documentation generator
# to extract the metrics definitions.
# --8<-- [start:metrics-definitions]
class Metrics:
"""
vLLM uses a multiprocessing-based frontend for the OpenAI server.
This means that we need to run prometheus_client in multiprocessing mode.
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
"""
labelname_finish_reason = "finished_reason"
labelname_waiting_lora_adapters = "waiting_lora_adapters"
labelname_running_lora_adapters = "running_lora_adapters"
labelname_max_lora = "max_lora"
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
# Unregister any existing vLLM collectors (for CI/CD)
self._unregister_vllm_metrics()
max_model_len = vllm_config.model_config.max_model_len
# Use this flag to hide metrics that were deprecated in
# a previous release and which will be removed in a future release.
self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics
# System stats
# Scheduler State
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_lora_info = self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
labelnames=[
self.labelname_running_lora_adapters,
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
],
multiprocess_mode="livemostrecent",
)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats
self.counter_num_preemption = self._counter_cls(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.",
labelnames=labelnames,
buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
])
self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
2560.0
])
self.histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
])
# Request stats
# Latency
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
]
self.histogram_e2e_time_request = self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=request_latency_buckets)
self.histogram_queue_time_request = self._histogram_cls(
name="vllm:request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
labelnames=labelnames,
buckets=request_latency_buckets)
self.histogram_inference_time_request = self._histogram_cls(
name="vllm:request_inference_time_seconds",
documentation=
"Histogram of time spent in RUNNING phase for request.",
labelnames=labelnames,
buckets=request_latency_buckets)
self.histogram_prefill_time_request = self._histogram_cls(
name="vllm:request_prefill_time_seconds",
documentation=
"Histogram of time spent in PREFILL phase for request.",
labelnames=labelnames,
buckets=request_latency_buckets)
self.histogram_decode_time_request = self._histogram_cls(
name="vllm:request_decode_time_seconds",
documentation=
"Histogram of time spent in DECODE phase for request.",
labelnames=labelnames,
buckets=request_latency_buckets)
# Metadata
self.histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = \
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_max_num_generation_tokens_request = self._histogram_cls(
name="vllm:request_max_num_generation_tokens",
documentation=
"Histogram of maximum number of requested generation tokens.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len))
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_max_tokens_request = self._histogram_cls(
name="vllm:request_params_max_tokens",
documentation="Histogram of the max_tokens request parameter.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.counter_request_success = self._counter_cls(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculative decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames,
multiprocess_mode="sum")
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# --8<-- [end:metrics-definitions]
def _unregister_vllm_metrics(self) -> None:
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
prometheus_client.REGISTRY.unregister(collector)
class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
multiprocess_mode: str = ""):
del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._gauge.set_default_tags(labels)
return self
def set(self, value: Union[int, float]):
return self._gauge.set(value)
def set_to_current_time(self):
# ray metrics doesn't have set_to_current_time, see https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
return self._gauge.set(time.time())
class _RayCounterWrapper:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._counter = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._counter.set_default_tags(labels)
return self
def inc(self, value: Union[int, float] = 1.0):
if value == 0:
return
return self._counter.inc(value)
class _RayHistogramWrapper:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=boundaries)
def labels(self, **labels):
self._histogram.set_default_tags(labels)
return self
def observe(self, value: Union[int, float]):
return self._histogram.observe(value)
class RayMetrics(Metrics):
"""
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls: Type[prometheus_client.Gauge] = cast(
Type[prometheus_client.Gauge], _RayGaugeWrapper)
_counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper)
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
if ray_metrics is None:
raise ImportError("RayMetrics requires Ray to be installed.")
super().__init__(labelnames, vllm_config)
def _unregister_vllm_metrics(self) -> None:
# No-op on purpose
pass
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
"""
exponent = 0
buckets: List[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
def build_1_2_5_buckets(max_value: int) -> List[int]:
"""
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
return build_buckets([1, 2, 5], max_value)
def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
"""
Example:
>>> build_1_2_3_5_8_buckets(100)
[1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100]
"""
return build_buckets([1, 2, 3, 5, 8], max_value)
def local_interval_elapsed(now: float, last_log: float,
local_interval: float) -> bool:
elapsed_time = now - last_log
return elapsed_time > local_interval
def get_throughput(tracked_stats: List[int], now: float,
last_log: float) -> float:
return float(np.sum(tracked_stats) / (now - last_log))
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
super().__init__(local_interval, vllm_config)
self.last_prompt_throughput: Optional[float] = None
self.last_generation_throughput: Optional[float] = None
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them
# to prometheus if applicable).
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
log_fn = logger.info
if not any((prompt_throughput, generation_throughput,
self.last_prompt_throughput,
self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
log_fn(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
log_fn(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
log_fn(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))
self._reset(stats, prompt_throughput, generation_throughput)
def _reset(self, stats, prompt_throughput, generation_throughput) -> None:
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None
self.last_prompt_throughput = prompt_throughput
self.last_generation_throughput = generation_throughput
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics
_gauge_cls = prometheus_client.Gauge
def __init__(self, local_interval: float, labels: Dict[str, str],
vllm_config: VllmConfig) -> None:
super().__init__(local_interval, vllm_config)
# Prometheus metrics
self.labels = labels
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
vllm_config=vllm_config)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
# Prevent ValueError from negative increment
if data < 0:
logger.warning("Skipping negative increment of %g to %s", data,
counter)
return
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int],
List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
gauge.labels(**data).set_to_current_time()
def _log_prometheus(self, stats: Stats) -> None:
# System state data
self._log_gauge(self.metrics.gauge_scheduler_running,
stats.num_running_sys)
self._log_gauge(self.metrics.gauge_scheduler_waiting,
stats.num_waiting_sys)
self._log_gauge(self.metrics.gauge_gpu_cache_usage,
stats.gpu_cache_usage_sys)
# Including max-lora in the metric; in the future this property of the
# lora config may be extended to be dynamic.
lora_info = {
self.metrics.labelname_running_lora_adapters:
",".join(stats.running_lora_adapters),
self.metrics.labelname_waiting_lora_adapters:
",".join(stats.waiting_lora_adapters),
self.metrics.labelname_max_lora:
stats.max_lora,
}
self._log_gauge_string(self.metrics.gauge_lora_info, lora_info)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
stats.num_generation_tokens_iter)
self._log_histogram(self.metrics.histogram_iteration_tokens,
[stats.num_tokens_iter])
self._log_histogram(self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter)
self._log_histogram(self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter)
# Request level data
# Latency
self._log_histogram(self.metrics.histogram_e2e_time_request,
stats.time_e2e_requests)
self._log_histogram(self.metrics.histogram_queue_time_request,
stats.time_queue_requests)
self._log_histogram(self.metrics.histogram_inference_time_request,
stats.time_inference_requests)
self._log_histogram(self.metrics.histogram_prefill_time_request,
stats.time_prefill_requests)
self._log_histogram(self.metrics.histogram_decode_time_request,
stats.time_decode_requests)
# Metadata
finished_reason_counter = CollectionsCounter(
stats.finished_reason_requests)
self._log_counter_labels(self.metrics.counter_request_success,
finished_reason_counter,
Metrics.labelname_finish_reason)
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests)
self._log_histogram(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(
self.metrics.histogram_max_num_generation_tokens_request,
stats.max_num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_max_tokens_request,
stats.max_tokens_requests)
def log(self, stats: Stats):
"""Logs to prometheus and tracked stats every iteration."""
# Log to prometheus.
self._log_prometheus(stats)
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
self.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
self.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
self.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
self.spec_decode_metrics.emitted_tokens)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Since prometheus multiprocessing mode does not support Info, emulate
# info here with a gauge.
if type == "cache_config":
metrics_info = obj.metrics_info()
info_gauge = self._gauge_cls(
name="vllm:cache_config_info",
documentation="Information of the LLMEngine CacheConfig",
labelnames=metrics_info.keys(),
multiprocess_mode="mostrecent")
info_gauge.labels(**metrics_info).set(1)
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
return None
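
For example, under the 1-2-5 scheme above, a hypothetical max_model_len of 4096 yields the following buckets; generation stops because the next candidate (5000) would exceed the cap:

>>> build_1_2_5_buckets(4096)
[1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]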


@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
These types are defined in this file to avoid importing vllm.engine.metrics
and therefore importing prometheus_client.
This is required due to usage of Prometheus multiprocess mode to enable
metrics after splitting out the uvicorn process from the engine process.
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
before prometheus_client is imported. Typically, this is done by setting
the env variable before launch, but since we are a library, we need to
do this in Python code and lazily import prometheus_client.
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
num_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
time_queue_requests: List[float]
time_inference_requests: List[float]
time_prefill_requests: List[float]
time_decode_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
n_requests: List[int]
max_num_generation_tokens_requests: List[int]
max_tokens_requests: List[int]
finished_reason_requests: List[str]
waiting_lora_adapters: List[str]
running_lora_adapters: List[str]
max_lora: str
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class StatLoggerBase(ABC):
"""Base class for StatLogger."""
def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
@abstractmethod
def log(self, stats: Stats) -> None:
raise NotImplementedError
@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics
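
A minimal sketch of the constraint described in the module docstring above: PROMETHEUS_MULTIPROC_DIR has to be set before prometheus_client is imported anywhere in the process. The temporary-directory choice here is purely illustrative.

import os
import tempfile

# Must run before the first `import prometheus_client`, otherwise
# multiprocess mode is not picked up by the client library.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

import prometheus_client  # noqa: E402  deliberately imported late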


@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import Device
VLLM_RPC_SUCCESS_STR = "SUCCESS"
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
class MQEngineDeadError(RuntimeError):
pass
@dataclass
class RPCProcessRequest:
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass
class RPCError:
request_id: Optional[str]
is_engine_errored: bool
exception: BaseException
@dataclass
class RPCAbortRequest:
request_id: str
class RPCStartupRequest(Enum):
IS_SERVER_READY = 1
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
class RPCUProfileRequest(Enum):
START_PROFILE = 1
STOP_PROFILE = 2
class RPCResetMultiModalCacheRequest(Enum):
RESET = 1
@dataclass
class RPCResetPrefixCacheRequest:
device: Device
class RPCSleepRequest(Enum):
SLEEP_LEVEL_1 = 1
SLEEP_LEVEL_2 = 2
@dataclass
class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass
class RPCIsSleepingRequest:
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCIsSleepingResponse:
request_id: str
is_sleeping: bool
@dataclass
class RPCLoadAdapterRequest:
lora_request: LoRARequest
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCAdapterLoadedResponse:
request_id: str
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest, RPCIsSleepingRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCIsSleepingResponse, RPCError]
def ENGINE_DEAD_ERROR(
error: Optional[BaseException] = None) -> MQEngineDeadError:
if error is None:
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error")
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")


@@ -0,0 +1,681 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast)
import cloudpickle
import psutil
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device
logger = init_logger(__name__)
class MQClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class MQLLMEngineClient(EngineClient):
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate(), MQLLMEngineClient does three things:
- Creates an asyncio output queue
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngineClient runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
# Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
# Create the tokenizer group.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)
# Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
# Receive streams of RequestOutput from the MQLLMEngine.
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Stream for each individual request.
self.output_queues: Dict[str, asyncio.Queue] = {}
# Loop to handle output of the LLMEngine periodically.
# Started after the MQLLMEngine is ready so that we can
# build the Client in an executor to enable clean shutdown.
self.output_loop: Optional[asyncio.Task] = None
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
self._engine_process = psutil.Process(engine_pid)
@staticmethod
def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported
return vllm_config.parallel_config.pipeline_parallel_size > 1
@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually checks to ensure the engine process
is still alive.
"""
try:
while True:
# Check if the engine process is running:
if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) "
"died."))
break
if await self.heartbeat_socket.poll(timeout=timeout):
# Heartbeat received- check the message
await self._check_success(
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)
logger.debug("Heartbeat successful.")
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))
except Exception as e:
self._set_errored(e)
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
try:
while True:
# Poll, checking for ENGINE_DEAD
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
) == 0:
logger.debug("Waiting for output from MQLLMEngine.")
# If errored, alert all running requests.
if self.errored:
for queue_j in tuple(self.output_queues.values()):
queue_j.put_nowait(
ENGINE_DEAD_ERROR(self._errored_with))
return
message: Frame = await self.output_socket.recv(copy=False)
request_outputs = pickle.loads(message.buffer)
is_error = isinstance(request_outputs,
(BaseException, RPCError))
if is_error:
if isinstance(request_outputs, RPCError):
rpc_error: RPCError = request_outputs
request_id = rpc_error.request_id
exception = rpc_error.exception
is_engine_errored = rpc_error.is_engine_errored
else:
# MQLLMEngine should always return an RPCError to
# the output_socket when an issue arises.
# If we are here, we are in a bad state and
# should shut down the server.
error: BaseException = request_outputs
logger.error(
"Received Exception %s rather than RPCError from "
"MPLLMEngine. This should never happen.", error)
request_id = None
exception = error
is_engine_errored = True
# Set to error state only on engine critical error
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform callers that the currently
# processed requests failed as well. Send back a dead
# engine error to give this feedback and also to hint
# to the server that it should shut down next.
exception = self.dead_error
if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, neither if it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else:
queue = self.output_queues.get(request_id)
if queue is not None:
queue.put_nowait(exception)
# Put each output into the appropriate queue.
elif isinstance(
request_outputs,
(RPCAdapterLoadedResponse, RPCIsSleepingResponse)):
self._add_output(request_outputs)
else:
for request_output in request_outputs:
self._add_output(request_output)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse,
RPCIsSleepingResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Start output_loop
if self.output_loop is None:
# Only create the task once to avoid multiple concurrent
# output_loops, which would lead to race conditions and
# out-of-order tokens returned by the engine.
# setup() may be called multiple times during engine startup.
self.output_loop = asyncio.create_task(
self.run_output_handler_loop())
with self.get_data_socket() as socket:
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
self.tracing_flag = response.tracing_enabled
# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
self.context.destroy(linger=0)
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()
def _set_errored(self, e: BaseException):
logger.exception(repr(e))
if self._errored_with is None:
self._errored_with = e
@staticmethod
async def _send_get_data_rpc_request(request: RPCStartupRequest,
expected_type: Any,
error_message: str,
socket: Socket) -> Any:
"""Send an RPC request that is expecting data back."""
# Ping RPCServer with a request.
await socket.send_multipart((pickle.dumps(request), ), copy=False)
# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")
# Await the data from the Server.
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)
if isinstance(data, BaseException):
raise data
elif not isinstance(data, expected_type):
raise ValueError(error_message)
return data
@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
socket: Socket):
"""Send one-way RPC request to trigger an action."""
if socket.closed:
raise MQClientClosedError()
await socket.send_multipart((pickle.dumps(request), ))
async def _await_ack(self, error_message: str, socket: Socket):
"""Await acknowledgement that a request succeeded."""
if socket.closed:
raise MQClientClosedError()
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("MQLLMEngine didn't reply within "
f"{VLLM_RPC_TIMEOUT}ms")
await self._check_success(error_message, socket)
@staticmethod
async def _check_success(error_message: str, socket: Socket):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
if socket.closed:
raise MQClientClosedError()
frame = await socket.recv(copy=False)
response = pickle.loads(frame.buffer)
# Raise error if unsuccessful
if isinstance(response, BaseException):
raise response
elif (not isinstance(response, str)
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.input_preprocessor
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
"""Wait for the RPCServer to start up."""
return await self._send_get_data_rpc_request(
request=RPCStartupRequest.IS_SERVER_READY,
expected_type=RPCStartupResponse,
error_message="Unable to start RPC Server",
socket=socket)
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
with suppress(MQClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket)
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass
async def check_health(self):
"""
The health check loop probes the status of the
engine every N seconds and sets _errored_with
if the engine is unhealthy.
"""
if self._errored_with is not None:
raise self._errored_with
@property
def is_running(self) -> bool:
return not self.errored
@property
def is_stopped(self) -> bool:
return self.errored
@property
def errored(self) -> bool:
return self._errored_with is not None
@property
def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with)
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority))
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
return cast(
AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
async def _process_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)
# Ensure the request id is unique among running requests
if request_id in self.output_queues:
raise ValueError(f"Request {request_id} already exists")
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending for CPU resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.backend
if self.decoding_config
else DecodingConfig.backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
# 1) Create the output queue for this request.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors
params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished = False
try:
while not finished:
request_output = await queue.get()
if isinstance(request_output, BaseException):
raise request_output
finished = request_output.finished
yield request_output
finally:
# Request was canceled by the client.
if not finished and not self.errored:
await self.abort(request_id)
finally:
self.output_queues.pop(request_id)
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache"""
await self._send_one_way_rpc_request(
request=RPCResetMultiModalCacheRequest.RESET,
socket=self.input_socket)
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest(device),
socket=self.input_socket)
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine for a given level"""
return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine"""
return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping"""
request = RPCIsSleepingRequest()
queue: asyncio.Queue[Union[BaseException,
RPCIsSleepingResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
request_output = await queue.get()
self.output_queues.pop(request.request_id)
if isinstance(request_output, BaseException):
raise request_output
return request_output.is_sleeping
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create the output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
# Send the request
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
# Wait for the response
request_output = await queue.get()
self.output_queues.pop(request.request_id)
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
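
A rough usage sketch of the flow described in the class docstring (construct, setup(), then stream results via generate()). The module path is an assumption since the file name is not shown in this diff, and ipc_path / engine_config / engine_pid are placeholders that would normally come from the server launcher.

import asyncio

from vllm import SamplingParams
# Assumed module path for the client class shown above.
from vllm.engine.multiprocessing.client import MQLLMEngineClient

async def stream_one_request(ipc_path, engine_config, engine_pid):
    client = MQLLMEngineClient(ipc_path, engine_config, engine_pid)
    await client.setup()  # starts the output/health loops, waits for server
    try:
        async for output in client.generate(
                prompt="Hello",  # illustrative prompt
                sampling_params=SamplingParams(max_tokens=8),
                request_id="req-0"):
            print(output.outputs[0].text)
    finally:
        client.close()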


@@ -0,0 +1,460 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import VllmConfig
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
"""A multiprocessing wrapper for
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in a concurrent manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
process is kicked off when a new RPCProcessRequest is received by the
input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
[`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
the RequestOutputs back over the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
as a callback to the llm_engine, which calls the logic asynchronously
such that the IPC can be overlapped with the GPU.
Args:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
"""
def __init__(self,
ipc_path: str,
use_async_sockets: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which allows
# the Python object to be reused.
kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets
if self.use_async_sockets:
self.engine.process_request_outputs_callback = \
self._async_socket_engine_callback
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Receive input from the client.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
# Send output stream back to client.
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
return ENGINE_DEAD_ERROR(self._errored_with)
else:
return ENGINE_DEAD_ERROR()
@classmethod
def from_vllm_config(cls, vllm_config: VllmConfig,
usage_context: UsageContext,
disable_log_requests: bool, disable_log_stats: bool,
ipc_path: str) -> "MQLLMEngine":
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
use_async_sockets = vllm_config.model_config.use_async_output_proc
return cls(
vllm_config=vllm_config,
executor_class=LLMEngine._get_executor_cls(vllm_config),
ipc_path=ipc_path,
usage_context=usage_context,
use_async_sockets=use_async_sockets,
log_requests=(not disable_log_requests),
log_stats=(not disable_log_stats),
)
@staticmethod
def from_engine_args(engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
vllm_config = engine_args.create_engine_config(usage_context)
return MQLLMEngine.from_vllm_config(
ipc_path=ipc_path,
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
def start(self):
try:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
logger.exception(repr(e))
except KeyboardInterrupt:
logger.debug("Shutting down MQLLMEngine.")
finally:
logger.debug("MQLLMEngine is shut down.")
self.cleanup()
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self.ctx.destroy(linger=0)
del self.engine
@contextmanager
def make_data_socket(
self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
socket = self.ctx.socket(zmq.constants.ROUTER)
try:
socket.bind(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
def run_startup_loop(self) -> None:
"""Startup loop for sending data from Engine -> Client."""
with self.make_data_socket() as socket:
response: Union[RPCStartupResponse, BaseException]
try:
identity, message = socket.recv_multipart(copy=False)
request: RPCStartupRequest = pickle.loads(message.buffer)
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
except Exception as e:
response = e
socket.send_multipart((identity, pickle.dumps(response)),
copy=False)
def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""
while True:
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
# Handle any input from the client.
self.handle_new_input()
# Engine step.
request_outputs = self.engine_step()
# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:
self._send_outputs(request_outputs)
def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""
try:
return self.engine.step()
except SystemExit:
raise
except InputProcessingError as e:
# Special case where we handle an error preparing the inputs for
# a single request in the batch
rpc_err = RPCError(request_id=e.request_id,
is_engine_errored=False,
exception=e.__cause__)
self._send_outputs(rpc_err)
return []
except BaseException as e:
self._set_errored(e)
rpc_err = RPCError(request_id=None,
is_engine_errored=True,
exception=e)
self._send_outputs(rpc_err)
raise e
def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCUProfileRequest):
if request == RPCUProfileRequest.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetMultiModalCacheRequest):
self.reset_mm_cache()
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest):
self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest):
self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request)
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
raise e from None
def _handle_process_request(self, request: RPCProcessRequest):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id = request.request_id
if self._errored_with is not None:
rpc_err = RPCError(request_id=request_id,
is_engine_errored=True,
exception=ENGINE_DEAD_ERROR(self._errored_with))
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
except Exception as e:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
logger.debug("Failed to add request %s to engine. %s",
request.request_id, e)
is_errored = self._errored_with is not None
rpc_err = RPCError(request_id=request_id,
is_engine_errored=is_errored,
exception=e)
self._send_outputs(rpc_err)
# Remove request from the engine.
self.engine.abort_request(request_id)
def _handle_abort_request(self, request: RPCAbortRequest):
self.engine.abort_request(request.request_id)
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
self.engine.add_lora(request.lora_request)
except BaseException as e:
            # Send back an error if the adapter fails to load
rpc_err = RPCError(request_id=request.request_id,
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
is_sleeping = self.is_sleeping()
self._send_outputs(
RPCIsSleepingResponse(request_id=request.request_id,
is_sleeping=is_sleeping))
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError
                # RayTaskError might not be picklable here. We need to unpack
                # the underlying exception as the real exception in the output.
if (isinstance(outputs, RPCError)
and isinstance(outputs.exception, RayTaskError)):
outputs.exception = outputs.exception.cause
except ImportError:
pass
output_bytes = pickle.dumps(outputs)
self.output_socket.send_multipart((output_bytes, ), copy=False)
def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
"""Callback used by engine to make socket handling async with GPU."""
self._send_outputs(request_outputs)
self.handle_new_input()
def _set_errored(self, e: BaseException):
"""Log and set errored status if this is the first issue."""
if self._errored_with is None:
self._errored_with = e
def start_profile(self) -> None:
self.engine.start_profile()
def stop_profile(self) -> None:
self.engine.stop_profile()
def reset_mm_cache(self) -> bool:
return self.engine.reset_mm_cache()
def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache()
def sleep(self, level: int = 1) -> None:
self.engine.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive):
try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_stats=disable_log_stats,
disable_log_requests=disable_log_requests,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e from None
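
run_mp_engine is written to be the target of a separate process, with engine_alive acting as a shared flag the parent can inspect if the child dies. Below is a minimal, self-contained sketch of that launch-and-monitor pattern, using a trivial stand-in target rather than the real entrypoint (which needs a full VllmConfig); the names are illustrative only.

# Hypothetical launch/monitor sketch; _fake_engine stands in for run_mp_engine.
import multiprocessing
import time


def _fake_engine(engine_alive) -> None:
    try:
        time.sleep(0.1)  # stand-in for engine.start()
    except BaseException:
        # Mirror the convention above: flip the shared flag on failure.
        engine_alive.value = False
        raise


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    engine_alive = ctx.Value("b", True, lock=False)
    proc = ctx.Process(target=_fake_engine, args=(engine_alive, ))
    proc.start()
    proc.join()
    print("engine alive:", bool(engine_alive.value))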

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
"""Interface for logic that processes new token ids in sequence groups,
managing detokenization, stop checking, and freeing/forking sequences with
the scheduler.
This is highly coupled with the LLMEngine and should be seen as an extension
of it. The logic is separated to simplify the LLMEngine class and allow
separate implementations for single-step decoding (which supports beam
search sequence forking) and multi-step decoding (which does not support
beam search, but does support speculative decoding).
"""
@staticmethod
def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
This returns a single-step output processor if num_lookahead_slots is
zero, else returns a multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(scheduler_config, detokenizer,
scheduler, seq_counter,
stop_checker)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
MultiStepOutputProcessor)
return MultiStepOutputProcessor(
detokenizer,
scheduler,
seq_counter,
get_tokenizer_for_seq,
stop_checker,
)
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
"""
pass
@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Update prompt logprobs received from outputs to seq_group."""
pass
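
A small, hedged illustration of the dispatch rule documented in create_output_processor above, using a stand-in config and string labels rather than the real processor classes:

# Hypothetical stand-ins to show the num_lookahead_slots dispatch rule only.
from dataclasses import dataclass


@dataclass
class FakeSchedulerConfig:
    num_lookahead_slots: int = 0


def pick_processor_kind(cfg: FakeSchedulerConfig) -> str:
    # Zero lookahead slots -> single-step processor, otherwise multi-step.
    return "single-step" if cfg.num_lookahead_slots == 0 else "multi-step"


assert pick_processor_kind(FakeSchedulerConfig(0)) == "single-step"
assert pick_processor_kind(FakeSchedulerConfig(4)) == "multi-step"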

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Callable, List, cast
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.single_step import (
single_step_process_prompt_logprob)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles logic related to
detokenization and stopping conditions. It specializes to "multi-step
decoding", where vLLM's worker may generate multiple tokens per invocation.
This is currently mutually exclusive with advanced sampling techniques like
beam search, which motivates the separation of this logic from the single
step output processor.
This class is responsible for things such as correctly appending all new
token ids to their sequence, detokenizing new token ids, truncating new
output tokens after an eos token, and correctly handling the case where the
number of new output tokens per sequence differs in a single batch.
"""
def __init__(
self,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process prompt logprobs associated with each step of a multi-step-
scheduled computation.
Args:
seq_group: the outputs are associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
for all scheduler steps
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
@staticmethod
@functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/features/compatibility_matrix.md
        # if the feature combo becomes valid
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization.
It also handles cases where there are tokens emitted after
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED
# if a client disconnects from the api server.
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
if seqs is None:
seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED)
for output in outputs:
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
sequence_group.metrics.spec_token_acceptance_counts[
output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
seq_id = seq.seq_id
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert all([
isinstance(output, CompletionSequenceGroupOutput)
for output in outputs
])
compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
assert all([
seq_id == output.samples[0].parent_seq_id
for output in compl_outputs
])
if is_async:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
else:
# Standard multi-step case
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in compl_outputs]
            # Entries in sample tokens may be invalid (e.g., due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
# When both spec-decode and pre-fill chunking are enabled, we
# don't have guaranteed samples here (e.g. all -1s).
if valid_samples:
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
new_char_count = 0
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
output_embeds = [sample.output_embed for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode
# generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be
# unintentionally ignored.
if not sampling_params.ignore_eos and self.detokenizer:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path
# is expensive.
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
break
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id, output_logprob, output_embed in zip(
output_token_ids, output_logprobs, output_embeds):
seq.append_token_id(
token_id=output_token_id,
logprobs=output_logprob,
token_embed=output_embed,
)
if is_prefill_sampled_token:
is_prefill_sampled_token = False
else:
# Update num_computed_tokens iff the sampled token is not from
# a prefill step.
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished():
break
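
The truncation rules in _process_seq_outputs (cap the new tokens to the remaining max_tokens budget, then cut everything after the first EOS) can be isolated into a small, self-contained sketch; the helper name and values below are illustrative only.

# Hypothetical sketch of the two truncation steps above.
from typing import List


def truncate_step_tokens(new_token_ids: List[int], current_output_len: int,
                         max_tokens: int, eos_token_id: int,
                         ignore_eos: bool = False) -> List[int]:
    # Cap to the remaining max_tokens budget (mirrors the remaining_tokens check).
    remaining = max_tokens - (current_output_len + len(new_token_ids))
    if remaining < 0:
        new_token_ids = new_token_ids[:remaining]
    # Cut everything after the first EOS token (mirrors the EOS truncation loop).
    if not ignore_eos:
        for i, tok in enumerate(new_token_ids):
            if tok == eos_token_id:
                return new_token_ids[:i + 1]
    return new_token_ids


# 3 tokens already generated, max_tokens=6, EOS id 0:
assert truncate_step_tokens([5, 7, 0, 9, 4], 3, 6, eos_token_id=0) == [5, 7, 0]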

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc:
[`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
instance
seq_group: the output is associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
"""
prompt_logprobs = output.prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
assert hasattr(sg_output_proc, 'detokenizer')
if (seq_group.sampling_params.detokenize
and sg_output_proc.detokenizer):
sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
scheduling of the next batch. Output processing logic includes
detokenization, and determining if a sequence is finished (e.g. via max len
or eos token).
The SingleStepOutputProcessor is specialized to the case where the model
emits at most a single token per invocation, which precludes configurations
such as speculative decoding or multi-step decoding. This enables beam
search sampling, which requires forking/finishing/freeing sequences in a way
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(self, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, scheduler: List[Scheduler],
seq_counter: Counter, stop_checker: StopChecker):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
"""
assert len(outputs) == 1, "Single step should only have 1 output."
output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
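
A hedged sketch of the prompt-logprob accumulation described in single_step_process_prompt_logprob above: the first prefill chunk prepends None because the token at index 0 has no logprob. The helper and numbers are illustrative, not the vLLM API.

# Hypothetical sketch of the first-chunk-prepends-None accumulation above.
from typing import List, Optional


def accumulate_prompt_logprobs(
        accumulated: Optional[List[Optional[float]]],
        chunk: List[float]) -> List[Optional[float]]:
    new_chunk: List[Optional[float]] = list(chunk)
    if not accumulated:
        # First (or only) prefill chunk: token 0 has no logprob, so prepend None.
        new_chunk = [None] + new_chunk
        accumulated = []
    return accumulated + new_chunk


# A 4-token prompt prefilled in two chunks of 2 tokens each: the sampler
# produces N-1 = 3 prompt logprobs in total.
first = accumulate_prompt_logprobs(None, [-1.2])
both = accumulate_prompt_logprobs(first, [-0.3, -2.0])
assert both == [None, -1.2, -0.3, -2.0]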

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop
checking. This checks things such as: whether the eos token was emitted,
whether the max_tokens has been consumed, whether a stop string has been
emitted, or if we have exceeded the max model len.
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
else:
return self._max_model_len
def maybe_stop_sequence(
self,
seq: Sequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() >= self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod
def check_stop_strings(
output_text: str,
new_char_count: int,
stop: List[str],
include_in_output: bool,
) -> Optional[Tuple[str, int]]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns tuple (stop_string, offset) if matched or else None.
Where stop_string is the matched stop string and offset is the
length to which output_text should be truncated, or -1 for no
truncation.
"""
if not new_char_count or not stop:
return None
for stop_str in stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = output_text.find(stop_str,
1 - new_char_count - stop_string_len)
if stop_index == -1:
continue
if include_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(output_text):
# No truncation required.
return stop_str, -1
# Truncate the output text to either the beginning
# or end of the stop string.
return stop_str, stop_index
return None
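
A hedged usage illustration of check_stop_strings, assuming the StopChecker class above is in scope; the text and character counts below are made up.

# Illustration only; StopChecker is the class defined above.
text = "hello world!"          # current output text
new_chars = 3                  # chars appended by the latest decode step ("ld!")
# Stop string excluded from output: truncate before "world" (offset 6).
assert StopChecker.check_stop_strings(text, new_chars, ["world"], False) == ("world", 6)
# Stop string included: the match ends before the end of the text, so
# truncate just after it (offset 11).
assert StopChecker.check_stop_strings(text, new_chars, ["world"], True) == ("world", 11)
# No stop string found in the newly generated characters -> None.
assert StopChecker.check_stop_strings(text, new_chars, ["xyz"], False) is None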

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List
from typing import Sequence as GenericSequence
from typing import cast
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
def create_output_by_sequence_group(
outputs: GenericSequence[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in outputs:
sequence_group_output: CompletionSequenceGroupOutput
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
# Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
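
The transposition performed by create_output_by_sequence_group can be shown with plain placeholder values instead of SamplerOutput objects (illustrative only):

# Illustrative transposition from [step][sequence group] to [sequence group][step].
steps = [
    ["g0_step0", "g1_step0"],  # step 0: one output per sequence group
    ["g0_step1", "g1_step1"],  # step 1
]
num_seq_groups = 2
by_group = [[] for _ in range(num_seq_groups)]
for step in steps:
    for i, group_output in enumerate(step):
        by_group[i].append(group_output)
assert by_group == [["g0_step0", "g0_step1"], ["g1_step0", "g1_step1"]]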

317
vllm/engine/protocol.py Normal file
View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
logger = init_logger(__name__)
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""
@property
@abstractmethod
def is_running(self) -> bool:
...
@property
@abstractmethod
def is_stopped(self) -> bool:
...
@property
@abstractmethod
def errored(self) -> bool:
...
@property
@abstractmethod
def dead_error(self) -> BaseException:
...
@abstractmethod
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
async def beam_search(
self,
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)
beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request)
]
completed = []
for _ in range(max_tokens):
prompts_batch, lora_req_batch = zip(*[(
TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs),
beam.lora_request,
) for beam in all_beams])
tasks = []
request_id = f"beam_search-{random_uuid()}"
for i, (individual_prompt,
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt,
beam_search_params,
request_id_item,
lora_request=lora_req)))
tasks.append(task)
output = await asyncio.gather(*tasks)
output = [x[0] for x in output]
new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(
BeamSearchSequence(
tokens=current_beam.tokens +
[token_id] if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs +
[logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
finish_reason="stop",
stop_reason=tokenizer.eos_token_id))
else:
new_beams.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.
multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs))
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
# Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1]
else:
tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens)
beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
finish_reason=beam.finish_reason if
beam.finish_reason is not None else "length",
stop_reason=beam.stop_reason)
for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None)
yield beam_search_output
@abstractmethod
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
...
@abstractmethod
async def abort(self, request_id: str) -> None:
"""Abort a request.
Args:
request_id: The unique id of the request.
"""
...
@abstractmethod
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
...
@abstractmethod
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
...
@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine."""
...
@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
...
@abstractmethod
async def is_tracing_enabled(self) -> bool:
...
@abstractmethod
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[list[SamplerOutput]] = None,
) -> None:
...
@abstractmethod
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
@abstractmethod
async def start_profile(self) -> None:
"""Start profiling the engine"""
...
@abstractmethod
async def stop_profile(self) -> None:
"""Start profiling the engine"""
...
@abstractmethod
async def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache"""
...
@abstractmethod
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
...
@abstractmethod
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine"""
...
@abstractmethod
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine"""
...
@abstractmethod
async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping"""
...
@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...
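
The default beam_search implementation above expands each beam with its candidate token logprobs every step and keeps only the beam_width best by cumulative logprob. A self-contained sketch of that expand-and-prune step with toy scores follows; the helper names and values are hypothetical.

# Hypothetical expand-and-prune step of the beam search loop above.
import heapq
from typing import Dict, List, Tuple

Beam = Tuple[float, List[int]]  # (cumulative logprob, token ids)


def expand_and_prune(beams: List[Beam],
                     step_logprobs: List[Dict[int, float]],
                     beam_width: int) -> List[Beam]:
    candidates: List[Beam] = []
    for (cum_logprob, tokens), logprobs in zip(beams, step_logprobs):
        for token_id, logprob in logprobs.items():
            candidates.append((cum_logprob + logprob, tokens + [token_id]))
    # Keep only the beam_width highest-scoring candidates.
    return heapq.nlargest(beam_width, candidates, key=lambda b: b[0])


beams = [(0.0, [1])]
step = [{7: -0.1, 9: -2.3, 4: -4.0}]  # candidate continuations for the one beam
print(expand_and_prune(beams, step, beam_width=2))
# -> [(-0.1, [1, 7]), (-2.3, [1, 9])]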