From b1721edbac0d936a158bec1594a42e8417171827 Mon Sep 17 00:00:00 2001 From: Yingchun Lai Date: Tue, 16 Sep 2025 01:52:49 +0800 Subject: [PATCH] [PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710) --- python/sglang/srt/disaggregation/decode.py | 6 +++- python/sglang/srt/disaggregation/prefill.py | 12 ++++++- python/sglang/srt/managers/schedule_batch.py | 35 +++++++++++++++++++- python/sglang/srt/managers/scheduler.py | 5 +++ python/sglang/srt/metrics/collector.py | 14 +++++++- python/sglang/srt/metrics/func_timer.py | 9 ++--- python/sglang/srt/metrics/utils.py | 7 ++++ 7 files changed, 77 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 0bddf3dcc..f4d7e8f7f 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import ( prepare_abort, ) from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch +from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool @@ -253,6 +253,7 @@ class DecodePreallocQueue: prefill_dp_rank=req.data_parallel_rank, ) + req.add_latency(RequestStage.DECODE_PREPARE) self.queue.append( DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) ) @@ -421,6 +422,7 @@ class DecodePreallocQueue: kv_indices, self.token_to_kv_pool_allocator.page_size ) decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) + decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) preallocated_reqs.append(decode_req) indices_to_remove.add(i) @@ -662,6 +664,7 @@ class DecodeTransferQueue: for i in indices_to_remove: idx = self.queue[i].metadata_buffer_index assert idx != -1 + self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED) self.req_to_metadata_buffer_idx_allocator.free(idx) self.queue = [ @@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin: # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) + req.add_latency(RequestStage.DECODE_WAITING) req.init_next_round_input(self.tree_cache) else: waiting_queue.append(req) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index b70748250..f128ae914 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import ( poll_and_all_reduce, prepare_abort, ) -from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import ( + FINISH_LENGTH, + Req, + RequestStage, + ScheduleBatch, +) from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.utils import ( DynamicGradMode, @@ -170,6 +175,7 @@ class PrefillBootstrapQueue: pp_rank=self.pp_rank, ) self._process_req(req) + req.add_latency(RequestStage.PREFILL_PREPARE) self.queue.append(req) def extend(self, reqs: List[Req], num_kv_heads: int) -> None: @@ -256,6 +262,8 @@ class PrefillBootstrapQueue: num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) + + req.add_latency(RequestStage.PREFILL_BOOTSTRAP) bootstrapped_reqs.append(req) indices_to_remove.add(i) @@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin: # There is no output_ids for prefill req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock + req.add_latency(RequestStage.PREFILL_FORWARD) self.disagg_prefill_inflight_queue.append(req) if ( logits_output is not None @@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin: ) for req in done_reqs: req: Req + req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE) self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) req.metadata_buffer_index = -1 diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b226b8331..9402e723f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,5 +1,7 @@ from __future__ import annotations +import enum + # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,6 +37,7 @@ import copy import dataclasses import logging import threading +import time from enum import Enum, auto from http import HTTPStatus from itertools import chain @@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.metrics.collector import TimeStats +from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams @@ -407,6 +410,23 @@ class MultimodalInputs: # other args would be kept intact +class RequestStage(str, enum.Enum): + # prefill + PREFILL_WAITING = "prefill_waiting" + + # disaggregation prefill + PREFILL_PREPARE = "prefill_prepare" + PREFILL_BOOTSTRAP = "prefill_bootstrap" + PREFILL_FORWARD = "prefill_forward" + PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache" + + # disaggregation decode + DECODE_PREPARE = "decode_prepare" + DECODE_BOOTSTRAP = "decode_bootstrap" + DECODE_WAITING = "decode_waiting" + DECODE_TRANSFERRED = "decode_transferred" + + class Req: """The input and output status of a request.""" @@ -433,6 +453,7 @@ class Req: bootstrap_room: Optional[int] = None, data_parallel_rank: Optional[int] = None, vocab_size: Optional[int] = None, + metrics_collector: Optional[SchedulerMetricsCollector] = None, ): # Input and output info self.rid = rid @@ -590,10 +611,12 @@ class Req: self.spec_verify_ct = 0 # For metrics + self.metrics_collector = metrics_collector self.time_stats: TimeStats = TimeStats() self.has_log_time_stats: bool = False self.queue_time_start = None self.queue_time_end = None + self.last_tic = time.monotonic() # For disaggregation self.bootstrap_host: str = bootstrap_host @@ -626,6 +649,16 @@ class Req: """Check if this request is prefill-only (no token generation needed).""" return self.sampling_params.max_new_tokens == 0 + def add_latency(self, stage: RequestStage): + if self.metrics_collector is None: + return + assert stage.name in RequestStage.__members__, f"{stage=} is invalid" + now = time.monotonic() + self.metrics_collector.observe_request_latency_seconds( + stage.value, now - self.last_tic + ) + self.last_tic = now + def extend_image_inputs(self, image_inputs): if self.multimodal_inputs is None: self.multimodal_inputs = image_inputs diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index de2ad078d..f2697e75f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, MultimodalInputs, Req, + RequestStage, ScheduleBatch, global_server_args_dict, ) @@ -1232,6 +1233,9 @@ class Scheduler( bootstrap_room=recv_req.bootstrap_room, data_parallel_rank=recv_req.data_parallel_rank, vocab_size=self.model_config.vocab_size, + metrics_collector=( + self.metrics_collector if self.enable_metrics else None + ), ) req.tokenizer = self.tokenizer @@ -1768,6 +1772,7 @@ class Scheduler( # only record queue time when enable_metrics is True to avoid overhead for req in can_run_list: req.queue_time_end = time.perf_counter() + req.add_latency(RequestStage.PREFILL_WAITING) self.waiting_queue = [ x for x in self.waiting_queue if x not in set(can_run_list) diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 551d51184..884f4e211 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Union -from sglang.srt.metrics.utils import generate_buckets +from sglang.srt.metrics.utils import exponential_buckets, generate_buckets from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var @@ -513,6 +513,14 @@ class SchedulerMetricsCollector: buckets=tree_traversal_time_buckets, ) + self.request_latency_seconds = Histogram( + name="sglang:request_latency_seconds", + documentation="The latency of each stage of requests.", + # captures latency in range [1ms - ~1191s] + buckets=exponential_buckets(start=0.001, width=1.62, length=30), + labelnames=list(labels.keys()) + ["stage"], + ) + def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) @@ -526,6 +534,10 @@ class SchedulerMetricsCollector: def increment_transfer_failed_reqs(self) -> None: self.num_transfer_failed_reqs.labels(**self.labels).inc(1) + def observe_request_latency_seconds(self, stage: str, latency: float) -> None: + labels_with_stage = {**self.labels, "stage": stage} + self.request_latency_seconds.labels(**labels_with_stage).observe(latency) + def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_used_tokens, stats.num_used_tokens) diff --git a/python/sglang/srt/metrics/func_timer.py b/python/sglang/srt/metrics/func_timer.py index e965d25f8..fbb01bac8 100644 --- a/python/sglang/srt/metrics/func_timer.py +++ b/python/sglang/srt/metrics/func_timer.py @@ -20,6 +20,8 @@ import time from functools import wraps from typing import Any, Callable, List, Optional +from sglang.srt.metrics.utils import exponential_buckets + enable_metrics = False @@ -42,13 +44,6 @@ def enable_func_timer(): FUNC_LATENCY = None -def exponential_buckets(start: float, width: float, length: int) -> List[float]: - buckets = [] - for i in range(length): - buckets.append(start * (width**i)) - return buckets - - def time_func_latency( func: Callable = None, name: Optional[str] = None ) -> Callable[..., Any]: diff --git a/python/sglang/srt/metrics/utils.py b/python/sglang/srt/metrics/utils.py index ffc7e1066..73c0b4e73 100644 --- a/python/sglang/srt/metrics/utils.py +++ b/python/sglang/srt/metrics/utils.py @@ -46,3 +46,10 @@ def generate_buckets( return sorted(set(default_buckets)) assert rule == "customer" return sorted(set([float(x) for x in buckets_rule[1:]])) + + +def exponential_buckets(start: float, width: float, length: int) -> List[float]: + buckets = [] + for i in range(length): + buckets.append(start * (width**i)) + return buckets