Revert "Add metrics for speculative decoding (acceptance rate, average acceptance length)" (#11433)
This commit is contained in:
@@ -233,7 +233,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
completion_tokens=recv_obj.completion_tokens,
|
completion_tokens=recv_obj.completion_tokens,
|
||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
spec_verify_ct=recv_obj.spec_verify_ct,
|
spec_verify_ct=recv_obj.spec_verify_ct,
|
||||||
spec_accepted_tokens=recv_obj.spec_accepted_tokens,
|
|
||||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||||
|
|||||||
@@ -816,7 +816,6 @@ class BatchTokenIDOutput(BaseBatchReq):
|
|||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
spec_verify_ct: List[int]
|
spec_verify_ct: List[int]
|
||||||
spec_accepted_tokens: List[int]
|
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
@@ -883,7 +882,6 @@ class BatchStrOutput(BaseBatchReq):
|
|||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
spec_verify_ct: List[int]
|
spec_verify_ct: List[int]
|
||||||
spec_accepted_tokens: List[int]
|
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
|
|||||||
@@ -246,11 +246,6 @@ def _handle_output_by_index(output, i):
|
|||||||
spec_verify_ct=(
|
spec_verify_ct=(
|
||||||
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
||||||
),
|
),
|
||||||
spec_accepted_tokens=(
|
|
||||||
[output.spec_accepted_tokens[i]]
|
|
||||||
if len(output.spec_accepted_tokens) > i
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
input_token_logprobs_val=(
|
input_token_logprobs_val=(
|
||||||
[output.input_token_logprobs_val[i]]
|
[output.input_token_logprobs_val[i]]
|
||||||
if output.input_token_logprobs_val
|
if output.input_token_logprobs_val
|
||||||
|
|||||||
@@ -631,10 +631,6 @@ class Req:
|
|||||||
# This is used to compute the average acceptance length per request.
|
# This is used to compute the average acceptance length per request.
|
||||||
self.spec_verify_ct = 0
|
self.spec_verify_ct = 0
|
||||||
|
|
||||||
# The number of accepted tokens in speculative decoding for this request.
|
|
||||||
# This is used to compute the acceptance rate and average acceptance length per request.
|
|
||||||
self.spec_accepted_tokens = 0
|
|
||||||
|
|
||||||
# For metrics
|
# For metrics
|
||||||
self.metrics_collector = metrics_collector
|
self.metrics_collector = metrics_collector
|
||||||
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
||||||
|
|||||||
@@ -216,24 +216,14 @@ class SchedulerMetricsMixin:
|
|||||||
|
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
spec_accept_length = 0
|
spec_accept_length = 0
|
||||||
spec_accept_rate = 0
|
|
||||||
else:
|
else:
|
||||||
spec_accept_length = (
|
spec_accept_length = (
|
||||||
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
||||||
)
|
)
|
||||||
# Calculate acceptance rate: accepted tokens / total draft tokens
|
|
||||||
total_draft_tokens = self.spec_num_total_forward_ct * (
|
|
||||||
self.server_args.speculative_num_steps or 1
|
|
||||||
)
|
|
||||||
spec_accept_rate = (
|
|
||||||
self.spec_num_total_accepted_tokens / total_draft_tokens
|
|
||||||
if total_draft_tokens > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
||||||
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
||||||
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
||||||
msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
|
msg += f"accept len: {spec_accept_length:.2f}, "
|
||||||
cache_hit_rate = 0.0
|
cache_hit_rate = 0.0
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
@@ -261,9 +251,6 @@ class SchedulerMetricsMixin:
|
|||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
self.stats.cache_hit_rate = cache_hit_rate
|
self.stats.cache_hit_rate = cache_hit_rate
|
||||||
|
|
||||||
# Speculative decoding
|
|
||||||
self.stats.spec_accept_rate = spec_accept_rate
|
|
||||||
self.stats.spec_accept_length = spec_accept_length
|
self.stats.spec_accept_length = spec_accept_length
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
|
|||||||
@@ -634,7 +634,6 @@ class SchedulerOutputProcessorMixin:
|
|||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
spec_verify_ct = []
|
spec_verify_ct = []
|
||||||
spec_accepted_tokens = []
|
|
||||||
output_hidden_states = None
|
output_hidden_states = None
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
@@ -726,7 +725,6 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
spec_verify_ct.append(req.spec_verify_ct)
|
spec_verify_ct.append(req.spec_verify_ct)
|
||||||
spec_accepted_tokens.append(req.spec_accepted_tokens)
|
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
if (
|
if (
|
||||||
@@ -827,7 +825,6 @@ class SchedulerOutputProcessorMixin:
|
|||||||
completion_tokens,
|
completion_tokens,
|
||||||
cached_tokens,
|
cached_tokens,
|
||||||
spec_verify_ct,
|
spec_verify_ct,
|
||||||
spec_accepted_tokens,
|
|
||||||
input_token_logprobs_val,
|
input_token_logprobs_val,
|
||||||
input_token_logprobs_idx,
|
input_token_logprobs_idx,
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
|
|||||||
@@ -1394,36 +1394,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
if state.finished:
|
if state.finished:
|
||||||
if self.server_args.speculative_algorithm:
|
if self.server_args.speculative_algorithm:
|
||||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||||
if (
|
|
||||||
recv_obj.spec_verify_ct[i] > 0
|
|
||||||
and self.server_args.speculative_num_steps is not None
|
|
||||||
and not isinstance(recv_obj, BatchEmbeddingOutput)
|
|
||||||
and hasattr(recv_obj, "spec_accepted_tokens")
|
|
||||||
# Checks that `spec_accepted_tokens[i]` will exist.
|
|
||||||
and len(recv_obj.spec_accepted_tokens) > i
|
|
||||||
):
|
|
||||||
total_draft_tokens = (
|
|
||||||
recv_obj.spec_verify_ct[i]
|
|
||||||
* self.server_args.speculative_num_steps
|
|
||||||
)
|
|
||||||
accepted_tokens = recv_obj.spec_accepted_tokens[i]
|
|
||||||
|
|
||||||
# Calculate per-request acceptance rate and average acceptance length.
|
|
||||||
if total_draft_tokens > 0:
|
|
||||||
# Calculate acceptance rate: accepted / (steps * lookahead)
|
|
||||||
meta_info["spec_accept_rate"] = (
|
|
||||||
accepted_tokens / total_draft_tokens
|
|
||||||
)
|
|
||||||
meta_info["spec_accept_length"] = (
|
|
||||||
recv_obj.completion_tokens[i]
|
|
||||||
/ recv_obj.spec_verify_ct[i]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
meta_info["spec_accept_rate"] = 0.0
|
|
||||||
meta_info["spec_accept_length"] = 0
|
|
||||||
else:
|
|
||||||
meta_info["spec_acceptance_rate"] = 0.0
|
|
||||||
meta_info["spec_accept_length"] = 0
|
|
||||||
state.finished_time = time.time()
|
state.finished_time = time.time()
|
||||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,6 @@ class SchedulerStats:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_accept_length: float = 0.0
|
spec_accept_length: float = 0.0
|
||||||
spec_accept_rate: float = 0.0
|
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
num_retracted_reqs: int = 0
|
num_retracted_reqs: int = 0
|
||||||
@@ -221,12 +220,6 @@ class SchedulerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="mostrecent",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
self.spec_accept_rate = Gauge(
|
|
||||||
name="sglang:spec_accept_rate",
|
|
||||||
documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
multiprocess_mode="mostrecent",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
self.num_retracted_reqs = Gauge(
|
self.num_retracted_reqs = Gauge(
|
||||||
@@ -527,7 +520,6 @@ class SchedulerMetricsCollector:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||||
self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
|
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
self._log_gauge(
|
self._log_gauge(
|
||||||
|
|||||||
@@ -378,13 +378,6 @@ class EagleVerifyInput(SpecInput):
|
|||||||
unfinished_accept_index.append(accept_index[i])
|
unfinished_accept_index.append(accept_index[i])
|
||||||
req.spec_verify_ct += 1
|
req.spec_verify_ct += 1
|
||||||
|
|
||||||
# For each request, accumulate # of accepted tokens for this verify pass.
|
|
||||||
accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
|
|
||||||
for i, (req, accepted_count) in enumerate(
|
|
||||||
zip(batch.reqs, accept_length_this_pass.tolist())
|
|
||||||
):
|
|
||||||
req.spec_accepted_tokens += accepted_count
|
|
||||||
|
|
||||||
if has_finished:
|
if has_finished:
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user