First commit
This commit is contained in:
0
vllm/engine/__init__.py
Normal file
0
vllm/engine/__init__.py
Normal file
BIN
vllm/engine/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/arg_utils.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/arg_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/async_timeout.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/async_timeout.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/llm_engine.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/llm_engine.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/metrics.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/metrics.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/metrics_types.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/metrics_types.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/protocol.cpython-310.pyc
Normal file
BIN
vllm/engine/__pycache__/protocol.cpython-310.pyc
Normal file
Binary file not shown.
1143
vllm/engine/arg_utils.py
Normal file
1143
vllm/engine/arg_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
1323
vllm/engine/async_llm_engine.py
Normal file
1323
vllm/engine/async_llm_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
189
vllm/engine/async_timeout.py
Normal file
189
vllm/engine/async_timeout.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# Workaround for https://github.com/python/cpython/issues/86296
|
||||
#
|
||||
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
||||
# Licensed under the Apache License (Apache-2.0)
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import sys
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from asyncio import timeout as asyncio_timeout
|
||||
else:
|
||||
|
||||
def asyncio_timeout(delay: Optional[float]) -> "Timeout":
|
||||
"""timeout context manager.
|
||||
Useful in cases when you want to apply timeout logic around block
|
||||
of code or in cases when asyncio.wait_for is not suitable. For example:
|
||||
>>> async with timeout(0.001):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
delay - value in seconds or None to disable timeout logic
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + delay if delay is not None else None
|
||||
return Timeout(deadline, loop)
|
||||
|
||||
class _State(enum.Enum):
|
||||
INIT = "INIT"
|
||||
ENTER = "ENTER"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
EXIT = "EXIT"
|
||||
|
||||
class Timeout:
|
||||
# Internal class, please don't instantiate it directly
|
||||
# Use timeout() and timeout_at() public factories instead.
|
||||
#
|
||||
# Implementation note: `async with timeout()` is preferred
|
||||
# over `with timeout()`.
|
||||
# While technically the Timeout class implementation
|
||||
# doesn't need to be async at all,
|
||||
# the `async with` statement explicitly points that
|
||||
# the context manager should be used from async function context.
|
||||
#
|
||||
# This design allows to avoid many silly misusages.
|
||||
#
|
||||
# TimeoutError is raised immediately when scheduled
|
||||
# if the deadline is passed.
|
||||
# The purpose is to time out as soon as possible
|
||||
# without waiting for the next await expression.
|
||||
|
||||
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
|
||||
|
||||
def __init__(self, deadline: Optional[float],
|
||||
loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._loop = loop
|
||||
self._state = _State.INIT
|
||||
|
||||
self._timeout_handler = None # type: Optional[asyncio.Handle]
|
||||
if deadline is None:
|
||||
self._deadline = None # type: Optional[float]
|
||||
else:
|
||||
self.update(deadline)
|
||||
|
||||
def __enter__(self) -> "Timeout":
|
||||
warnings.warn(
|
||||
"with timeout() is deprecated, use async with timeout()",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "Timeout":
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
"""Is timeout expired during execution?"""
|
||||
return self._state == _State.TIMEOUT
|
||||
|
||||
@property
|
||||
def deadline(self) -> Optional[float]:
|
||||
return self._deadline
|
||||
|
||||
def reject(self) -> None:
|
||||
"""Reject scheduled timeout if any."""
|
||||
# cancel is maybe better name but
|
||||
# task.cancel() raises CancelledError in asyncio world.
|
||||
if self._state not in (_State.INIT, _State.ENTER):
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._reject()
|
||||
|
||||
def _reject(self) -> None:
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._timeout_handler = None
|
||||
|
||||
def shift(self, delay: float) -> None:
|
||||
"""Advance timeout on delay seconds.
|
||||
The delay can be negative.
|
||||
Raise RuntimeError if shift is called when deadline is not scheduled
|
||||
"""
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
raise RuntimeError(
|
||||
"cannot shift timeout if deadline is not scheduled")
|
||||
self.update(deadline + delay)
|
||||
|
||||
def update(self, deadline: float) -> None:
|
||||
"""Set deadline to absolute value.
|
||||
deadline argument points on the time in the same clock system
|
||||
as loop.time().
|
||||
If new deadline is in the past the timeout is raised immediately.
|
||||
Please note: it is not POSIX time but a time with
|
||||
undefined starting base, e.g. the time of the system power on.
|
||||
"""
|
||||
if self._state == _State.EXIT:
|
||||
raise RuntimeError(
|
||||
"cannot reschedule after exit from context manager")
|
||||
if self._state == _State.TIMEOUT:
|
||||
raise RuntimeError("cannot reschedule expired timeout")
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._deadline = deadline
|
||||
if self._state != _State.INIT:
|
||||
self._reschedule()
|
||||
|
||||
def _reschedule(self) -> None:
|
||||
assert self._state == _State.ENTER
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
return
|
||||
|
||||
now = self._loop.time()
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
|
||||
task = asyncio.current_task()
|
||||
if deadline <= now:
|
||||
self._timeout_handler = self._loop.call_soon(
|
||||
self._on_timeout, task)
|
||||
else:
|
||||
self._timeout_handler = self._loop.call_at(
|
||||
deadline, self._on_timeout, task)
|
||||
|
||||
def _do_enter(self) -> None:
|
||||
if self._state != _State.INIT:
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._state = _State.ENTER
|
||||
self._reschedule()
|
||||
|
||||
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
|
||||
if exc_type is asyncio.CancelledError and \
|
||||
self._state == _State.TIMEOUT:
|
||||
self._timeout_handler = None
|
||||
raise asyncio.TimeoutError
|
||||
# timeout has not expired
|
||||
self._state = _State.EXIT
|
||||
self._reject()
|
||||
return None
|
||||
|
||||
def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
|
||||
if task:
|
||||
task.cancel()
|
||||
self._state = _State.TIMEOUT
|
||||
# drop the reference early
|
||||
self._timeout_handler = None
|
||||
1934
vllm/engine/llm_engine.py
Normal file
1934
vllm/engine/llm_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
555
vllm/engine/metrics.py
Normal file
555
vllm/engine/metrics.py
Normal file
@@ -0,0 +1,555 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Counter as CollectionsCounter
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import prometheus_client
|
||||
|
||||
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
|
||||
SupportsMetricsInfo)
|
||||
from vllm.executor.ray_utils import ray
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if ray is not None:
|
||||
from ray.util import metrics as ray_metrics
|
||||
else:
|
||||
ray_metrics = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
prometheus_client.disable_created_metrics()
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the metrics definitions.
|
||||
|
||||
|
||||
# begin-metrics-definitions
|
||||
class Metrics:
|
||||
"""
|
||||
vLLM uses a multiprocessing-based frontend for the OpenAI server.
|
||||
This means that we need to run prometheus_client in multiprocessing mode
|
||||
See https://prometheus.github.io/client_python/multiprocess/ for more
|
||||
details on limitations.
|
||||
"""
|
||||
labelname_finish_reason = "finished_reason"
|
||||
_gauge_cls = prometheus_client.Gauge
|
||||
_counter_cls = prometheus_client.Counter
|
||||
_histogram_cls = prometheus_client.Histogram
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
# Unregister any existing vLLM collectors (for CI/CD)
|
||||
self._unregister_vllm_metrics()
|
||||
|
||||
# System stats
|
||||
# Scheduler State
|
||||
self.gauge_scheduler_running = self._gauge_cls(
|
||||
name="vllm:num_requests_running",
|
||||
documentation="Number of requests currently running on GPU.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_scheduler_waiting = self._gauge_cls(
|
||||
name="vllm:num_requests_waiting",
|
||||
documentation="Number of requests waiting to be processed.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_scheduler_swapped = self._gauge_cls(
|
||||
name="vllm:num_requests_swapped",
|
||||
documentation="Number of requests swapped to CPU.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
# KV Cache Usage in %
|
||||
self.gauge_gpu_cache_usage = self._gauge_cls(
|
||||
name="vllm:gpu_cache_usage_perc",
|
||||
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_cpu_cache_usage = self._gauge_cls(
|
||||
name="vllm:cpu_cache_usage_perc",
|
||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
# Prefix caching block hit rate
|
||||
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
|
||||
name="vllm:cpu_prefix_cache_hit_rate",
|
||||
documentation="CPU prefix cache block hit rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
|
||||
name="vllm:gpu_prefix_cache_hit_rate",
|
||||
documentation="GPU prefix cache block hit rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
|
||||
# Iteration stats
|
||||
self.counter_num_preemption = self._counter_cls(
|
||||
name="vllm:num_preemptions_total",
|
||||
documentation="Cumulative number of preemption from the engine.",
|
||||
labelnames=labelnames)
|
||||
self.counter_prompt_tokens = self._counter_cls(
|
||||
name="vllm:prompt_tokens_total",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.counter_generation_tokens = self._counter_cls(
|
||||
name="vllm:generation_tokens_total",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.histogram_time_to_first_token = self._histogram_cls(
|
||||
name="vllm:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
|
||||
])
|
||||
self.histogram_time_per_output_token = self._histogram_cls(
|
||||
name="vllm:time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[
|
||||
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
|
||||
1.0, 2.5
|
||||
])
|
||||
|
||||
# Request stats
|
||||
# Latency
|
||||
self.histogram_e2e_time_request = self._histogram_cls(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of end to end request latency in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
|
||||
# Metadata
|
||||
self.histogram_num_prompt_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_num_generation_tokens_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_n_request = self._histogram_cls(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.counter_request_success = self._counter_cls(
|
||||
name="vllm:request_success_total",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Speculatie decoding stats
|
||||
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
||||
name="vllm:spec_decode_draft_acceptance_rate",
|
||||
documentation="Speulative token acceptance rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
||||
name="vllm:spec_decode_efficiency",
|
||||
documentation="Speculative decoding system efficiency.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames))
|
||||
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_emitted_tokens_total",
|
||||
documentation="Number of emitted tokens.",
|
||||
labelnames=labelnames))
|
||||
|
||||
# Deprecated in favor of vllm:prompt_tokens_total
|
||||
self.gauge_avg_prompt_throughput = self._gauge_cls(
|
||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||
documentation="Average prefill throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum",
|
||||
)
|
||||
# Deprecated in favor of vllm:generation_tokens_total
|
||||
self.gauge_avg_generation_throughput = self._gauge_cls(
|
||||
name="vllm:avg_generation_throughput_toks_per_s",
|
||||
documentation="Average generation throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum",
|
||||
)
|
||||
|
||||
|
||||
# end-metrics-definitions
|
||||
|
||||
def _unregister_vllm_metrics(self) -> None:
|
||||
for collector in list(prometheus_client.REGISTRY._collector_to_names):
|
||||
if hasattr(collector, "_name") and "vllm" in collector._name:
|
||||
prometheus_client.REGISTRY.unregister(collector)
|
||||
|
||||
|
||||
class _RayGaugeWrapper:
|
||||
"""Wraps around ray.util.metrics.Gauge to provide same API as
|
||||
prometheus_client.Gauge"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None,
|
||||
multiprocess_mode: str = ""):
|
||||
del multiprocess_mode
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._gauge = ray_metrics.Gauge(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._gauge.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def set(self, value: Union[int, float]):
|
||||
return self._gauge.set(value)
|
||||
|
||||
|
||||
class _RayCounterWrapper:
|
||||
"""Wraps around ray.util.metrics.Counter to provide same API as
|
||||
prometheus_client.Counter"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None):
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._counter = ray_metrics.Counter(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._counter.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def inc(self, value: Union[int, float] = 1.0):
|
||||
if value == 0:
|
||||
return
|
||||
return self._counter.inc(value)
|
||||
|
||||
|
||||
class _RayHistogramWrapper:
|
||||
"""Wraps around ray.util.metrics.Histogram to provide same API as
|
||||
prometheus_client.Histogram"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None,
|
||||
buckets: Optional[List[float]] = None):
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._histogram = ray_metrics.Histogram(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple,
|
||||
boundaries=buckets)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._histogram.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def observe(self, value: Union[int, float]):
|
||||
return self._histogram.observe(value)
|
||||
|
||||
|
||||
class RayMetrics(Metrics):
|
||||
"""
|
||||
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
|
||||
Provides the same metrics as Metrics but uses Ray's util.metrics library.
|
||||
"""
|
||||
_gauge_cls = _RayGaugeWrapper
|
||||
_counter_cls = _RayCounterWrapper
|
||||
_histogram_cls = _RayHistogramWrapper
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
if ray_metrics is None:
|
||||
raise ImportError("RayMetrics requires Ray to be installed.")
|
||||
super().__init__(labelnames, max_model_len)
|
||||
|
||||
def _unregister_vllm_metrics(self) -> None:
|
||||
# No-op on purpose
|
||||
pass
|
||||
|
||||
|
||||
def build_1_2_5_buckets(max_value: int) -> List[int]:
|
||||
"""
|
||||
Builds a list of buckets with increasing powers of 10 multiplied by
|
||||
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
|
||||
|
||||
Example:
|
||||
>>> build_1_2_5_buckets(100)
|
||||
[1, 2, 5, 10, 20, 50, 100]
|
||||
"""
|
||||
mantissa_lst = [1, 2, 5]
|
||||
exponent = 0
|
||||
buckets: List[int] = []
|
||||
while True:
|
||||
for m in mantissa_lst:
|
||||
value = m * 10**exponent
|
||||
if value <= max_value:
|
||||
buckets.append(value)
|
||||
else:
|
||||
return buckets
|
||||
exponent += 1
|
||||
|
||||
|
||||
def local_interval_elapsed(now: float, last_log: float,
|
||||
local_interval: float) -> bool:
|
||||
elapsed_time = now - last_log
|
||||
return elapsed_time > local_interval
|
||||
|
||||
|
||||
def get_throughput(tracked_stats: List[int], now: float,
|
||||
last_log: float) -> float:
|
||||
return float(np.sum(tracked_stats) / (now - last_log))
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
"""Called by LLMEngine.
|
||||
Logs to Stdout every self.local_interval seconds."""
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to promethus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
# Log to stdout.
|
||||
logger.info(
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Swapped: %d reqs, "
|
||||
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
||||
"CPU KV cache usage: %.1f%%.",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
)
|
||||
if (stats.cpu_prefix_cache_hit_rate >= 0
|
||||
or stats.gpu_prefix_cache_hit_rate >= 0):
|
||||
logger.info(
|
||||
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
|
||||
stats.gpu_prefix_cache_hit_rate * 100,
|
||||
stats.cpu_prefix_cache_hit_rate * 100,
|
||||
)
|
||||
if self.spec_decode_metrics is not None:
|
||||
logger.info(
|
||||
self._format_spec_decode_metrics_str(
|
||||
self.spec_decode_metrics))
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
self.spec_decode_metrics = None
|
||||
|
||||
def _format_spec_decode_metrics_str(
|
||||
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||
|
||||
return ("Speculative metrics: "
|
||||
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
|
||||
f"System efficiency: {metrics.system_efficiency:.3f}, "
|
||||
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
|
||||
f"Number of accepted tokens: {metrics.accepted_tokens}, "
|
||||
f"Number of draft tokens: {metrics.draft_tokens}, "
|
||||
f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
|
||||
_metrics_cls = Metrics
|
||||
_gauge_cls = prometheus_client.Gauge
|
||||
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str],
|
||||
max_model_len: int) -> None:
|
||||
super().__init__(local_interval)
|
||||
# Prometheus metrics
|
||||
self.labels = labels
|
||||
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
|
||||
max_model_len=max_model_len)
|
||||
|
||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to gauge.
|
||||
gauge.labels(**self.labels).set(data)
|
||||
|
||||
def _log_counter(self, counter, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self.labels).inc(data)
|
||||
|
||||
def _log_counter_labels(self, counter, data: CollectionsCounter,
|
||||
label_key: str) -> None:
|
||||
# Convenience function for collection counter of labels.
|
||||
for label, count in data.items():
|
||||
counter.labels(**{**self.labels, label_key: label}).inc(count)
|
||||
|
||||
def _log_histogram(self, histogram, data: Union[List[int],
|
||||
List[float]]) -> None:
|
||||
# Convenience function for logging list to histogram.
|
||||
for datum in data:
|
||||
histogram.labels(**self.labels).observe(datum)
|
||||
|
||||
def _log_prometheus(self, stats: Stats) -> None:
|
||||
# System state data
|
||||
self._log_gauge(self.metrics.gauge_scheduler_running,
|
||||
stats.num_running_sys)
|
||||
self._log_gauge(self.metrics.gauge_scheduler_swapped,
|
||||
stats.num_swapped_sys)
|
||||
self._log_gauge(self.metrics.gauge_scheduler_waiting,
|
||||
stats.num_waiting_sys)
|
||||
self._log_gauge(self.metrics.gauge_gpu_cache_usage,
|
||||
stats.gpu_cache_usage_sys)
|
||||
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
|
||||
stats.cpu_cache_usage_sys)
|
||||
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
|
||||
stats.cpu_prefix_cache_hit_rate)
|
||||
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
|
||||
stats.gpu_prefix_cache_hit_rate)
|
||||
|
||||
# Iteration level data
|
||||
self._log_counter(self.metrics.counter_num_preemption,
|
||||
stats.num_preemption_iter)
|
||||
self._log_counter(self.metrics.counter_prompt_tokens,
|
||||
stats.num_prompt_tokens_iter)
|
||||
self._log_counter(self.metrics.counter_generation_tokens,
|
||||
stats.num_generation_tokens_iter)
|
||||
self._log_histogram(self.metrics.histogram_time_to_first_token,
|
||||
stats.time_to_first_tokens_iter)
|
||||
self._log_histogram(self.metrics.histogram_time_per_output_token,
|
||||
stats.time_per_output_tokens_iter)
|
||||
|
||||
# Request level data
|
||||
# Latency
|
||||
self._log_histogram(self.metrics.histogram_e2e_time_request,
|
||||
stats.time_e2e_requests)
|
||||
# Metadata
|
||||
finished_reason_counter = CollectionsCounter(
|
||||
stats.finished_reason_requests)
|
||||
self._log_counter_labels(self.metrics.counter_request_success,
|
||||
finished_reason_counter,
|
||||
Metrics.labelname_finish_reason)
|
||||
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
|
||||
stats.num_prompt_tokens_requests)
|
||||
self._log_histogram(
|
||||
self.metrics.histogram_num_generation_tokens_request,
|
||||
stats.num_generation_tokens_requests)
|
||||
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
|
||||
|
||||
def _log_prometheus_interval(self, prompt_throughput: float,
|
||||
generation_throughput: float) -> None:
|
||||
# Logs metrics to prometheus that are computed every logging_interval.
|
||||
# Support legacy gauge metrics that make throughput calculations on
|
||||
# the vLLM side. Moving forward, we should use counters like
|
||||
# counter_prompt_tokens, counter_generation_tokens
|
||||
# Which log raw data and calculate summaries using rate() on the
|
||||
# grafana/prometheus side. See
|
||||
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
|
||||
self.metrics.gauge_avg_prompt_throughput.labels(
|
||||
**self.labels).set(prompt_throughput)
|
||||
self.metrics.gauge_avg_generation_throughput.labels(
|
||||
**self.labels).set(generation_throughput)
|
||||
|
||||
def log(self, stats: Stats):
|
||||
"""Logs to prometheus and tracked stats every iteration."""
|
||||
# Log to prometheus.
|
||||
self._log_prometheus(stats)
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to promethus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
self._log_prometheus_interval(
|
||||
prompt_throughput=prompt_throughput,
|
||||
generation_throughput=generation_throughput)
|
||||
|
||||
if self.spec_decode_metrics is not None:
|
||||
self._log_gauge(
|
||||
self.metrics.gauge_spec_decode_draft_acceptance_rate,
|
||||
self.spec_decode_metrics.draft_acceptance_rate)
|
||||
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
|
||||
self.spec_decode_metrics.system_efficiency)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_accepted_tokens,
|
||||
self.spec_decode_metrics.accepted_tokens)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_draft_tokens,
|
||||
self.spec_decode_metrics.draft_tokens)
|
||||
self._log_counter(
|
||||
self.metrics.counter_spec_decode_num_emitted_tokens,
|
||||
self.spec_decode_metrics.emitted_tokens)
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
self.spec_decode_metrics = None
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
# Info type metrics are syntactic sugar for a gauge permanently set to 1
|
||||
# Since prometheus multiprocessing mode does not support Info, emulate
|
||||
# info here with a gauge.
|
||||
if type == "cache_config":
|
||||
metrics_info = obj.metrics_info()
|
||||
info_gauge = self._gauge_cls(
|
||||
name="vllm:cache_config_info",
|
||||
documentation="Information of the LLMEngine CacheConfig",
|
||||
labelnames=metrics_info.keys(),
|
||||
multiprocess_mode="mostrecent")
|
||||
info_gauge.labels(**metrics_info).set(1)
|
||||
|
||||
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||
_metrics_cls = RayMetrics
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
return None
|
||||
87
vllm/engine/metrics_types.py
Normal file
87
vllm/engine/metrics_types.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
These types are defined in this file to avoid importing vllm.engine.metrics
|
||||
and therefore importing prometheus_client.
|
||||
|
||||
This is required due to usage of Prometheus multiprocess mode to enable
|
||||
metrics after splitting out the uvicorn process from the engine process.
|
||||
|
||||
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
|
||||
before prometheus_client is imported. Typically, this is done by setting
|
||||
the env variable before launch, but since we are a library, we need to
|
||||
do this in Python code and lazily import prometheus_client.
|
||||
"""
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
|
||||
@dataclass
|
||||
class Stats:
|
||||
"""Created by LLMEngine for use by StatLogger."""
|
||||
now: float
|
||||
|
||||
# System stats (should have _sys suffix)
|
||||
# Scheduler State
|
||||
num_running_sys: int
|
||||
num_waiting_sys: int
|
||||
num_swapped_sys: int
|
||||
# KV Cache Usage in %
|
||||
gpu_cache_usage_sys: float
|
||||
cpu_cache_usage_sys: float
|
||||
# Prefix caching block hit rate
|
||||
cpu_prefix_cache_hit_rate: float
|
||||
gpu_prefix_cache_hit_rate: float
|
||||
|
||||
# Iteration stats (should have _iter suffix)
|
||||
num_prompt_tokens_iter: int
|
||||
num_generation_tokens_iter: int
|
||||
time_to_first_tokens_iter: List[float]
|
||||
time_per_output_tokens_iter: List[float]
|
||||
num_preemption_iter: int
|
||||
|
||||
# Request stats (should have _requests suffix)
|
||||
# Latency
|
||||
time_e2e_requests: List[float]
|
||||
# Metadata
|
||||
num_prompt_tokens_requests: List[int]
|
||||
num_generation_tokens_requests: List[int]
|
||||
n_requests: List[int]
|
||||
finished_reason_requests: List[str]
|
||||
|
||||
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
|
||||
def metrics_info(self) -> Dict[str, str]:
|
||||
...
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
"""Base class for StatLogger."""
|
||||
|
||||
def __init__(self, local_interval: float) -> None:
|
||||
# Tracked stats over current local logging interval.
|
||||
self.num_prompt_tokens: List[int] = []
|
||||
self.num_generation_tokens: List[int] = []
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
@abstractmethod
|
||||
def log(self, stats: Stats) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def maybe_update_spec_decode_metrics(self, stats: Stats):
|
||||
"""Save spec decode metrics (since they are unlikely
|
||||
to be emitted at same time as log interval)."""
|
||||
if stats.spec_decode_metrics is not None:
|
||||
self.spec_decode_metrics = stats.spec_decode_metrics
|
||||
135
vllm/engine/multiprocessing/__init__.py
Normal file
135
vllm/engine/multiprocessing/__init__.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Mapping, Optional, Union, overload
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import deprecate_kwargs
|
||||
|
||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
||||
|
||||
IPC_INPUT_EXT = "_input_socket"
|
||||
IPC_OUTPUT_EXT = "_output_socket"
|
||||
IPC_HEALTH_EXT = "_health_socket"
|
||||
IPC_DATA_EXT = "_data_socket"
|
||||
|
||||
|
||||
class MQEngineDeadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCProcessRequest:
|
||||
prompt: PromptType
|
||||
params: Union[SamplingParams, PoolingParams]
|
||||
request_id: str
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
priority: int = 0
|
||||
|
||||
@overload # DEPRECATED
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
inputs: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
"inputs",
|
||||
additional_message="Please use the 'prompt' parameter instead.",
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
prompt: Optional[PromptType] = None,
|
||||
params: Optional[Union[SamplingParams, PoolingParams]] = None,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert (prompt is not None and params is not None
|
||||
and request_id is not None)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.prompt = prompt
|
||||
self.params = params
|
||||
self.request_id = request_id
|
||||
self.lora_request = lora_request
|
||||
self.trace_headers = trace_headers
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.priority = priority
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCError:
|
||||
request_id: Optional[str]
|
||||
is_engine_errored: bool
|
||||
exception: BaseException
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCAbortRequest:
|
||||
request_id: str
|
||||
|
||||
|
||||
class RPCStartupRequest(Enum):
|
||||
IS_SERVER_READY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCStartupResponse:
|
||||
tracing_enabled: bool
|
||||
|
||||
|
||||
class RPCUProfileRequest(Enum):
|
||||
START_PROFILE = 1
|
||||
STOP_PROFILE = 2
|
||||
|
||||
|
||||
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
|
||||
RPCUProfileRequest]
|
||||
|
||||
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
|
||||
|
||||
|
||||
def ENGINE_DEAD_ERROR(
|
||||
error: Optional[BaseException] = None) -> MQEngineDeadError:
|
||||
if error is None:
|
||||
return MQEngineDeadError(
|
||||
"Engine loop is not running. Inspect the stacktrace to "
|
||||
"find the original error")
|
||||
|
||||
return MQEngineDeadError(
|
||||
"Engine loop is not running. Inspect the stacktrace to "
|
||||
f"find the original error: {repr(error)}.")
|
||||
BIN
vllm/engine/multiprocessing/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/engine/multiprocessing/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/multiprocessing/__pycache__/client.cpython-310.pyc
Normal file
BIN
vllm/engine/multiprocessing/__pycache__/client.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/multiprocessing/__pycache__/engine.cpython-310.pyc
Normal file
BIN
vllm/engine/multiprocessing/__pycache__/engine.cpython-310.pyc
Normal file
Binary file not shown.
704
vllm/engine/multiprocessing/client.py
Normal file
704
vllm/engine/multiprocessing/client.py
Normal file
@@ -0,0 +1,704 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import pickle
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
||||
Optional, Union, overload)
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
from zmq.asyncio import Socket
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.async_llm_engine import (
|
||||
build_guided_decoding_logits_processor_async)
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCProcessRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.inputs import PromptType, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
|
||||
RequestOutput)
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
|
||||
random_uuid)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MQClientClosedError(Exception):
|
||||
"""Exception class raised when the client is used post-close.
|
||||
|
||||
The client can be closed, which closes the ZMQ context. This normally
|
||||
happens on server shutdown. In some cases, methods like abort and
|
||||
do_log_stats will still be called and then try to open a socket, which
|
||||
causes a ZMQError and creates a huge stack trace.
|
||||
So, we throw this error such that we can suppress it.
|
||||
"""
|
||||
|
||||
|
||||
class MQLLMEngineClient:
|
||||
"""A client wrapper for MQLLMEngine that conforms to the
|
||||
EngineClient protocol.
|
||||
|
||||
MQLLMEngine and MQLLMEngineClient are intended to run in separate
|
||||
processes communicating via zeromq ipc sockets.
|
||||
|
||||
The entrypoint to MQLLMEngineClient is through the generate()
|
||||
method. On generate() MQLLMEngine does three things:
|
||||
- Creates an asyncio output queue
|
||||
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
|
||||
- Pulls RequestOutputs from its queue and yields them
|
||||
|
||||
MQLLMEngine runs two background loops:
|
||||
- output_loop: the output loop pulls List[RequestOutput]
|
||||
from the MQLLMEngine via zmq (each list is the output
|
||||
of one engine_step in the LLMEngine). It then parses
|
||||
the list and pushes individual request_outputs into
|
||||
the corresponding output_queue such that they can be
|
||||
consumed by the .generate() method.
|
||||
- health_loop: the health loop queries the health socket
|
||||
every N seconds, confirming the engine is healthy
|
||||
"""
|
||||
|
||||
def __init__(self, ipc_path: str, engine_config: EngineConfig):
|
||||
self.context = zmq.asyncio.Context()
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Get the configs.
|
||||
self.model_config = engine_config.model_config
|
||||
self.decoding_config = engine_config.decoding_config
|
||||
|
||||
# Create the tokenizer group.
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
parallel_config=engine_config.parallel_config,
|
||||
enable_lora=bool(engine_config.lora_config),
|
||||
)
|
||||
|
||||
# Send RPCGenerateRequest to the MQLLMEngine.
|
||||
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||
|
||||
# Receive streams of RequestOutput from the MQLLMEngine.
|
||||
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# IPC path for acking heartbeats.
|
||||
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
# Stream for each individual request.
|
||||
self.output_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
|
||||
|
||||
# Loop to check health of the LLMEngine periodically.
|
||||
# Started after the MQLLMEngine is ready.
|
||||
self.health_loop: Optional[asyncio.Task] = None
|
||||
|
||||
@staticmethod
|
||||
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||
# Pipeline parallel not yet supported
|
||||
return engine_args.pipeline_parallel_size > 1
|
||||
|
||||
@contextmanager
|
||||
def get_data_socket(self) -> Iterator[Socket]:
|
||||
socket = self.context.socket(zmq.constants.DEALER)
|
||||
try:
|
||||
socket.connect(self.data_ipc_path)
|
||||
yield socket
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
|
||||
async def run_heartbeat_loop(self, timeout: int):
|
||||
"""Background loop that continually listens to the RPCServer for
|
||||
heartbeats.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
|
||||
# No heartbeat was received. Set error and exit the loop
|
||||
self._set_errored(
|
||||
TimeoutError("No heartbeat received "
|
||||
"from MQLLMEngine"))
|
||||
logger.debug("Shutting down MQLLMEngineClient check "
|
||||
"health loop due to timeout")
|
||||
break
|
||||
|
||||
else:
|
||||
# Heartbeat received- check the message
|
||||
await self._check_success(
|
||||
error_message="Heartbeat failed.",
|
||||
socket=self.heartbeat_socket)
|
||||
|
||||
logger.debug("Heartbeat successful.")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
|
||||
async def run_output_handler_loop(self):
|
||||
"""Get RequestOutputs from Engine and stream to Request Queues"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Poll, checking for ENGINE_DEAD
|
||||
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
|
||||
) == 0:
|
||||
logger.debug("Waiting for output from MQLLMEngine.")
|
||||
|
||||
# If errored, alert all running requests.
|
||||
if self.errored:
|
||||
for queue_j in tuple(self.output_queues.values()):
|
||||
queue_j.put_nowait(
|
||||
ENGINE_DEAD_ERROR(self._errored_with))
|
||||
return
|
||||
|
||||
message: Frame = await self.output_socket.recv(copy=False)
|
||||
request_outputs = pickle.loads(message.buffer)
|
||||
|
||||
is_error = isinstance(request_outputs,
|
||||
(BaseException, RPCError))
|
||||
if is_error:
|
||||
if isinstance(request_outputs, RPCError):
|
||||
rpc_error: RPCError = request_outputs
|
||||
request_id = rpc_error.request_id
|
||||
exception = rpc_error.exception
|
||||
is_engine_errored = rpc_error.is_engine_errored
|
||||
else:
|
||||
# MPLLMEngine should always return an RPCError to
|
||||
# the output_socket when an issue arises.
|
||||
# If we are here, we are in a bad state and
|
||||
# should shut down the server.
|
||||
error: BaseException = request_outputs
|
||||
logger.error(
|
||||
"Received Exception %s rather than RPCError from "
|
||||
"MPLLMEngine. This should never happen.", error)
|
||||
request_id = None
|
||||
exception = error
|
||||
is_engine_errored = True
|
||||
|
||||
# Set to error state only on engine critical error
|
||||
# (and record only the first one)
|
||||
if is_engine_errored and not self._errored_with:
|
||||
self._errored_with = exception
|
||||
|
||||
if request_id is None:
|
||||
for queue_i in tuple(self.output_queues.values()):
|
||||
queue_i.put_nowait(exception)
|
||||
else:
|
||||
queue = self.output_queues.get(request_id)
|
||||
if queue is not None:
|
||||
queue.put_nowait(exception)
|
||||
else:
|
||||
# Put each output into the appropriate steam.
|
||||
for request_output in request_outputs:
|
||||
queue = self.output_queues.get(
|
||||
request_output.request_id)
|
||||
if queue is not None:
|
||||
queue.put_nowait(request_output)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient output handler.")
|
||||
|
||||
async def setup(self):
|
||||
"""Setup the client before it starts sending server requests."""
|
||||
|
||||
with self.get_data_socket() as socket:
|
||||
# Wait until server is ready.
|
||||
response = await self._wait_for_server_rpc(socket)
|
||||
|
||||
self.tracing_flag = response.tracing_enabled
|
||||
|
||||
# Start health_loop.
|
||||
self.health_loop = asyncio.create_task(
|
||||
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
|
||||
def close(self):
|
||||
"""Destroy the ZeroMQ Context."""
|
||||
# Close all sockets and terminate the context.
|
||||
self.context.destroy(linger=0)
|
||||
|
||||
# Cancel background tasks.
|
||||
if self.health_loop is not None:
|
||||
self.health_loop.cancel()
|
||||
self.output_loop.cancel()
|
||||
|
||||
def _set_errored(self, e: BaseException):
|
||||
logger.exception(repr(e))
|
||||
if self._errored_with is None:
|
||||
self._errored_with = e
|
||||
|
||||
@staticmethod
|
||||
async def _send_get_data_rpc_request(request: RPCStartupRequest,
|
||||
expected_type: Any,
|
||||
error_message: str,
|
||||
socket: Socket) -> Any:
|
||||
"""Send an RPC request that is expecting data back."""
|
||||
|
||||
# Ping RPCServer with a request.
|
||||
await socket.send_multipart((pickle.dumps(request), ), copy=False)
|
||||
|
||||
# Make sure the server responds in time.
|
||||
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||
raise TimeoutError("RPCServer didn't reply within "
|
||||
f"{VLLM_RPC_TIMEOUT} ms")
|
||||
|
||||
# Await the data from the Server.
|
||||
frame = await socket.recv(copy=False)
|
||||
data = pickle.loads(frame.buffer)
|
||||
|
||||
if isinstance(data, BaseException):
|
||||
raise data
|
||||
elif not isinstance(data, expected_type):
|
||||
raise ValueError(error_message)
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
|
||||
socket: Socket):
|
||||
"""Send one-way RPC request to trigger an action."""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
await socket.send_multipart((pickle.dumps(request), ))
|
||||
|
||||
async def _await_ack(self, error_message: str, socket: Socket):
|
||||
"""Await acknowledgement that a request succeeded."""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||
raise TimeoutError("MQLLMEngine didn't reply within "
|
||||
f"{VLLM_RPC_TIMEOUT}ms")
|
||||
|
||||
await self._check_success(error_message, socket)
|
||||
|
||||
@staticmethod
|
||||
async def _check_success(error_message: str, socket: Socket):
|
||||
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
frame = await socket.recv(copy=False)
|
||||
response = pickle.loads(frame.buffer)
|
||||
|
||||
# Raise error if unsuccessful
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
elif (not isinstance(response, str)
|
||||
or response != VLLM_RPC_SUCCESS_STR):
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def get_tokenizer(self, lora_request: LoRARequest):
|
||||
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
||||
|
||||
async def get_decoding_config(self) -> DecodingConfig:
|
||||
return self.decoding_config
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.tracing_flag
|
||||
|
||||
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
|
||||
"""Wait for the RPCServer to start up."""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
request=RPCStartupRequest.IS_SERVER_READY,
|
||||
expected_type=RPCStartupResponse,
|
||||
error_message="Unable to start RPC Server",
|
||||
socket=socket)
|
||||
|
||||
async def abort(self, request_id: str):
|
||||
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
||||
|
||||
with suppress(MQClientClosedError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCAbortRequest(request_id), socket=self.input_socket)
|
||||
|
||||
async def do_log_stats(self):
|
||||
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
|
||||
pass
|
||||
|
||||
async def check_health(self):
|
||||
"""
|
||||
The check health loop probes the health status of the
|
||||
Engine's health every N seconds and sets _errored_with
|
||||
if the engine is unhealthy.
|
||||
"""
|
||||
if self._errored_with is not None:
|
||||
raise self._errored_with
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return not self.errored
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return self._errored_with is not None
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
return ENGINE_DEAD_ERROR(self._errored_with)
|
||||
|
||||
@overload # DEPRECATED
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
inputs: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
"inputs",
|
||||
additional_message="Please use the 'prompt' parameter instead.",
|
||||
)
|
||||
def generate(
|
||||
self,
|
||||
prompt: Optional[PromptType] = None,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None # DEPRECATED
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
request into the waiting queue of the LLMEngine and streams the outputs
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each input.
|
||||
sampling_params: The sampling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request to use
|
||||
for generation, if any.
|
||||
priority: Priority of the request (lower means earlier handling).
|
||||
Any priority other than 0 will lead to an error if the
|
||||
scheduling policy is not "priority".
|
||||
"""
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert (prompt is not None and sampling_params is not None
|
||||
and request_id is not None)
|
||||
|
||||
return self._process_request(prompt, sampling_params, request_id,
|
||||
lora_request, trace_headers,
|
||||
prompt_adapter_request, priority)
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: Union[PromptType, List[int]],
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
ignore_eos = params.ignore_eos
|
||||
temperature = params.temperature
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
tokenizer = await self.get_tokenizer(lora_request=None)
|
||||
tokenizedPrompt = prompt if isinstance(
|
||||
prompt, list) else tokenizer.encode(prompt)
|
||||
tokenizedLength = len(tokenizedPrompt)
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(
|
||||
tokenizer.eos_token_id, length_penalty)
|
||||
|
||||
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature)
|
||||
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch = [
|
||||
TokensPrompt(prompt_token_ids=beam.tokens)
|
||||
for beam in all_beams
|
||||
]
|
||||
|
||||
tasks = []
|
||||
|
||||
request_id = f"beam_search-{random_uuid()}"
|
||||
for i, individual_prompt in enumerate(prompts_batch):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.generate(individual_prompt, beam_search_params,
|
||||
request_id_item)))
|
||||
tasks.append(task)
|
||||
|
||||
output = await asyncio.gather(*tasks)
|
||||
|
||||
output = [x[0] for x in output]
|
||||
|
||||
logger.info(output)
|
||||
|
||||
new_beams = []
|
||||
for i, current_beam in enumerate(all_beams):
|
||||
result = output[i]
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob)
|
||||
|
||||
if token_id == tokenizer.eos_token_id and \
|
||||
not ignore_eos:
|
||||
completed.append(new_beam)
|
||||
else:
|
||||
new_beams.append(new_beam)
|
||||
|
||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||
all_beams = sorted_beams[:beam_width]
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
|
||||
|
||||
beam_search_output = RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
text=beam.text,
|
||||
cumulative_logprob=beam.cum_logprob,
|
||||
token_ids=beam.tokens,
|
||||
index=i,
|
||||
logprobs=beam.cum_logprob,
|
||||
) for (i, beam) in enumerate(best_beams)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=tokenizedPrompt,
|
||||
prompt_logprobs=None)
|
||||
|
||||
logger.info(beam_search_output)
|
||||
|
||||
yield beam_search_output
|
||||
|
||||
@overload # DEPRECATED
|
||||
def encode(
|
||||
self,
|
||||
*,
|
||||
inputs: PromptType,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
"inputs",
|
||||
additional_message="Please use the 'prompt' parameter instead.",
|
||||
)
|
||||
def encode(
|
||||
self,
|
||||
prompt: Optional[PromptType] = None,
|
||||
pooling_params: Optional[PoolingParams] = None,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None # DEPRECATED
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
"""Generate outputs for a request from an embedding model.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
request into the waiting queue of the LLMEngine and streams the outputs
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each input.
|
||||
pooling_params: The pooling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
|
||||
Yields:
|
||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||
for the request.
|
||||
"""
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert (prompt is not None and pooling_params is not None
|
||||
and request_id is not None)
|
||||
|
||||
return self._process_request(prompt, pooling_params, request_id,
|
||||
lora_request, trace_headers, None,
|
||||
priority)
|
||||
|
||||
async def _process_request(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
|
||||
EmbeddingRequestOutput, None]]:
|
||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||
|
||||
# If already dead, error out.
|
||||
if self._errored_with is not None:
|
||||
raise ENGINE_DEAD_ERROR(self._errored_with)
|
||||
|
||||
# Constructing guided decoding logits processors is expensive, so we do
|
||||
# it here to avoid contending with cpu resources and the GIL on the
|
||||
# backend process.
|
||||
if isinstance(params, SamplingParams) and \
|
||||
params.guided_decoding is not None:
|
||||
params = await \
|
||||
build_guided_decoding_logits_processor_async(
|
||||
sampling_params=params,
|
||||
tokenizer=await self.get_tokenizer(lora_request),
|
||||
default_guided_backend=self.decoding_config.guided_decoding_backend
|
||||
)
|
||||
|
||||
# 1) Create output queue for this requests.
|
||||
queue: asyncio.Queue[Union[RequestOutput,
|
||||
BaseException]] = asyncio.Queue()
|
||||
self.output_queues[request_id] = queue
|
||||
|
||||
try:
|
||||
# 2) Detach logits processors so that they can be pickled
|
||||
# separately (may require cloudpickle which is slower)
|
||||
if isinstance(params, SamplingParams) and params.logits_processors:
|
||||
# Defensive shallow copy
|
||||
params = copy.copy(params)
|
||||
logits_processors = params.logits_processors
|
||||
params.logits_processors = None
|
||||
lp_bytes = cloudpickle.dumps(logits_processors)
|
||||
else:
|
||||
lp_bytes = None
|
||||
|
||||
request_bytes = pickle.dumps(
|
||||
RPCProcessRequest(
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
))
|
||||
|
||||
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
||||
parts = (request_bytes,
|
||||
lp_bytes) if lp_bytes else (request_bytes, )
|
||||
await self.input_socket.send_multipart(parts, copy=False)
|
||||
|
||||
# 4) Stream the RequestOutputs from the output queue. Note
|
||||
# that the output_loop pushes RequestOutput objects to this
|
||||
# queue after pulling them from the zmq socket.
|
||||
finished = False
|
||||
try:
|
||||
while not finished:
|
||||
request_output = await queue.get()
|
||||
|
||||
if isinstance(request_output, BaseException):
|
||||
raise request_output
|
||||
|
||||
finished = request_output.finished
|
||||
yield request_output
|
||||
finally:
|
||||
# Request was canceled by the client.
|
||||
if not finished and not self.errored:
|
||||
await self.abort(request_id)
|
||||
finally:
|
||||
self.output_queues.pop(request_id)
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
"""Start profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
"""Stop profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
|
||||
395
vllm/engine/multiprocessing/engine.py
Normal file
395
vllm/engine/multiprocessing/engine.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import pickle
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
|
||||
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCProcessRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
# Union of the per-component config objects that can be exchanged over RPC.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]

logger = init_logger(__name__)

# How long (in milliseconds) the engine loop blocks on zmq poll() while
# waiting for new work before re-checking liveness and logging stats.
POLLING_TIMEOUT_MS = 10000
# Pre-pickled HEALTHY payload: the same bytes are sent for every heartbeat,
# so serialize the success marker once at import time.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
||||
|
||||
|
||||
class MQLLMEngine:
    """A multiprocessing wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to enable use
    in concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The :class:`LLMEngine` generate or encode process is kicked off when a new
    RPCProcessRequest is received by the input_socket.

    The self.engine_loop checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    :class:`LLMEngine.step()`, and sends the RequestOutputs back over
    the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        use_cached_outputs = True

        self.engine = LLMEngine(*args,
                                **kwargs,
                                use_cached_outputs=use_cached_outputs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            # The engine will invoke this callback to overlap IPC with GPU
            # work (see _async_socket_engine_callback).
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state.
        self._errored_with: Optional[BaseException] = None

        # Heartbeat thread
        self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
                                                 daemon=True)
        self._heartbeat_stop_event = threading.Event()
        # The heartbeat needs to be faster than what the client will wait for
        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
        self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

        self._last_alive_time = time.time()
        # The heartbeats can tolerate a long period of the engine chugging
        # away at a generation request.
        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
        self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

    @property
    def dead_error(self) -> BaseException:
        """Exception describing why/that the engine is dead, for clients."""
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    def from_engine_args(cls, engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        engine_config = engine_args.create_engine_config()

        executor_class = LLMEngine._get_executor_cls(engine_config)

        return cls(
            ipc_path=ipc_path,
            use_async_sockets=engine_config.model_config.use_async_output_proc,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context)

    def start(self):
        """Run the startup handshake, heartbeat thread, and engine loop.

        Blocks until the engine loop exits (error or KeyboardInterrupt),
        then cleans up zeromq state.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting heartbeat thread")
                self.heartbeat_thread.start()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            # Raised by the SIGTERM handler installed in run_mp_engine.
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self._heartbeat_stop_event.set()
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Yield a ROUTER socket bound to the data IPC path; closes on exit."""
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client."""

        with self.make_data_socket() as socket:
            response: Union[RPCStartupResponse, BaseException]
            try:
                identity, message = socket.recv_multipart(copy=False)
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                # NOTE(review): if the request is not IS_SERVER_READY,
                # `response` stays unbound here — confirm the enum only has
                # this one startup value. Similarly, if recv_multipart itself
                # raises, `identity` is unbound in the send below.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            self._alive()
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    self._alive()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling."""
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except BaseException as e:
            # A failed step means the whole engine is unusable: record the
            # error, broadcast it to all clients, and re-raise.
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket"""
        try:
            # Drain everything currently queued without blocking (timeout=0).
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            # NOTE(review): execution continues past this error reply into
            # add_request below (no return) — confirm this fall-through is
            # intended.
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)

        try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the given request in the wrapped engine."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _heartbeat_loop(self):
        """Background thread body: emit a heartbeat every interval."""
        while not self._heartbeat_stop_event.wait(
                timeout=self.heartbeat_interval_seconds):
            # Loops until the stop event is set
            self._heartbeat()

        logger.debug("Exiting MQLLMEngine heartbeat thread")

    def _heartbeat(self):
        """Send one HEALTHY/UNHEALTHY message based on current state."""
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)

        # Check for life of the main loop
        elif time.time() - self._last_alive_time > self.last_alive_threshold:
            self._send_unhealthy(RuntimeError("Engine loop has died"))

        else:
            # Otherwise- check health of the engine
            # self.engine.check_health() raises on unhealthy
            try:
                self.engine.check_health()
                self._send_healthy()
            except Exception as e:
                self._set_errored(e)
                self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send List of RequestOutput to RPCClient."""
        if outputs:
            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Log and set errored status if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e

    def _alive(self):
        """Record a liveness timestamp; read by the heartbeat thread."""
        self._last_alive_time = time.time()

    def start_profile(self) -> None:
        """Start profiling on the executor (all workers if distributed)."""
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")

    def stop_profile(self) -> None:
        """Stop profiling on the executor (all workers if distributed)."""
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")
|
||||
|
||||
|
||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                  ipc_path: str):
    """Process entry point: build an MQLLMEngine and run its loops.

    Installs a SIGTERM handler so the parent can terminate this process
    through the engine's normal KeyboardInterrupt shutdown path.
    """

    def raise_interrupt(*_) -> None:
        # Interrupt server on sigterm
        raise KeyboardInterrupt("MQLLMEngine terminated")

    signal.signal(signal.SIGTERM, raise_interrupt)

    mq_engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
                                             usage_context=usage_context,
                                             ipc_path=ipc_path)
    mq_engine.start()
|
||||
0
vllm/engine/output_processor/__init__.py
Normal file
0
vllm/engine/output_processor/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/engine/output_processor/__pycache__/util.cpython-310.pyc
Normal file
BIN
vllm/engine/output_processor/__pycache__/util.cpython-310.pyc
Normal file
Binary file not shown.
72
vllm/engine/output_processor/interfaces.py
Normal file
72
vllm/engine/output_processor/interfaces.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
|
||||
class SequenceGroupOutputProcessor(ABC):
    """Interface for logic that processes new token ids in sequence groups,
    managing detokenization, stop checking, and freeing/forking sequences with
    the scheduler.

    This is highly coupled with the LLMEngine and should be seen as an extension
    of it. The logic is separated to simplify the LLMEngine class and allow
    separate implementations for single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

        This returns a single-step output processor if num_lookahead_slots is
        zero, else returns a multi-step output processor.
        """
        if scheduler_config.num_lookahead_slots == 0:
            # Importing here to avoid cycle.
            from vllm.engine.output_processor.single_step import (
                SingleStepOutputProcessor)
            return SingleStepOutputProcessor(scheduler_config, detokenizer,
                                             scheduler, seq_counter,
                                             stop_checker)
        else:
            # Importing here to avoid cycle.
            # NOTE: the multi-step processor does not take scheduler_config;
            # it uses get_tokenizer_for_seq instead (e.g. for EOS lookup).
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
        """
        pass

    @abstractmethod
    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Update prompt logprobs received from outputs to seq_group."""
        pass
|
||||
188
vllm/engine/output_processor/multi_step.py
Normal file
188
vllm/engine/output_processor/multi_step.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import functools
|
||||
from typing import Callable, List
|
||||
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.single_step import (
|
||||
single_step_process_prompt_logprob)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
|
||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        # Used to look up the EOS token id for truncation in
        # _process_seq_outputs.
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
            seq_group: the outputs are associated with this :class:`SequenceGroup`
            outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache()
    def _log_prompt_logprob_unsupported_warning_once():
        # lru_cache on a no-arg staticmethod => the warning is emitted at
        # most once per process.
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
        # if a client disconnects from the api server.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        if seqs is None:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]
        seq_id = seq.seq_id
        # Every scheduler step's sample must belong to this one sequence.
        assert all(
            [seq_id == output.samples[0].parent_seq_id for output in outputs])

        if is_async:
            # Async case: We process tokens one by one. Here, we know the token
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
        else:
            # Standard multi-step case

            # Since there's only one sequence per sequence group,
            # we can take the first sample.
            samples = [output.samples[0] for output in outputs]

            # entries in sample tokens may be invalid (eg. due to spec decode
            # rejecting tokens).
            valid_samples = [
                sample for sample in samples
                if sample.output_token != VLLM_INVALID_TOKEN_ID
            ]
            assert valid_samples

            self._process_seq_outputs(seq, valid_samples,
                                      sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        """Detokenize the latest token(s) and apply stop-condition checks."""
        new_char_count = 0
        if sampling_params.detokenize:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append valid sampled tokens one at a time, truncating at
        max_tokens / EOS, and stop as soon as the sequence finishes."""
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            # Negative slice drops the overflowing tail.
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i in range(len(output_token_ids)):
                if output_token_ids[i] == eos_token_id:
                    # Keep the EOS token itself, drop everything after it.
                    output_token_ids = output_token_ids[:i + 1]
                    break

        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id, output_logprob in zip(output_token_ids,
                                                   output_logprobs):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
            )

            if is_prefill_sampled_token:
                is_prefill_sampled_token = False
            else:
                # Update num_computed_tokens iff the sampled token is not from
                # a prefill step.
                seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break
|
||||
215
vllm/engine/output_processor/single_step.py
Normal file
215
vllm/engine/output_processor/single_step.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def single_step_process_prompt_logprob(
        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
        output: SequenceGroupOutput) -> None:
    """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
    for a given step.

    Do nothing if the output has no prompt logprobs.

    Account for the fact that transformers do not compute first-token logprobs.

    Args:
      sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
      seq_group: the output is associated with this :class:`SequenceGroup`
      output: the :class:`SequenceGroupOutput` for a single scheduler step
    """
    prompt_logprobs = output.prompt_logprobs

    # If this is the first (or only) "chunk" of the prefill, we need
    # to prepend None to the list of prompt logprobs. The reason for this
    # is that for N prompt tokens, the Sampler will generate N-1 total
    # prompt logprobs during prefill since the token at idx 0 will not
    # have a logprob associated with it.
    if prompt_logprobs is not None:
        if not seq_group.prompt_logprobs:
            prompt_logprobs = [None] + prompt_logprobs
            seq_group.prompt_logprobs = []

        # The processor may legitimately lack a detokenizer (e.g. when
        # detokenization is disabled), hence the attribute check.
        assert hasattr(sg_output_proc, 'detokenizer')
        if (seq_group.sampling_params.detokenize
                and sg_output_proc.detokenizer):
            sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
                seq_group,
                prompt_logprobs,
                # Offset by what previous chunks already contributed.
                position_offset=len(seq_group.prompt_logprobs))

        seq_group.prompt_logprobs.extend(prompt_logprobs)
|
||||
|
||||
|
||||
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
"""SequenceGroupOutputProcessor which handles "output processing" logic,
|
||||
which happens after the model returns generated token ids and before
|
||||
scheduling of the next batch. Output processing logic includes
|
||||
detokenization, and determining if a sequence is finished (e.g. via max len
|
||||
or eos token).
|
||||
|
||||
The SingleStepOutputProcessor is specialized to the case where the model
|
||||
emits at most a single token per invocation, which precludes configurations
|
||||
such as speculative decoding or multi-step decoding. This enables beam
|
||||
search sampling, which requires forking/finishing/freeing sequences in a way
|
||||
that is currently difficult to schedule multiple steps ahead of time.
|
||||
"""
|
||||
|
||||
    def __init__(self, scheduler_config: SchedulerConfig,
                 detokenizer: Detokenizer, scheduler: List[Scheduler],
                 seq_counter: Counter, stop_checker: StopChecker):
        """Store collaborators used during output processing.

        Args:
            scheduler_config: engine scheduling configuration.
            detokenizer: converts new token ids to text in place.
            scheduler: per-virtual-engine schedulers (used to free/fork seqs).
            seq_counter: source of fresh sequence ids for beam-search forks.
            stop_checker: applies stop-string/EOS/max-len stop conditions.
        """
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker
|
||||
|
||||
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        # Single-step decoding produces exactly one output per step.
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0],
                                                    is_async)
|
||||
|
||||
    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with one step of a single-step-
        scheduled computation.

        Args:
            seq_group: the output is associated with this :class:`SequenceGroup`
            output: the :class:`SequenceGroupOutput` for a single scheduler step
        """
        assert len(outputs) == 1, ("Single step should only has 1 output.")
        output = outputs[0]
        # Delegate to the shared helper also used by the multi-step processor.
        single_step_process_prompt_logprob(self, seq_group, output)
|
||||
|
||||
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        """Apply one step's samples to ``seq_group``.

        Fast path (n == 1): append the single sampled token, detokenize,
        run stop checks, and free the sequence if it finished.

        Beam-search path (n > 1): map samples back to their parent
        sequences, fork parents that produced multiple children, abort
        parents that produced none, then detokenize / stop-check every
        surviving candidate. New forks are registered in the block manager
        before any finished parent is freed (ordering is load-bearing —
        see NOTE below).
        """
        sampling_params = seq_group.sampling_params
        if sampling_params.n == 1:
            # only have one output sample
            sample = outputs.samples[0]
            # only have one sequence
            seq = seq_group.seqs[0]
            if not is_async:
                # In the async case the token was already appended
                # externally before the next schedule() call.
                seq.append_token_id(sample.output_token, sample.logprobs)
            if sampling_params.detokenize and self.detokenizer:
                # new_char_count = number of characters the new token added
                # to seq.output_text; the stop checker uses it to truncate
                # stop strings from the visible output.
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count,
                sampling_params,
                lora_req=seq_group.lora_request,
            )
            if seq.is_finished():
                # Release the sequence's blocks in every scheduler.
                for scheduler in self.scheduler:
                    scheduler.free_seq(seq)
            return

        # TODO: Add support for async for beam search
        assert not is_async

        # Process samples: bucket each sample under its parent sequence id.
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        parent_child_dict: Dict[int, List[SequenceOutput]] = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            # Guard against a KeyError which can occur if the request was
            # aborted while the output was generated
            if (child_list :=
                    parent_child_dict.get(sample.parent_seq_id)) is not None:
                child_list.append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                for scheduler in self.scheduler:
                    scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id: int = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        # Detokenize and stop-check every surviving candidate.
        for seq, _ in child_seqs:
            if sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count,
                sampling_params,
                lora_req=seq_group.lora_request,
            )

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    for scheduler in self.scheduler:
                        scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        # NOTE: we need to fork the new sequences before freeing the
        # old sequences.
        for seq, parent in child_seqs:
            if seq is parent and seq.is_finished():
                for scheduler in self.scheduler:
                    scheduler.free_seq(seq)
        return
|
||||
117
vllm/engine/output_processor/stop_checker.py
Normal file
117
vllm/engine/output_processor/stop_checker.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceStatus
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class StopChecker:
    """LLMEngine helper that decides when a sequence must stop generating.

    A sequence stops when it emits the EOS token, hits a stop token or stop
    string, exhausts ``max_tokens``, or exceeds the maximum model length.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        # Private on purpose: reads must go through `self._get_max_model_len`
        # so long-LoRA overrides are honored.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
        """Return the effective max length, honoring a long-LoRA override."""
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` finished if any stop condition is met.

        ``new_char_count`` is how many characters the newly generated token
        appended to the sequence's output text; it is used to strip stop
        tokens/strings back out of the visible output.
        """
        # Too few tokens generated so far: stop-token / stop-string checks
        # do not apply yet.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # One token is produced per step, so the stop candidates are all
        # functions of the last token id.
        last_token = seq.get_last_token_id()
        # Whether the just-emitted stop token's text must be hidden from the
        # output (prevents unintended exposure of e.g. the EOS token).
        hide_stop_text = (new_char_count
                          and not sampling_params.include_stop_str_in_output)

        # EOS token emitted?
        if not sampling_params.ignore_eos and last_token == seq.eos_token_id:
            if hide_stop_text:
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Explicit stop token emitted?
        if last_token in sampling_params.stop_token_ids:
            if hide_stop_text:
                # Strip the stop token's text from the output.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token
            return

        # Stop string matched in the freshly decoded text?
        matched_stop = self._check_stop_strings(seq, new_char_count,
                                                sampling_params)
        if matched_stop is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = matched_stop
            return

        # Model context length exhausted?
        if seq.get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Requested number of output tokens reached?
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Return the first matching stop string, truncating the output.

        The sequence's output text is cut at (or just after, when
        ``include_stop_str_in_output`` is set) the matched stop string.
        Returns ``None`` when nothing matched.
        """
        if not new_char_count:
            return None

        text = seq.output_text
        for candidate in sampling_params.stop:
            candidate_len = len(candidate)
            # A new match must overlap the freshly appended characters, so
            # only search the tail instead of the whole accumulated text.
            match_at = text.find(candidate, -new_char_count - candidate_len)
            if match_at == -1:
                continue

            cut = (match_at + candidate_len
                   if sampling_params.include_stop_str_in_output
                   else match_at)
            if cut < len(text):
                # Truncate to the beginning (or end) of the stop string.
                seq.output_text = text[:cut]
            return candidate
        return None
|
||||
22
vllm/engine/output_processor/util.py
Normal file
22
vllm/engine/output_processor/util.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import List
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import PoolerOutput, SequenceGroupOutput
|
||||
|
||||
|
||||
def create_output_by_sequence_group(
        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
    """Transpose per-step outputs into per-sequence-group lists.

    The input is organized as [step][sequence group]; the result is
    organized as [sequence group][step].
    """
    transposed: List[List[SequenceGroupOutput]] = [
        [] for _ in range(num_seq_groups)
    ]
    for step_outputs in outputs:
        for group_index, group_output in enumerate(step_outputs):
            transposed[group_index].append(group_output)
    return transposed
|
||||
103
vllm/engine/protocol.py
Normal file
103
vllm/engine/protocol.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
|
||||
runtime_checkable)
|
||||
|
||||
from vllm.config import DecodingConfig, ModelConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
@runtime_checkable
class EngineClient(Protocol):
    """Protocol class for Clients to Engine"""

    @property
    def is_running(self) -> bool:
        ...

    @property
    def is_stopped(self) -> bool:
        ...

    @property
    def errored(self) -> bool:
        ...

    @property
    def dead_error(self) -> BaseException:
        ...

    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    async def get_decoding_config(self) -> DecodingConfig:
        # BUGFIX: the docstring was previously placed AFTER the `...` body,
        # making it a dead string statement instead of the method docstring.
        """Get the decoding configuration of the vLLM engine."""
        ...

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    async def is_tracing_enabled(self) -> bool:
        """Return whether request tracing is enabled on the engine."""
        ...

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Emit periodic engine statistics to the configured loggers."""
        ...

    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    async def stop_profile(self) -> None:
        # BUGFIX: docstring previously said "Start profiling the engine"
        # (copy-paste from start_profile).
        """Stop profiling the engine"""
        ...
|
||||
Reference in New Issue
Block a user