Improve structured outputs: fix race condition, server crash, metrics and style (#6188)
This commit is contained in:
@@ -15,7 +15,119 @@
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Union
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
|
||||
|
||||
|
||||
@dataclass
class TimeStats:
    """
    Store the timestamps for each stage of a request.

    Unified: wait_queue -> forward -> completion
    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
    """

    # Absolute timestamps (presumably time.time() seconds — a 0.0 value means
    # the request never entered that stage; see get_type()).
    lb_entry_time: float = 0.0
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0
    # Disaggregated-prefill stages.
    prefill_bootstrap_queue_entry_time: float = 0.0
    prefill_transfer_queue_entry_time: float = 0.0
    # Disaggregated-decode stages.
    decode_prealloc_queue_entry_time: float = 0.0
    decode_transfer_queue_entry_time: float = 0.0

    class RequestType(Enum):
        # Which deployment mode produced the timestamps.
        UNIFIED = "unified"
        PREFILL = "prefill"
        DECODE = "decode"
        INVALID = "invalid"

    def __str__(self) -> str:
        """Format the per-stage durations of this request for logging.

        The duration asserts only fire when SGLANG_TEST_REQUEST_TIME_STATS is
        set (test runs); in production a negative duration is logged as-is.
        """
        # if unified
        _type = self.get_type()

        if _type == self.RequestType.UNIFIED:
            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    queue_duration >= 0 and forward_duration >= 0
                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
        elif _type == self.RequestType.PREFILL:
            bootstrap_duration = (
                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
            )

            queue_duration = self.forward_entry_time - self.wait_queue_entry_time

            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    bootstrap_duration >= 0
                    and queue_duration >= 0
                    and forward_duration >= 0
                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
        # if decode
        elif _type == self.RequestType.DECODE:
            prealloc_duration = (
                self.decode_transfer_queue_entry_time
                - self.decode_prealloc_queue_entry_time
            )

            transfer_duration = (
                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
            )
            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    prealloc_duration >= 0
                    and transfer_duration >= 0
                    and queue_duration >= 0
                    and forward_duration >= 0
                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
        else:
            return "Invalid Time Stats"

    def format_duration(self, duration: float) -> str:
        """Render a duration in seconds as a millisecond string, e.g. '12.34ms'."""
        return f"{duration * 1e3:.2f}ms"

    def get_type(self) -> RequestType:
        """Determine the type of request based on timestamp values."""
        # No disaggregation timestamps at all -> plain unified request.
        if (
            self.prefill_bootstrap_queue_entry_time == 0.0
            and self.prefill_transfer_queue_entry_time == 0.0
            and self.decode_prealloc_queue_entry_time == 0.0
            and self.decode_transfer_queue_entry_time == 0.0
        ):
            return self.RequestType.UNIFIED
        elif (
            self.prefill_bootstrap_queue_entry_time > 0.0
            and self.prefill_transfer_queue_entry_time > 0.0
        ):
            return self.RequestType.PREFILL
        elif (
            self.decode_prealloc_queue_entry_time > 0.0
            and self.decode_transfer_queue_entry_time > 0.0
            and self.wait_queue_entry_time > 0.0
        ):
            return self.RequestType.DECODE
        else:
            # Partially-filled disaggregation timestamps: inconsistent state.
            return self.RequestType.INVALID
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -26,15 +138,20 @@ class SchedulerStats:
|
||||
gen_throughput: float = 0.0
|
||||
num_queue_reqs: int = 0
|
||||
cache_hit_rate: float = 0.0
|
||||
num_grammar_queue_reqs: int = 0
|
||||
spec_accept_length: float = 0.0
|
||||
avg_request_queue_latency: float = 0.0
|
||||
num_prefill_prealloc_queue_reqs: int = 0
|
||||
num_prefill_infight_queue_reqs: int = 0
|
||||
num_decode_prealloc_queue_reqs: int = 0
|
||||
num_decode_transfer_queue_reqs: int = 0
|
||||
|
||||
|
||||
class SchedulerMetricsCollector:
|
||||
|
||||
def __init__(self, labels: Dict[str, str]) -> None:
|
||||
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
||||
from prometheus_client import Gauge, Histogram
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
||||
self.labels = labels
|
||||
self.last_log_time = time.time()
|
||||
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_grammar_queue_reqs = Gauge(
|
||||
name="sglang:num_grammar_queue_reqs",
|
||||
documentation="The number of requests in the grammar waiting queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.cache_hit_rate = Gauge(
|
||||
name="sglang:cache_hit_rate",
|
||||
documentation="The prefix cache hit rate.",
|
||||
@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
# Disaggregation queue metrics
|
||||
self.num_prefill_prealloc_queue_reqs = Gauge(
|
||||
name="sglang:num_prefill_prealloc_queue_reqs",
|
||||
documentation="The number of requests in the prefill prealloc queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_prefill_infight_queue_reqs = Gauge(
|
||||
name="sglang:num_prefill_infight_queue_reqs",
|
||||
documentation="The number of requests in the prefill infight queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_decode_prealloc_queue_reqs = Gauge(
|
||||
name="sglang:num_decode_prealloc_queue_reqs",
|
||||
documentation="The number of requests in the decode prealloc queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_decode_transfer_queue_reqs = Gauge(
|
||||
name="sglang:num_decode_transfer_queue_reqs",
|
||||
documentation="The number of requests in the decode transfer queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_bootstrap_failed_reqs = Counter(
|
||||
name="sglang:num_bootstrap_failed_reqs",
|
||||
documentation="The number of bootstrap failed requests.",
|
||||
labelnames=labels.keys(),
|
||||
)
|
||||
|
||||
self.num_transfer_failed_reqs = Counter(
|
||||
name="sglang:num_transfer_failed_reqs",
|
||||
documentation="The number of transfer failed requests.",
|
||||
labelnames=labels.keys(),
|
||||
)
|
||||
|
||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
    """Set `gauge` to `data` under this collector's label set.

    Convenience wrapper so callers don't repeat the labels() plumbing.
    """
    gauge.labels(**self.labels).set(data)
|
||||
|
||||
def increment_bootstrap_failed_reqs(self) -> None:
    """Count one request that failed during the disaggregation bootstrap phase."""
    self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
|
||||
|
||||
def increment_transfer_failed_reqs(self) -> None:
    """Count one request that failed during the KV-cache transfer phase."""
    self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
|
||||
|
||||
def log_stats(self, stats: SchedulerStats) -> None:
    """Publish one SchedulerStats snapshot to all scheduler gauges.

    Also records the wall-clock time of this emission in `last_log_time`
    so callers can rate-limit or measure logging intervals.
    """
    self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
    self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
    self._log_gauge(self.token_usage, stats.token_usage)
    self._log_gauge(self.gen_throughput, stats.gen_throughput)
    self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
    self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
    self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
    self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
    self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)

    # Disaggregation metrics: queue depths for the prefill/decode split.
    self._log_gauge(
        self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
    )
    self._log_gauge(
        self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
    )
    self._log_gauge(
        self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
    )
    self._log_gauge(
        self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
    )

    self.last_log_time = time.time()
|
||||
|
||||
|
||||
class TokenizerMetricsCollector:
|
||||
def __init__(self, labels: Dict[str, str]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
labels: Dict[str, str],
|
||||
bucket_time_to_first_token: Optional[List[float]] = None,
|
||||
bucket_inter_token_latency: Optional[List[float]] = None,
|
||||
bucket_e2e_request_latency: Optional[List[float]] = None,
|
||||
collect_tokens_histogram: bool = False,
|
||||
) -> None:
|
||||
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
self.labels = labels
|
||||
self.collect_tokens_histogram = collect_tokens_histogram
|
||||
|
||||
self.prompt_tokens_total = Counter(
|
||||
name="sglang:prompt_tokens_total",
|
||||
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
|
||||
labelnames=labels.keys(),
|
||||
)
|
||||
|
||||
if collect_tokens_histogram:
|
||||
bucket_prompt_tokens = [
|
||||
100,
|
||||
300,
|
||||
500,
|
||||
700,
|
||||
1000,
|
||||
1500,
|
||||
2000,
|
||||
3000,
|
||||
4000,
|
||||
5000,
|
||||
6000,
|
||||
7000,
|
||||
8000,
|
||||
9000,
|
||||
10000,
|
||||
12000,
|
||||
15000,
|
||||
20000,
|
||||
22000,
|
||||
25000,
|
||||
30000,
|
||||
35000,
|
||||
40000,
|
||||
]
|
||||
self.prompt_tokens_histogram = Histogram(
|
||||
name="sglang:prompt_tokens_histogram",
|
||||
documentation="Histogram of prompt token length.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=bucket_prompt_tokens,
|
||||
)
|
||||
bucket_generation_tokens = [
|
||||
100,
|
||||
300,
|
||||
500,
|
||||
1000,
|
||||
1200,
|
||||
1500,
|
||||
1700,
|
||||
2000,
|
||||
2500,
|
||||
3000,
|
||||
3500,
|
||||
4000,
|
||||
4500,
|
||||
5000,
|
||||
6000,
|
||||
7000,
|
||||
8000,
|
||||
9000,
|
||||
10000,
|
||||
]
|
||||
self.generation_tokens_histogram = Histogram(
|
||||
name="sglang:generation_tokens_histogram",
|
||||
documentation="Histogram of generation token length.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=bucket_generation_tokens,
|
||||
)
|
||||
|
||||
self.cached_tokens_total = Counter(
|
||||
name="sglang:cached_tokens_total",
|
||||
documentation="Number of cached prompt tokens.",
|
||||
@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
|
||||
labelnames=labels.keys(),
|
||||
)
|
||||
|
||||
self.histogram_time_to_first_token = Histogram(
|
||||
name="sglang:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
self.num_so_requests_total = Counter(
|
||||
name="sglang:num_so_requests_total",
|
||||
documentation="Number of structured output requests processed.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
)
|
||||
|
||||
if bucket_time_to_first_token is None:
|
||||
bucket_time_to_first_token = [
|
||||
0.1,
|
||||
0.2,
|
||||
0.4,
|
||||
@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
|
||||
100,
|
||||
200,
|
||||
400,
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
self.histogram_inter_token_latency_seconds = Histogram(
|
||||
name="sglang:inter_token_latency_seconds",
|
||||
documentation="Histogram of inter-token latency in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
if bucket_e2e_request_latency is None:
|
||||
bucket_e2e_request_latency = [
|
||||
0.1,
|
||||
0.2,
|
||||
0.4,
|
||||
0.6,
|
||||
0.8,
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
20,
|
||||
40,
|
||||
60,
|
||||
80,
|
||||
100,
|
||||
200,
|
||||
400,
|
||||
800,
|
||||
]
|
||||
|
||||
if bucket_inter_token_latency is None:
|
||||
bucket_inter_token_latency = [
|
||||
0.002,
|
||||
0.004,
|
||||
0.006,
|
||||
@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
|
||||
4.000,
|
||||
6.000,
|
||||
8.000,
|
||||
],
|
||||
]
|
||||
|
||||
self.histogram_time_to_first_token = Histogram(
|
||||
name="sglang:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=bucket_time_to_first_token,
|
||||
)
|
||||
|
||||
self.histogram_inter_token_latency_seconds = Histogram(
|
||||
name="sglang:inter_token_latency_seconds",
|
||||
documentation="Histogram of inter-token latency in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=bucket_inter_token_latency,
|
||||
)
|
||||
|
||||
self.histogram_e2e_request_latency = Histogram(
|
||||
name="sglang:e2e_request_latency_seconds",
|
||||
documentation="Histogram of End-to-end request latency in seconds",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.2,
|
||||
0.4,
|
||||
0.6,
|
||||
0.8,
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
20,
|
||||
40,
|
||||
60,
|
||||
80,
|
||||
100,
|
||||
200,
|
||||
400,
|
||||
800,
|
||||
],
|
||||
buckets=bucket_e2e_request_latency,
|
||||
)
|
||||
|
||||
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
||||
@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
|
||||
generation_tokens: int,
|
||||
cached_tokens: int,
|
||||
e2e_latency: float,
|
||||
has_grammar: bool,
|
||||
):
|
||||
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
||||
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
||||
if cached_tokens > 0:
|
||||
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
|
||||
self.num_requests_total.labels(**self.labels).inc(1)
|
||||
if has_grammar:
|
||||
self.num_so_requests_total.labels(**self.labels).inc(1)
|
||||
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
||||
if self.collect_tokens_histogram:
|
||||
self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
|
||||
self._log_histogram(self.generation_tokens_histogram, generation_tokens)
|
||||
|
||||
def observe_time_to_first_token(self, value: float):
    """Record one time-to-first-token observation (seconds) in the TTFT histogram."""
    self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
||||
|
||||
Reference in New Issue
Block a user