support prometheus metrics (#1853)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
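This change adds an `--enable-metrics` server flag and mounts a Prometheus endpoint at `/metrics`. A quick end-to-end check might look like the sketch below (the model path and port are placeholders, not part of this diff):

```python
# Hypothetical smoke test. First start the server with metrics enabled, e.g.:
#   python -m sglang.launch_server --model-path <model> --enable-metrics
# then scrape the /metrics route this commit mounts (port is an assumption):
import urllib.request

body = urllib.request.urlopen("http://127.0.0.1:30000/metrics").read().decode()
print("\n".join(body.splitlines()[:5]))  # first few HELP/TYPE lines
```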
@@ -31,6 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union

import torch
@@ -254,6 +255,16 @@ class Req:

        # For Qwen2-VL
        self.mrope_position_delta = []  # use mutable object

        # Lifetime traces
        # time when the request is created and added to the waiting queue
        self.created_time = None
        # time when the request is added to a prefill batch
        self.queued_time = None
        # time when the request starts being processed
        self.started_time = None
        # time when the request is finished
        self.finished_time = None

    # whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None
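For reference, a minimal sketch of how these timestamps turn into the request-level latency metrics collected later in this diff (the helper name is illustrative, not part of the commit):

```python
import time

def request_latencies(req) -> dict:
    """Derive per-request latencies from the lifetime traces above (illustrative)."""
    now = time.time()
    return {
        "time_waiting": req.queued_time - req.created_time,  # queueing delay
        "time_e2e": now - req.created_time,                  # end-to-end latency
        "ttft": now - req.started_time,  # meaningful during extend (prefill) only
    }
```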
@@ -1028,6 +1039,9 @@ class ScheduleBatch:
            f"#req={(len(self.reqs))})"
        )

    def mark_reqs_started(self):
        for req in self.reqs:
            req.started_time = time.time()

@dataclasses.dataclass
class ModelWorkerBatch:

@@ -17,6 +17,7 @@ limitations under the License.

import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -306,6 +307,7 @@ class PrefillAdder:
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
            req.queued_time = time.time()
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(
                prefix_len,
@@ -324,6 +326,7 @@
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
            self.can_run_list.append(req)
            req.queued_time = time.time()
            self.new_inflight_req = req
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(prefix_len, trunc_len, 0)

@@ -62,6 +62,8 @@ from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
@@ -222,7 +224,8 @@ class Scheduler:
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()  # time of the last stats update (every iteration)
        self.last_log_tic = time.time()  # time of the last decode-log print
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
@@ -291,6 +294,15 @@
                ],
                with_stack=True,
            )
        # Init metrics stats
        self.stats = Stats()
        self.metrics_collector = PrometheusMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                # TODO: add lora name/path in the future
            },
            max_model_len=self.max_total_num_tokens,
        )

    def watchdog_thread(self):
        self.watchdog_last_forward_ct = 0
@@ -338,6 +350,11 @@
        else:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
        # Log stats
        if self.is_generation and self.server_args.enable_metrics:
            stats = self.get_stats(batch)
            self.log_stats(stats)
            self.last_stats_tic = time.time()

        self.last_batch = batch

@@ -476,6 +493,7 @@
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        req.created_time = time.time()
        self.waiting_queue.append(req)

    def handle_embedding_request(
@@ -504,9 +522,11 @@
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
        self.num_generated_tokens = 0
        self.last_log_tic = time.time()
        # set system stats
        self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
@@ -676,6 +696,9 @@
            self.token_to_kv_pool.available_size()
            + self.tree_cache.evictable_size()
        )
        # set system stats
        self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
        self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)

        if num_mixed_running > 0:
            logger.info(
@@ -770,6 +793,7 @@
        if self.is_generation:
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                model_worker_batch = batch.get_model_worker_batch()
                batch.mark_reqs_started()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
@@ -789,6 +813,88 @@
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def get_stats(self, batch: ScheduleBatch):
        # TODO: get stats for chunked prefill

        now = time.time()
        # System stats
        # Scheduler state
        new_seq: int = 0
        num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
        num_waiting_req = len(self.waiting_queue)
        # Cache state
        cache_hit_rate: float = 0.0
        token_usage: float = 0.0

        # Set stats from prefill
        if self.stats is not None:
            # new_seq = self.stats.new_seq
            cache_hit_rate = self.stats.cache_hit_rate
            token_usage = self.stats.token_usage
        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []

        # Request stats
        # Decode
        gen_throughput: float = 0.0
        # Latency
        time_e2e_requests: List[float] = []
        time_waiting_requests: List[float] = []
        # Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # _, next_token_ids, _ = result
        if batch is not None:
            num_generation_tokens_iter = len(batch.output_ids)
            gen_throughput = round(
                num_generation_tokens_iter / (now - self.last_stats_tic), 2
            )

            for req in batch.reqs:
                # NOTE: the batch forward mode is extend before decoding starts
                if batch.forward_mode.is_extend():
                    num_prompt_tokens_iter = len(batch.input_ids) + sum(
                        batch.prefix_lens
                    )
                    time_to_first_tokens_iter.append(now - req.started_time)
                else:
                    time_per_output_tokens_iter.append(now - self.last_stats_tic)

                if req.finished():
                    time_e2e_requests.append(now - req.created_time)
                    time_waiting_requests.append(req.queued_time - req.created_time)
                    num_prompt_tokens_requests.append(len(req.origin_input_ids))
                    num_generation_tokens_requests.append(len(req.output_ids))
                    finished_reason_requests.append(
                        req.finished_reason.to_json()
                        if req.finished_reason is not None
                        else None
                    )

        return Stats(
            new_seq=new_seq,
            num_running_req=num_running_req,
            num_waiting_req=num_waiting_req,
            cache_hit_rate=cache_hit_rate,
            token_usage=token_usage,
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            gen_throughput=gen_throughput,
            time_e2e_requests=time_e2e_requests,
            time_waiting_requests=time_waiting_requests,
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            context_len=self.model_config.context_len,
            max_total_num_tokens=self.max_total_num_tokens,
            max_prefill_tokens=self.max_prefill_tokens,
            max_running_requests=self.max_running_requests,
        )

    def log_stats(self, stats: Stats):
        self.metrics_collector.log_stats(stats)

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():

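Taken together, the scheduler path is: stamp lifetime timestamps on each `Req`, fold them into a `Stats` snapshot once per iteration, and hand that snapshot to the collector. A condensed sketch of one iteration's report (values are made up):

```python
# Illustrative only: what one scheduler iteration might report when metrics are on.
from sglang.srt.metrics.metrics_types import Stats

stats = Stats(
    num_running_req=4,
    num_waiting_req=1,
    gen_throughput=85.3,      # decode tokens/s since the last stats tick
    time_e2e_requests=[1.7],  # one request finished this iteration
)
# scheduler.log_stats(stats) forwards this to PrometheusMetricsCollector.log_stats()
```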
python/sglang/srt/metrics/metrics_collector.py (new file, 297 lines)
@@ -0,0 +1,297 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Utilities for Prometheus Metrics Collection."""

import logging
from abc import ABC, abstractmethod
from typing import Counter as CollectionsCounter
from typing import Dict, List, Union

import numpy as np
from prometheus_client import Counter, Gauge, Histogram

from sglang.srt.metrics.metrics_types import Stats

class Metrics:
    """
    SGLang Metrics
    """

    def __init__(self, labelnames: List[str], max_model_len):

        # Configuration Stats
        self.max_total_num_tokens = Gauge(
            name="sglang:max_total_num_tokens",
            documentation="Maximum total number of tokens",
            labelnames=labelnames,
            multiprocess_mode="min",
        )  # static across processes

        self.max_prefill_tokens = Gauge(
            name="sglang:max_prefill_tokens",
            documentation="Maximum prefill tokens",
            labelnames=labelnames,
            multiprocess_mode="min",
        )  # static across processes

        self.max_running_requests = Gauge(
            name="sglang:max_running_requests",
            documentation="Maximum running requests",
            labelnames=labelnames,
            multiprocess_mode="min",
        )  # static across processes

        self.context_len = Gauge(
            name="sglang:context_len",
            documentation="Context length",
            labelnames=labelnames,
            multiprocess_mode="min",
        )  # static across processes
        # Decode Stats
        self.num_running_sys = Gauge(
            name="sglang:num_requests_running",
            documentation="Number of requests currently running on GPU",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.num_waiting_sys = Gauge(
            name="sglang:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.gen_throughput = Gauge(
            name="sglang:gen_throughput",
            documentation="Generation token throughput (token/s)",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.token_usage = Gauge(
            name="sglang:token_usage",
            documentation="Total token usage",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        # System Stats
        # KV Cache Usage in %
        # self.gpu_cache_usage_sys = Gauge(
        #     "gpu_cache_usage_perc",
        #     "GPU KV-cache usage. 1 means 100 percent usage.",
        #     labelnames=labelnames,
        #     multiprocess_mode="sum")

        self.new_seq = Gauge(
            name="sglang:new_seq",
            documentation="Number of new sequences",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.new_token = Gauge(
            name="sglang:new_token",
            documentation="Number of new tokens",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        # Prefix caching block hit rate
        self.cached_token = Gauge(
            name="sglang:cached_token",
            documentation="Number of cached tokens",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.cache_hit_rate = Gauge(
            name="sglang:cache_hit_rate",
            documentation="Cache hit rate",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        self.queue_req = Gauge(
            name="sglang:queue_req",
            documentation="Number of queued requests",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )

        # Iteration stats
        self.counter_prompt_tokens = Counter(
            name="sglang:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames,
        )
        self.counter_generation_tokens = Counter(
            name="sglang:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames,
        )
        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0,
            ],
        )
        self.histogram_time_per_output_token = Histogram(
            name="sglang:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075,
                0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5,
            ],
        )

        # Request Stats
        # Metadata
        self.num_prompt_tokens_requests = Histogram(
            name="sglang:request_prompt_tokens",
            documentation="Number of prefill tokens processed",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.num_generation_tokens_requests = Histogram(
            name="sglang:request_generation_tokens",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.finished_reason_requests = Counter(
            name="sglang:request_success_total",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + ["finished_reason"],
        )
        self.histogram_time_e2e_requests = Histogram(
            name="sglang:e2e_request_latency_seconds",
            documentation="Histogram of end-to-end request latency in seconds",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.histogram_time_waiting_requests = Histogram(
            name="sglang:waiting_request_latency_seconds",
            documentation="Histogram of request waiting time in seconds",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.histogram_time_decode_requests = Histogram(
            name="sglang:decode_request_latency_seconds",
            documentation="Histogram of request decoding time in seconds",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )


class MetricsCollector(ABC):
    """
    SGLang Metrics Collector
    """

    @abstractmethod
    def log_stats(self, stats: Stats) -> None:
        pass


class PrometheusMetricsCollector(MetricsCollector):
    """
    SGLang Prometheus Metrics Collector
    """

    def __init__(self, labels: Dict[str, str], max_model_len: int) -> None:
        self.labels = labels
        self.metrics = Metrics(
            labelnames=list(labels.keys()), max_model_len=max_model_len
        )

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        # Convenience function for logging to a gauge.
        gauge.labels(**self.labels).set(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        # Convenience function for logging to a counter.
        counter.labels(**self.labels).inc(data)

    def _log_counter_labels(
        self, counter, data: CollectionsCounter, label_key: str
    ) -> None:
        # Convenience function for logging a counter of labels.
        for label, count in data.items():
            counter.labels(**{**self.labels, label_key: label}).inc(count)

    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
        # Convenience function for logging a list to a histogram.
        for datum in data:
            histogram.labels(**self.labels).observe(datum)

    def log_stats(self, stats: Stats) -> None:
        self._log_gauge(self.metrics.max_total_num_tokens, stats.max_total_num_tokens)
        self._log_gauge(self.metrics.max_prefill_tokens, stats.max_prefill_tokens)
        self._log_gauge(self.metrics.max_running_requests, stats.max_running_requests)
        self._log_gauge(self.metrics.context_len, stats.context_len)
        self._log_histogram(
            self.metrics.num_prompt_tokens_requests, stats.num_prompt_tokens_requests
        )
        self._log_histogram(
            self.metrics.num_generation_tokens_requests,
            stats.num_generation_tokens_requests,
        )

        self._log_counter(
            self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
        )
        self._log_counter(
            self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
        )
        self._log_histogram(
            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
        )
        self._log_histogram(
            self.metrics.histogram_time_per_output_token,
            stats.time_per_output_tokens_iter,
        )

        # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
        self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)
        self._log_gauge(self.metrics.num_waiting_sys, stats.num_waiting_req)
        self._log_gauge(self.metrics.gen_throughput, stats.gen_throughput)
        self._log_gauge(self.metrics.token_usage, stats.token_usage)
        self._log_histogram(
            self.metrics.histogram_time_e2e_requests, stats.time_e2e_requests
        )
        self._log_histogram(
            self.metrics.histogram_time_waiting_requests, stats.time_waiting_requests
        )
        self._log_histogram(
            self.metrics.histogram_time_decode_requests, stats.time_decode_requests
        )
        self._log_gauge(self.metrics.new_seq, stats.new_seq)
        self._log_gauge(self.metrics.new_token, stats.new_token)
        self._log_gauge(self.metrics.cached_token, stats.cached_token)
        self._log_gauge(self.metrics.cache_hit_rate, stats.cache_hit_rate)
        self._log_gauge(self.metrics.queue_req, stats.queue_req)


def build_1_2_5_buckets(max_value: int) -> List[int]:
    """
    Builds a list of buckets with increasing powers of 10 multiplied by
    mantissa values (1, 2, 5) until the value exceeds the specified maximum.

    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    mantissa_lst = [1, 2, 5]
    exponent = 0
    buckets: List[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1
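A minimal standalone usage sketch of the new collector. The label value and model length below are assumptions; in the server, `Scheduler` constructs this with `served_model_name` and `max_total_num_tokens`:

```python
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats

collector = PrometheusMetricsCollector(
    labels={"model_name": "my-model"},  # hypothetical label value
    max_model_len=8192,  # hypothetical; sized from max_total_num_tokens in-server
)
collector.log_stats(Stats(num_running_req=2, gen_throughput=120.5))
```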
python/sglang/srt/metrics/metrics_types.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Metrics Types"""

from dataclasses import dataclass, field
from typing import List


@dataclass
class Stats:
    # config
    max_total_num_tokens: int = 0
    max_prefill_tokens: int = 0
    max_running_requests: int = 0
    context_len: int = 0
    # request stats
    num_prompt_tokens_requests: List[int] = field(default_factory=list)
    num_generation_tokens_requests: List[int] = field(default_factory=list)
    finished_reason_requests: List[str] = field(default_factory=list)
    # decode stats
    num_running_req: int = 0
    num_waiting_req: int = 0
    gen_throughput: float = 0.0
    num_token: int = 0
    waiting_queue: int = 0
    time_e2e_requests: List[float] = field(default_factory=list)
    time_waiting_requests: List[float] = field(default_factory=list)
    time_decode_requests: List[float] = field(default_factory=list)
    # system stats
    token_usage: float = 0.0
    is_mixed_chunk: bool = False
    new_seq: int = 0
    new_token: int = 0
    cached_token: int = 0
    cache_hit_rate: float = 0.0
    running_req: int = 0
    queue_req: int = 0

    # Iteration stats (should have _iter suffix)
    num_prompt_tokens_iter: int = 0
    num_generation_tokens_iter: int = 0
    time_to_first_tokens_iter: List[float] = field(default_factory=list)
    time_per_output_tokens_iter: List[float] = field(default_factory=list)
@@ -25,12 +25,15 @@ import json
import logging
import multiprocessing as mp
import os
import re
import tempfile
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

import orjson
from starlette.routing import Mount

# Work around a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -86,6 +89,10 @@ from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -412,6 +419,18 @@ def launch_engine(
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()

def add_prometheus_middleware(app: FastAPI):
    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for the 307 redirect on /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def launch_server(
    server_args: ServerArgs,
@@ -439,6 +458,11 @@ def launch_server(
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add prometheus middleware
    if server_args.enable_metrics:
        _set_prometheus_env()
        add_prometheus_middleware(app)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
@@ -466,6 +490,21 @@
    finally:
        t.join()

def _set_prometheus_env():
    # Set the prometheus multiprocess directory.
    # sglang uses prometheus multiprocess mode; this must be set
    # before prometheus_client is imported.
    # https://prometheus.github.io/client_python/multiprocess/
    global prometheus_multiproc_dir
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.debug("User-set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")

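For reference, a sketch of reading the aggregated metrics outside the server process once this directory is set. The path is illustrative, and the import-after-env-var ordering follows the comments above:

```python
import os

os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/sglang-metrics"  # hypothetical path

# Import only after the env var is set, per multiprocess-mode requirements.
from prometheus_client import CollectorRegistry, generate_latest, multiprocess

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)  # merges per-process *.db files
print(generate_latest(registry).decode())
```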
def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments

@@ -70,6 +70,7 @@ class ServerArgs:
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False

    # Other
    api_key: Optional[str] = None
@@ -414,6 +415,12 @@ class ServerArgs:
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable Prometheus metrics logging.",
        )

        parser.add_argument(
            "--api-key",
            type=str,