diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bb43c9c79..e646c2a6c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1394,37 +1394,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): state.finished = recv_obj.finished_reasons[i] is not None if state.finished: if self.server_args.speculative_algorithm: - meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] - if ( - recv_obj.spec_verify_ct[i] > 0 - and self.server_args.speculative_num_steps is not None - and not isinstance(recv_obj, BatchEmbeddingOutput) - and hasattr(recv_obj, "spec_accepted_tokens") - # Checks that `spec_accepted_tokens[i]` will exist. - and len(recv_obj.spec_accepted_tokens) > i - ): - total_draft_tokens = ( - recv_obj.spec_verify_ct[i] - * self.server_args.speculative_num_steps - ) - accepted_tokens = recv_obj.spec_accepted_tokens[i] - - # Calculate per-request acceptance rate and average acceptance length. 
def _calculate_spec_decoding_metrics(
    self,
    meta_info: Dict[str, Any],
    recv_obj: Union[
        BatchStrOutput,
        BatchEmbeddingOutput,
        BatchMultimodalOutput,
        BatchTokenIDOutput,
    ],
    i: int,
) -> None:
    """Write speculative decoding metrics for request ``i`` into ``meta_info``.

    Three keys are always written:

    * ``spec_verify_ct``: ``recv_obj.spec_verify_ct[i]``.
    * ``spec_accept_rate``: accepted tokens divided by
      ``spec_verify_ct[i] * server_args.speculative_num_steps``.
    * ``spec_accept_length``: ``completion_tokens[i] / spec_verify_ct[i]``.

    The rate and length stay at their defaults (``0.0`` and ``0``) whenever
    the inputs needed to compute them are unavailable.

    Args:
        meta_info: Per-request metadata dict, updated in place.
        recv_obj: Batch output holding the per-request counters.
        i: Index of the request inside ``recv_obj``.
    """
    # Defaults first, so every early return below leaves a complete record.
    meta_info["spec_accept_rate"] = 0.0
    meta_info["spec_accept_length"] = 0
    verify_ct = recv_obj.spec_verify_ct[i]
    meta_info["spec_verify_ct"] = verify_ct

    if verify_ct <= 0:
        return
    num_steps = self.server_args.speculative_num_steps
    if num_steps is None:
        return

    # `hasattr` alone is not enough: the attribute may exist but be None,
    # in which case `len()` would raise instead of keeping the defaults.
    accepted_per_req = getattr(recv_obj, "spec_accepted_tokens", None)
    if accepted_per_req is None or len(accepted_per_req) <= i:
        return
    if isinstance(recv_obj, BatchEmbeddingOutput):
        return

    # Acceptance rate: accepted / (verify steps * draft steps per verify).
    total_draft_tokens = verify_ct * num_steps
    if total_draft_tokens > 0:
        meta_info["spec_accept_rate"] = accepted_per_req[i] / total_draft_tokens
        meta_info["spec_accept_length"] = recv_obj.completion_tokens[i] / verify_ct