[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)
This commit is contained in:
@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||||
@@ -253,6 +253,7 @@ class DecodePreallocQueue:
|
|||||||
prefill_dp_rank=req.data_parallel_rank,
|
prefill_dp_rank=req.data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
req.add_latency(RequestStage.DECODE_PREPARE)
|
||||||
self.queue.append(
|
self.queue.append(
|
||||||
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
||||||
)
|
)
|
||||||
@@ -421,6 +422,7 @@ class DecodePreallocQueue:
|
|||||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||||
)
|
)
|
||||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||||
|
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
||||||
preallocated_reqs.append(decode_req)
|
preallocated_reqs.append(decode_req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
@@ -662,6 +664,7 @@ class DecodeTransferQueue:
|
|||||||
for i in indices_to_remove:
|
for i in indices_to_remove:
|
||||||
idx = self.queue[i].metadata_buffer_index
|
idx = self.queue[i].metadata_buffer_index
|
||||||
assert idx != -1
|
assert idx != -1
|
||||||
|
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
|
||||||
self.req_to_metadata_buffer_idx_allocator.free(idx)
|
self.req_to_metadata_buffer_idx_allocator.free(idx)
|
||||||
|
|
||||||
self.queue = [
|
self.queue = [
|
||||||
@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
# we can only add at most `num_not_used_batch` new requests to the running queue
|
# we can only add at most `num_not_used_batch` new requests to the running queue
|
||||||
if i < num_not_used_batch:
|
if i < num_not_used_batch:
|
||||||
can_run_list.append(req)
|
can_run_list.append(req)
|
||||||
|
req.add_latency(RequestStage.DECODE_WAITING)
|
||||||
req.init_next_round_input(self.tree_cache)
|
req.init_next_round_input(self.tree_cache)
|
||||||
else:
|
else:
|
||||||
waiting_queue.append(req)
|
waiting_queue.append(req)
|
||||||
|
|||||||
@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
FINISH_LENGTH,
|
||||||
|
Req,
|
||||||
|
RequestStage,
|
||||||
|
ScheduleBatch,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
|
|||||||
pp_rank=self.pp_rank,
|
pp_rank=self.pp_rank,
|
||||||
)
|
)
|
||||||
self._process_req(req)
|
self._process_req(req)
|
||||||
|
req.add_latency(RequestStage.PREFILL_PREPARE)
|
||||||
self.queue.append(req)
|
self.queue.append(req)
|
||||||
|
|
||||||
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
||||||
@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
|
|||||||
|
|
||||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||||
|
|
||||||
|
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
||||||
bootstrapped_reqs.append(req)
|
bootstrapped_reqs.append(req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
# There is no output_ids for prefill
|
# There is no output_ids for prefill
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||||
|
req.add_latency(RequestStage.PREFILL_FORWARD)
|
||||||
self.disagg_prefill_inflight_queue.append(req)
|
self.disagg_prefill_inflight_queue.append(req)
|
||||||
if (
|
if (
|
||||||
logits_output is not None
|
logits_output is not None
|
||||||
@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
)
|
)
|
||||||
for req in done_reqs:
|
for req in done_reqs:
|
||||||
req: Req
|
req: Req
|
||||||
|
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
|
||||||
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
||||||
req.metadata_buffer_index = -1
|
req.metadata_buffer_index = -1
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
# Copyright 2023-2024 SGLang Team
|
# Copyright 2023-2024 SGLang Team
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -35,6 +37,7 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
|||||||
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
|
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
from sglang.srt.metrics.collector import TimeStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -407,6 +410,23 @@ class MultimodalInputs:
|
|||||||
# other args would be kept intact
|
# other args would be kept intact
|
||||||
|
|
||||||
|
|
||||||
|
class RequestStage(str, enum.Enum):
    """Names of the per-request stages whose latency is recorded.

    Each member's value is the lowercase string passed as the stage label
    when a latency sample is observed for a request.
    """

    # --- prefill ---
    PREFILL_WAITING = "prefill_waiting"

    # --- disaggregation: prefill side ---
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # --- disaggregation: decode side ---
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
|
||||||
|
|
||||||
|
|
||||||
class Req:
|
class Req:
|
||||||
"""The input and output status of a request."""
|
"""The input and output status of a request."""
|
||||||
|
|
||||||
@@ -433,6 +453,7 @@ class Req:
|
|||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
vocab_size: Optional[int] = None,
|
vocab_size: Optional[int] = None,
|
||||||
|
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
@@ -590,10 +611,12 @@ class Req:
|
|||||||
self.spec_verify_ct = 0
|
self.spec_verify_ct = 0
|
||||||
|
|
||||||
# For metrics
|
# For metrics
|
||||||
|
self.metrics_collector = metrics_collector
|
||||||
self.time_stats: TimeStats = TimeStats()
|
self.time_stats: TimeStats = TimeStats()
|
||||||
self.has_log_time_stats: bool = False
|
self.has_log_time_stats: bool = False
|
||||||
self.queue_time_start = None
|
self.queue_time_start = None
|
||||||
self.queue_time_end = None
|
self.queue_time_end = None
|
||||||
|
self.last_tic = time.monotonic()
|
||||||
|
|
||||||
# For disaggregation
|
# For disaggregation
|
||||||
self.bootstrap_host: str = bootstrap_host
|
self.bootstrap_host: str = bootstrap_host
|
||||||
@@ -626,6 +649,16 @@ class Req:
|
|||||||
"""Check if this request is prefill-only (no token generation needed)."""
|
"""Check if this request is prefill-only (no token generation needed)."""
|
||||||
return self.sampling_params.max_new_tokens == 0
|
return self.sampling_params.max_new_tokens == 0
|
||||||
|
|
||||||
|
def add_latency(self, stage: RequestStage):
    """Record the time elapsed since the previous stage boundary.

    Does nothing when ``self.metrics_collector`` is ``None``. Otherwise the
    interval between ``self.last_tic`` and the current monotonic clock
    reading is observed on the collector under ``stage.value``, and
    ``self.last_tic`` is advanced to that reading so the next call measures
    only the following stage.

    Args:
        stage: The stage that has just completed for this request.

    Raises:
        TypeError: If ``stage`` is not a ``RequestStage`` member.
    """
    collector = self.metrics_collector
    if collector is None:
        return

    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # and a non-member (e.g. a bare string) should fail with a clear error.
    if not isinstance(stage, RequestStage):
        raise TypeError(f"{stage=} is invalid")

    now = time.monotonic()
    collector.observe_request_latency_seconds(stage.value, now - self.last_tic)
    self.last_tic = now
|
||||||
|
|
||||||
def extend_image_inputs(self, image_inputs):
|
def extend_image_inputs(self, image_inputs):
|
||||||
if self.multimodal_inputs is None:
|
if self.multimodal_inputs is None:
|
||||||
self.multimodal_inputs = image_inputs
|
self.multimodal_inputs = image_inputs
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
Req,
|
Req,
|
||||||
|
RequestStage,
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
@@ -1232,6 +1233,9 @@ class Scheduler(
|
|||||||
bootstrap_room=recv_req.bootstrap_room,
|
bootstrap_room=recv_req.bootstrap_room,
|
||||||
data_parallel_rank=recv_req.data_parallel_rank,
|
data_parallel_rank=recv_req.data_parallel_rank,
|
||||||
vocab_size=self.model_config.vocab_size,
|
vocab_size=self.model_config.vocab_size,
|
||||||
|
metrics_collector=(
|
||||||
|
self.metrics_collector if self.enable_metrics else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
@@ -1768,6 +1772,7 @@ class Scheduler(
|
|||||||
# only record queue time when enable_metrics is True to avoid overhead
|
# only record queue time when enable_metrics is True to avoid overhead
|
||||||
for req in can_run_list:
|
for req in can_run_list:
|
||||||
req.queue_time_end = time.perf_counter()
|
req.queue_time_end = time.perf_counter()
|
||||||
|
req.add_latency(RequestStage.PREFILL_WAITING)
|
||||||
|
|
||||||
self.waiting_queue = [
|
self.waiting_queue = [
|
||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.metrics.utils import generate_buckets
|
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
|
|||||||
buckets=tree_traversal_time_buckets,
|
buckets=tree_traversal_time_buckets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_latency_seconds = Histogram(
|
||||||
|
name="sglang:request_latency_seconds",
|
||||||
|
documentation="The latency of each stage of requests.",
|
||||||
|
# captures latency in range [1ms - ~1191s]
|
||||||
|
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
|
||||||
|
labelnames=list(labels.keys()) + ["stage"],
|
||||||
|
)
|
||||||
|
|
||||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
||||||
# Convenience function for logging to gauge.
|
# Convenience function for logging to gauge.
|
||||||
gauge.labels(**self.labels).set(data)
|
gauge.labels(**self.labels).set(data)
|
||||||
@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
|
|||||||
def increment_transfer_failed_reqs(self) -> None:
|
def increment_transfer_failed_reqs(self) -> None:
|
||||||
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
|
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
|
||||||
|
|
||||||
|
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    """Observe one latency sample on the request-latency histogram.

    The sample is labelled with this collector's labels plus a ``"stage"``
    entry set to ``stage``; ``self.labels`` itself is not modified.

    Args:
        stage: Value for the ``"stage"`` label.
        latency: The value passed to the histogram's ``observe``.
    """
    # Copy first so the per-call "stage" entry never leaks into self.labels.
    stage_labels = dict(self.labels)
    stage_labels["stage"] = stage
    labelled_histogram = self.request_latency_seconds.labels(**stage_labels)
    labelled_histogram.observe(latency)
|
||||||
|
|
||||||
def log_stats(self, stats: SchedulerStats) -> None:
|
def log_stats(self, stats: SchedulerStats) -> None:
|
||||||
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
|
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
|
||||||
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
|
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import time
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.metrics.utils import exponential_buckets
|
||||||
|
|
||||||
enable_metrics = False
|
enable_metrics = False
|
||||||
|
|
||||||
|
|
||||||
@@ -42,13 +44,6 @@ def enable_func_timer():
|
|||||||
FUNC_LATENCY = None
|
FUNC_LATENCY = None
|
||||||
|
|
||||||
|
|
||||||
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
|
|
||||||
buckets = []
|
|
||||||
for i in range(length):
|
|
||||||
buckets.append(start * (width**i))
|
|
||||||
return buckets
|
|
||||||
|
|
||||||
|
|
||||||
def time_func_latency(
|
def time_func_latency(
|
||||||
func: Callable = None, name: Optional[str] = None
|
func: Callable = None, name: Optional[str] = None
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
|
|||||||
@@ -46,3 +46,10 @@ def generate_buckets(
|
|||||||
return sorted(set(default_buckets))
|
return sorted(set(default_buckets))
|
||||||
assert rule == "customer"
|
assert rule == "customer"
|
||||||
return sorted(set([float(x) for x in buckets_rule[1:]]))
|
return sorted(set([float(x) for x in buckets_rule[1:]]))
|
||||||
|
|
||||||
|
|
||||||
|
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` geometrically spaced bucket boundaries.

    Element ``i`` of the result is ``start * width**i``, so the first
    boundary is ``start`` and each following one is ``width`` times the
    previous.

    Args:
        start: Value of the first bucket boundary.
        width: Multiplicative factor between consecutive boundaries.
        length: Number of boundaries to generate; a value of zero or less
            yields an empty list.

    Returns:
        The list of bucket boundaries, in generation order.
    """
    return [start * (width**i) for i in range(length)]
|
||||||
|
|||||||
Reference in New Issue
Block a user