From 48761171716302446a95c8d9d1fe1a469f12309e Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 13 Oct 2024 01:07:09 -0700 Subject: [PATCH] [Fix] fix eos trim inconsistency (#1650) --- .../srt/managers/detokenizer_manager.py | 41 ++++++++++++---- python/sglang/srt/managers/io_struct.py | 1 + python/sglang/srt/managers/scheduler.py | 3 ++ python/sglang/srt/openai_api/adapter.py | 49 ++++++++++++------- python/sglang/srt/openai_api/protocol.py | 1 + python/sglang/srt/sampling/sampling_params.py | 2 + python/sglang/srt/utils.py | 7 +++ 7 files changed, 77 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 49c9e6fdb..6e90d19cd 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -18,7 +18,7 @@ limitations under the License. import dataclasses import logging from collections import OrderedDict -from typing import List +from typing import List, Union import zmq @@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import ( BatchTokenIDOut, UpdateWeightReqOutput, ) -from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR +from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import find_printable_text, get_exception_traceback @@ -75,6 +75,21 @@ class DetokenizerManager: self.decode_status = LimitedCapacityDict() + def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim): + if no_eos_trim: + return output + + # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit + if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str): + pos = output.find(finished_reason.matched) + return output[:pos] if pos != -1 else output + if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance( + output, list + ): + assert len(output) > 0 + return output[:-1] + return output + def event_loop(self): """The event loop that handles requests""" @@ -122,7 +137,13 @@ class DetokenizerManager: s = self.decode_status[rid] s.decode_ids = recv_obj.decode_ids[i] - read_ids.append(s.decode_ids[s.surr_offset :]) + read_ids.append( + self.trim_eos( + s.decode_ids[s.surr_offset :], + recv_obj.finished_reason[i], + recv_obj.no_eos_trim[i], + ) + ) surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request @@ -152,13 +173,13 @@ class DetokenizerManager: else: new_text = find_printable_text(new_text) - output_strs.append(s.decoded_text + new_text) - - # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit - if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): - pos = output_strs[i].find(recv_obj.finished_reason[i].matched) - if pos != -1: - output_strs[i] = output_strs[i][:pos] + output_strs.append( + self.trim_eos( + s.decoded_text + new_text, + recv_obj.finished_reason[i], + recv_obj.no_eos_trim[i], + ) + ) self.send_to_tokenizer.send_pyobj( BatchStrOut( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c9ee00e9d..9cc847706 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -295,6 +295,7 @@ class BatchTokenIDOut: spaces_between_special_tokens: List[bool] meta_info: List[Dict] finished_reason: List[BaseFinishReason] + no_eos_trim: List[bool] @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 62d1ff9ed..03ae37d66 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -883,6 +883,7 @@ class Scheduler: output_read_offsets = [] output_skip_special_tokens = [] output_spaces_between_special_tokens = [] + output_no_eos_trim = [] else: # embedding or reward model output_embeddings = [] unfinished_indices = [] @@ -914,6 +915,7 @@ class Scheduler: output_spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens ) + output_no_eos_trim.append(req.sampling_params.no_eos_trim) meta_info = { "prompt_tokens": len(req.origin_input_ids), @@ -961,6 +963,7 @@ class Scheduler: output_spaces_between_special_tokens, output_meta_info, output_finished_reason, + output_no_eos_trim, ) ) else: # embedding or reward model diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 04b7befa2..a3638d601 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -493,23 +493,38 @@ def v1_generate_request( top_logprobs_nums.append( request.logprobs if request.logprobs is not None else 0 ) - sampling_params_list.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": request.stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "json_schema": request.json_schema, - "n": request.n, - "ignore_eos": request.ignore_eos, - } - ) + sampling_params = [] + if isinstance(request.no_eos_trim, list): + num_reqs = len(request.prompt) + else: + num_reqs = 1 + for i in range(num_reqs): + sampling_params.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": request.stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "json_schema": request.json_schema, + "n": request.n, + "ignore_eos": request.ignore_eos, + "no_eos_trim": ( + request.no_eos_trim + if not isinstance(request.no_eos_trim, list) + else request.no_eos_trim[i] + ), + } + ) + if num_reqs == 1: + sampling_params_list.append(sampling_params[0]) + else: + sampling_params_list.append(sampling_params) if len(all_requests) == 1: prompt = prompts[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index ff4c62f00..4b382240a 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -174,6 +174,7 @@ class CompletionRequest(BaseModel): min_tokens: int = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) + no_eos_trim: Union[bool, List[bool]] = False class CompletionResponseChoice(BaseModel): diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 700fefa3d..6e497ea7b 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -40,6 +40,7 @@ class SamplingParams: regex: Optional[str] = None, n: int = 1, json_schema: Optional[str] = None, + no_eos_trim: bool = False, ) -> None: self.temperature = temperature self.top_p = top_p @@ -60,6 +61,7 @@ class SamplingParams: self.regex = regex self.n = n self.json_schema = json_schema + self.no_eos_trim = no_eos_trim # Process some special cases if self.temperature < _SAMPLING_EPS: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f0ac21fb1..ac2a8cf7f 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1): prof.export_chrome_trace(f"trace/{name}_{step_counter}.json") step_counter += 1 return result + + +def first_rank_print(*args, **kwargs): + if torch.cuda.current_device() == 0: + print(*args, **kwargs) + else: + pass