NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy. Context-line indentation is inferred from
Python block structure and must be checked against the target tree (or apply
with `git apply --ignore-whitespace`). The `index` blob hashes are kept from
the original patch and do not reflect the edited added lines.

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9f9c53eaa..fb6202932 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -601,7 +601,7 @@ class TokenizerManager:
         while not self.gracefully_exit:
             await asyncio.sleep(5)
 
-        # drain requests
+        # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
@@ -679,45 +679,7 @@ class TokenizerManager:
                     state.event.set()
 
                     if self.enable_metrics:
-                        completion_tokens = (
-                            recv_obj.completion_tokens[i]
-                            if getattr(recv_obj, "completion_tokens", None)
-                            else 0
-                        )
-
-                        if state.first_token_time is None:
-                            state.first_token_time = time.time()
-                            self.metrics_collector.observe_time_to_first_token(
-                                state.first_token_time - state.created_time
-                            )
-                        else:
-                            if completion_tokens >= 2:
-                                # Compute time_per_output_token for the streaming case
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.first_token_time)
-                                    / (completion_tokens - 1)
-                                )
-
-                        if state.finished:
-                            self.metrics_collector.inc_prompt_tokens(
-                                recv_obj.prompt_tokens[i]
-                            )
-                            self.metrics_collector.inc_generation_tokens(
-                                completion_tokens
-                            )
-                            self.metrics_collector.observe_e2e_request_latency(
-                                time.time() - state.created_time
-                            )
-                            # Compute time_per_output_token for the non-streaming case
-                            if (
-                                hasattr(state.obj, "stream")
-                                and not state.obj.stream
-                                and completion_tokens >= 1
-                            ):
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.created_time)
-                                    / completion_tokens
-                                )
+                        self.collect_metrics(state, recv_obj, i)
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id if recv_obj.success else None
@@ -820,6 +782,52 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+        """Record latency and token metrics for request `i` of `recv_obj`.
+
+        Args:
+            state: Per-request state. `first_token_time` is set here on the
+                first call for the request; `created_time`, `finished` and
+                `obj` are read.
+            recv_obj: Batch output. `completion_tokens` is optional (treated
+                as 0 when missing or empty); `prompt_tokens[i]` is read once
+                the request has finished.
+            i: Index of the request inside `recv_obj`.
+        """
+        completion_tokens = (
+            recv_obj.completion_tokens[i]
+            if getattr(recv_obj, "completion_tokens", None)
+            else 0
+        )
+
+        if state.first_token_time is None:
+            state.first_token_time = time.time()
+            self.metrics_collector.observe_time_to_first_token(
+                state.first_token_time - state.created_time
+            )
+        elif completion_tokens >= 2:
+            # Compute time_per_output_token for the streaming case
+            self.metrics_collector.observe_time_per_output_token(
+                (time.time() - state.first_token_time) / (completion_tokens - 1)
+            )
+
+        if state.finished:
+            self.metrics_collector.observe_one_finished_request(
+                recv_obj.prompt_tokens[i], completion_tokens
+            )
+            self.metrics_collector.observe_e2e_request_latency(
+                time.time() - state.created_time
+            )
+            # Compute time_per_output_token for the non-streaming case
+            if (
+                hasattr(state.obj, "stream")
+                and not state.obj.stream
+                and completion_tokens >= 1
+            ):
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.created_time) / completion_tokens
+                )
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index 9505f012f..070b405be 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -109,6 +109,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -185,11 +191,11 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def inc_prompt_tokens(self, value: int):
-        self._log_counter(self.prompt_tokens_total, value)
-
-    def inc_generation_tokens(self, value: int):
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        """Update the token and request counters for one finished request."""
+        self._log_counter(self.prompt_tokens_total, prompt_tokens)
+        self._log_counter(self.generation_tokens_total, generation_tokens)
+        self._log_counter(self.num_requests_total, 1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index ccaea5be8..69babf795 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase):
                 "sglang:func_latency_seconds",
                 "sglang:prompt_tokens_total",
                 "sglang:generation_tokens_total",
+                "sglang:num_requests_total",
                 "sglang:time_to_first_token_seconds",
                 "sglang:time_per_output_token_seconds",
                 "sglang:e2e_request_latency_seconds",