From 93dffd699bd653fb1dfef44f30eb3d7ec40d6a4d Mon Sep 17 00:00:00 2001 From: zifeitong Date: Mon, 16 Sep 2024 13:29:18 -0700 Subject: [PATCH] Add constrained_json_whitespace_pattern to ServerArgs (#1438) --- python/sglang/srt/constrained/fsm_cache.py | 6 +++++- python/sglang/srt/managers/tp_worker.py | 11 +++++------ python/sglang/srt/server_args.py | 7 +++++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index fd5995dad..4ac4cef48 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -29,6 +29,7 @@ class FSMCache(BaseToolCache): tokenizer_args_dict, enable=True, skip_tokenizer_init=False, + constrained_json_whitespace_pattern=None, ): super().__init__(enable=enable) @@ -63,11 +64,14 @@ class FSMCache(BaseToolCache): self.outlines_tokenizer.vocabulary = ( self.outlines_tokenizer.tokenizer.get_vocab() ) + self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern def init_value(self, key): key_type, key_string = key if key_type == "json": - regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*") + regex = build_regex_from_schema( + key_string, whitespace_pattern=self.constrained_json_whitespace_pattern + ) elif key_type == "regex": regex = key_string else: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 09a2ede21..fe9017f12 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -198,6 +198,7 @@ class ModelTpServer: "trust_remote_code": server_args.trust_remote_code, }, skip_tokenizer_init=server_args.skip_tokenizer_init, + constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, ) self.jump_forward_cache = JumpForwardCache() @@ -807,12 +808,10 @@ class ModelTpServer: unfinished_indices.append(i) if req.finished() or ( - ( - req.stream - and ( - self.decode_forward_ct % self.stream_interval == 0 - or len(req.output_ids) == 1 - ) + req.stream + and ( + self.decode_forward_ct % self.stream_interval == 0 + or len(req.output_ids) == 1 ) ): output_rids.append(req.rid) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 11769b57f..818856716 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -70,6 +70,7 @@ class ServerArgs: tp_size: int = 1 stream_interval: int = 1 random_seed: Optional[int] = None + constrained_json_whitespace_pattern: Optional[str] = None # Logging log_level: str = "info" @@ -370,6 +371,12 @@ class ServerArgs: default=ServerArgs.random_seed, help="The random seed.", ) + parser.add_argument( + "--constrained-json-whitespace-pattern", + type=str, + default=ServerArgs.constrained_json_whitespace_pattern, + help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", + ) parser.add_argument( "--log-level", type=str,