diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index b174bbeb3..7cbcb6949 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -50,6 +50,9 @@ class TimeStats: DECODE = "decode" INVALID = "invalid" + def get_queueing_time(self) -> float: + return self.forward_entry_time - self.wait_queue_entry_time + def __str__(self) -> str: # if unified _type = self.get_type() @@ -134,27 +137,48 @@ class TimeStats: @dataclass class SchedulerStats: + # Basics num_running_reqs: int = 0 num_used_tokens: int = 0 token_usage: float = 0.0 + swa_token_usage: float = 0.0 gen_throughput: float = 0.0 num_queue_reqs: int = 0 - cache_hit_rate: float = 0.0 num_grammar_queue_reqs: int = 0 - spec_accept_length: float = 0.0 + num_running_reqs_offline_batch: int = 0 avg_request_queue_latency: float = 0.0 + cache_hit_rate: float = 0.0 + + # Speculative decoding + spec_accept_length: float = 0.0 + + # PD disaggregation num_prefill_prealloc_queue_reqs: int = 0 num_prefill_inflight_queue_reqs: int = 0 num_decode_prealloc_queue_reqs: int = 0 num_decode_transfer_queue_reqs: int = 0 + kv_transfer_speed_gb_s: float = 0.0 + kv_transfer_latency_ms: float = 0.0 + + # Retract total_retracted_reqs: int = 0 + num_retracted_reqs: int = 0 + num_paused_reqs: int = 0 + + # Utilization + utilization: float = 0.0 + max_running_requests_under_SLO: Optional[int] = None + + # Engine startup + engine_startup_time: float = 0.0 + engine_load_weights_time: float = 0.0 class SchedulerMetricsCollector: def __init__(self, labels: Dict[str, str]) -> None: # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` - from prometheus_client import Counter, Gauge + from prometheus_client import Counter, Gauge, Histogram self.labels = labels self.last_log_time = time.perf_counter() @@ -165,42 +189,54 @@ class SchedulerMetricsCollector: labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_used_tokens = Gauge( name="sglang:num_used_tokens", documentation="The number of used tokens.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.token_usage = Gauge( name="sglang:token_usage", documentation="The token usage.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - + self.swa_token_usage = Gauge( + name="sglang:swa_token_usage", + documentation="The token usage for SWA layers.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) self.gen_throughput = Gauge( name="sglang:gen_throughput", documentation="The generation throughput (token/s).", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_queue_reqs = Gauge( name="sglang:num_queue_reqs", documentation="The number of requests in the waiting queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_grammar_queue_reqs = Gauge( name="sglang:num_grammar_queue_reqs", documentation="The number of requests in the grammar waiting queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - + self.num_running_reqs_offline_batch = Gauge( + name="sglang:num_running_reqs_offline_batch", + documentation="The number of running low-priority offline batch requests(label is 'batch').", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.avg_request_queue_latency = Gauge( + name="sglang:avg_request_queue_latency", + documentation="The average request queue latency for the last batch of requests in seconds.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", documentation="The prefix cache hit rate.", @@ -208,6 +244,7 @@ class SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) + # Speculative decoding self.spec_accept_length = Gauge( name="sglang:spec_accept_length", documentation="The average acceptance length of speculative decoding.", @@ -215,65 +252,275 @@ class SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) - self.avg_request_queue_latency = Gauge( - name="sglang:avg_request_queue_latency", - documentation="The average request queue latency for the last batch of requests in seconds.", - labelnames=labels.keys(), - multiprocess_mode="mostrecent", - ) - - self.total_retracted_reqs = Gauge( - name="sglang:total_retracted_reqs", - documentation="The total number of retracted requests due to kvcache full.", - labelnames=labels.keys(), - multiprocess_mode="mostrecent", - ) - - # Disaggregation queue metrics + # PD disaggregation self.num_prefill_prealloc_queue_reqs = Gauge( name="sglang:num_prefill_prealloc_queue_reqs", documentation="The number of requests in the prefill prealloc queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_prefill_inflight_queue_reqs = Gauge( name="sglang:num_prefill_inflight_queue_reqs", documentation="The number of requests in the prefill inflight queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_decode_prealloc_queue_reqs = Gauge( name="sglang:num_decode_prealloc_queue_reqs", documentation="The number of requests in the decode prealloc queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_decode_transfer_queue_reqs = Gauge( name="sglang:num_decode_transfer_queue_reqs", documentation="The number of requests in the decode transfer queue.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.num_bootstrap_failed_reqs = Counter( - name="sglang:num_bootstrap_failed_reqs", + name="sglang:num_bootstrap_failed_reqs_total", documentation="The number of bootstrap failed requests.", labelnames=labels.keys(), ) - self.num_transfer_failed_reqs = Counter( - name="sglang:num_transfer_failed_reqs", + name="sglang:num_transfer_failed_reqs_total", documentation="The number of transfer failed requests.", labelnames=labels.keys(), ) + self.kv_transfer_speed_gb_s = Gauge( + name="sglang:kv_transfer_speed_gb_s", + documentation="The transfer speed of the KV cache in GB/s.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.kv_transfer_latency_ms = Gauge( + name="sglang:kv_transfer_latency_ms", + documentation="The transfer latency of the KV cache in ms.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + # Retract + self.total_retracted_reqs = Gauge( + name="sglang:total_retracted_reqs", + documentation="The total number of retracted requests due to kvcache full.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.num_retracted_reqs = Gauge( + name="sglang:num_retracted_reqs", + documentation="The number of retracted requests.", + labelnames=labels.keys(), + ) + self.num_paused_reqs = Gauge( + name="sglang:num_paused_reqs", + documentation="The number of paused requests by async weight sync.", + labelnames=labels.keys(), + ) + + # Utilization + self.utilization = Gauge( + name="sglang:utilization", + documentation="The utilization.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.max_running_requests_under_SLO = Gauge( + name="sglang:max_running_requests_under_SLO", + documentation="The maximum number of running requests under SLO.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + # Engine startup + self.engine_startup_time = Gauge( + name="sglang:engine_startup_time", + documentation="The time taken for the engine to start up.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.engine_load_weights_time = Gauge( + name="sglang:engine_load_weights_time", + documentation="The time taken for the engine to load weights.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + # Additional queueing time histogram + self.queue_time = Histogram( + name="sglang:queue_time_s", + documentation="Histogram of queueing time in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.0, + 0.1, + 0.2, + 0.5, + 1, + 2, + 3, + 4, + 5, + 10, + 15, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 200, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + 1200, + 1400, + 1600, + 1800, + 2000, + 2500, + 3000, + ], + ) + + # Grammar metrics + self.grammar_compilation_time = Histogram( + name="sglang:grammar_compilation_time_seconds", + documentation="Histogram of grammar compilation time in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.0, + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1, + 2, + 5, + 10, + 20, + 30, + 60, + 90, + 120, + 240, + ], + ) + self.num_grammar_cache_hit = Counter( + name="sglang:num_grammar_cache_hit_total", + documentation="Number of grammar cache hits.", + labelnames=labels.keys(), + ) + self.num_grammar_aborted = Counter( + name="sglang:num_grammar_aborted_total", + documentation="Number of grammar aborted requests.", + labelnames=labels.keys(), + ) + self.num_grammar_total = Counter( + name="sglang:num_grammar_total", + documentation="Number of the total grammar requests.", + labelnames=labels.keys(), + ) + self.grammar_schema_count = Histogram( + name="sglang:grammar_schema_count", + documentation="Histogram of grammar schema count.", + labelnames=labels.keys(), + buckets=[ + 0, + 1, + 2, + 5, + 10, + 20, + 30, + 40, + 60, + 80, + 100, + 120, + 140, + 160, + 180, + 200, + 300, + 400, + 500, + 700, + 1000, + ], + ) + self.grammar_ebnf_size = Histogram( + name="sglang:grammar_ebnf_size", + documentation="Histogram of grammar EBNF size.", + labelnames=labels.keys(), + buckets=[ + 0, + 50, + 100, + 200, + 300, + 500, + 1000, + 2000, + 3000, + 5000, + 10000, + 20000, + 30000, + 50000, + 100000, + ], + ) + + tree_traversal_time_buckets = [ + 0.0, + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1, + 2, + 5, + 10, + 15, + 30, + 60, + 90, + 120, + 240, + ] + self.grammar_tree_traversal_time_avg = Histogram( + name="sglang:grammar_tree_traversal_time_avg", + documentation="Histogram of average grammar tree traversal time in seconds.", + labelnames=labels.keys(), + buckets=tree_traversal_time_buckets, + ) + self.grammar_tree_traversal_time_max = Histogram( + name="sglang:grammar_tree_traversal_time_max", + documentation="Histogram of max grammar tree traversal time in seconds.", + labelnames=labels.keys(), + buckets=tree_traversal_time_buckets, + ) def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) + def log_histogram(self, histogram, data: Union[int, float]) -> None: + histogram.labels(**self.labels).observe(data) + def increment_bootstrap_failed_reqs(self) -> None: self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1) @@ -284,14 +531,19 @@ class SchedulerMetricsCollector: self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.token_usage, stats.token_usage) + self._log_gauge(self.swa_token_usage, stats.swa_token_usage) self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs) + self._log_gauge( + self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch + ) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) - self._log_gauge(self.spec_accept_length, stats.spec_accept_length) - self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs) - # Disaggregation metrics + # Speculative decoding + self._log_gauge(self.spec_accept_length, stats.spec_accept_length) + + # PD disaggregation self._log_gauge( self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs ) @@ -304,15 +556,59 @@ class SchedulerMetricsCollector: self._log_gauge( self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs ) + self._log_gauge(self.kv_transfer_speed_gb_s, stats.kv_transfer_speed_gb_s) + self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms) + + # Retract + self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs) + self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs) + self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs) + + # Utilization + self._log_gauge(self.utilization, stats.utilization) + if stats.max_running_requests_under_SLO is not None: + self._log_gauge( + self.max_running_requests_under_SLO, + stats.max_running_requests_under_SLO, + ) + + # Engine startup time + self._log_gauge(self.engine_startup_time, stats.engine_startup_time) + if stats.engine_load_weights_time is not None: + self._log_gauge( + self.engine_load_weights_time, stats.engine_load_weights_time + ) self.last_log_time = time.perf_counter() + def log_grammar_stats(self, grammar_stats) -> None: + # Duck-typed GrammarStats to avoid cross-package dependency + if getattr(grammar_stats, "compilation_time", None) is not None: + self.log_histogram( + self.grammar_compilation_time, grammar_stats.compilation_time + ) + if getattr(grammar_stats, "schema_count", None) is not None: + self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count) + if getattr(grammar_stats, "ebnf_size", None) is not None: + self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size) + tree_times = getattr(grammar_stats, "tree_traversal_time", None) + if tree_times: + max_time = max(tree_times) + avg_time = sum(tree_times) / len(tree_times) + self.log_histogram(self.grammar_tree_traversal_time_max, max_time) + self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time) + if getattr(grammar_stats, "is_cache_hit", False): + self.num_grammar_cache_hit.labels(**self.labels).inc(1) + if getattr(grammar_stats, "is_grammar_aborted", False): + self.num_grammar_aborted.labels(**self.labels).inc(1) + self.num_grammar_total.labels(**self.labels).inc(1) + class TokenizerMetricsCollector: def __init__( self, - server_args: ServerArgs, - labels: Dict[str, str], + server_args: Optional[ServerArgs] = None, + labels: Dict[str, str] = None, bucket_time_to_first_token: Optional[List[float]] = None, bucket_inter_token_latency: Optional[List[float]] = None, bucket_e2e_request_latency: Optional[List[float]] = None, @@ -321,7 +617,7 @@ class TokenizerMetricsCollector: # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` from prometheus_client import Counter, Histogram - self.labels = labels + self.labels = labels or {} self.collect_tokens_histogram = collect_tokens_histogram self.prompt_tokens_total = Counter( @@ -361,6 +657,13 @@ class TokenizerMetricsCollector: 30000, 35000, 40000, + 66000, + 99000, + 132000, + 300000, + 600000, + 900000, + 1100000, ] self.prompt_tokens_histogram = Histogram( name="sglang:prompt_tokens_histogram", @@ -370,34 +673,13 @@ class TokenizerMetricsCollector: server_args.prompt_tokens_buckets, default_bucket_prompt_tokens ), ) - default_bucket_generation_tokens = [ - 100, - 300, - 500, - 1000, - 1200, - 1500, - 1700, - 2000, - 2500, - 3000, - 3500, - 4000, - 4500, - 5000, - 6000, - 7000, - 8000, - 9000, - 10000, - ] self.generation_tokens_histogram = Histogram( name="sglang:generation_tokens_histogram", documentation="Histogram of generation token length.", labelnames=labels.keys(), buckets=generate_buckets( server_args.generation_tokens_buckets, - default_bucket_generation_tokens, + default_bucket_prompt_tokens, ), ) @@ -467,7 +749,10 @@ class TokenizerMetricsCollector: 100, 200, 400, - 800, + 600, + 1200, + 1800, + 2400, ] if bucket_inter_token_latency is None: @@ -518,6 +803,14 @@ class TokenizerMetricsCollector: buckets=bucket_e2e_request_latency, ) + # Offline batch specific TTFB histogram + self.histogram_time_to_first_token_offline_batch = Histogram( + name="sglang:time_to_first_token_seconds_offline_batch", + documentation="Histogram of time to first token in seconds for offline batch requests.", + labelnames=labels.keys(), + buckets=bucket_time_to_first_token, + ) + def _log_histogram(self, histogram, data: Union[int, float]) -> None: histogram.labels(**self.labels).observe(data) @@ -541,8 +834,26 @@ class TokenizerMetricsCollector: self._log_histogram(self.prompt_tokens_histogram, prompt_tokens) self._log_histogram(self.generation_tokens_histogram, generation_tokens) - def observe_time_to_first_token(self, value: float): - self.histogram_time_to_first_token.labels(**self.labels).observe(value) + def observe_time_to_first_token(self, value: float, label: str = ""): + if label == "batch": + self.histogram_time_to_first_token_offline_batch.labels( + **self.labels + ).observe(value) + else: + self.histogram_time_to_first_token.labels(**self.labels).observe(value) + + def check_time_to_first_token_straggler(self, value: float) -> bool: + his = self.histogram_time_to_first_token.labels(**self.labels) + total_observations = sum(bucket._value for bucket in his._buckets) + if total_observations < 100: + return False + p99_threshold = total_observations * 0.99 + cumulative_count = 0 + for i, bucket in enumerate(his._buckets): + cumulative_count += bucket._value + if cumulative_count > p99_threshold: + return value >= his._upper_bounds[i] + return False def observe_inter_token_latency(self, internval: float, num_new_tokens: int): adjusted_interval = internval / num_new_tokens diff --git a/python/sglang/srt/metrics/startup_func_log_and_timer.py b/python/sglang/srt/metrics/startup_func_log_and_timer.py new file mode 100644 index 000000000..752daccbd --- /dev/null +++ b/python/sglang/srt/metrics/startup_func_log_and_timer.py @@ -0,0 +1,150 @@ +""" +Records startup latency breakdown by context using gauge metrics in seconds +""" + +import logging +import time +from contextlib import contextmanager +from functools import wraps +from typing import Any, Callable, Dict, Generator, Optional + +logger = logging.getLogger(__name__) + +enable_startup_metrics = False +STARTUP_LATENCY_SECONDS = None +# Track maximum durations for each context +_max_durations: Dict[str, float] = {} + + +def enable_startup_timer(): + """Initialize startup latency metrics when metrics are enabled""" + # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` + from prometheus_client import Gauge + + global enable_startup_metrics, STARTUP_LATENCY_SECONDS + enable_startup_metrics = True + + STARTUP_LATENCY_SECONDS = Gauge( + "sglang:startup_latency_breakdown_seconds_max", + "Startup latency breakdown in seconds by context, only records the maximum duration if the context is called multiple times.", + labelnames=["context"], + multiprocess_mode="mostrecent", + ) + + +def set_startup_metric(context: str, value: float, should_log: bool = True): + """Set the startup metric for a given context""" + if should_log: + logger.info(f"Setting startup metric: {context} took {value:.3f}s") + + if not enable_startup_metrics: + return + current_max = _max_durations.get(context, 0.0) + if value > current_max: + _max_durations[context] = value + STARTUP_LATENCY_SECONDS.labels(context=context).set(value) + + +def reset_startup_timers(): + """Reset all recorded maximum durations. Useful for testing or reinitialization.""" + global _max_durations + _max_durations.clear() + + +def get_max_duration(context: str) -> Optional[float]: + """Get the maximum recorded duration for a context name.""" + return _max_durations.get(context) + + +@contextmanager +def startup_timer(name: str, log_only: bool = False) -> Generator[None, None, None]: + """ + Context manager to measure startup latency for arbitrary code blocks. + Only records the maximum duration if the context is called multiple times. + + Usage: + with startup_timer("model_loading"): + # model loading code + model = load_model() + + with startup_timer("memory_allocation"): + # memory setup code + allocate_memory() + """ + start_time = time.monotonic() + try: + yield + finally: + duration_seconds = time.monotonic() - start_time + + # Track the maximum duration for this context name + current_max = _max_durations.get(name, 0.0) + is_new_max = duration_seconds > current_max + + if is_new_max: + _max_durations[name] = duration_seconds + + # Only update Prometheus gauge if this is a new maximum + if enable_startup_metrics and not log_only: + STARTUP_LATENCY_SECONDS.labels(context=name).set(duration_seconds) + + # Log with indication if this was a new max + logger.info(f"Startup timing: {name} took {duration_seconds:.3f}s") + + +def time_startup_latency( + func: Callable = None, name: Optional[str] = None, log_only: bool = False +) -> Callable[..., Any]: + """ + A decorator to measure startup context latency and record it in seconds. + Only records the maximum duration if the context is called multiple times. + + Usage: + @time_startup_latency + def load_model(): + # model loading code + + @time_startup_latency(name="custom_init") + def initialize_something(): + # initialization code + + @time_startup_latency(name="debug_only", log_only=True) + def debug_function(): + # This will only log, not record to Prometheus + """ + + def measure(func: Callable[..., Any]) -> Callable[..., Any]: + nonlocal name + name = name or func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.monotonic() + try: + result = func(*args, **kwargs) + return result + finally: + duration_seconds = time.monotonic() - start_time + + # Track the maximum duration for this context name + current_max = _max_durations.get(name, 0.0) + is_new_max = duration_seconds > current_max + + if is_new_max: + _max_durations[name] = duration_seconds + + # Only update Prometheus gauge if this is a new maximum + if enable_startup_metrics and not log_only: + STARTUP_LATENCY_SECONDS.labels(context=name).set( + duration_seconds + ) + + # Log the timing + logger.info(f"Startup timing: {name} took {duration_seconds:.3f}s") + + return wrapper + + if func: + return measure(func) + else: + return measure