Return more info for computing average acceptance length (#3152)

This commit is contained in:
Lianmin Zheng
2025-01-26 04:51:54 -08:00
committed by GitHub
parent 7e0976133c
commit 1dda8c5e4c
10 changed files with 97 additions and 15 deletions

View File

@@ -201,6 +201,7 @@ class DetokenizerManager:
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
spec_verify_ct=recv_obj.spec_verify_ct,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,

View File

@@ -354,10 +354,13 @@ class BatchTokenIDOut:
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
no_stop_trim: List[bool]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
@@ -382,6 +385,7 @@ class BatchStrOut:
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
# Logprobs
input_token_logprobs_val: List[float]

View File

@@ -252,7 +252,6 @@ class Req:
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.custom_logit_processor = custom_logit_processor
# Memory pool info
@@ -300,7 +299,7 @@ class Req:
self.logprob_start_len = 0
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
# Logprobs (return values)
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,10 +328,15 @@ class Req:
# Constrained decoding
self.grammar: Optional[BaseGrammarObject] = None
# The number of cached tokens, that were already cached in the KV cache
# The number of cached tokens that were already cached in the KV cache
self.cached_tokens = 0
self.already_computed = 0
# The number of verification forward passes in speculative decoding.
# This is used to compute the average acceptance length per request.
self.spec_verify_ct = 0
self.lora_path = lora_path
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs

View File

@@ -281,6 +281,7 @@ class Scheduler:
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
},
)
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -1371,6 +1377,7 @@ class Scheduler:
prompt_tokens = []
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
if return_logprob:
input_token_logprobs_val = []
@@ -1424,6 +1431,9 @@ class Scheduler:
completion_tokens.append(len(req.output_ids))
cached_tokens.append(req.cached_tokens)
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
if return_logprob:
input_token_logprobs_val.append(req.input_token_logprobs_val)
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1461,7 @@ class Scheduler:
prompt_tokens,
completion_tokens,
cached_tokens,
spec_verify_ct,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,

View File

@@ -785,6 +785,9 @@ class TokenizerManager:
i,
)
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update(
{
@@ -809,6 +812,7 @@ class TokenizerManager:
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()