From 26908d9568ef6b7f658cbaea6096f5fcd7df5451 Mon Sep 17 00:00:00 2001 From: Pan Lyu Date: Sun, 7 Jul 2024 05:53:22 +0800 Subject: [PATCH] * fix(detokenizer_manager.py): fix truncated decoded output (#586) Co-authored-by: hnyls2002 --- python/sglang/backend/runtime_endpoint.py | 10 +++++----- python/sglang/srt/managers/detokenizer_manager.py | 11 ++++++++--- python/sglang/srt/managers/io_struct.py | 3 ++- python/sglang/srt/managers/tokenizer_manager.py | 3 ++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index f7b8f7b5d..6f11d5492 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -8,7 +8,7 @@ from sglang.global_config import global_config from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams -from sglang.utils import find_printable_text, http_request +from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): @@ -187,11 +187,11 @@ class RuntimeEndpoint(BaseBackend): if chunk == "data: [DONE]": break data = json.loads(chunk[5:].strip("\n")) - text = find_printable_text(data["text"][pos:]) + chunk_text = data["text"][pos:] + incomplete_text = data["incomplete_text"] meta_info = data["meta_info"] - pos += len(text) - incomplete_text = data["text"][pos:] - yield text, meta_info + pos += len(chunk_text) + yield chunk_text, meta_info if len(incomplete_text) > 0: yield incomplete_text, meta_info diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index ecba679e2..be1eb4d44 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -11,7 +11,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.utils import get_exception_traceback, graceful_registry +from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -55,9 +55,13 @@ class DetokenizerManager: # Trim stop str # TODO(lmzheng): handle the case where multiple stop strs are hit output_strs = [] + incomplete_strs = [] for i in range(len(recv_obj.rids)): new_text = read_texts[i][len(surr_texts[i]) :] - output_strs.append(recv_obj.decoded_texts[i] + new_text) + complete_new_text = find_printable_text(new_text) + incomplete_new_text = new_text[len(complete_new_text) :] + output_strs.append(recv_obj.decoded_texts[i] + complete_new_text) + incomplete_strs.append(incomplete_new_text) if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): pos = output_strs[i].find(recv_obj.finished_reason[i].matched) @@ -67,7 +71,8 @@ class DetokenizerManager: self.send_to_tokenizer.send_pyobj( BatchStrOut( rids=recv_obj.rids, - output_str=output_strs, + output_strs=output_strs, + incomplete_strs=incomplete_strs, meta_info=recv_obj.meta_info, finished_reason=recv_obj.finished_reason, ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e661edfaf..681e888a9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -122,7 +122,8 @@ class BatchTokenIDOut: @dataclass class BatchStrOut: rids: List[str] - output_str: List[str] + output_strs: List[str] + incomplete_strs: List[str] meta_info: List[Dict] finished_reason: List[BaseFinishReason] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 42f970370..90aebcd4b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -316,7 +316,8 @@ class TokenizerManager: recv_obj.meta_info[i]["id"] = rid out_dict = { - "text": recv_obj.output_str[i], + "text": recv_obj.output_strs[i], + "incomplete_text": recv_obj.incomplete_strs[i], "meta_info": recv_obj.meta_info[i], } state.out_list.append(out_dict)