diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 002a12927..7bb9bc234 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -27,6 +27,9 @@ class Req: self.input_ids = input_ids self.output_ids = [] + # Original prompt length; input_ids may grow during jump-forward decoding. + self.orig_prompt_tokens = len(input_ids) + # For vision input self.pixel_values = None self.image_size = None diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index dfcf34378..b57f6ce52 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -534,10 +534,16 @@ class ModelRpcServer(rpyc.Service): output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) + + # input_ids accumulates extra tokens during jump-forward, so report the + # original prompt length and derive completion tokens from the total. meta_info = { - "prompt_tokens": len(req.input_ids), - "completion_tokens": len(req.output_ids), + "prompt_tokens": req.orig_prompt_tokens, + "completion_tokens": len(req.input_ids) + + len(req.output_ids) + - req.orig_prompt_tokens, } + if req.return_logprob: meta_info["prompt_logprob"] = req.logprob meta_info["token_logprob"] = req.token_logprob