Add metrics for speculative decoding (acceptance rate, average acceptance length) (#11144)
This commit is contained in:
@@ -233,6 +233,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
completion_tokens=recv_obj.completion_tokens,
|
completion_tokens=recv_obj.completion_tokens,
|
||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
spec_verify_ct=recv_obj.spec_verify_ct,
|
spec_verify_ct=recv_obj.spec_verify_ct,
|
||||||
|
spec_accepted_tokens=recv_obj.spec_accepted_tokens,
|
||||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||||
|
|||||||
@@ -816,6 +816,7 @@ class BatchTokenIDOutput(BaseBatchReq):
|
|||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
spec_verify_ct: List[int]
|
spec_verify_ct: List[int]
|
||||||
|
spec_accepted_tokens: List[int]
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
@@ -882,6 +883,7 @@ class BatchStrOutput(BaseBatchReq):
|
|||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
spec_verify_ct: List[int]
|
spec_verify_ct: List[int]
|
||||||
|
spec_accepted_tokens: List[int]
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
|
|||||||
@@ -246,6 +246,11 @@ def _handle_output_by_index(output, i):
|
|||||||
spec_verify_ct=(
|
spec_verify_ct=(
|
||||||
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
||||||
),
|
),
|
||||||
|
spec_accepted_tokens=(
|
||||||
|
[output.spec_accepted_tokens[i]]
|
||||||
|
if len(output.spec_accepted_tokens) > i
|
||||||
|
else None
|
||||||
|
),
|
||||||
input_token_logprobs_val=(
|
input_token_logprobs_val=(
|
||||||
[output.input_token_logprobs_val[i]]
|
[output.input_token_logprobs_val[i]]
|
||||||
if output.input_token_logprobs_val
|
if output.input_token_logprobs_val
|
||||||
|
|||||||
@@ -631,6 +631,10 @@ class Req:
|
|||||||
# This is used to compute the average acceptance length per request.
|
# This is used to compute the average acceptance length per request.
|
||||||
self.spec_verify_ct = 0
|
self.spec_verify_ct = 0
|
||||||
|
|
||||||
|
# The number of accepted tokens in speculative decoding for this request.
|
||||||
|
# This is used to compute the acceptance rate and average acceptance length per request.
|
||||||
|
self.spec_accepted_tokens = 0
|
||||||
|
|
||||||
# For metrics
|
# For metrics
|
||||||
self.metrics_collector = metrics_collector
|
self.metrics_collector = metrics_collector
|
||||||
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
||||||
|
|||||||
@@ -216,14 +216,24 @@ class SchedulerMetricsMixin:
|
|||||||
|
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
spec_accept_length = 0
|
spec_accept_length = 0
|
||||||
|
spec_accept_rate = 0
|
||||||
else:
|
else:
|
||||||
spec_accept_length = (
|
spec_accept_length = (
|
||||||
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
||||||
)
|
)
|
||||||
|
# Calculate acceptance rate: accepted tokens / total draft tokens
|
||||||
|
total_draft_tokens = self.spec_num_total_forward_ct * (
|
||||||
|
self.server_args.speculative_num_steps or 1
|
||||||
|
)
|
||||||
|
spec_accept_rate = (
|
||||||
|
self.spec_num_total_accepted_tokens / total_draft_tokens
|
||||||
|
if total_draft_tokens > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
||||||
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
||||||
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
||||||
msg += f"accept len: {spec_accept_length:.2f}, "
|
msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
|
||||||
cache_hit_rate = 0.0
|
cache_hit_rate = 0.0
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
@@ -251,6 +261,9 @@ class SchedulerMetricsMixin:
|
|||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
self.stats.cache_hit_rate = cache_hit_rate
|
self.stats.cache_hit_rate = cache_hit_rate
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
self.stats.spec_accept_rate = spec_accept_rate
|
||||||
self.stats.spec_accept_length = spec_accept_length
|
self.stats.spec_accept_length = spec_accept_length
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
|
|||||||
@@ -634,6 +634,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
spec_verify_ct = []
|
spec_verify_ct = []
|
||||||
|
spec_accepted_tokens = []
|
||||||
output_hidden_states = None
|
output_hidden_states = None
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
@@ -725,6 +726,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
spec_verify_ct.append(req.spec_verify_ct)
|
spec_verify_ct.append(req.spec_verify_ct)
|
||||||
|
spec_accepted_tokens.append(req.spec_accepted_tokens)
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
if (
|
if (
|
||||||
@@ -825,6 +827,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
completion_tokens,
|
completion_tokens,
|
||||||
cached_tokens,
|
cached_tokens,
|
||||||
spec_verify_ct,
|
spec_verify_ct,
|
||||||
|
spec_accepted_tokens,
|
||||||
input_token_logprobs_val,
|
input_token_logprobs_val,
|
||||||
input_token_logprobs_idx,
|
input_token_logprobs_idx,
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
|
|||||||
@@ -1394,6 +1394,36 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
if state.finished:
|
if state.finished:
|
||||||
if self.server_args.speculative_algorithm:
|
if self.server_args.speculative_algorithm:
|
||||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||||
|
if (
|
||||||
|
recv_obj.spec_verify_ct[i] > 0
|
||||||
|
and self.server_args.speculative_num_steps is not None
|
||||||
|
and not isinstance(recv_obj, BatchEmbeddingOutput)
|
||||||
|
and hasattr(recv_obj, "spec_accepted_tokens")
|
||||||
|
# Checks that `spec_accepted_tokens[i]` will exist.
|
||||||
|
and len(recv_obj.spec_accepted_tokens) > i
|
||||||
|
):
|
||||||
|
total_draft_tokens = (
|
||||||
|
recv_obj.spec_verify_ct[i]
|
||||||
|
* self.server_args.speculative_num_steps
|
||||||
|
)
|
||||||
|
accepted_tokens = recv_obj.spec_accepted_tokens[i]
|
||||||
|
|
||||||
|
# Calculate per-request acceptance rate and average acceptance length.
|
||||||
|
if total_draft_tokens > 0:
|
||||||
|
# Calculate acceptance rate: accepted / (steps * lookahead)
|
||||||
|
meta_info["spec_accept_rate"] = (
|
||||||
|
accepted_tokens / total_draft_tokens
|
||||||
|
)
|
||||||
|
meta_info["spec_accept_length"] = (
|
||||||
|
recv_obj.completion_tokens[i]
|
||||||
|
/ recv_obj.spec_verify_ct[i]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
meta_info["spec_accept_rate"] = 0.0
|
||||||
|
meta_info["spec_accept_length"] = 0
|
||||||
|
else:
|
||||||
|
meta_info["spec_acceptance_rate"] = 0.0
|
||||||
|
meta_info["spec_accept_length"] = 0
|
||||||
state.finished_time = time.time()
|
state.finished_time = time.time()
|
||||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||||
|
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ class SchedulerStats:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_accept_length: float = 0.0
|
spec_accept_length: float = 0.0
|
||||||
|
spec_accept_rate: float = 0.0
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
num_retracted_reqs: int = 0
|
num_retracted_reqs: int = 0
|
||||||
@@ -220,6 +221,12 @@ class SchedulerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="mostrecent",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
self.spec_accept_rate = Gauge(
|
||||||
|
name="sglang:spec_accept_rate",
|
||||||
|
documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
multiprocess_mode="mostrecent",
|
||||||
|
)
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
self.num_retracted_reqs = Gauge(
|
self.num_retracted_reqs = Gauge(
|
||||||
@@ -520,6 +527,7 @@ class SchedulerMetricsCollector:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||||
|
self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
self._log_gauge(
|
self._log_gauge(
|
||||||
|
|||||||
@@ -378,6 +378,13 @@ class EagleVerifyInput(SpecInput):
|
|||||||
unfinished_accept_index.append(accept_index[i])
|
unfinished_accept_index.append(accept_index[i])
|
||||||
req.spec_verify_ct += 1
|
req.spec_verify_ct += 1
|
||||||
|
|
||||||
|
# For each request, accumulate # of accepted tokens for this verify pass.
|
||||||
|
accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
for i, (req, accepted_count) in enumerate(
|
||||||
|
zip(batch.reqs, accept_length_this_pass.tolist())
|
||||||
|
):
|
||||||
|
req.spec_accepted_tokens += accepted_count
|
||||||
|
|
||||||
if has_finished:
|
if has_finished:
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user