diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 79fe1cf9f..1abd67424 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -31,7 +31,6 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch import dataclasses import logging -import time from typing import List, Optional, Tuple, Union import torch @@ -255,16 +254,6 @@ class Req: # For Qwen2-VL self.mrope_position_delta = [] # use mutable object - # Lifetime traces - # time when request is created and added to waitlist - self.created_time = None - # time when request is added to prefill batch - self.queued_time = None - # time when request is being processed - self.started_time = None - # time when request is finished - self.finished_time = None - # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None @@ -1038,10 +1027,6 @@ class ScheduleBatch: f"#req={(len(self.reqs))})" ) - def mark_reqs_started(self): - for req in self.reqs: - req.started_time = time.time() - @dataclasses.dataclass class ModelWorkerBatch: diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 994ffdeb5..2bfdffc42 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -17,7 +17,6 @@ limitations under the License. import os import random -import time from collections import defaultdict from contextlib import contextmanager from enum import Enum, auto @@ -307,7 +306,6 @@ class PrefillAdder: ): # Non-chunked prefill self.can_run_list.append(req) - req.queued_time = time.time() self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req( prefix_len, @@ -326,7 +324,6 @@ class PrefillAdder: req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] self.can_run_list.append(req) - req.queued_time = time.time() self.new_inflight_req = req self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 49159923a..2dc1944d5 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -62,8 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector -from sglang.srt.metrics.metrics_types import Stats +from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( broadcast_pyobj, @@ -106,6 +105,7 @@ class Scheduler: self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = server_args.enable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init + self.enable_metrics = server_args.enable_metrics # Init inter-process communication context = zmq.Context(2) @@ -224,8 +224,7 @@ class Scheduler: self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 - self.last_stats_tic = time.time() # time of last stats for every iter - self.last_log_tic = time.time() # time of last log for print decode log + self.last_decode_stats_tic = time.time() 
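+        # `last_decode_stats_tic` replaces the removed `last_stats_tic`/`last_log_tic`
+        # pair; log_decode_stats() later in this file derives decode throughput as
+        # num_generated_tokens / (time.time() - last_decode_stats_tic).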
self.stream_interval = server_args.stream_interval # Init chunked prefill @@ -294,15 +293,16 @@ class Scheduler: ], with_stack=True, ) + # Init metrics stats - self.stats = Stats() - self.metrics_collector = PrometheusMetricsCollector( - labels={ - "model_name": self.server_args.served_model_name, - # TODO: Add lora name/path in the future, - }, - max_model_len=self.max_total_num_tokens, - ) + self.stats = SchedulerStats() + if self.enable_metrics: + self.metrics_collector = SchedulerMetricsCollector( + labels={ + "model_name": self.server_args.served_model_name, + # TODO: Add lora name/path in the future, + }, + ) def watchdog_thread(self): self.watchdog_last_forward_ct = 0 @@ -350,11 +350,6 @@ class Scheduler: else: self.check_memory() self.new_token_ratio = self.init_new_token_ratio - # log stats - if self.is_generation and self.server_args.enable_metrics: - stats = self.get_stats(batch) - self.log_stats(stats) - self.last_stats_tic = time.time() self.last_batch = batch @@ -493,7 +488,6 @@ class Scheduler: self.max_req_len - len(req.origin_input_ids) - 1, ) - req.created_time = time.time() self.waiting_queue.append(req) def handle_embedding_request( @@ -518,25 +512,68 @@ class Scheduler: self.waiting_queue.append(req) - def print_decode_stats(self): + def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight): + if isinstance(self.tree_cache, RadixCache): + self.tree_cache_metrics["total"] += ( + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + tree_cache_hit_rate = ( + self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] + ) + else: + tree_cache_hit_rate = 0.0 + num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - throughput = self.num_generated_tokens / (time.time() - self.last_log_tic) + + logger.info( + f"Prefill batch. " + f"#new-seq: {len(can_run_list)}, " + f"#new-token: {adder.log_input_tokens}, " + f"#cached-token: {adder.log_hit_tokens}, " + f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"#running-req: {running_bs}, " + f"#queue-req: {len(self.waiting_queue) + has_inflight}" + ) + + if self.enable_metrics: + self.stats.num_running_reqs = running_bs + self.stats.num_used_tokens = num_used + self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) + self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight + self.stats.cache_hit_rate = tree_cache_hit_rate + self.metrics_collector.log_stats(self.stats) + + def log_decode_stats(self): + num_used = self.max_total_num_tokens - ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + gen_throughput = self.num_generated_tokens / ( + time.time() - self.last_decode_stats_tic + ) self.num_generated_tokens = 0 - self.last_log_tic = time.time() - # set system stats - self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) + self.last_decode_stats_tic = time.time() num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 logger.info( f"Decode batch. 
" f"#running-req: {num_running_reqs}, " f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {throughput:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}" ) + if self.enable_metrics: + self.stats.num_running_reqs = num_running_reqs + self.stats.num_used_tokens = num_used + self.stats.token_usage = num_used / self.max_total_num_tokens + self.stats.gen_throughput = gen_throughput + self.stats.num_queue_reqs = len(self.waiting_queue) + self.metrics_collector.log_stats(self.stats) + def check_memory(self): available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() @@ -612,7 +649,6 @@ class Scheduler: prefix_computed = self.policy.calc_priority(self.waiting_queue) # Prefill policy - num_mixed_running = running_bs if self.is_mixed_chunk else 0 adder = PrefillAdder( self.tree_cache, self.running_batch, @@ -620,7 +656,7 @@ class Scheduler: self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, - num_mixed_running, + running_bs if self.is_mixed_chunk else 0, ) has_inflight = self.being_chunked_req is not None @@ -677,47 +713,7 @@ class Scheduler: # Print stats if self.tp_rank == 0: - if isinstance(self.tree_cache, RadixCache): - self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 - tree_cache_hit_rate = ( - self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] - ) - else: - tree_cache_hit_rate = 0.0 - - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool.available_size() - + self.tree_cache.evictable_size() - ) - # set system stats - self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2) - self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) - - if num_mixed_running > 0: - logger.info( - f"Prefill batch" - f"(mixed #running-req: {num_mixed_running}). " - f"#new-seq: {len(can_run_list)}, " - f"#new-token: {adder.log_input_tokens}, " - f"#cached-token: {adder.log_hit_tokens}, " - f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"#queue-req: {len(self.waiting_queue) + has_inflight}" - ) - else: - logger.info( - f"Prefill batch. 
" - f"#new-seq: {len(can_run_list)}, " - f"#new-token: {adder.log_input_tokens}, " - f"#cached-token: {adder.log_hit_tokens}, " - f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) + has_inflight}" - ) + self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight) # Create a new batch new_batch = ScheduleBatch.init_new( @@ -789,7 +785,6 @@ class Scheduler: if self.is_generation: model_worker_batch = batch.get_model_worker_batch() if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: - batch.mark_reqs_started() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) @@ -810,94 +805,6 @@ class Scheduler: ret = embeddings, model_worker_batch.bid return ret - def get_stats(self, batch: ScheduleBatch): - # TODO: get stats for chunked prefill - - now = time.time() - # system stats - # Scheduler State - new_seq: int = 0 - num_running_req = len(self.running_batch.reqs) if self.running_batch else 0 - num_waiting_req = len(self.waiting_queue) - # Cache State - cache_hit_rate: float = 0.0 - token_usage: float = 0.0 - - # set stats from prefill - if self.stats is not None: - # new_seq=self.stats.new_seq - cache_hit_rate = self.stats.cache_hit_rate - token_usage = self.stats.token_usage - # Iteration stats - num_prompt_tokens_iter = 0 - num_generation_tokens_iter = 0 - time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] - - # Request stats - # Decode - gen_throughput: float = 0.0 - # Latency - time_e2e_requests: List[float] = [] - time_waiting_requests: List[float] = [] - # Metadata - num_prompt_tokens_requests: List[int] = [] - num_generation_tokens_requests: List[int] = [] - finished_reason_requests: List[str] = [] - - # _, next_token_ids, _ = result - if batch is not None: - num_generation_tokens_iter = len(batch.output_ids) - gen_throughput = round( - num_generation_tokens_iter / (now - self.last_stats_tic), 2 - ) - - for i, req in enumerate(batch.reqs): - # NOTE: Batch forward mode is extend befor start decode, - if batch.forward_mode.is_extend(): - num_prompt_tokens_iter = len(batch.input_ids) + sum( - batch.prefix_lens - ) - time_to_first_tokens_iter.append(now - req.started_time) - else: - time_per_output_tokens_iter.append(now - self.last_stats_tic) - - if req.finished(): - time_e2e_requests.append(now - req.created_time) - time_waiting_requests.append(req.queued_time - req.created_time) - num_prompt_tokens_requests.append(len(req.origin_input_ids)) - num_generation_tokens_requests.append(len(req.output_ids)) - finished_reason_requests.append( - req.finished_reason.to_json() - if req.finished_reason is not None - else None - ) - - return Stats( - new_seq=new_seq, - num_running_req=num_running_req, - num_waiting_req=num_waiting_req, - cache_hit_rate=cache_hit_rate, - token_usage=token_usage, - num_prompt_tokens_iter=num_prompt_tokens_iter, - num_generation_tokens_iter=num_generation_tokens_iter, - time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, - gen_throughput=gen_throughput, - time_e2e_requests=time_e2e_requests, - time_waiting_requests=time_waiting_requests, - num_prompt_tokens_requests=num_prompt_tokens_requests, - num_generation_tokens_requests=num_generation_tokens_requests, - finished_reason_requests=finished_reason_requests, - context_len=self.model_config.context_len, - 
max_total_num_tokens=self.max_total_num_tokens, - max_prefill_tokens=self.max_prefill_tokens, - max_running_requests=self.max_running_requests, - ) - - def log_stats(self, stats: Stats): - self.metrics_collector.log_stats(stats) - def process_batch_result(self, batch: ScheduleBatch, result): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) @@ -1035,7 +942,7 @@ class Scheduler: self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): - self.print_decode_stats() + self.log_decode_stats() def add_logprob_return_values( self, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 11e14e5b6..1db60ef49 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -22,6 +22,7 @@ import logging import os import signal import sys +import time from typing import Dict, List, Optional, Tuple, Union import fastapi @@ -52,6 +53,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightReqInput, UpdateWeightReqOutput, ) +from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_zmq_socket, kill_child_process @@ -69,6 +71,10 @@ class ReqState: finished: bool event: asyncio.Event + # For metrics + created_time: float + first_token_time: Optional[float] = None + class TokenizerManager: """TokenizerManager is a process that tokenizes the text.""" @@ -80,6 +86,7 @@ class TokenizerManager: ): # Parse args self.server_args = server_args + self.enable_metrics = server_args.enable_metrics # Init inter-process communication context = zmq.asyncio.Context(2) @@ -142,11 +149,22 @@ class TokenizerManager: # Others self.gracefully_exit = False + # Metrics + if self.enable_metrics: + self.metrics_collector = TokenizerMetricsCollector( + labels={ + "model_name": self.server_args.served_model_name, + # TODO: Add lora name/path in the future, + }, + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): + created_time = time.time() + if self.to_create_loop: self.create_handle_loop() @@ -164,10 +182,12 @@ class TokenizerManager: if is_single: tokenized_obj = await self._tokenize_one_request(obj) self.send_to_scheduler.send_pyobj(tokenized_obj) - async for response in self._wait_one_response(obj, request): + async for response in self._wait_one_response(obj, request, created_time): yield response else: - async for response in self._handle_batch_request(obj, request): + async for response in self._handle_batch_request( + obj, request, created_time + ): yield response async def _tokenize_one_request( @@ -231,10 +251,11 @@ class TokenizerManager: self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, + created_time: Optional[float] = None, ): """Wait for the response of one request.""" event = asyncio.Event() - state = ReqState([], False, event) + state = ReqState([], False, event, created_time=created_time) self.rid_to_state[obj.rid] = state while True: @@ -272,6 +293,7 @@ class TokenizerManager: self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, + created_time: Optional[float] = None, ): batch_size = obj.batch_size @@ -283,7 +305,9 @@ class TokenizerManager: tmp_obj = obj[i] tokenized_obj = await 
self._tokenize_one_request(tmp_obj) self.send_to_scheduler.send_pyobj(tokenized_obj) - generators.append(self._wait_one_response(tmp_obj, request)) + generators.append( + self._wait_one_response(tmp_obj, request, created_time) + ) rids.append(tmp_obj.rid) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. @@ -303,7 +327,9 @@ class TokenizerManager: tokenized_obj.sampling_params.max_new_tokens = 0 tokenized_obj.stream = False self.send_to_scheduler.send_pyobj(tokenized_obj) - await self._wait_one_response(tmp_obj, request).__anext__() + await self._wait_one_response( + tmp_obj, request, created_time + ).__anext__() # Expand requests, assign new rids for them, and send them for i in range(batch_size): @@ -312,7 +338,9 @@ class TokenizerManager: tokenized_obj = copy.copy(tokenized_objs[i]) tokenized_obj.rid = tmp_obj.regenerate_rid() self.send_to_scheduler.send_pyobj(tokenized_obj) - generators.append(self._wait_one_response(tmp_obj, request)) + generators.append( + self._wait_one_response(tmp_obj, request, created_time) + ) rids.append(tmp_obj.rid) # Wait for all requests @@ -524,6 +552,34 @@ class TokenizerManager: state.finished = recv_obj.finished_reason[i] is not None state.event.set() + if self.enable_metrics: + completion_tokens = recv_obj.meta_info[i]["completion_tokens"] + + if state.first_token_time is None: + state.first_token_time = time.time() + self.metrics_collector.observe_time_to_first_token( + state.first_token_time - state.created_time + ) + else: + if completion_tokens >= 2: + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.first_token_time) + / (completion_tokens - 1) + ) + + if state.finished: + self.metrics_collector.inc_prompt_tokens( + recv_obj.meta_info[i]["prompt_tokens"] + ) + self.metrics_collector.inc_generation_tokens(completion_tokens) + self.metrics_collector.observe_e2e_request_latency( + time.time() - state.created_time + ) + if completion_tokens >= 1: + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.created_time) / completion_tokens + ) + def convert_logprob_style( self, ret: dict, diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py new file mode 100644 index 000000000..5a9d42c63 --- /dev/null +++ b/python/sglang/srt/metrics/collector.py @@ -0,0 +1,211 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Utilities for Prometheus Metrics Collection.""" + +from dataclasses import dataclass +from typing import Dict, Union + + +@dataclass +class SchedulerStats: + num_running_reqs: int = 0 + num_used_tokens: int = 0 + token_usage: float = 0.0 + gen_throughput: float = 0.0 + num_queue_reqs: int = 0 + cache_hit_rate: float = 0.0 + + +class SchedulerMetricsCollector: + + def __init__(self, labels: Dict[str, str]) -> None: + # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` + from prometheus_client import Gauge + + self.labels = labels + + self.num_running_reqs = Gauge( + name="sglang:num_running_reqs", + documentation="The number of running requests", + labelnames=labels.keys(), + multiprocess_mode="sum", + ) + + self.num_used_tokens = Gauge( + name="sglang:num_used_tokens", + documentation="The number of used tokens", + labelnames=labels.keys(), + multiprocess_mode="sum", + ) + + self.token_usage = Gauge( + name="sglang:token_usage", + documentation="The token usage", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.gen_throughput = Gauge( + name="sglang:gen_throughput", + documentation="The generate throughput (token/s)", + labelnames=labels.keys(), + multiprocess_mode="sum", + ) + + self.num_queue_reqs = Gauge( + name="sglang:num_queue_reqs", + documentation="The number of requests in the waiting queue", + labelnames=labels.keys(), + multiprocess_mode="sum", + ) + + self.cache_hit_rate = Gauge( + name="sglang:cache_hit_rate", + documentation="The cache hit rate", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def log_stats(self, stats: SchedulerStats) -> None: + self._log_gauge(self.num_running_reqs, stats.num_running_reqs) + self._log_gauge(self.num_used_tokens, stats.num_used_tokens) + self._log_gauge(self.token_usage, stats.token_usage) + self._log_gauge(self.gen_throughput, stats.gen_throughput) + self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) + self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) + + +class TokenizerMetricsCollector: + def __init__(self, labels: Dict[str, str]) -> None: + # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` + from prometheus_client import Counter, Histogram + + self.labels = labels + + self.prompt_tokens_total = Counter( + name="sglang:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labels.keys(), + ) + + self.generation_tokens_total = Counter( + name="sglang:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labels.keys(), + ) + + self.histogram_time_to_first_token = Histogram( + name="sglang:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 15.0, + 20.0, + 25.0, + 30.0, + ], + ) + + self.histogram_time_per_output_token = Histogram( + name="sglang:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.005, + 0.01, + 0.015, + 0.02, + 0.025, + 0.03, + 0.04, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ) + + 
self.histogram_e2e_request_latency = Histogram( + name="sglang:e2e_request_latency_seconds", + documentation="Histogram of End-to-end request latency in seconds", + labelnames=labels.keys(), + buckets=[ + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + ], + ) + + def _log_histogram(self, histogram, data: Union[int, float]) -> None: + histogram.labels(**self.labels).observe(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def inc_prompt_tokens(self, value: int): + self._log_counter(self.prompt_tokens_total, value) + + def inc_generation_tokens(self, value: int): + self._log_counter(self.generation_tokens_total, value) + + def observe_time_to_first_token(self, value: Union[float, int]): + self._log_histogram(self.histogram_time_to_first_token, value) + + def observe_time_per_output_token(self, value: Union[float, int]): + self._log_histogram(self.histogram_time_per_output_token, value) + + def observe_e2e_request_latency(self, value: Union[float, int]): + self._log_histogram(self.histogram_e2e_request_latency, value) diff --git a/python/sglang/srt/metrics/func_timer.py b/python/sglang/srt/metrics/func_timer.py new file mode 100644 index 000000000..71258d868 --- /dev/null +++ b/python/sglang/srt/metrics/func_timer.py @@ -0,0 +1,108 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Records the latency of some functions +""" + +import asyncio +import time +from functools import wraps +from typing import Any, Callable, List, Optional + +enable_metrics = False + + +def enable_func_timer(): + # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` + from prometheus_client import Histogram + + global enable_metrics, FUNC_LATENCY + enable_metrics = True + + FUNC_LATENCY = Histogram( + "sglang:func_latency_seconds", + "Function latency in seconds", + # captures latency in range [50ms - ~50s] + buckets=exponential_buckets(start=0.05, width=1.5, length=18), + labelnames=["name"], + ) + + +FUNC_LATENCY = None + + +def exponential_buckets(start: float, width: float, length: int) -> List[float]: + buckets = [] + for i in range(length): + buckets.append(start * (width**i)) + return buckets + + +def time_func_latency( + func: Callable = None, name: Optional[str] = None +) -> Callable[..., Any]: + """ + A decorator to observe the latency of a function's execution. Supports both sync and async functions. + + NOTE: We use our own implementation of a timer decorator since prometheus_client does not support async + context manager yet. + + Overhead: for an async function, the wrapper introduces an extra `await`, which creates + one more coroutine object; under heavy load this can show up as longer wall time + (scheduling delays due to the additional awaitable). 
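+
+    Usage (an illustrative sketch; `update_weights` mirrors the decorated
+    endpoint in server.py below, while `sync_helper` and "custom_name" are
+    hypothetical):
+
+        @time_func_latency
+        async def update_weights(obj, request): ...
+
+        @time_func_latency(name="custom_name")
+        def sync_helper(): ...
+
+    Observations land in the `sglang:func_latency_seconds` histogram, with
+    the `name` label defaulting to `func.__name__`.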
+ """ + + def measure(func: Callable[..., Any]) -> Callable[..., Any]: + nonlocal name + + name = name or func.__name__ + + @wraps(func) + async def async_wrapper(*args, **kwargs): + if not enable_metrics: + return await func(*args, **kwargs) + + metric = FUNC_LATENCY + start = time.monotonic() + ret = func(*args, **kwargs) + if isinstance(ret, asyncio.Future) or asyncio.iscoroutine(ret): + try: + ret = await ret + finally: + metric.labels(name=name).observe(time.monotonic() - start) + return ret + + @wraps(func) + def sync_wrapper(*args, **kwargs): + if not enable_metrics: + return func(*args, **kwargs) + + metric = FUNC_LATENCY + start = time.monotonic() + try: + ret = func(*args, **kwargs) + finally: + metric.labels(name=name).observe(time.monotonic() - start) + return ret + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + if func: + return measure(func) + else: + return measure diff --git a/python/sglang/srt/metrics/metrics_collector.py b/python/sglang/srt/metrics/metrics_collector.py deleted file mode 100644 index 5299cecc8..000000000 --- a/python/sglang/srt/metrics/metrics_collector.py +++ /dev/null @@ -1,388 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Utilities for Prometheus Metrics Collection.""" - -import logging -from abc import ABC, abstractmethod -from typing import Counter as CollectionsCounter -from typing import Dict, List, Union - -import numpy as np -from prometheus_client import Counter, Gauge, Histogram - -from sglang.srt.metrics.metrics_types import Stats - - -class Metrics: - """ - SGLang Metrics - """ - - def __init__(self, labelnames: List[str], max_model_len): - - # Configuration Stats - self.max_total_num_tokens = Gauge( - name="sglang:max_total_num_tokens", - documentation="Maximum total number of tokens", - labelnames=labelnames, - multiprocess_mode="min", - ) # static across processes - - self.max_prefill_tokens = Gauge( - name="sglang:max_prefill_tokens", - documentation="Maximum prefill tokens", - labelnames=labelnames, - multiprocess_mode="min", - ) # static across processes - - self.max_running_requests = Gauge( - name="sglang:max_running_requests", - documentation="Maximum running requests", - labelnames=labelnames, - multiprocess_mode="min", - ) # static across processes - - self.context_len = Gauge( - name="sglang:context_len", - documentation="Context length", - labelnames=labelnames, - multiprocess_mode="min", - ) # static across processes - # Decode Stats - self.num_running_sys = Gauge( - name="sglang:num_requests_running", - documentation="Number of requests currently running on GPU", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.num_waiting_sys = Gauge( - name="sglang:num_requests_waiting", - documentation="Number of requests waiting to be processed.", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.gen_throughput = Gauge( - name="sglang:gen_throughput", - documentation="Gen token throughput (token/s)", - labelnames=labelnames, - 
multiprocess_mode="sum", - ) - self.token_usage = Gauge( - name="sglang:token_usage", - documentation="Total token usage", - labelnames=labelnames, - multiprocess_mode="sum", - ) - # System Stats - # KV Cache Usage in % - # self.gpu_cache_usage_sys = Gauge( - # "gpu_cache_usage_perc", - # "GPU KV-cache usage. 1 means 100 percent usage.", - # labelnames=labelnames, - # multiprocess_mode="sum") - - self.new_seq = Gauge( - name="sglang:new_seq", - documentation="Number of new sequences", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.new_token = Gauge( - name="sglang:new_token", - documentation="Number of new token", - labelnames=labelnames, - multiprocess_mode="sum", - ) - # Prefix caching block hit rate - self.cached_token = Gauge( - name="sglang:cached_token", - documentation="Number of cached token", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.cache_hit_rate = Gauge( - name="sglang:cache_hit_rate", - documentation="Cache hit rate", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.queue_req = Gauge( - name="sglang:queue_req", - documentation="Number of queued requests", - labelnames=labelnames, - multiprocess_mode="sum", - ) - - # Iteration stats - self.counter_prompt_tokens = Counter( - name="sglang:prompt_tokens_total", - documentation="Number of prefill tokens processed.", - labelnames=labelnames, - ) - self.counter_generation_tokens = Counter( - name="sglang:generation_tokens_total", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - ) - self.histogram_time_to_first_token = Histogram( - name="sglang:time_to_first_token_seconds", - documentation="Histogram of time to first token in seconds.", - labelnames=labelnames, - buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, - 0.1, - 0.25, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - 15.0, - 20.0, - 25.0, - 30.0, - ], - ) - self.histogram_time_per_output_token = Histogram( - name="sglang:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", - labelnames=labelnames, - buckets=[ - 0.005, - 0.01, - 0.015, - 0.02, - 0.025, - 0.03, - 0.04, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - ], - ) - - # Request Stats - # Metadata - self.num_prompt_tokens_requests = Histogram( - name="sglang:request_prompt_tokens", - documentation="Number of prefill tokens processed", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.num_generation_tokens_requests = Histogram( - name="sglang:request_generation_tokens", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.finished_reason_requests = Counter( - name="sglang:request_success_total", - documentation="Count of successfully processed requests.", - labelnames=labelnames + ["finished_reason"], - ) - self.histogram_time_e2e_requests = Histogram( - name="sglang:e2e_request_latency_seconds", - documentation="Histogram of End-to-end request latency in seconds", - labelnames=labelnames, - buckets=[ - 0.3, - 0.5, - 0.8, - 1.0, - 1.5, - 2.0, - 2.5, - 5.0, - 10.0, - 15.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, - ], - ) - self.histogram_time_waiting_requests = Histogram( - name="sglang:waiting_request_latency_seconds", - documentation="Histogram of request waiting time in seconds", - labelnames=labelnames, - buckets=[ - 0.3, - 0.5, - 0.8, - 1.0, - 1.5, - 2.0, - 2.5, - 5.0, - 10.0, - 15.0, - 20.0, - 30.0, - 40.0, - 
50.0, - 60.0, - ], - ) - self.histogram_time_decode_requests = Histogram( - name="sglang:decode_request_latency_seconds", - documentation="Histogram of request decoding time in seconds", - labelnames=labelnames, - buckets=[ - 0.3, - 0.5, - 0.8, - 1.0, - 1.5, - 2.0, - 2.5, - 5.0, - 10.0, - 15.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, - ], - ) - - -class MetricsCollector(ABC): - """ - SGLang Metrics Collector - """ - - @abstractmethod - def log_stats(self, stats: Stats) -> None: - pass - - -class PrometheusMetricsCollector(MetricsCollector): - """ - SGLang Metrics Collector - """ - - def __init__(self, labels: Dict[str, str], max_model_len: int) -> None: - self.labels = labels - self.metrics = Metrics( - labelnames=list(labels.keys()), max_model_len=max_model_len - ) - - def _log_gauge(self, gauge, data: Union[int, float]) -> None: - # Convenience function for logging to gauge. - gauge.labels(**self.labels).set(data) - - def _log_counter(self, counter, data: Union[int, float]) -> None: - # Convenience function for logging to counter. - counter.labels(**self.labels).inc(data) - - def _log_counter_labels( - self, counter, data: CollectionsCounter, label_key: str - ) -> None: - # Convenience function for collection counter of labels. - for label, count in data.items(): - counter.labels(**{**self.labels, label_key: label}).inc(count) - - def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: - # Convenience function for logging list to histogram. - for datum in data: - histogram.labels(**self.labels).observe(datum) - - def log_stats(self, stats: Stats) -> None: - self._log_gauge(self.metrics.max_total_num_tokens, stats.max_total_num_tokens) - self._log_gauge(self.metrics.max_prefill_tokens, stats.max_prefill_tokens) - self._log_gauge(self.metrics.max_running_requests, stats.max_running_requests) - self._log_gauge(self.metrics.context_len, stats.context_len) - self._log_histogram( - self.metrics.num_prompt_tokens_requests, stats.num_prompt_tokens_requests - ) - self._log_histogram( - self.metrics.num_generation_tokens_requests, - stats.num_generation_tokens_requests, - ) - - self._log_counter( - self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter - ) - self._log_counter( - self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter - ) - self._log_histogram( - self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter - ) - self._log_histogram( - self.metrics.histogram_time_per_output_token, - stats.time_per_output_tokens_iter, - ) - - # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys) - self._log_gauge(self.metrics.num_running_sys, stats.num_running_req) - self._log_gauge(self.metrics.num_waiting_sys, stats.num_waiting_req) - self._log_gauge(self.metrics.gen_throughput, stats.gen_throughput) - self._log_gauge(self.metrics.token_usage, stats.token_usage) - self._log_histogram( - self.metrics.histogram_time_e2e_requests, stats.time_e2e_requests - ) - self._log_histogram( - self.metrics.histogram_time_waiting_requests, stats.time_waiting_requests - ) - self._log_histogram( - self.metrics.histogram_time_decode_requests, stats.time_decode_requests - ) - self._log_gauge(self.metrics.new_seq, stats.new_seq) - self._log_gauge(self.metrics.new_token, stats.new_token) - self._log_gauge(self.metrics.cached_token, stats.cached_token) - self._log_gauge(self.metrics.cache_hit_rate, stats.cache_hit_rate) - self._log_gauge(self.metrics.queue_req, stats.queue_req) - - -def build_1_2_5_buckets(max_value: int) -> 
List[int]: - """ - Builds a list of buckets with increasing powers of 10 multiplied by - mantissa values (1, 2, 5) until the value exceeds the specified maximum. - - Example: - >>> build_1_2_5_buckets(100) - [1, 2, 5, 10, 20, 50, 100] - """ - mantissa_lst = [1, 2, 5] - exponent = 0 - buckets: List[int] = [] - while True: - for m in mantissa_lst: - value = m * 10**exponent - if value <= max_value: - buckets.append(value) - else: - return buckets - exponent += 1 diff --git a/python/sglang/srt/metrics/metrics_types.py b/python/sglang/srt/metrics/metrics_types.py deleted file mode 100644 index 2bd0f54f5..000000000 --- a/python/sglang/srt/metrics/metrics_types.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Metrics Types""" - -from dataclasses import dataclass, field -from typing import List - - -@dataclass -class Stats: - # config - max_total_num_tokens: int = 0 - max_prefill_tokens: int = 0 - max_running_requests: int = 0 - context_len: int = 0 - # request stats - num_prompt_tokens_requests: List[int] = field(default_factory=list) - num_generation_tokens_requests: List[int] = field(default_factory=list) - finished_reason_requests: List[str] = field(default_factory=list) - # decode stats - num_running_req: int = 0 - num_waiting_req: int = 0 - gen_throughput: float = 0.0 - waiting_queue: int = 0 - time_e2e_requests: List[float] = field(default_factory=list) - time_waiting_requests: List[float] = field(default_factory=list) - time_decode_requests: List[float] = field(default_factory=list) - # system stats - token_usage: float = 0.0 - new_seq: int = 0 - new_token: int = 0 - cached_token: int = 0 - cache_hit_rate: float = 0.0 - running_req: int = 0 - queue_req: int = 0 - - # Iteration stats (should have _iter suffix) - num_prompt_tokens_iter: int = 0 - num_generation_tokens_iter: int = 0 - time_to_first_tokens_iter: List[float] = field(default_factory=list) - time_per_output_tokens_iter: List[float] = field(default_factory=list) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ce01154a6..8aa2ba453 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -56,6 +56,7 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, v1_batches, @@ -196,6 +197,7 @@ async def get_memory_pool_size(): @app.post("/update_weights") +@time_func_latency async def update_weights(obj: UpdateWeightReqInput, request: Request): """Update the weights inplace without re-launching the server.""" success, message = await tokenizer_manager.update_weights(obj, request) @@ -212,7 +214,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): ) -# fastapi implicitly converts json in the request to obj (dataclass) 
+@time_func_latency async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: @@ -245,10 +247,12 @@ async def generate_request(obj: GenerateReqInput, request: Request): ) +# fastapi implicitly converts json in the request to obj (dataclass) app.post("/generate")(generate_request) app.put("/generate")(generate_request) +@time_func_latency async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" try: @@ -264,6 +268,7 @@ app.post("/encode")(encode_request) app.put("/encode")(encode_request) +@time_func_latency async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" try: @@ -283,16 +288,19 @@ app.put("/classify")(classify_request) @app.post("/v1/completions") +@time_func_latency async def openai_v1_completions(raw_request: Request): return await v1_completions(tokenizer_manager, raw_request) @app.post("/v1/chat/completions") +@time_func_latency async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) @app.post("/v1/embeddings", response_class=ORJSONResponse) +@time_func_latency async def openai_v1_embeddings(raw_request: Request): response = await v1_embeddings(tokenizer_manager, raw_request) return response @@ -455,6 +463,7 @@ def launch_server( # add prometheus middleware if server_args.enable_metrics: add_prometheus_middleware(app) + enable_func_timer() # Send a warmup request t = threading.Thread( @@ -492,6 +501,10 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + # Set ulimit set_ulimit() @@ -510,10 +523,6 @@ def _set_envs_and_config(server_args: ServerArgs): "at https://docs.flashinfer.ai/installation.html.", ) - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - mp.set_start_method("spawn", force=True) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d8184b018..ff0ef5e42 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -781,6 +781,7 @@ def set_prometheus_multiproc_dir(): def add_prometheus_middleware(app): + # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess registry = CollectorRegistry() diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 37ba6c7ef..163a7cc0e 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -22,23 +22,41 @@ class TestEnableMetrics(unittest.TestCase): ) try: - # Make a request to generate some metrics + # Make some requests to generate some metrics response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate") self.assertEqual(response.status_code, 200) + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + "stream": True, + }, + stream=True, + ) + for _ in response.iter_lines(decode_unicode=False): + pass + # Get metrics metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") self.assertEqual(metrics_response.status_code, 200) metrics_content = metrics_response.text - 
print(f"{metrics_content=}") + print(f"metrics_content=\n{metrics_content}") # Verify essential metrics are present essential_metrics = [ + "sglang:num_running_reqs", + "sglang:token_usage", + "sglang:gen_throughput", + "sglang:cache_hit_rate", + "sglang:func_latency_seconds", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", - "sglang:max_total_num_tokens", - "sglang:context_len", "sglang:time_to_first_token_seconds", "sglang:time_per_output_token_seconds", "sglang:e2e_request_latency_seconds", @@ -50,6 +68,7 @@ class TestEnableMetrics(unittest.TestCase): # Verify model name label is present and correct expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST self.assertIn(f'model_name="{expected_model_name}"', metrics_content) + # Verify metrics have values (not empty) self.assertIn("_sum{", metrics_content) self.assertIn("_count{", metrics_content)