From 4d303c4fa365dbe8b4d474be6e613954bb829939 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 9 Feb 2024 20:06:15 -0800 Subject: [PATCH] Fix token usage with jump forward (#174) --- python/sglang/srt/managers/router/infer_batch.py | 3 +++ python/sglang/srt/managers/router/model_rpc.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 002a12927..7bb9bc234 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -27,6 +27,9 @@ class Req: self.input_ids = input_ids self.output_ids = [] + # The original prompt length; input_ids grows during jump-forward. + self.orig_prompt_tokens = len(input_ids) + # For vision input self.pixel_values = None self.image_size = None diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index dfcf34378..b57f6ce52 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -534,10 +534,16 @@ class ModelRpcServer(rpyc.Service): output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) + + # The length of input_ids accumulates during jump-forward, so use + # the original prompt length to calculate the token usage info. meta_info = { - "prompt_tokens": len(req.input_ids), - "completion_tokens": len(req.output_ids), + "prompt_tokens": req.orig_prompt_tokens, + "completion_tokens": len(req.input_ids) + + len(req.output_ids) + - req.orig_prompt_tokens, } + if req.return_logprob: meta_info["prompt_logprob"] = req.logprob meta_info["token_logprob"] = req.token_logprob