From b6fb5d766623f155818efb3abe800c31ca1fd1c3 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Mon, 13 Oct 2025 11:24:27 -0700 Subject: [PATCH] Add metrics for speculative decoding (acceptance rate, average acceptance length) (#11441) --- .../srt/managers/detokenizer_manager.py | 1 + python/sglang/srt/managers/io_struct.py | 2 ++ .../srt/managers/multi_tokenizer_mixin.py | 5 ++++ python/sglang/srt/managers/schedule_batch.py | 4 +++ .../srt/managers/scheduler_metrics_mixin.py | 15 +++++++++- .../scheduler_output_processor_mixin.py | 3 ++ .../sglang/srt/managers/tokenizer_manager.py | 30 +++++++++++++++++++ python/sglang/srt/metrics/collector.py | 8 +++++ python/sglang/srt/speculative/eagle_info.py | 3 ++ 9 files changed, 70 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index f8135767e..8db07db71 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -233,6 +233,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, spec_verify_ct=recv_obj.spec_verify_ct, + spec_accepted_tokens=recv_obj.spec_accepted_tokens, input_token_logprobs_val=recv_obj.input_token_logprobs_val, input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, output_token_logprobs_val=recv_obj.output_token_logprobs_val, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index bb542b7bd..ef2b6f611 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -823,6 +823,7 @@ class BatchTokenIDOutput(BaseBatchReq): completion_tokens: List[int] cached_tokens: List[int] spec_verify_ct: List[int] + spec_accepted_tokens: List[int] # Logprobs input_token_logprobs_val: List[float] @@ -896,6 +897,7 @@ class BatchStrOutput(BaseBatchReq): completion_tokens: List[int] cached_tokens: List[int] spec_verify_ct: List[int] + spec_accepted_tokens: List[int] # Logprobs input_token_logprobs_val: List[float] diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 302546e5f..4a3eb1323 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -252,6 +252,11 @@ def _handle_output_by_index(output, i): spec_verify_ct=( [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None ), + spec_accepted_tokens=( + [output.spec_accepted_tokens[i]] + if len(output.spec_accepted_tokens) > i + else None + ), input_token_logprobs_val=( [output.input_token_logprobs_val[i]] if output.input_token_logprobs_val diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5ebf3f61a..720ac2b67 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -608,6 +608,10 @@ class Req: # This is used to compute the average acceptance length per request. self.spec_verify_ct = 0 + # The number of accepted tokens in speculative decoding for this request. + # This is used to compute the acceptance rate and average acceptance length per request. + self.spec_accepted_tokens = 0 + # For metrics self.metrics_collector = metrics_collector self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 8beea66db..521cf3d52 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -252,14 +252,24 @@ class SchedulerMetricsMixin: if self.spec_algorithm.is_none(): spec_accept_length = 0 + spec_accept_rate = 0 else: spec_accept_length = ( self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct ) + # Calculate acceptance rate: accepted tokens / total draft tokens + total_draft_tokens = self.spec_num_total_forward_ct * ( + self.server_args.speculative_num_steps or 1 + ) + spec_accept_rate = ( + self.spec_num_total_accepted_tokens / total_draft_tokens + if total_draft_tokens > 0 + else 0 + ) self.cum_spec_accept_length += self.spec_num_total_accepted_tokens self.cum_spec_accept_count += self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 - msg += f"accept len: {spec_accept_length:.2f}, " + msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, " cache_hit_rate = 0.0 if self.disaggregation_mode == DisaggregationMode.DECODE: @@ -287,6 +297,9 @@ class SchedulerMetricsMixin: self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.cache_hit_rate = cache_hit_rate + + # Speculative decoding + self.stats.spec_accept_rate = spec_accept_rate self.stats.spec_accept_length = spec_accept_length # Retract diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index ba3b09e1a..bde62957e 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -711,6 +711,7 @@ class SchedulerOutputProcessorMixin: completion_tokens = [] cached_tokens = [] spec_verify_ct = [] + spec_accepted_tokens = [] output_hidden_states = None if return_logprob: @@ -808,6 +809,7 @@ class SchedulerOutputProcessorMixin: if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) + spec_accepted_tokens.append(req.spec_accepted_tokens) if return_logprob: if ( @@ -908,6 +910,7 @@ class SchedulerOutputProcessorMixin: completion_tokens, cached_tokens, spec_verify_ct, + spec_accepted_tokens, input_token_logprobs_val, input_token_logprobs_idx, output_token_logprobs_val, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c521b1112..bb43c9c79 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1395,6 +1395,36 @@ class TokenizerManager(TokenizerCommunicatorMixin): if state.finished: if self.server_args.speculative_algorithm: meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] + if ( + recv_obj.spec_verify_ct[i] > 0 + and self.server_args.speculative_num_steps is not None + and not isinstance(recv_obj, BatchEmbeddingOutput) + and hasattr(recv_obj, "spec_accepted_tokens") + # Checks that `spec_accepted_tokens[i]` will exist. + and len(recv_obj.spec_accepted_tokens) > i + ): + total_draft_tokens = ( + recv_obj.spec_verify_ct[i] + * self.server_args.speculative_num_steps + ) + accepted_tokens = recv_obj.spec_accepted_tokens[i] + + # Calculate per-request acceptance rate and average acceptance length. + if total_draft_tokens > 0: + # Calculate acceptance rate: accepted / (steps * lookahead) + meta_info["spec_accept_rate"] = ( + accepted_tokens / total_draft_tokens + ) + meta_info["spec_accept_length"] = ( + recv_obj.completion_tokens[i] + / recv_obj.spec_verify_ct[i] + ) + else: + meta_info["spec_accept_rate"] = 0.0 + meta_info["spec_accept_length"] = 0 + else: + meta_info["spec_accept_rate"] = 0.0 + meta_info["spec_accept_length"] = 0 state.finished_time = time.time() meta_info["e2e_latency"] = state.finished_time - state.created_time diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index e793eb988..60e0758ea 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -127,6 +127,7 @@ class SchedulerStats: # Speculative decoding spec_accept_length: float = 0.0 + spec_accept_rate: float = 0.0 # Retract num_retracted_reqs: int = 0 @@ -220,6 +221,12 @@ class SchedulerMetricsCollector: labelnames=labels.keys(), multiprocess_mode="mostrecent", ) + self.spec_accept_rate = Gauge( + name="sglang:spec_accept_rate", + documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) # Retract self.num_retracted_reqs = Gauge( @@ -520,6 +527,7 @@ class SchedulerMetricsCollector: # Speculative decoding self._log_gauge(self.spec_accept_length, stats.spec_accept_length) + self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate) # PD disaggregation self._log_gauge( diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index ad94a7c5c..4373aa3f3 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -384,6 +384,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): else: unfinished_accept_index.append(accept_index[i]) req.spec_verify_ct += 1 + req.spec_accepted_tokens += ( + sum(1 for idx in accept_index_row if idx != -1) - 1 + ) if has_finished: accept_length = (accept_index != -1).sum(dim=1) - 1