[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -35,6 +37,7 @@ import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -407,6 +410,23 @@ class MultimodalInputs:
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
class RequestStage(str, enum.Enum):
|
||||
# prefill
|
||||
PREFILL_WAITING = "prefill_waiting"
|
||||
|
||||
# disaggregation prefill
|
||||
PREFILL_PREPARE = "prefill_prepare"
|
||||
PREFILL_BOOTSTRAP = "prefill_bootstrap"
|
||||
PREFILL_FORWARD = "prefill_forward"
|
||||
PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
|
||||
|
||||
# disaggregation decode
|
||||
DECODE_PREPARE = "decode_prepare"
|
||||
DECODE_BOOTSTRAP = "decode_bootstrap"
|
||||
DECODE_WAITING = "decode_waiting"
|
||||
DECODE_TRANSFERRED = "decode_transferred"
|
||||
|
||||
|
||||
class Req:
|
||||
"""The input and output status of a request."""
|
||||
|
||||
@@ -433,6 +453,7 @@ class Req:
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
vocab_size: Optional[int] = None,
|
||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -590,10 +611,12 @@ class Req:
|
||||
self.spec_verify_ct = 0
|
||||
|
||||
# For metrics
|
||||
self.metrics_collector = metrics_collector
|
||||
self.time_stats: TimeStats = TimeStats()
|
||||
self.has_log_time_stats: bool = False
|
||||
self.queue_time_start = None
|
||||
self.queue_time_end = None
|
||||
self.last_tic = time.monotonic()
|
||||
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
@@ -626,6 +649,16 @@ class Req:
|
||||
"""Check if this request is prefill-only (no token generation needed)."""
|
||||
return self.sampling_params.max_new_tokens == 0
|
||||
|
||||
def add_latency(self, stage: RequestStage):
|
||||
if self.metrics_collector is None:
|
||||
return
|
||||
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
|
||||
now = time.monotonic()
|
||||
self.metrics_collector.observe_request_latency_seconds(
|
||||
stage.value, now - self.last_tic
|
||||
)
|
||||
self.last_tic = now
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.multimodal_inputs is None:
|
||||
self.multimodal_inputs = image_inputs
|
||||
|
||||
Reference in New Issue
Block a user