Return more info for computing average acceptance length (#3152)
This commit is contained in:
@@ -201,6 +201,7 @@ class DetokenizerManager:
|
||||
prompt_tokens=recv_obj.prompt_tokens,
|
||||
completion_tokens=recv_obj.completion_tokens,
|
||||
cached_tokens=recv_obj.cached_tokens,
|
||||
spec_verify_ct=recv_obj.spec_verify_ct,
|
||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||
|
||||
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
|
||||
skip_special_tokens: List[bool]
|
||||
spaces_between_special_tokens: List[bool]
|
||||
no_stop_trim: List[bool]
|
||||
|
||||
# Token counts
|
||||
prompt_tokens: List[int]
|
||||
completion_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
spec_verify_ct: List[int]
|
||||
|
||||
# Logprobs
|
||||
input_token_logprobs_val: List[float]
|
||||
input_token_logprobs_idx: List[int]
|
||||
@@ -382,6 +385,7 @@ class BatchStrOut:
|
||||
prompt_tokens: List[int]
|
||||
completion_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
spec_verify_ct: List[int]
|
||||
|
||||
# Logprobs
|
||||
input_token_logprobs_val: List[float]
|
||||
|
||||
@@ -252,7 +252,6 @@ class Req:
|
||||
|
||||
# Sampling info
|
||||
self.sampling_params = sampling_params
|
||||
self.lora_path = lora_path
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
|
||||
# Memory pool info
|
||||
@@ -300,7 +299,7 @@ class Req:
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
|
||||
# Logprobs (return value)
|
||||
# Logprobs (return values)
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||
@@ -329,10 +328,15 @@ class Req:
|
||||
# Constrained decoding
|
||||
self.grammar: Optional[BaseGrammarObject] = None
|
||||
|
||||
# The number of cached tokens, that were already cached in the KV cache
|
||||
# The number of cached tokens that were already cached in the KV cache
|
||||
self.cached_tokens = 0
|
||||
self.already_computed = 0
|
||||
|
||||
# The number of verification forward passes performed during speculative decoding.
|
||||
# This is used to compute the average acceptance length per request.
|
||||
self.spec_verify_ct = 0
|
||||
self.lora_path = lora_path
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
|
||||
@@ -281,6 +281,7 @@ class Scheduler:
|
||||
# Print debug info
|
||||
logger.info(
|
||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
|
||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||
f"max_running_requests={self.max_running_requests}, "
|
||||
f"context_len={self.model_config.context_len}"
|
||||
@@ -408,6 +409,11 @@ class Scheduler:
|
||||
},
|
||||
)
|
||||
|
||||
# The largest prefill length of a single request
|
||||
self._largest_prefill_len: int = 0
|
||||
# The largest context length (prefill + generation) of a single request
|
||||
self._largest_prefill_decode_len: int = 0
|
||||
|
||||
# Init request dispatcher
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
@@ -1371,6 +1377,7 @@ class Scheduler:
|
||||
prompt_tokens = []
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
@@ -1424,6 +1431,9 @@ class Scheduler:
|
||||
completion_tokens.append(len(req.output_ids))
|
||||
cached_tokens.append(req.cached_tokens)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
spec_verify_ct.append(req.spec_verify_ct)
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
||||
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
||||
@@ -1451,6 +1461,7 @@ class Scheduler:
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
spec_verify_ct,
|
||||
input_token_logprobs_val,
|
||||
input_token_logprobs_idx,
|
||||
output_token_logprobs_val,
|
||||
|
||||
@@ -785,6 +785,9 @@ class TokenizerManager:
|
||||
i,
|
||||
)
|
||||
|
||||
if self.server_args.speculative_algorithm:
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
|
||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||
meta_info.update(
|
||||
{
|
||||
@@ -809,6 +812,7 @@ class TokenizerManager:
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reasons[i] is not None
|
||||
state.event.set()
|
||||
|
||||
Reference in New Issue
Block a user