From 63ba630bbbb2d55787ac54ac0a01cbde993afc20 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 15 Feb 2024 10:54:20 -0800
Subject: [PATCH] Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)

---
 python/sglang/srt/managers/io_struct.py        |  4 ++-
 .../sglang/srt/managers/router/infer_batch.py  |  8 ++++--
 .../sglang/srt/managers/router/model_rpc.py    | 11 ++++----
 python/sglang/srt/server.py                    | 27 ++++++++++++++-----
 4 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index a2070a5b1..b6817994a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False

diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 7bb9bc234..da5cab42d 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []

-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0

         # For vision input
         self.pixel_values = None
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index c89fa67f3..8f1f1e58a 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):

         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()

@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
                     req.sampling_params.skip_special_tokens
                 )

-                # For the length of input_ids, which will be accumulated during jump-forward.
-                # Use the original length of input_ids to calculate the token usage info.
                 meta_info = {
-                    "prompt_tokens": req.orig_prompt_tokens,
+                    "prompt_tokens": req.prompt_tokens,
                     "completion_tokens": len(req.input_ids) + len(req.output_ids)
-                    - req.orig_prompt_tokens,
+                    - req.prompt_tokens,
+                    "completion_tokens_wo_jump_forward":
+                        req.completion_tokens_wo_jump_forward
                 }
-
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
                     meta_info["token_logprob"] = req.token_logprob

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index d132f5814..55b3ff046 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+from sglang.srt.utils import handle_port_init

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()

-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)

@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
+
     return ret


@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):

     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
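
Note (not part of the patch): a minimal client-side sketch of how the new fields could be
exercised. It assumes the generate_request handler above is mounted at /generate on
localhost:30000 and that GenerateReqInput also accepts a "text" field and the usual
sampling params; neither of those details is shown in this diff.

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",  # assumed route and port; adjust for your deployment
        json={
            "text": "The capital of France is",         # assumed field name
            "sampling_params": {"max_new_tokens": 16},  # assumed param name
            "return_logprob": True,                     # from io_struct.py
            "return_text_in_logprobs": True,            # added by this patch
        },
    )
    meta = resp.json()["meta_info"]

    # Per the diff: completion_tokens counts everything appended after the original
    # prompt (jump-forward tokens included), while completion_tokens_wo_jump_forward
    # counts only the tokens produced by decoding.
    print(meta["prompt_tokens"], meta["completion_tokens"], meta["completion_tokens_wo_jump_forward"])

    # With return_text_in_logprobs=True, token_logprob entries are (text, logprob)
    # pairs instead of (token_id, logprob) pairs.
    for text, logprob in meta["token_logprob"]:
        print(repr(text), logprob)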