[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)
This commit is contained in:
@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
@@ -253,6 +253,7 @@ class DecodePreallocQueue:
|
||||
prefill_dp_rank=req.data_parallel_rank,
|
||||
)
|
||||
|
||||
req.add_latency(RequestStage.DECODE_PREPARE)
|
||||
self.queue.append(
|
||||
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
||||
)
|
||||
@@ -421,6 +422,7 @@ class DecodePreallocQueue:
|
||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
||||
preallocated_reqs.append(decode_req)
|
||||
indices_to_remove.add(i)
|
||||
|
||||
@@ -662,6 +664,7 @@ class DecodeTransferQueue:
|
||||
for i in indices_to_remove:
|
||||
idx = self.queue[i].metadata_buffer_index
|
||||
assert idx != -1
|
||||
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
|
||||
self.req_to_metadata_buffer_idx_allocator.free(idx)
|
||||
|
||||
self.queue = [
|
||||
@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
# we can only add at least `num_not_used_batch` new batch to the running queue
|
||||
if i < num_not_used_batch:
|
||||
can_run_list.append(req)
|
||||
req.add_latency(RequestStage.DECODE_WAITING)
|
||||
req.init_next_round_input(self.tree_cache)
|
||||
else:
|
||||
waiting_queue.append(req)
|
||||
|
||||
@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_LENGTH,
|
||||
Req,
|
||||
RequestStage,
|
||||
ScheduleBatch,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
|
||||
pp_rank=self.pp_rank,
|
||||
)
|
||||
self._process_req(req)
|
||||
req.add_latency(RequestStage.PREFILL_PREPARE)
|
||||
self.queue.append(req)
|
||||
|
||||
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
||||
@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
|
||||
|
||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||
|
||||
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
||||
bootstrapped_reqs.append(req)
|
||||
indices_to_remove.add(i)
|
||||
|
||||
@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# There is no output_ids for prefill
|
||||
req.output_ids.append(next_token_id)
|
||||
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||
req.add_latency(RequestStage.PREFILL_FORWARD)
|
||||
self.disagg_prefill_inflight_queue.append(req)
|
||||
if (
|
||||
logits_output is not None
|
||||
@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
)
|
||||
for req in done_reqs:
|
||||
req: Req
|
||||
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
|
||||
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
||||
req.metadata_buffer_index = -1
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -35,6 +37,7 @@ import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -407,6 +410,23 @@ class MultimodalInputs:
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
class RequestStage(str, enum.Enum):
|
||||
# prefill
|
||||
PREFILL_WAITING = "prefill_waiting"
|
||||
|
||||
# disaggregation prefill
|
||||
PREFILL_PREPARE = "prefill_prepare"
|
||||
PREFILL_BOOTSTRAP = "prefill_bootstrap"
|
||||
PREFILL_FORWARD = "prefill_forward"
|
||||
PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
|
||||
|
||||
# disaggregation decode
|
||||
DECODE_PREPARE = "decode_prepare"
|
||||
DECODE_BOOTSTRAP = "decode_bootstrap"
|
||||
DECODE_WAITING = "decode_waiting"
|
||||
DECODE_TRANSFERRED = "decode_transferred"
|
||||
|
||||
|
||||
class Req:
|
||||
"""The input and output status of a request."""
|
||||
|
||||
@@ -433,6 +453,7 @@ class Req:
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
vocab_size: Optional[int] = None,
|
||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -590,10 +611,12 @@ class Req:
|
||||
self.spec_verify_ct = 0
|
||||
|
||||
# For metrics
|
||||
self.metrics_collector = metrics_collector
|
||||
self.time_stats: TimeStats = TimeStats()
|
||||
self.has_log_time_stats: bool = False
|
||||
self.queue_time_start = None
|
||||
self.queue_time_end = None
|
||||
self.last_tic = time.monotonic()
|
||||
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
@@ -626,6 +649,16 @@ class Req:
|
||||
"""Check if this request is prefill-only (no token generation needed)."""
|
||||
return self.sampling_params.max_new_tokens == 0
|
||||
|
||||
def add_latency(self, stage: RequestStage):
|
||||
if self.metrics_collector is None:
|
||||
return
|
||||
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
|
||||
now = time.monotonic()
|
||||
self.metrics_collector.observe_request_latency_seconds(
|
||||
stage.value, now - self.last_tic
|
||||
)
|
||||
self.last_tic = now
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.multimodal_inputs is None:
|
||||
self.multimodal_inputs = image_inputs
|
||||
|
||||
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
MultimodalInputs,
|
||||
Req,
|
||||
RequestStage,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
@@ -1232,6 +1233,9 @@ class Scheduler(
|
||||
bootstrap_room=recv_req.bootstrap_room,
|
||||
data_parallel_rank=recv_req.data_parallel_rank,
|
||||
vocab_size=self.model_config.vocab_size,
|
||||
metrics_collector=(
|
||||
self.metrics_collector if self.enable_metrics else None
|
||||
),
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -1768,6 +1772,7 @@ class Scheduler(
|
||||
# only record queue time when enable_metrics is True to avoid overhead
|
||||
for req in can_run_list:
|
||||
req.queue_time_end = time.perf_counter()
|
||||
req.add_latency(RequestStage.PREFILL_WAITING)
|
||||
|
||||
self.waiting_queue = [
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.metrics.utils import generate_buckets
|
||||
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
|
||||
buckets=tree_traversal_time_buckets,
|
||||
)
|
||||
|
||||
self.request_latency_seconds = Histogram(
|
||||
name="sglang:request_latency_seconds",
|
||||
documentation="The latency of each stage of requests.",
|
||||
# captures latency in range [1ms - ~1191s]
|
||||
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
|
||||
labelnames=list(labels.keys()) + ["stage"],
|
||||
)
|
||||
|
||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to gauge.
|
||||
gauge.labels(**self.labels).set(data)
|
||||
@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
|
||||
def increment_transfer_failed_reqs(self) -> None:
|
||||
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
|
||||
|
||||
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
|
||||
labels_with_stage = {**self.labels, "stage": stage}
|
||||
self.request_latency_seconds.labels(**labels_with_stage).observe(latency)
|
||||
|
||||
def log_stats(self, stats: SchedulerStats) -> None:
|
||||
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
|
||||
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
|
||||
|
||||
@@ -20,6 +20,8 @@ import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from sglang.srt.metrics.utils import exponential_buckets
|
||||
|
||||
enable_metrics = False
|
||||
|
||||
|
||||
@@ -42,13 +44,6 @@ def enable_func_timer():
|
||||
FUNC_LATENCY = None
|
||||
|
||||
|
||||
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
|
||||
buckets = []
|
||||
for i in range(length):
|
||||
buckets.append(start * (width**i))
|
||||
return buckets
|
||||
|
||||
|
||||
def time_func_latency(
|
||||
func: Callable = None, name: Optional[str] = None
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
@@ -46,3 +46,10 @@ def generate_buckets(
|
||||
return sorted(set(default_buckets))
|
||||
assert rule == "customer"
|
||||
return sorted(set([float(x) for x in buckets_rule[1:]]))
|
||||
|
||||
|
||||
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
|
||||
buckets = []
|
||||
for i in range(length):
|
||||
buckets.append(start * (width**i))
|
||||
return buckets
|
||||
|
||||
Reference in New Issue
Block a user