Refactor spec decoding metrics calculation into separate TokenizerManager utility function (#11586)
This commit is contained in:
@@ -1394,37 +1394,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
state.finished = recv_obj.finished_reasons[i] is not None
|
||||
if state.finished:
|
||||
if self.server_args.speculative_algorithm:
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
if (
|
||||
recv_obj.spec_verify_ct[i] > 0
|
||||
and self.server_args.speculative_num_steps is not None
|
||||
and not isinstance(recv_obj, BatchEmbeddingOutput)
|
||||
and hasattr(recv_obj, "spec_accepted_tokens")
|
||||
# Checks that `spec_accepted_tokens[i]` will exist.
|
||||
and len(recv_obj.spec_accepted_tokens) > i
|
||||
):
|
||||
total_draft_tokens = (
|
||||
recv_obj.spec_verify_ct[i]
|
||||
* self.server_args.speculative_num_steps
|
||||
)
|
||||
accepted_tokens = recv_obj.spec_accepted_tokens[i]
|
||||
|
||||
# Calculate per-request acceptance rate and average acceptance length.
|
||||
if total_draft_tokens > 0:
|
||||
# Calculate acceptance rate: accepted / (steps * lookahead)
|
||||
meta_info["spec_accept_rate"] = (
|
||||
accepted_tokens / total_draft_tokens
|
||||
)
|
||||
meta_info["spec_accept_length"] = (
|
||||
recv_obj.completion_tokens[i]
|
||||
/ recv_obj.spec_verify_ct[i]
|
||||
)
|
||||
else:
|
||||
meta_info["spec_accept_rate"] = 0.0
|
||||
meta_info["spec_accept_length"] = 0
|
||||
else:
|
||||
meta_info["spec_accept_rate"] = 0.0
|
||||
meta_info["spec_accept_length"] = 0
|
||||
self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
|
||||
state.finished_time = time.time()
|
||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||
|
||||
@@ -1572,6 +1542,43 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
ret.append(None)
|
||||
return ret
|
||||
|
||||
def _calculate_spec_decoding_metrics(
|
||||
self,
|
||||
meta_info: Dict[str, Any],
|
||||
recv_obj: Union[
|
||||
BatchStrOutput,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchTokenIDOutput,
|
||||
],
|
||||
i: int,
|
||||
) -> None:
|
||||
"""Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics."""
|
||||
meta_info["spec_accept_rate"] = 0.0
|
||||
meta_info["spec_accept_length"] = 0
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
|
||||
if (
|
||||
recv_obj.spec_verify_ct[i] > 0
|
||||
and self.server_args.speculative_num_steps is not None
|
||||
and not isinstance(recv_obj, BatchEmbeddingOutput)
|
||||
and hasattr(recv_obj, "spec_accepted_tokens")
|
||||
# Checks that `spec_accepted_tokens[i]` will exist.
|
||||
and len(recv_obj.spec_accepted_tokens) > i
|
||||
):
|
||||
total_draft_tokens = (
|
||||
recv_obj.spec_verify_ct[i] * self.server_args.speculative_num_steps
|
||||
)
|
||||
accepted_tokens = recv_obj.spec_accepted_tokens[i]
|
||||
|
||||
# Calculate per-request acceptance rate and average acceptance length.
|
||||
if total_draft_tokens > 0:
|
||||
# Calculate acceptance rate: accepted / (steps * lookahead)
|
||||
meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
|
||||
meta_info["spec_accept_length"] = (
|
||||
recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
|
||||
)
|
||||
|
||||
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
|
||||
completion_tokens = (
|
||||
recv_obj.completion_tokens[i]
|
||||
|
||||
Reference in New Issue
Block a user