Files
enginex-c_series-vllm/v1/metrics/ray_wrappers.py
2025-08-13 19:46:19 +08:00

132 lines
4.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
try:
from ray.util import metrics as ray_metrics
from ray.util.metrics import Metric
except ImportError:
ray_metrics = None
class RayPrometheusMetric:
    """Base adapter exposing a prometheus_client-like API over Ray metrics.

    Subclasses call ``super().__init__()`` (which verifies Ray is importable)
    and then assign the concrete ``ray.util.metrics`` object to
    ``self.metric``.
    """

    def __init__(self):
        # Fail fast with a clear message when the optional Ray import at the
        # top of the module did not succeed.
        if ray_metrics is None:
            raise ImportError(
                "RayPrometheusMetric requires Ray to be installed.")
        # Placeholder; subclasses replace it with the wrapped Ray metric.
        self.metric: Optional[Metric] = None

    def labels(self, *labels, **labelskwargs):
        """Bind tag values on the wrapped metric, prometheus_client style.

        Values may be given positionally (in the order of the metric's
        ``_tag_keys``) or by keyword, but not both, as in
        ``prometheus_client``'s ``labels()``. Every value is converted to
        ``str`` before being installed via ``set_default_tags``.

        Unlike prometheus_client, which returns a per-label child, this
        mutates the tags of the shared wrapped metric and returns ``self``.
        Calling with no labels at all leaves the tags unchanged.

        Returns:
            ``self``, so calls can be chained (``m.labels(...).inc()``).

        Raises:
            ValueError: if both positional and keyword labels are given, or
                if the number of positional labels differs from the number
                of tag keys.
        """
        if labels and labelskwargs:
            raise ValueError("Can't pass both *args and **kwargs")

        if labelskwargs:
            tags = {key: str(value) for key, value in labelskwargs.items()}
        elif labels:
            tag_keys = self.metric._tag_keys
            if len(labels) != len(tag_keys):
                raise ValueError(
                    "Number of labels must match the number of tag keys. "
                    f"Expected {len(tag_keys)}, got {len(labels)}")
            # prometheus_client stringifies positional values as well; keep
            # both call styles producing identical str-valued tags.
            tags = {key: str(value) for key, value in zip(tag_keys, labels)}
        else:
            return self

        self.metric.set_default_tags(tags)
        return self
class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        """Create the underlying Ray gauge.

        Args:
            name: Metric name passed to Ray.
            documentation: Human-readable description; ``None`` is treated
                as an empty description.
            labelnames: Tag keys for the metric; empty/``None`` means no tags.

        Raises:
            ImportError: if Ray is not installed (raised by the base class).
        """
        # Run the base-class check so a missing Ray install surfaces as a
        # clear ImportError instead of an AttributeError on ``None.Gauge``.
        super().__init__()
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Gauge(name=name,
                                        description=documentation or "",
                                        tag_keys=labelnames_tuple)

    def set(self, value: Union[int, float]):
        """Set the gauge to ``value``; returns whatever Ray's ``set`` returns."""
        return self.metric.set(value)

    def set_to_current_time(self):
        """Set the gauge to the current Unix timestamp from ``time.time()``."""
        # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())
class RayCounterWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Counter to provide same API as
    prometheus_client.Counter"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        """Create the underlying Ray counter.

        Args:
            name: Metric name passed to Ray.
            documentation: Human-readable description; ``None`` is treated
                as an empty description.
            labelnames: Tag keys for the metric; empty/``None`` means no tags.

        Raises:
            ImportError: if Ray is not installed (raised by the base class).
        """
        # Run the base-class check so a missing Ray install surfaces as a
        # clear ImportError instead of an AttributeError on ``None.Counter``.
        super().__init__()
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Counter(name=name,
                                          description=documentation or "",
                                          tag_keys=labelnames_tuple)

    def inc(self, value: Union[int, float] = 1.0):
        """Increment the counter by ``value`` (default 1.0).

        A zero increment is a no-op: it returns ``None`` without calling
        into the Ray metric.
        """
        # NOTE(review): presumably Ray's Counter.inc rejects non-positive
        # values, which is why zero is filtered out here — confirm.
        if value == 0:
            return
        return self.metric.inc(value)
class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Histogram to provide same API as
    prometheus_client.Histogram"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None,
                 buckets: Optional[list[float]] = None):
        """Create the underlying Ray histogram.

        Args:
            name: Metric name passed to Ray.
            documentation: Human-readable description; ``None`` is treated
                as an empty description.
            labelnames: Tag keys for the metric; empty/``None`` means no tags.
            buckets: Bucket boundaries, passed to Ray as ``boundaries``. Any
                sequence (e.g. a tuple, as prometheus_client commonly uses)
                is accepted; empty/``None`` becomes an empty list.

        Raises:
            ImportError: if Ray is not installed (raised by the base class).
        """
        # Run the base-class check so a missing Ray install surfaces as a
        # clear ImportError instead of an AttributeError on ``None.Histogram``.
        super().__init__()
        labelnames_tuple = tuple(labelnames) if labelnames else None
        boundaries = list(buckets) if buckets else []
        self.metric = ray_metrics.Histogram(name=name,
                                            description=documentation or "",
                                            tag_keys=labelnames_tuple,
                                            boundaries=boundaries)

    def observe(self, value: Union[int, float]):
        """Record one observation; returns whatever Ray's ``observe`` returns."""
        return self.metric.observe(value)
class RaySpecDecodingProm(SpecDecodingProm):
    """Ray-backed flavour of SpecDecodingProm, used by RayMetrics.

    Emits the same metrics as its parent but through Ray's util.metrics
    library. The only override is the counter factory, which is pointed at
    RayCounterWrapper; all other behaviour is inherited unchanged.
    """

    # Counters are created via the Ray adapter instead of the parent's class.
    _counter_cls = RayCounterWrapper
class RayPrometheusStatLogger(PrometheusStatLogger):
    """PrometheusStatLogger variant that records through Ray metrics.

    Each ``_*_cls`` class attribute is overridden with the Ray-backed
    wrapper from this module. NOTE(review): presumably the parent builds its
    metrics through these attributes — confirm in PrometheusStatLogger.
    """

    _spec_decoding_cls = RaySpecDecodingProm
    _histogram_cls = RayHistogramWrapper
    _counter_cls = RayCounterWrapper
    _gauge_cls = RayGaugeWrapper

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        # Explicit pass-through that pins this constructor's signature.
        super().__init__(vllm_config, engine_index)

    @staticmethod
    def _unregister_vllm_metrics():
        """Intentionally a no-op override of the parent's hook.

        NOTE(review): presumably the Ray-backed metrics never enter the
        prometheus_client registry, so there is nothing to unregister —
        confirm.
        """