[gpt-oss] Add gpt-oss bf16 support

2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions


vllm/v1/metrics/loggers.py (new file, 523 lines)

@@ -0,0 +1,523 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional
import numpy as np
import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
class StatLoggerBase(ABC):
"""Interface for logging metrics.
API users may define custom loggers that implement this interface.
However, note that the `SchedulerStats` and `IterationStats` classes
are not considered stable interfaces and may change in future versions.
"""
@abstractmethod
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
...
@abstractmethod
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
...
@abstractmethod
def log_engine_initialized(self):
...
def log(self): # noqa
pass
class LoggingStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.engine_index = engine_index
self.vllm_config = vllm_config
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
def _reset(self, now):
self.last_log_time = now
# Tracked stats over current local logging interval.
self.num_prompt_tokens: list[int] = []
self.num_generation_tokens: list[int] = []
def _track_iteration_stats(self, iteration_stats: IterationStats):
# Save tracked stats for token counters.
self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens)
self.num_generation_tokens.append(
iteration_stats.num_generation_tokens)
def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
# Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output."""
if iteration_stats:
self._track_iteration_stats(iteration_stats)
if scheduler_stats is not None:
self.prefix_caching_metrics.observe(
scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.gpu_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", self.engine_index,
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
_spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
unregister_vllm_metrics()
self.vllm_config = vllm_config
self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in
# a previous release and which will be removed in a future release.
self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name", "engine"]
labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, labelvalues)
#
# Scheduler state
#
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
#
# GPU cache
#
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = self._counter_cls(
name="vllm:gpu_prefix_cache_queries",
documentation=
"GPU prefix cache queries, in terms of number of queried tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_hits = self._counter_cls(
name="vllm:gpu_prefix_cache_hits",
documentation=
"GPU prefix cache hits, in terms of number of cached tokens.",
labelnames=labelnames).labels(*labelvalues)
#
# Counters
#
self.counter_num_preempted_reqs = self._counter_cls(
name="vllm:num_preemptions",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames).labels(*labelvalues)
self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames).labels(*labelvalues)
self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames).labels(*labelvalues)
self.counter_request_success: dict[FinishReason,
prometheus_client.Counter] = {}
counter_request_success_base = self._counter_cls(
name="vllm:request_success",
documentation="Count of successfully processed requests.",
labelnames=labelnames + ["finished_reason"])
for reason in FinishReason:
self.counter_request_success[
reason] = counter_request_success_base.labels(*(labelvalues +
[str(reason)]))
#
# Histograms of counts
#
self.histogram_num_prompt_tokens_request = \
self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
self.histogram_num_generation_tokens_request = \
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
# TODO: This metric might be incorrect when running with multiple
# api_server processes, which rely on prometheus multiprocessing.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \
self._histogram_cls(
name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.",
buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
16384
],
labelnames=labelnames).labels(*labelvalues)
self.histogram_max_num_generation_tokens_request = \
self._histogram_cls(
name="vllm:request_max_num_generation_tokens",
documentation=
"Histogram of maximum number of requested generation tokens.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
self.histogram_n_request = \
self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
buckets=[1, 2, 5, 10, 20],
labelnames=labelnames).labels(*labelvalues)
self.histogram_max_tokens_request = \
self._histogram_cls(
name="vllm:request_params_max_tokens",
documentation="Histogram of the max_tokens request parameter.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
#
# Histogram of timing intervals
#
self.histogram_time_to_first_token = \
self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
640.0, 2560.0
],
labelnames=labelnames).labels(*labelvalues)
self.histogram_time_per_output_token = \
self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=labelnames).labels(*labelvalues)
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
]
self.histogram_e2e_time_request = \
self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of e2e request latency in seconds.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_queue_time_request = \
self._histogram_cls(
name="vllm:request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_inference_time_request = \
self._histogram_cls(
name="vllm:request_inference_time_seconds",
documentation=
"Histogram of time spent in RUNNING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_prefill_time_request = \
self._histogram_cls(
name="vllm:request_prefill_time_seconds",
documentation=
"Histogram of time spent in PREFILL phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_decode_time_request = \
self._histogram_cls(
name="vllm:request_decode_time_seconds",
documentation=
"Histogram of time spent in DECODE phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
#
# LoRA metrics
#
# TODO: This metric might be incorrect when running with multiple
# api_server processes, which rely on prometheus multiprocessing.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None:
self.labelname_max_lora = "max_lora"
self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
self.labelname_running_lora_adapters = "running_lora_adapters"
self.max_lora = vllm_config.lora_config.max_loras
self.gauge_lora_info = \
self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
multiprocess_mode="sum",
labelnames=[
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
self.labelname_running_lora_adapters,
],
)
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index
name, documentation = None, None
if type == "cache_config":
name = "vllm:cache_config_info"
documentation = "Information of the LLMEngine CacheConfig"
assert name is not None, f"Unknown metrics info type {type}"
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Since prometheus multiprocessing mode does not support Info, emulate
# info here with a gauge.
info_gauge = self._gauge_cls(
name=name,
documentation=documentation,
multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1)
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
"""Log to prometheus."""
if scheduler_stats is not None:
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)
if iteration_stats is None:
return
self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc(
iteration_stats.num_generation_tokens)
self.histogram_iteration_tokens.observe(
iteration_stats.num_prompt_tokens + \
iteration_stats.num_generation_tokens)
for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
self.histogram_max_num_generation_tokens_request.observe(
max_gen_tokens)
for n_param in iteration_stats.n_params_iter:
self.histogram_n_request.observe(n_param)
for ttft in iteration_stats.time_to_first_tokens_iter:
self.histogram_time_to_first_token.observe(ttft)
for tpot in iteration_stats.time_per_output_tokens_iter:
self.histogram_time_per_output_token.observe(tpot)
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason].inc()
self.histogram_e2e_time_request.observe(
finished_request.e2e_latency)
self.histogram_queue_time_request.observe(
finished_request.queued_time)
self.histogram_prefill_time_request.observe(
finished_request.prefill_time)
self.histogram_inference_time_request.observe(
finished_request.inference_time)
self.histogram_decode_time_request.observe(
finished_request.decode_time)
self.histogram_num_prompt_tokens_request.observe(
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
finished_request.num_generation_tokens)
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if self.gauge_lora_info is not None:
running_lora_adapters = \
",".join(iteration_stats.running_lora_adapters.keys())
waiting_lora_adapters = \
",".join(iteration_stats.waiting_lora_adapters.keys())
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels)\
.set_to_current_time()
def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
"""
exponent = 0
buckets: list[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
def build_1_2_5_buckets(max_value: int) -> list[int]:
"""
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
return build_buckets([1, 2, 5], max_value)
def setup_default_loggers(
vllm_config: VllmConfig,
log_stats: bool,
engine_num: int,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> list[list[StatLoggerBase]]:
"""Setup logging and prometheus metrics."""
if not log_stats:
return []
factories: list[StatLoggerFactory]
if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = [PrometheusStatLogger]
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
stat_loggers: list[list[StatLoggerBase]] = []
for i in range(engine_num):
per_engine_stat_loggers: list[StatLoggerBase] = []
for logger_factory in factories:
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
stat_loggers.append(per_engine_stat_loggers)
return stat_loggers
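
The StatLoggerBase docstring above invites API users to plug in their own loggers via StatLoggerFactory. A minimal illustrative sketch of such a logger follows; the class name TokenTotalsLogger is made up here for the example and is not part of vLLM.

# Illustrative sketch of a custom stat logger implementing StatLoggerBase.
from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class TokenTotalsLogger(StatLoggerBase):
    """Toy logger that accumulates per-engine token totals (illustrative only)."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.prompt_tokens = 0
        self.generation_tokens = 0

    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        # IterationStats carries the per-step token counts used above.
        if iteration_stats is not None:
            self.prompt_tokens += iteration_stats.num_prompt_tokens
            self.generation_tokens += iteration_stats.num_generation_tokens

    def log_engine_initialized(self):
        print(f"Engine {self.engine_index}: custom stat logger ready")


# The class itself matches StatLoggerFactory = Callable[[VllmConfig, int], StatLoggerBase],
# so it can be passed as, e.g.,
#   setup_default_loggers(vllm_config, log_stats=True, engine_num=1,
#                         custom_stat_loggers=[TokenTotalsLogger])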

vllm/v1/metrics/prometheus.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile
from typing import Optional
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
def setup_multiprocess_prometheus():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global _prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
_prometheus_multiproc_dir.name)
else:
logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
def get_prometheus_registry():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
logger.debug("Using multiprocess registry for prometheus metrics")
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry
return REGISTRY
def unregister_vllm_metrics():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry = REGISTRY
# Unregister any existing vLLM collectors
for collector in list(registry._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
registry.unregister(collector)
def shutdown_prometheus():
"""Shutdown prometheus metrics."""
path = _prometheus_multiproc_dir
if path is None:
return
try:
pid = os.getpid()
multiprocess.mark_process_dead(pid, path)
logger.debug("Marked Prometheus metrics for process %d as dead", pid)
except Exception as e:
logger.error("Error during metrics cleanup: %s", str(e))
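
Taken together, the helpers above are used in a fixed order. A minimal ordering sketch, assuming the vllm.v1.metrics.prometheus module path implied by the import in loggers.py and an external HTTP server that exposes the returned registry:

# Minimal ordering sketch for the helpers above; not the actual vLLM entrypoint code.
from vllm.v1.metrics.prometheus import (get_prometheus_registry,
                                        setup_multiprocess_prometheus,
                                        shutdown_prometheus)

setup_multiprocess_prometheus()       # sets PROMETHEUS_MULTIPROC_DIR if the user has not
registry = get_prometheus_registry()  # multiprocess registry when the env var is set
# ... expose `registry` on a /metrics endpoint while the engine runs ...
shutdown_prometheus()                 # mark this process dead for the multiprocess collector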


@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
try:
from ray.util import metrics as ray_metrics
from ray.util.metrics import Metric
except ImportError:
ray_metrics = None
class RayPrometheusMetric:
def __init__(self):
if ray_metrics is None:
raise ImportError(
"RayPrometheusMetric requires Ray to be installed.")
self.metric: Optional[Metric] = None
def labels(self, *labels, **labelskwargs):
if labelskwargs:
for k, v in labelskwargs.items():
if not isinstance(v, str):
labelskwargs[k] = str(v)
self.metric.set_default_tags(labelskwargs)
if labels:
if len(labels) != len(self.metric._tag_keys):
raise ValueError(
"Number of labels must match the number of tag keys. "
f"Expected {len(self.metric._tag_keys)}, got {len(labels)}"
)
self.metric.set_default_tags(
dict(zip(self.metric._tag_keys, labels)))
return self
class RayGaugeWrapper(RayPrometheusMetric):
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def __init__(self,
name: str,
documentation: Optional[str] = "",
labelnames: Optional[list[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self.metric = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def set(self, value: Union[int, float]):
return self.metric.set(value)
def set_to_current_time(self):
# ray metrics doesn't have set_to_current_time; see https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
return self.metric.set(time.time())
class RayCounterWrapper(RayPrometheusMetric):
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def __init__(self,
name: str,
documentation: Optional[str] = "",
labelnames: Optional[list[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self.metric = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def inc(self, value: Union[int, float] = 1.0):
if value == 0:
return
return self.metric.inc(value)
class RayHistogramWrapper(RayPrometheusMetric):
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def __init__(self,
name: str,
documentation: Optional[str] = "",
labelnames: Optional[list[str]] = None,
buckets: Optional[list[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self.metric = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=boundaries)
def observe(self, value: Union[int, float]):
return self.metric.observe(value)
class RaySpecDecodingProm(SpecDecodingProm):
"""
RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
Provides the same metrics as SpecDecodingProm but uses Ray's
util.metrics library.
"""
_counter_cls = RayCounterWrapper
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_gauge_cls = RayGaugeWrapper
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
@staticmethod
def _unregister_vllm_metrics():
# No-op on purpose
pass
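
Because RayPrometheusStatLogger keeps the (VllmConfig, engine_index) constructor signature, it already satisfies the StatLoggerFactory callable defined in loggers.py. A usage sketch, assuming Ray is installed and a VllmConfig has already been built:

# Usage sketch: route the standard vLLM metrics through Ray's util.metrics backend.
# Assumes Ray is installed, `vllm_config` is an already-built VllmConfig, and
# RayPrometheusStatLogger is the class defined just above (its import path is
# not shown in this diff, so no import is written for it here).
from vllm.v1.metrics.loggers import setup_default_loggers

stat_loggers = setup_default_loggers(
    vllm_config=vllm_config,
    log_stats=True,
    engine_num=1,
    custom_stat_loggers=[RayPrometheusStatLogger],
)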

vllm/v1/metrics/reader.py (new file, 246 lines)

@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample
@dataclass
class Metric:
"""A base class for prometheus metrics.
Each metric may be associated with key=value labels, and
in some cases a single vLLM instance may have multiple
metrics with the same name but different sets of labels.
"""
name: str
labels: dict[str, str]
@dataclass
class Counter(Metric):
"""A monotonically increasing integer counter."""
value: int
@dataclass
class Vector(Metric):
"""An ordered array of integer counters.
This type - which doesn't exist in Prometheus - models one very
specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
"""
values: list[int]
@dataclass
class Gauge(Metric):
"""A numerical value that can go up or down."""
value: float
@dataclass
class Histogram(Metric):
"""Observations recorded in configurable buckets.
Buckets are represented by a dictionary. The key is
the upper limit of the bucket, and the value is the
observed count in that bucket. A '+Inf' key always
exists.
The count property is the total count across all
buckets, identical to the count of the '+Inf' bucket.
The sum property is the total sum of all observed
values.
"""
count: int
sum: float
buckets: dict[str, int]
def get_metrics_snapshot() -> list[Metric]:
"""An API for accessing in-memory Prometheus metrics.
Example:
>>> for metric in llm.get_metrics():
... if isinstance(metric, Counter):
... print(f"{metric} = {metric.value}")
... elif isinstance(metric, Gauge):
... print(f"{metric} = {metric.value}")
... elif isinstance(metric, Histogram):
... print(f"{metric}")
... print(f" sum = {metric.sum}")
... print(f" count = {metric.count}")
... for bucket_le, value in metric.buckets.items():
... print(f" {bucket_le} = {value}")
"""
collected: list[Metric] = []
for metric in REGISTRY.collect():
if not metric.name.startswith("vllm:"):
continue
if metric.type == "gauge":
samples = _get_samples(metric)
for s in samples:
collected.append(
Gauge(name=metric.name, labels=s.labels, value=s.value))
elif metric.type == "counter":
samples = _get_samples(metric, "_total")
if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
#
# Ugly vllm:num_accepted_tokens_per_pos special case.
#
# This metric is a vector of counters - for each spec
# decoding token position, we observe the number of
# accepted tokens using a Counter labeled with 'position'.
# We convert these into a vector of integer values.
#
for labels, values in _digest_num_accepted_by_pos_samples(
samples):
collected.append(
Vector(name=metric.name, labels=labels, values=values))
else:
for s in samples:
collected.append(
Counter(name=metric.name,
labels=s.labels,
value=int(s.value)))
elif metric.type == "histogram":
#
# A histogram has a number of '_bucket' samples where
# the 'le' label represents the upper limit of the bucket.
# We convert these bucketized values into a dict of values
# indexed by the value of the 'le' label. The 'le=+Inf'
# label is a special case, catching all values observed.
#
bucket_samples = _get_samples(metric, "_bucket")
count_samples = _get_samples(metric, "_count")
sum_samples = _get_samples(metric, "_sum")
for labels, buckets, count_value, sum_value in _digest_histogram(
bucket_samples, count_samples, sum_samples):
collected.append(
Histogram(name=metric.name,
labels=labels,
buckets=buckets,
count=count_value,
sum=sum_value))
else:
raise AssertionError(f"Unknown metric type {metric.type}")
return collected
def _get_samples(metric: PromMetric,
suffix: Optional[str] = None) -> list[Sample]:
name = (metric.name + suffix) if suffix is not None else metric.name
return [s for s in metric.samples if s.name == name]
def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
labels_copy = labels.copy()
labels_copy.pop(key_to_remove)
return labels_copy
def _digest_histogram(
bucket_samples: list[Sample], count_samples: list[Sample],
sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
#
# In the case of DP, we have an indigestible
# per-bucket-per-engine count as a list of labelled
# samples, along with total and sum samples
#
# bucket_samples (in):
# labels = {bucket: 100, idx: 0}, value = 2
# labels = {bucket: 200, idx: 0}, value = 4
# labels = {bucket: Inf, idx: 0}, value = 10
# labels = {bucket: 100, idx: 1}, value = 1
# labels = {bucket: 200, idx: 2}, value = 5
# labels = {bucket: Inf, idx: 3}, value = 7
# count_samples (in):
# labels = {idx: 0}, value = 10
# labels = {idx: 1}, value = 7
# sum_samples (in):
# labels = {idx: 0}, value = 2000
# labels = {idx: 1}, value = 1200
#
# output: [
# {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
# {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
# ]
buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
for s in bucket_samples:
bucket = s.labels["le"]
labels_key = frozenset(_strip_label(s.labels, "le").items())
if labels_key not in buckets_by_labels:
buckets_by_labels[labels_key] = {}
buckets_by_labels[labels_key][bucket] = int(s.value)
counts_by_labels: dict[frozenset[tuple[str, str]], int] = {}
for s in count_samples:
labels_key = frozenset(s.labels.items())
counts_by_labels[labels_key] = int(s.value)
sums_by_labels: dict[frozenset[tuple[str, str]], float] = {}
for s in sum_samples:
labels_key = frozenset(s.labels.items())
sums_by_labels[labels_key] = s.value
assert set(buckets_by_labels.keys()) == set(
counts_by_labels.keys()) == set(sums_by_labels.keys())
output = []
label_keys = list(buckets_by_labels.keys())
for k in label_keys:
labels = dict(k)
output.append((labels, buckets_by_labels[k], counts_by_labels[k],
sums_by_labels[k]))
return output
def _digest_num_accepted_by_pos_samples(
samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
#
# In the case of DP, we have an indigestible
# per-position-per-engine count as a list of
# labelled samples
#
# samples (in):
# labels = {pos: 0, idx: 0}, value = 10
# labels = {pos: 1, idx: 0}, value = 7
# labels = {pos: 2, idx: 0}, value = 2
# labels = {pos: 0, idx: 1}, value = 5
# labels = {pos: 1, idx: 1}, value = 3
# labels = {pos: 2, idx: 1}, value = 1
#
# output: [
# {idx: 0}, [10, 7, 2]
# {idx: 1}, [5, 3, 1]
# ]
#
max_pos = 0
values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
for s in samples:
position = int(s.labels["position"])
max_pos = max(max_pos, position)
labels_key = frozenset(_strip_label(s.labels, "position").items())
if labels_key not in values_by_labels:
values_by_labels[labels_key] = {}
values_by_labels[labels_key][position] = int(s.value)
output = []
for labels_key, values_by_position in values_by_labels.items():
labels = dict(labels_key)
values = [0] * (max_pos + 1)
for pos, val in values_by_position.items():
values[pos] = val
output.append((labels, values))
return output
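
A short usage sketch for get_metrics_snapshot(), assuming a vLLM engine has already registered its metrics in the default prometheus registry:

# Usage sketch: print every vLLM metric currently in the default registry.
from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector,
                                    get_metrics_snapshot)

for metric in get_metrics_snapshot():
    if isinstance(metric, (Counter, Gauge)):
        print(f"{metric.name}{metric.labels} = {metric.value}")
    elif isinstance(metric, Vector):
        print(f"{metric.name}{metric.labels} = {metric.values}")
    elif isinstance(metric, Histogram):
        print(f"{metric.name}{metric.labels}: count={metric.count} sum={metric.sum}")
        for le, count in metric.buckets.items():
            print(f"  le={le}: {count}")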

vllm/v1/metrics/stats.py (new file, 239 lines)

@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@dataclass
class PrefixCacheStats:
"""Stores prefix cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False
# The number of requests in this update.
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of tokens that were queried from the cache.
queries: int = 0
# The number of hits in these requests.
hits: int = 0
@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
num_running_reqs: int = 0
num_waiting_reqs: int = 0
gpu_cache_usage: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass
class LoRAStats:
waiting_requests: set[str] = field(default_factory=set)
running_requests: set[str] = field(default_factory=set)
@dataclass
class RequestStateStats:
"""Stats that need to be tracked across delta updates."""
num_generation_tokens: int = 0
# This is an engine frontend timestamp (wall-clock)
arrival_time: float = 0.0
# These are engine core timestamps (monotonic)
queued_ts: float = 0.0
scheduled_ts: float = 0.0
first_token_ts: float = 0.0
last_token_ts: float = 0.0
@dataclass
class FinishedRequestStats:
"""Stats associated with a finished request."""
finish_reason: "FinishReason"
e2e_latency: float = 0.0
num_prompt_tokens: int = 0
num_generation_tokens: int = 0
max_tokens_param: Optional[int] = None
queued_time: float = 0.0
prefill_time: float = 0.0
inference_time: float = 0.0
decode_time: float = 0.0
class IterationStats:
"""Stats associated with a single set of EngineCoreOutputs."""
def __init__(self):
self.iteration_timestamp = time.time()
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.num_preempted_reqs = 0
self.finished_requests: list[FinishedRequestStats] = []
self.max_num_generation_tokens_iter: list[int] = []
self.n_params_iter: list[int] = []
self.time_to_first_tokens_iter: list[float] = []
self.time_per_output_tokens_iter: list[float] = []
self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: dict[str, int] = {}
def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp."""
return self.iteration_timestamp - start
def update_from_output(self, output: "EngineCoreOutput",
engine_core_timestamp: float, is_prefilling: bool,
prompt_len: int, req_stats: RequestStateStats,
lora_stats: Optional[LoRAStats]):
num_new_generation_tokens = len(output.new_token_ids)
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.num_generation_tokens += num_new_generation_tokens
# Process request-level engine core events
if output.events is not None:
self.update_from_events(output.request_id, output.events,
is_prefilling, req_stats, lora_stats)
# Process the batch-level "new tokens" engine core event
if is_prefilling:
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,
lora_stats: Optional[LoRAStats]):
# Avoid circular dependency
from vllm.v1.engine import EngineCoreEventType
for event in events:
if event.type == EngineCoreEventType.QUEUED:
req_stats.queued_ts = event.timestamp
if lora_stats is not None:
lora_stats.waiting_requests.add(req_id)
elif event.type == EngineCoreEventType.SCHEDULED:
if req_stats.scheduled_ts == 0.0: # ignore preemptions
req_stats.scheduled_ts = event.timestamp
LoRARequestStates.scheduled_request(lora_stats, req_id)
elif event.type == EngineCoreEventType.PREEMPTED:
self.num_preempted_reqs += 1
LoRARequestStates.preempted_request(lora_stats, req_id)
def update_from_finished_request(self, finish_reason: "FinishReason",
num_prompt_tokens: int,
max_tokens_param: Optional[int],
req_stats: RequestStateStats):
e2e_latency = self._time_since(req_stats.arrival_time)
# Queued interval is from first QUEUED event to first SCHEDULED
queued_time = req_stats.scheduled_ts - req_stats.queued_ts
# Prefill interval is from first SCHEDULED to first NEW_TOKEN
# Any preemptions during prefill are included in the interval
prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts
# Decode interval is from first NEW_TOKEN to last NEW_TOKEN
# Any preemptions during decode are included
decode_time = req_stats.last_token_ts - req_stats.first_token_ts
# Inference interval is from first SCHEDULED to last NEW_TOKEN
# Any preemptions during prefill or decode are included
inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
finished_req = \
FinishedRequestStats(finish_reason=finish_reason,
e2e_latency=e2e_latency,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=req_stats.num_generation_tokens,
max_tokens_param=max_tokens_param,
queued_time=queued_time,
prefill_time=prefill_time,
inference_time=inference_time,
decode_time=decode_time)
self.finished_requests.append(finished_req)
class LoRARequestStates:
"""Per-LoRA request state stats."""
def __init__(self):
self.lora_name_to_stats: dict[str, LoRAStats] = {}
def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
if req_state.lora_name is None:
return None
if req_state.lora_name not in self.lora_name_to_stats:
self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
return self.lora_name_to_stats[req_state.lora_name]
def add_request(self, req_state: 'RequestState'):
if (lora_stats := self.get_stats(req_state)) is not None:
lora_stats.waiting_requests.add(req_state.request_id)
def finish_request(self, req_state: 'RequestState'):
if req_state.lora_name is None:
return
lora_stats = self.lora_name_to_stats[req_state.lora_name]
lora_stats.running_requests.remove(req_state.request_id)
def abort_request(self, req_state: 'RequestState'):
if req_state.lora_name is None:
return
lora_stats = self.lora_name_to_stats[req_state.lora_name]
lora_stats.waiting_requests.discard(req_state.request_id)
lora_stats.running_requests.discard(req_state.request_id)
# Break the pattern for these lifecycle methods so we can
# call them from IterationStats.update_from_events()
@staticmethod
def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
if lora_stats is None:
return
lora_stats.waiting_requests.remove(request_id)
lora_stats.running_requests.add(request_id)
@staticmethod
def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
if lora_stats is None:
return
lora_stats.running_requests.remove(request_id)
lora_stats.waiting_requests.add(request_id)
def update_iteration_stats(self,
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
for lora_name, stats in self.lora_name_to_stats.items():
if stats.waiting_requests:
iteration_stats.waiting_lora_adapters[lora_name] = \
len(stats.waiting_requests)
if stats.running_requests:
iteration_stats.running_lora_adapters[lora_name] = \
len(stats.running_requests)
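
The interval definitions in update_from_finished_request() above can be sanity-checked with a small standalone sketch; the timestamps below are made up for illustration and are not taken from a real trace.

# Worked sketch of the interval arithmetic above, with illustrative monotonic timestamps.
queued_ts, scheduled_ts, first_token_ts, last_token_ts = 10.0, 10.5, 12.0, 20.0

queued_time = scheduled_ts - queued_ts           # 0.5 s: first QUEUED to first SCHEDULED
prefill_time = first_token_ts - scheduled_ts     # 1.5 s: first SCHEDULED to first NEW_TOKEN
decode_time = last_token_ts - first_token_ts     # 8.0 s: first NEW_TOKEN to last NEW_TOKEN
inference_time = last_token_ts - scheduled_ts    # 9.5 s: prefill_time + decode_time
assert abs(inference_time - (prefill_time + decode_time)) < 1e-9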