Add metrics for speculative decoding (acceptance rate, average acceptance length) (#11144)
This commit is contained in:
@@ -233,6 +233,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
completion_tokens=recv_obj.completion_tokens,
|
||||
cached_tokens=recv_obj.cached_tokens,
|
||||
spec_verify_ct=recv_obj.spec_verify_ct,
|
||||
spec_accepted_tokens=recv_obj.spec_accepted_tokens,
|
||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||
|
||||
@@ -816,6 +816,7 @@ class BatchTokenIDOutput(BaseBatchReq):
|
||||
completion_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
spec_verify_ct: List[int]
|
||||
spec_accepted_tokens: List[int]
|
||||
|
||||
# Logprobs
|
||||
input_token_logprobs_val: List[float]
|
||||
@@ -882,6 +883,7 @@ class BatchStrOutput(BaseBatchReq):
|
||||
completion_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
spec_verify_ct: List[int]
|
||||
spec_accepted_tokens: List[int]
|
||||
|
||||
# Logprobs
|
||||
input_token_logprobs_val: List[float]
|
||||
|
||||
@@ -246,6 +246,11 @@ def _handle_output_by_index(output, i):
|
||||
spec_verify_ct=(
|
||||
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
||||
),
|
||||
spec_accepted_tokens=(
|
||||
[output.spec_accepted_tokens[i]]
|
||||
if len(output.spec_accepted_tokens) > i
|
||||
else None
|
||||
),
|
||||
input_token_logprobs_val=(
|
||||
[output.input_token_logprobs_val[i]]
|
||||
if output.input_token_logprobs_val
|
||||
|
||||
@@ -631,6 +631,10 @@ class Req:
|
||||
# This is used to compute the average acceptance length per request.
|
||||
self.spec_verify_ct = 0
|
||||
|
||||
# The number of accepted tokens in speculative decoding for this request.
|
||||
# This is used to compute the acceptance rate and average acceptance length per request.
|
||||
self.spec_accepted_tokens = 0
|
||||
|
||||
# For metrics
|
||||
self.metrics_collector = metrics_collector
|
||||
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
||||
|
||||
@@ -216,14 +216,24 @@ class SchedulerMetricsMixin:
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
spec_accept_rate = 0
|
||||
else:
|
||||
spec_accept_length = (
|
||||
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
||||
)
|
||||
# Calculate acceptance rate: accepted tokens / total draft tokens
|
||||
total_draft_tokens = self.spec_num_total_forward_ct * (
|
||||
self.server_args.speculative_num_steps or 1
|
||||
)
|
||||
spec_accept_rate = (
|
||||
self.spec_num_total_accepted_tokens / total_draft_tokens
|
||||
if total_draft_tokens > 0
|
||||
else 0
|
||||
)
|
||||
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
||||
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
||||
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
||||
msg += f"accept len: {spec_accept_length:.2f}, "
|
||||
msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
|
||||
cache_hit_rate = 0.0
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
@@ -251,6 +261,9 @@ class SchedulerMetricsMixin:
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||
self.stats.cache_hit_rate = cache_hit_rate
|
||||
|
||||
# Speculative decoding
|
||||
self.stats.spec_accept_rate = spec_accept_rate
|
||||
self.stats.spec_accept_length = spec_accept_length
|
||||
|
||||
# Retract
|
||||
|
||||
@@ -634,6 +634,7 @@ class SchedulerOutputProcessorMixin:
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
spec_accepted_tokens = []
|
||||
output_hidden_states = None
|
||||
|
||||
if return_logprob:
|
||||
@@ -725,6 +726,7 @@ class SchedulerOutputProcessorMixin:
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
spec_verify_ct.append(req.spec_verify_ct)
|
||||
spec_accepted_tokens.append(req.spec_accepted_tokens)
|
||||
|
||||
if return_logprob:
|
||||
if (
|
||||
@@ -825,6 +827,7 @@ class SchedulerOutputProcessorMixin:
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
spec_verify_ct,
|
||||
spec_accepted_tokens,
|
||||
input_token_logprobs_val,
|
||||
input_token_logprobs_idx,
|
||||
output_token_logprobs_val,
|
||||
|
||||
@@ -1394,6 +1394,36 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
if state.finished:
|
||||
if self.server_args.speculative_algorithm:
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
if (
|
||||
recv_obj.spec_verify_ct[i] > 0
|
||||
and self.server_args.speculative_num_steps is not None
|
||||
and not isinstance(recv_obj, BatchEmbeddingOutput)
|
||||
and hasattr(recv_obj, "spec_accepted_tokens")
|
||||
# Checks that `spec_accepted_tokens[i]` will exist.
|
||||
and len(recv_obj.spec_accepted_tokens) > i
|
||||
):
|
||||
total_draft_tokens = (
|
||||
recv_obj.spec_verify_ct[i]
|
||||
* self.server_args.speculative_num_steps
|
||||
)
|
||||
accepted_tokens = recv_obj.spec_accepted_tokens[i]
|
||||
|
||||
# Calculate per-request acceptance rate and average acceptance length.
|
||||
if total_draft_tokens > 0:
|
||||
# Calculate acceptance rate: accepted / (steps * lookahead)
|
||||
meta_info["spec_accept_rate"] = (
|
||||
accepted_tokens / total_draft_tokens
|
||||
)
|
||||
meta_info["spec_accept_length"] = (
|
||||
recv_obj.completion_tokens[i]
|
||||
/ recv_obj.spec_verify_ct[i]
|
||||
)
|
||||
else:
|
||||
meta_info["spec_accept_rate"] = 0.0
|
||||
meta_info["spec_accept_length"] = 0
|
||||
else:
|
||||
meta_info["spec_acceptance_rate"] = 0.0
|
||||
meta_info["spec_accept_length"] = 0
|
||||
state.finished_time = time.time()
|
||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||
|
||||
|
||||
@@ -127,6 +127,7 @@ class SchedulerStats:
|
||||
|
||||
# Speculative decoding
|
||||
spec_accept_length: float = 0.0
|
||||
spec_accept_rate: float = 0.0
|
||||
|
||||
# Retract
|
||||
num_retracted_reqs: int = 0
|
||||
@@ -220,6 +221,12 @@ class SchedulerMetricsCollector:
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
self.spec_accept_rate = Gauge(
|
||||
name="sglang:spec_accept_rate",
|
||||
documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
# Retract
|
||||
self.num_retracted_reqs = Gauge(
|
||||
@@ -520,6 +527,7 @@ class SchedulerMetricsCollector:
|
||||
|
||||
# Speculative decoding
|
||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||
self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
|
||||
|
||||
# PD disaggregation
|
||||
self._log_gauge(
|
||||
|
||||
@@ -378,6 +378,13 @@ class EagleVerifyInput(SpecInput):
|
||||
unfinished_accept_index.append(accept_index[i])
|
||||
req.spec_verify_ct += 1
|
||||
|
||||
# For each request, accumulate # of accepted tokens for this verify pass.
|
||||
accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
|
||||
for i, (req, accepted_count) in enumerate(
|
||||
zip(batch.reqs, accept_length_this_pass.tolist())
|
||||
):
|
||||
req.spec_accepted_tokens += accepted_count
|
||||
|
||||
if has_finished:
|
||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user