From 0babd487368165b682fb25d4af6ad8b42927976e Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 10 Oct 2025 00:46:44 -0700 Subject: [PATCH] Add metrics for speculative decoding (acceptance rate, average acceptance length) (#11144) --- .../srt/managers/detokenizer_manager.py | 1 + python/sglang/srt/managers/io_struct.py | 2 ++ .../srt/managers/multi_tokenizer_mixin.py | 5 ++++ python/sglang/srt/managers/schedule_batch.py | 4 +++ .../srt/managers/scheduler_metrics_mixin.py | 15 +++++++++- .../scheduler_output_processor_mixin.py | 3 ++ .../sglang/srt/managers/tokenizer_manager.py | 30 +++++++++++++++++++ python/sglang/srt/metrics/collector.py | 8 +++++ python/sglang/srt/speculative/eagle_info.py | 7 +++++ 9 files changed, 74 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 68132991c..6d98bc173 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -233,6 +233,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, spec_verify_ct=recv_obj.spec_verify_ct, + spec_accepted_tokens=recv_obj.spec_accepted_tokens, input_token_logprobs_val=recv_obj.input_token_logprobs_val, input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, output_token_logprobs_val=recv_obj.output_token_logprobs_val, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e6dfa35c4..6d6cbb3d7 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -816,6 +816,7 @@ class BatchTokenIDOutput(BaseBatchReq): completion_tokens: List[int] cached_tokens: List[int] spec_verify_ct: List[int] + spec_accepted_tokens: List[int] # Logprobs input_token_logprobs_val: List[float] @@ -882,6 +883,7 @@ class BatchStrOutput(BaseBatchReq): completion_tokens: List[int] 
cached_tokens: List[int] spec_verify_ct: List[int] + spec_accepted_tokens: List[int] # Logprobs input_token_logprobs_val: List[float] diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 83c966ec6..bd391d1a5 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -246,6 +246,11 @@ def _handle_output_by_index(output, i): spec_verify_ct=( [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None ), + spec_accepted_tokens=( + [output.spec_accepted_tokens[i]] + if len(output.spec_accepted_tokens) > i + else None + ), input_token_logprobs_val=( [output.input_token_logprobs_val[i]] if output.input_token_logprobs_val diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 075d90477..3a10ab4b4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -631,6 +631,10 @@ class Req: # This is used to compute the average acceptance length per request. self.spec_verify_ct = 0 + # The number of accepted tokens in speculative decoding for this request. + # This is used to compute the acceptance rate and average acceptance length per request. 
+ self.spec_accepted_tokens = 0 + # For metrics self.metrics_collector = metrics_collector self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 4fa4bfee1..2fa5e4575 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -216,14 +216,24 @@ class SchedulerMetricsMixin: if self.spec_algorithm.is_none(): spec_accept_length = 0 + spec_accept_rate = 0 else: spec_accept_length = ( self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct ) + # Calculate acceptance rate: accepted tokens / total draft tokens + total_draft_tokens = self.spec_num_total_forward_ct * ( + self.server_args.speculative_num_steps or 1 + ) + spec_accept_rate = ( + self.spec_num_total_accepted_tokens / total_draft_tokens + if total_draft_tokens > 0 + else 0 + ) self.cum_spec_accept_length += self.spec_num_total_accepted_tokens self.cum_spec_accept_count += self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 - msg += f"accept len: {spec_accept_length:.2f}, " + msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, " cache_hit_rate = 0.0 if self.disaggregation_mode == DisaggregationMode.DECODE: @@ -251,6 +261,9 @@ class SchedulerMetricsMixin: self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.cache_hit_rate = cache_hit_rate + + # Speculative decoding + self.stats.spec_accept_rate = spec_accept_rate self.stats.spec_accept_length = spec_accept_length # Retract diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 2072f9f68..a205fdb7f 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ 
b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -634,6 +634,7 @@ class SchedulerOutputProcessorMixin: completion_tokens = [] cached_tokens = [] spec_verify_ct = [] + spec_accepted_tokens = [] output_hidden_states = None if return_logprob: @@ -725,6 +726,7 @@ class SchedulerOutputProcessorMixin: if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) + spec_accepted_tokens.append(req.spec_accepted_tokens) if return_logprob: if ( @@ -825,6 +827,7 @@ class SchedulerOutputProcessorMixin: completion_tokens, cached_tokens, spec_verify_ct, + spec_accepted_tokens, input_token_logprobs_val, input_token_logprobs_idx, output_token_logprobs_val, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index be9e5699a..619782592 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1394,6 +1394,36 @@ class TokenizerManager(TokenizerCommunicatorMixin): if state.finished: if self.server_args.speculative_algorithm: meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] + if ( + recv_obj.spec_verify_ct[i] > 0 + and self.server_args.speculative_num_steps is not None + and not isinstance(recv_obj, BatchEmbeddingOutput) + and hasattr(recv_obj, "spec_accepted_tokens") + # Checks that `spec_accepted_tokens[i]` will exist. + and len(recv_obj.spec_accepted_tokens) > i + ): + total_draft_tokens = ( + recv_obj.spec_verify_ct[i] + * self.server_args.speculative_num_steps + ) + accepted_tokens = recv_obj.spec_accepted_tokens[i] + + # Calculate per-request acceptance rate and average acceptance length. 
+ if total_draft_tokens > 0: + # Calculate acceptance rate: accepted / (steps * lookahead) + meta_info["spec_accept_rate"] = ( + accepted_tokens / total_draft_tokens + ) + meta_info["spec_accept_length"] = ( + recv_obj.completion_tokens[i] + / recv_obj.spec_verify_ct[i] + ) + else: + meta_info["spec_accept_rate"] = 0.0 + meta_info["spec_accept_length"] = 0 + else: + meta_info["spec_accept_rate"] = 0.0 + meta_info["spec_accept_length"] = 0 state.finished_time = time.time() meta_info["e2e_latency"] = state.finished_time - state.created_time diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index e793eb988..60e0758ea 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -127,6 +127,7 @@ class SchedulerStats: # Speculative decoding spec_accept_length: float = 0.0 + spec_accept_rate: float = 0.0 # Retract num_retracted_reqs: int = 0 @@ -220,6 +221,12 @@ class SchedulerMetricsCollector: labelnames=labels.keys(), multiprocess_mode="mostrecent", ) + self.spec_accept_rate = Gauge( + name="sglang:spec_accept_rate", + documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) # Retract self.num_retracted_reqs = Gauge( @@ -520,6 +527,7 @@ class SchedulerMetricsCollector: # Speculative decoding self._log_gauge(self.spec_accept_length, stats.spec_accept_length) + self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate) # PD disaggregation self._log_gauge( diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 5d8c920c4..80d78592e 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -378,6 +378,13 @@ class EagleVerifyInput(SpecInput): unfinished_accept_index.append(accept_index[i]) req.spec_verify_ct += 1 + # For each request, accumulate # of
accepted tokens for this verify pass. + accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1 + for i, (req, accepted_count) in enumerate( + zip(batch.reqs, accept_length_this_pass.tolist()) + ): + req.spec_accepted_tokens += accepted_count + if has_finished: accept_length = (accept_index != -1).sum(dim=1) - 1