Sync from v0.13
This commit is contained in:
0
vllm/v1/metrics/__init__.py
Normal file
0
vllm/v1/metrics/__init__.py
Normal file
1305
vllm/v1/metrics/loggers.py
Normal file
1305
vllm/v1/metrics/loggers.py
Normal file
File diff suppressed because it is too large
Load Diff
82
vllm/v1/metrics/prometheus.py
Normal file
82
vllm/v1/metrics/prometheus.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Global temporary directory for prometheus multiprocessing
|
||||
_prometheus_multiproc_dir: tempfile.TemporaryDirectory | None = None
|
||||
|
||||
|
||||
def setup_multiprocess_prometheus():
    """Ensure PROMETHEUS_MULTIPROC_DIR points at a directory vLLM manages.

    If the user already exported the variable, only warn: stale files in a
    user-managed directory will make metrics inaccurate across runs.
    Otherwise, create a process-lifetime temporary directory and export it.
    """
    global _prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup."
        )
        return

    # Keep a module-global reference so the TemporaryDirectory stays alive
    # for the whole process and is automatically cleaned up upon exit.
    _prometheus_multiproc_dir = tempfile.TemporaryDirectory()
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
    logger.debug(
        "Created PROMETHEUS_MULTIPROC_DIR at %s", _prometheus_multiproc_dir.name
    )
|
||||
|
||||
|
||||
def get_prometheus_registry() -> CollectorRegistry:
    """Get the appropriate prometheus registry based on multiprocessing
    configuration.

    Returns:
        Registry: A prometheus registry
    """
    # No multiprocess dir configured: the default global registry suffices.
    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is None:
        return REGISTRY

    logger.debug("Using multiprocess registry for prometheus metrics")
    multiproc_registry = CollectorRegistry()
    # The collector aggregates samples written by all processes into the dir.
    multiprocess.MultiProcessCollector(multiproc_registry)
    return multiproc_registry
|
||||
|
||||
|
||||
def unregister_vllm_metrics():
    """Unregister any existing vLLM collectors from the prometheus registry.

    This is useful for testing and CI/CD where metrics may be registered
    multiple times across test runs.

    Also, in case of multiprocess, we need to unregister the metrics from the
    global registry.
    """
    registry = REGISTRY
    # Snapshot first: unregistering while iterating would mutate the mapping.
    vllm_collectors = [
        collector
        for collector in list(registry._collector_to_names)
        if hasattr(collector, "_name") and "vllm" in collector._name
    ]
    for collector in vllm_collectors:
        registry.unregister(collector)
|
||||
|
||||
|
||||
def shutdown_prometheus():
    """Mark this process's multiprocess metric files as dead.

    No-op unless setup_multiprocess_prometheus() created the multiprocess
    directory in this process. Any cleanup failure is logged, never raised.
    """
    tmp_dir = _prometheus_multiproc_dir
    if tmp_dir is None:
        return
    try:
        pid = os.getpid()
        # BUG FIX: mark_process_dead expects the directory *path* string,
        # not the TemporaryDirectory wrapper object. Passing the object
        # raised a TypeError that the broad except below silently logged,
        # so per-process metric files were never marked dead.
        multiprocess.mark_process_dead(pid, tmp_dir.name)
        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
    except Exception as e:
        logger.error("Error during metrics cleanup: %s", str(e))
|
||||
194
vllm/v1/metrics/ray_wrappers.py
Normal file
194
vllm/v1/metrics/ray_wrappers.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
|
||||
from vllm.v1.metrics.loggers import PrometheusStatLogger
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingProm
|
||||
|
||||
try:
|
||||
from ray import serve as ray_serve
|
||||
from ray.util import metrics as ray_metrics
|
||||
from ray.util.metrics import Metric
|
||||
except ImportError:
|
||||
ray_metrics = None
|
||||
ray_serve = None
|
||||
import regex as re
|
||||
|
||||
|
||||
def _get_replica_id() -> str | None:
    """Get the current Ray Serve replica ID, or None if not in a Serve context."""
    if ray_serve is None:
        return None
    try:
        context = ray_serve.get_replica_context()
    except ray_serve.exceptions.RayServeException:
        # Not running inside a Serve replica.
        return None
    return context.replica_id.unique_id
|
||||
|
||||
|
||||
class RayPrometheusMetric:
    """Base for wrappers that expose a prometheus_client-style API on top
    of ray.util.metrics instruments.

    Subclasses assign the concrete Ray instrument to ``self.metric``.
    """

    def __init__(self):
        if ray_metrics is None:
            raise ImportError("RayPrometheusMetric requires Ray to be installed.")
        # FIX: annotation was `Metric`, but the attribute starts as None
        # until a subclass assigns the concrete instrument.
        self.metric: Metric | None = None

    @staticmethod
    def _get_tag_keys(labelnames: list[str] | None) -> tuple[str, ...]:
        """Return the user-provided labels plus the implicit ReplicaId tag."""
        labels = list(labelnames) if labelnames else []
        labels.append("ReplicaId")
        return tuple(labels)

    def labels(self, *labels, **labelskwargs):
        """Bind label values as Ray default tags; mirrors prometheus_client's
        chained `.labels(...)` API.

        Always stamps the current Serve ReplicaId (empty string outside a
        Serve replica). Positional labels are matched against the metric's
        tag keys in order.
        """
        if labels:
            # -1 because ReplicaId was added automatically
            expected = len(self.metric._tag_keys) - 1
            if len(labels) != expected:
                raise ValueError(
                    "Number of labels must match the number of tag keys. "
                    f"Expected {expected}, got {len(labels)}"
                )
            labelskwargs.update(zip(self.metric._tag_keys, labels))

        labelskwargs["ReplicaId"] = _get_replica_id() or ""

        # Ray requires string tag values; coerce everything else.
        # (The old `if labelskwargs:` guard was dead — ReplicaId is always set.)
        for k, v in labelskwargs.items():
            if not isinstance(v, str):
                labelskwargs[k] = str(v)
        self.metric.set_default_tags(labelskwargs)
        return self

    @staticmethod
    def _get_sanitized_opentelemetry_name(name: str) -> str:
        """
        For compatibility with Ray + OpenTelemetry, the metric name must be
        sanitized. In particular, this replaces disallowed character (e.g., ':')
        with '_' in the metric name.
        Allowed characters: a-z, A-Z, 0-9, _

        # ruff: noqa: E501
        Ref: https://github.com/open-telemetry/opentelemetry-cpp/blob/main/sdk/src/metrics/instrument_metadata_validator.cc#L22-L23
        Ref: https://github.com/ray-project/ray/blob/master/src/ray/stats/metric.cc#L107
        """
        return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
||||
class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""

    def __init__(
        self,
        name: str,
        documentation: str | None = "",
        labelnames: list[str] | None = None,
        multiprocess_mode: str | None = "",
    ):
        # All Ray metrics are keyed by WorkerId, so multiprocess modes like
        # "mostrecent", "all", "sum" do not apply. This logic can be manually
        # implemented at the observability layer (Prometheus/Grafana).
        del multiprocess_mode

        self.metric = ray_metrics.Gauge(
            name=self._get_sanitized_opentelemetry_name(name),
            description=documentation,
            tag_keys=self._get_tag_keys(labelnames),
        )

    def set(self, value: int | float):
        return self.metric.set(value)

    def set_to_current_time(self):
        # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())
|
||||
|
||||
|
||||
class RayCounterWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Counter to provide same API as
    prometheus_client.Counter"""

    def __init__(
        self,
        name: str,
        documentation: str | None = "",
        labelnames: list[str] | None = None,
    ):
        self.metric = ray_metrics.Counter(
            name=self._get_sanitized_opentelemetry_name(name),
            description=documentation,
            tag_keys=self._get_tag_keys(labelnames),
        )

    def inc(self, value: int | float = 1.0):
        # Skip no-op zero increments entirely.
        if value == 0:
            return
        return self.metric.inc(value)
|
||||
|
||||
|
||||
class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Histogram to provide same API as
    prometheus_client.Histogram"""

    def __init__(
        self,
        name: str,
        documentation: str | None = "",
        labelnames: list[str] | None = None,
        buckets: list[float] | None = None,
    ):
        # Ray takes explicit bucket boundaries; default to none.
        boundaries = buckets if buckets else []
        self.metric = ray_metrics.Histogram(
            name=self._get_sanitized_opentelemetry_name(name),
            description=documentation,
            tag_keys=self._get_tag_keys(labelnames),
            boundaries=boundaries,
        )

    def observe(self, value: int | float):
        return self.metric.observe(value)
|
||||
|
||||
|
||||
class RaySpecDecodingProm(SpecDecodingProm):
    """
    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
    Provides the same metrics as SpecDecodingProm but uses Ray's
    util.metrics library.
    """

    # Swap the prometheus_client Counter for the Ray-backed wrapper.
    _counter_cls = RayCounterWrapper
|
||||
|
||||
|
||||
class RayKVConnectorPrometheus(KVConnectorPrometheus):
    """
    RayKVConnectorPrometheus is used by RayMetrics to log Ray
    metrics. Provides the same metrics as KV connectors but
    uses Ray's util.metrics library.
    """

    # Swap every prometheus_client instrument for its Ray-backed wrapper.
    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
|
||||
|
||||
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""

    # Replace every prometheus_client instrument class with the Ray wrapper.
    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
    _spec_decoding_cls = RaySpecDecodingProm
    _kv_connector_cls = RayKVConnectorPrometheus

    @staticmethod
    def _unregister_vllm_metrics():
        # No-op on purpose — presumably because Ray-backed instruments are
        # not registered in the prometheus_client registry; confirm against
        # PrometheusStatLogger's implementation.
        pass
|
||||
257
vllm/v1/metrics/reader.py
Normal file
257
vllm/v1/metrics/reader.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prometheus_client import REGISTRY
|
||||
from prometheus_client import Metric as PromMetric
|
||||
from prometheus_client.samples import Sample
|
||||
|
||||
|
||||
@dataclass
class Metric:
    """A base class for prometheus metrics.

    Each metric may be associated with key=value labels, and
    in some cases a single vLLM instance may have multiple
    metrics with the same name but different sets of labels.
    """

    # Metric name; the snapshot API only collects names starting "vllm:".
    name: str
    # Prometheus key=value labels attached to this sample.
    labels: dict[str, str]


@dataclass
class Counter(Metric):
    """A monotonically increasing integer counter."""

    # Cumulative count (prometheus "_total" sample, truncated to int).
    value: int


@dataclass
class Vector(Metric):
    """An ordered array of integer counters.

    This type - which doesn't exist in Prometheus - models one very
    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """

    # values[i] is the counter for spec-decode token position i.
    values: list[int]


@dataclass
class Gauge(Metric):
    """A numerical value that can go up or down."""

    value: float


@dataclass
class Histogram(Metric):
    """Observations recorded in configurable buckets.

    Buckets are represented by a dictionary. The key is
    the upper limit of the bucket, and the value is the
    observed count in that bucket. A '+Inf' key always
    exists.

    The count property is the total count across all
    buckets, identical to the count of the '+Inf' bucket.

    The sum property is the total sum of all observed
    values.
    """

    count: int
    sum: float
    buckets: dict[str, int]
|
||||
|
||||
|
||||
def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Example:
        >>> for metric in llm.get_metrics():
        ...     if isinstance(metric, Counter):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Gauge):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Histogram):
        ...         print(f"{metric}")
        ...         print(f"    sum = {metric.sum}")
        ...         print(f"    count = {metric.count}")
        ...         for bucket_le, value in metric.buckets.items():
        ...             print(f"    {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        # Only vLLM's own metrics are exposed through this API.
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            samples = _get_samples(metric)
            for s in samples:
                collected.append(
                    Gauge(name=metric.name, labels=s.labels, value=s.value)
                )
        elif metric.type == "counter":
            # Counter samples carry a "_total" suffix on the sample name.
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:num_accepted_tokens_per_pos special case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                for labels, values in _digest_num_accepted_by_pos_samples(samples):
                    collected.append(
                        Vector(name=metric.name, labels=labels, values=values)
                    )
            else:
                for s in samples:
                    collected.append(
                        Counter(name=metric.name, labels=s.labels, value=int(s.value))
                    )

        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            for labels, buckets, count_value, sum_value in _digest_histogram(
                bucket_samples, count_samples, sum_samples
            ):
                collected.append(
                    Histogram(
                        name=metric.name,
                        labels=labels,
                        buckets=buckets,
                        count=count_value,
                        sum=sum_value,
                    )
                )
        else:
            # Only gauge/counter/histogram are ever produced by vLLM.
            raise AssertionError(f"Unknown metric type {metric.type}")

    return collected
|
||||
|
||||
|
||||
def _get_samples(metric: PromMetric, suffix: str | None = None) -> list[Sample]:
    """Select the samples of *metric* named ``metric.name + suffix``."""
    target = metric.name if suffix is None else metric.name + suffix
    return [sample for sample in metric.samples if sample.name == target]
|
||||
|
||||
|
||||
def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
|
||||
labels_copy = labels.copy()
|
||||
labels_copy.pop(key_to_remove)
|
||||
return labels_copy
|
||||
|
||||
|
||||
def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample], sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    """Regroup flat histogram samples into one tuple per label set.

    With DP, each engine contributes its own per-bucket samples, plus a
    count and a sum sample, all distinguished by labels. For example:

        bucket_samples: {le: 100, idx: 0}=2, {le: 200, idx: 0}=4,
                        {le: Inf, idx: 0}=10, {le: 100, idx: 1}=1, ...
        count_samples:  {idx: 0}=10, {idx: 1}=7
        sum_samples:    {idx: 0}=2000, {idx: 1}=1200

    becomes:

        [({idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000),
         ({idx: 1}, {...}, 7, 1200)]
    """
    # Group bucket counts by the label set with 'le' removed.
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for sample in bucket_samples:
        upper = sample.labels["le"]
        key = frozenset(_strip_label(sample.labels, "le").items())
        buckets_by_labels.setdefault(key, {})[upper] = int(sample.value)

    counts_by_labels = {
        frozenset(s.labels.items()): int(s.value) for s in count_samples
    }
    sums_by_labels = {frozenset(s.labels.items()): s.value for s in sum_samples}

    # Every label set must appear in all three sample families.
    assert (
        set(buckets_by_labels) == set(counts_by_labels) == set(sums_by_labels)
    )

    return [
        (dict(key), buckets, counts_by_labels[key], sums_by_labels[key])
        for key, buckets in buckets_by_labels.items()
    ]
|
||||
|
||||
|
||||
def _digest_num_accepted_by_pos_samples(
    samples: list[Sample],
) -> list[tuple[dict[str, str], list[int]]]:
    """Convert per-position counter samples into one dense vector per label set.

    With DP, samples arrive flat, one per (position, engine) pair, e.g.:

        {pos: 0, idx: 0}=10, {pos: 1, idx: 0}=7, {pos: 2, idx: 0}=2,
        {pos: 0, idx: 1}=5,  {pos: 1, idx: 1}=3, {pos: 2, idx: 1}=1

    becomes:

        [({idx: 0}, [10, 7, 2]), ({idx: 1}, [5, 3, 1])]

    Vectors are padded with 0 up to the globally largest position seen.
    """
    max_pos = 0
    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
    for sample in samples:
        pos = int(sample.labels["position"])
        max_pos = max(max_pos, pos)
        key = frozenset(_strip_label(sample.labels, "position").items())
        values_by_labels.setdefault(key, {})[pos] = int(sample.value)

    result = []
    for key, by_pos in values_by_labels.items():
        vector = [by_pos.get(p, 0) for p in range(max_pos + 1)]
        result.append((dict(key), vector))
    return result
|
||||
437
vllm/v1/metrics/stats.py
Normal file
437
vllm/v1/metrics/stats.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
|
||||
|
||||
|
||||
@dataclass
class BaseCacheStats:
    """Stores cache hit statistics for one stats update (delta, not cumulative)."""

    reset: bool = False
    """Whether the cache was reset."""

    requests: int = 0
    """The number of requests in this update."""

    queries: int = 0
    """The number of queries in these requests."""

    hits: int = 0
    """The number of hits in these requests."""
|
||||
|
||||
|
||||
class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """
    # NOTE: the docstring previously documented a non-existent `interval`
    # argument; the parameter is `max_recent_requests`.

    def __init__(self, max_recent_requests: int = 1000) -> None:
        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats):
        """Observe the prefix caching for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `max_recent_requests` requests, the oldest
        set of requests are removed from the metrics.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do NOT append empty stats: no-op updates would push useful
        # entries out of the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until number of requests does not exceed
        # the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def empty(self) -> bool:
        """Return true if no requests have been observed."""
        return self.aggregated_requests == 0

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
|
||||
|
||||
|
||||
@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.
    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.
    """

    preempted_requests: int = 0
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""

    def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
        """Aggregate request information into the stats.

        Preempted (re-scheduled) requests are tracked separately so they
        do not skew the first-time hit-rate numbers.
        """
        if preempted:
            # Previously preempted request
            self.preempted_requests += 1
            self.preempted_queries += num_tokens
            self.preempted_hits += num_hits
        else:
            # New request
            self.requests += 1
            self.queries += num_tokens
            self.hits += num_hits


@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Stores multi-modal cache hit statistics.
    - `reset`: Whether `reset_mm_cache` was invoked.
    - `queries`: Refers to the number of multi-modal data items
      that were queried.
    """


@dataclass
class KVCacheEvictionEvent:
    """Single KV cache block eviction sample."""

    # Seconds the block existed before eviction.
    lifetime_seconds: float
    # Seconds since the block was last used.
    idle_seconds: float
    # Gaps (seconds) between successive reuses of the block — assumed; confirm
    # against the scheduler code that emits these events.
    reuse_gaps_seconds: tuple[float, ...]


@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    # NOTE(review): presumably the fraction of KV cache in use — confirm.
    kv_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
    connector_prefix_cache_stats: PrefixCacheStats | None = None

    kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list)

    spec_decoding_stats: SpecDecodingStats | None = None
    kv_connector_stats: dict[str, Any] | None = None

    # Per-LoRA request counts, filled in by LoRARequestStates.
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)

    cudagraph_stats: CUDAGraphStat | None = None
|
||||
|
||||
|
||||
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # First token latency (arrival to first token, wall-clock).
    first_token_latency: float = 0.0

    # Track if this request is corrupted (NaNs in logits)
    is_corrupted: bool = False


@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request.

    Interval fields are derived from RequestStateStats timestamps in
    IterationStats.update_from_finished_request.
    """

    finish_reason: "FinishReason"
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    max_tokens_param: int | None = None
    # First QUEUED event to first SCHEDULED event.
    queued_time: float = 0.0
    # First SCHEDULED event to first new token.
    prefill_time: float = 0.0
    # First SCHEDULED event to last new token.
    inference_time: float = 0.0
    # First new token to last new token.
    decode_time: float = 0.0
    # decode_time / (num_generation_tokens - 1); 0 for single-token outputs.
    mean_time_per_output_token: float = 0.0
    is_corrupted: bool = False
    num_cached_tokens: int = 0
|
||||
|
||||
|
||||
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock creation time; _time_since() measures against this.
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []
        self.num_corrupted_reqs: int = 0

    def __repr__(self) -> str:
        # Debug-friendly dump of every attribute.
        field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(
        self,
        output: "EngineCoreOutput",
        engine_core_timestamp: float,
        is_prefilling: bool,
        prompt_len: int,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Fold one request's EngineCoreOutput into this iteration's stats.

        Also updates the per-request `req_stats` in place (token counts,
        first/last token timestamps, corruption flag).
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            self.num_prompt_tokens += prompt_len

            # TTFT is wall-clock: iteration timestamp minus frontend arrival.
            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Track if this request is corrupted (only check once per request)
        # The `not req_stats.is_corrupted` term short-circuits the NaN
        # check once a request has already been flagged.
        if (
            envs.VLLM_COMPUTE_NANS_IN_LOGITS
            and not req_stats.is_corrupted
            and output.num_nans_in_logits > 0
        ):
            req_stats.is_corrupted = True

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(
                output.request_id,
                output.events,
                is_prefilling,
                req_stats,
                lora_states,
                lora_name,
            )

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(
        self,
        req_id: str,
        events: list["EngineCoreEvent"],
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Record QUEUED/SCHEDULED/PREEMPTED event timestamps and keep the
        per-LoRA waiting/running bookkeeping in sync."""
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType

        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                lora_states.request_waiting(req_id, lora_name)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                lora_states.request_running(req_id, lora_name)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                lora_states.request_waiting(req_id, lora_name)

    def update_from_finished_request(
        self,
        finish_reason: "FinishReason",
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        """Derive a FinishedRequestStats record from the accumulated
        per-request timestamps and append it to finished_requests."""
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase
        mean_time_per_output_token = (
            decode_time / (req_stats.num_generation_tokens - 1)
            if req_stats.num_generation_tokens - 1 > 0
            else 0
        )

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)

        # Count corrupted requests when they finish (only once per request)
        if req_stats.is_corrupted:
            self.num_corrupted_reqs += 1
|
||||
|
||||
|
||||
class LoRAStats:
    """Tracks waiting and running request IDs for a single LoRA."""

    def __init__(self):
        self.waiting: set[str] = set()
        self.running: set[str] = set()

    def update(self, req_id: str, waiting: bool, running: bool):
        """Place req_id in (or remove it from) the waiting/running sets.

        A request may be in at most one of the two states.
        """
        assert not (waiting and running)
        for members, should_contain in (
            (self.waiting, waiting),
            (self.running, running),
        ):
            if should_contain:
                members.add(req_id)
            else:
                members.discard(req_id)

    @property
    def empty(self) -> bool:
        """True when no request is tracked in either state."""
        return not self.waiting and not self.running
|
||||
|
||||
|
||||
class LoRARequestStates:
    """A per-LoRA count of running and waiting requests."""

    def __init__(self, log_stats: bool = False):
        self.log_stats = log_stats
        self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)

    def _request_update(
        self, req_id: str, lora_name: str | None, waiting: bool, running: bool
    ):
        # Only track when stat logging is on and the request uses a LoRA.
        if lora_name is None or not self.log_stats:
            return

        lora_stats = self.requests[lora_name]
        lora_stats.update(req_id, waiting, running)
        # Drop LoRAs that no longer track any request, keeping the dict small.
        if lora_stats.empty:
            del self.requests[lora_name]

    def request_waiting(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=True, running=False)

    def request_running(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=False, running=True)

    def request_finished(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=False, running=False)

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Copy the current per-LoRA waiting/running counts into
        scheduler_stats' adapter dicts."""
        if not self.log_stats or scheduler_stats is None:
            return
        for lora_name, stats in self.requests.items():
            scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
            scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
|
||||
Reference in New Issue
Block a user