Fix token usage with jump forward (#174)

Author: Cody Yu
Date: 2024-02-09 20:06:15 -08:00
Committed by: GitHub
Parent: 37b42297f8
Commit: 4d303c4fa3
2 changed files with 11 additions and 2 deletions

@@ -27,6 +27,9 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
+        # Original prompt length; jump-forward appends tokens to input_ids,
+        # so snapshot it here for accurate token-usage accounting.
+        self.orig_prompt_tokens = len(input_ids)
         # For vision input
         self.pixel_values = None
         self.image_size = None

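For context on why the snapshot is needed: jump-forward decoding appends deterministically forced tokens directly to req.input_ids mid-generation, so len(input_ids) drifts above the user's real prompt length. A minimal sketch of the effect, assuming a Req built roughly as in the hunk above (the constructor call and token values are illustrative, not the repository's actual API):

req = Req(input_ids=list(range(10)))  # 10-token prompt (illustrative)
assert req.orig_prompt_tokens == 10

# A jump-forward event extends input_ids with tokens forced by the
# regex/grammar constraint (values are illustrative).
req.input_ids.extend([101, 102, 103])

assert len(req.input_ids) == 13      # inflated by jump-forward
assert req.orig_prompt_tokens == 10  # still the true prompt length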

@@ -534,10 +534,16 @@ class ModelRpcServer(rpyc.Service):
             output_skip_special_tokens.append(
                 req.sampling_params.skip_special_tokens
             )
+            # len(input_ids) is accumulated during jump-forward, so use the
+            # original prompt length to compute the token usage info.
             meta_info = {
-                "prompt_tokens": len(req.input_ids),
-                "completion_tokens": len(req.output_ids),
+                "prompt_tokens": req.orig_prompt_tokens,
+                "completion_tokens": len(req.input_ids)
+                + len(req.output_ids)
+                - req.orig_prompt_tokens,
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
                 meta_info["token_logprob"] = req.token_logprob