[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)

This commit is contained in:
Yingchun Lai
2025-09-16 01:52:49 +08:00
committed by GitHub
parent 57234d0c9c
commit b1721edbac
7 changed files with 77 additions and 11 deletions

View File

@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort, prepare_abort,
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
@@ -253,6 +253,7 @@ class DecodePreallocQueue:
prefill_dp_rank=req.data_parallel_rank, prefill_dp_rank=req.data_parallel_rank,
) )
req.add_latency(RequestStage.DECODE_PREPARE)
self.queue.append( self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
) )
@@ -421,6 +422,7 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size kv_indices, self.token_to_kv_pool_allocator.page_size
) )
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req) preallocated_reqs.append(decode_req)
indices_to_remove.add(i) indices_to_remove.add(i)
@@ -662,6 +664,7 @@ class DecodeTransferQueue:
for i in indices_to_remove: for i in indices_to_remove:
idx = self.queue[i].metadata_buffer_index idx = self.queue[i].metadata_buffer_index
assert idx != -1 assert idx != -1
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
self.req_to_metadata_buffer_idx_allocator.free(idx) self.req_to_metadata_buffer_idx_allocator.free(idx)
self.queue = [ self.queue = [
@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
# we can only add at least `num_not_used_batch` new batch to the running queue # we can only add at least `num_not_used_batch` new batch to the running queue
if i < num_not_used_batch: if i < num_not_used_batch:
can_run_list.append(req) can_run_list.append(req)
req.add_latency(RequestStage.DECODE_WAITING)
req.init_next_round_input(self.tree_cache) req.init_next_round_input(self.tree_cache)
else: else:
waiting_queue.append(req) waiting_queue.append(req)

View File

@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort, prepare_abort,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import ( from sglang.srt.utils import (
DynamicGradMode, DynamicGradMode,
@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
pp_rank=self.pp_rank, pp_rank=self.pp_rank,
) )
self._process_req(req) self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req) self.queue.append(req)
def extend(self, reqs: List[Req], num_kv_heads: int) -> None: def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req) bootstrapped_reqs.append(req)
indices_to_remove.add(i) indices_to_remove.add(i)
@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
# There is no output_ids for prefill # There is no output_ids for prefill
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req) self.disagg_prefill_inflight_queue.append(req)
if ( if (
logits_output is not None logits_output is not None
@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
) )
for req in done_reqs: for req in done_reqs:
req: Req req: Req
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1 req.metadata_buffer_index = -1

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import enum
# Copyright 2023-2024 SGLang Team # Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -35,6 +37,7 @@ import copy
import dataclasses import dataclasses
import logging import logging
import threading import threading
import time
from enum import Enum, auto from enum import Enum, auto
from http import HTTPStatus from http import HTTPStatus
from itertools import chain from itertools import chain
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
@@ -407,6 +410,23 @@ class MultimodalInputs:
# other args would be kept intact # other args would be kept intact
class RequestStage(str, enum.Enum):
    """Named stages of a request, grouped by where they occur.

    Every member is also a ``str``; its value is the lowercase stage name.
    """

    # Prefill
    PREFILL_WAITING = "prefill_waiting"

    # Disaggregation: prefill side
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # Disaggregation: decode side
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
class Req: class Req:
"""The input and output status of a request.""" """The input and output status of a request."""
@@ -433,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None, bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None, vocab_size: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
@@ -590,10 +611,12 @@ class Req:
self.spec_verify_ct = 0 self.spec_verify_ct = 0
# For metrics # For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats() self.time_stats: TimeStats = TimeStats()
self.has_log_time_stats: bool = False self.has_log_time_stats: bool = False
self.queue_time_start = None self.queue_time_start = None
self.queue_time_end = None self.queue_time_end = None
self.last_tic = time.monotonic()
# For disaggregation # For disaggregation
self.bootstrap_host: str = bootstrap_host self.bootstrap_host: str = bootstrap_host
@@ -626,6 +649,16 @@ class Req:
"""Check if this request is prefill-only (no token generation needed).""" """Check if this request is prefill-only (no token generation needed)."""
return self.sampling_params.max_new_tokens == 0 return self.sampling_params.max_new_tokens == 0
def add_latency(self, stage: RequestStage):
    """Report the time elapsed since the last checkpoint under ``stage``.

    The interval between ``self.last_tic`` and the current monotonic clock
    reading is passed to
    ``self.metrics_collector.observe_request_latency_seconds`` together with
    the stage's string value, and ``self.last_tic`` is then advanced to the
    current reading so the next call measures only the following stage.

    Does nothing when ``self.metrics_collector`` is ``None``.

    Args:
        stage: A ``RequestStage`` member, or the string value of one.

    Raises:
        ValueError: If ``stage`` does not correspond to a ``RequestStage``.
    """
    if self.metrics_collector is None:
        return
    # Validate by explicit coercion rather than `assert`: an assert is
    # stripped under `python -O`, and the previous membership check could
    # never fail for a genuine member (while a plain string crashed with
    # AttributeError on `.name`).
    stage = RequestStage(stage)
    now = time.monotonic()
    self.metrics_collector.observe_request_latency_seconds(
        stage.value, now - self.last_tic
    )
    self.last_tic = now
def extend_image_inputs(self, image_inputs): def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None: if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs self.multimodal_inputs = image_inputs

View File

@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
FINISH_ABORT, FINISH_ABORT,
MultimodalInputs, MultimodalInputs,
Req, Req,
RequestStage,
ScheduleBatch, ScheduleBatch,
global_server_args_dict, global_server_args_dict,
) )
@@ -1232,6 +1233,9 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room, bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank, data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size, vocab_size=self.model_config.vocab_size,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
@@ -1768,6 +1772,7 @@ class Scheduler(
# only record queue time when enable_metrics is True to avoid overhead # only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list: for req in can_run_list:
req.queue_time_end = time.perf_counter() req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [ self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.metrics.utils import generate_buckets from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
buckets=tree_traversal_time_buckets, buckets=tree_traversal_time_buckets,
) )
self.request_latency_seconds = Histogram(
name="sglang:request_latency_seconds",
documentation="The latency of each stage of requests.",
# captures latency in range [1ms - ~1191s]
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
labelnames=list(labels.keys()) + ["stage"],
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge. # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data) gauge.labels(**self.labels).set(data)
@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
def increment_transfer_failed_reqs(self) -> None: def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1) self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    """Record ``latency`` in the request-latency histogram for ``stage``.

    The collector's own labels are copied and extended with a ``stage``
    entry, so ``self.labels`` itself is left untouched.
    """
    stage_labels = dict(self.labels)
    stage_labels["stage"] = stage
    histogram_child = self.request_latency_seconds.labels(**stage_labels)
    histogram_child.observe(latency)
def log_stats(self, stats: SchedulerStats) -> None: def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.num_used_tokens, stats.num_used_tokens)

View File

@@ -20,6 +20,8 @@ import time
from functools import wraps from functools import wraps
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
from sglang.srt.metrics.utils import exponential_buckets
enable_metrics = False enable_metrics = False
@@ -42,13 +44,6 @@ def enable_func_timer():
FUNC_LATENCY = None FUNC_LATENCY = None
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Build ``length`` bucket boundaries that grow by a factor of ``width``.

    Boundary ``i`` (0-based) is ``start * width**i``; the first boundary is
    therefore ``start``.
    """
    return [start * (width**power) for power in range(length)]
def time_func_latency( def time_func_latency(
func: Callable = None, name: Optional[str] = None func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]: ) -> Callable[..., Any]:

View File

@@ -46,3 +46,10 @@ def generate_buckets(
return sorted(set(default_buckets)) return sorted(set(default_buckets))
assert rule == "customer" assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]])) return sorted(set([float(x) for x in buckets_rule[1:]]))
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` geometrically spaced histogram bucket boundaries.

    The boundary at 0-based position ``i`` is ``start * width**i``, so the
    list begins at ``start`` and is empty when ``length`` is not positive.
    """
    return [start * (width**exponent) for exponent in range(length)]