From 183df4728260a1469612f848f980cc71266591b9 Mon Sep 17 00:00:00 2001 From: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com> Date: Wed, 1 May 2024 07:17:12 +0800 Subject: [PATCH] SamplingParams add "spaces_between_special_tokens" argument (#392) --- python/sglang/backend/runtime_endpoint.py | 4 ++++ python/sglang/global_config.py | 1 + python/sglang/srt/managers/detokenizer_manager.py | 3 ++- python/sglang/srt/managers/io_struct.py | 1 + python/sglang/srt/managers/router/model_rpc.py | 5 +++++ python/sglang/srt/sampling_params.py | 2 ++ 6 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 77b9a3277..c4adf6987 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -107,6 +107,7 @@ class RuntimeEndpoint(BaseBackend): "text": s.text_, "sampling_params": { "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, **sampling_params.to_srt_kwargs(), }, } @@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend): "text": s.text_, "sampling_params": { "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, "dtype": "int", **sampling_params.to_srt_kwargs(), }, @@ -145,6 +147,7 @@ class RuntimeEndpoint(BaseBackend): "text": s.text_, "sampling_params": { "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, **sampling_params.to_srt_kwargs(), }, } @@ -153,6 +156,7 @@ class RuntimeEndpoint(BaseBackend): "text": s.text_, "sampling_params": { "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, "dtype": "int", **sampling_params.to_srt_kwargs(), }, diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 36458b7b1..ef0853b7e 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -12,6 +12,7 @@ class GlobalConfig: # Output configs self.skip_special_tokens_in_output = True + self.spaces_between_special_tokens_in_out = True # Optimization configs self.eager_fill_image = False diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 32454ead4..595355425 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -38,10 +38,11 @@ class DetokenizerManager: if isinstance(recv_obj, BatchTokenIDOut): output_tokens = recv_obj.output_tokens - # TODO(lmzheng): handle skip_special_tokens per request + # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request output_strs = self.tokenizer.batch_decode( output_tokens, skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) # Trim stop str diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 53b1f552a..6e64380c9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -97,6 +97,7 @@ class BatchTokenIDOut: output_and_jump_forward_strs: List[str] hit_stop_str: List[Optional[str]] skip_special_tokens: List[bool] + spaces_between_special_tokens: List[bool] meta_info: List[Dict] finished: List[bool] diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 02c98560b..80c38b0df 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -549,6 +549,7 @@ class ModelRpcServer: output_and_jump_forward_strs = [] output_hit_stop_str = [] output_skip_special_tokens = [] + output_spaces_between_special_tokens = [] output_meta_info = [] output_finished = [] finished_indices = [] @@ -575,6 +576,9 @@ class ModelRpcServer: output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) + output_spaces_between_special_tokens.append( + req.sampling_params.spaces_between_special_tokens + ) meta_info = { "prompt_tokens": req.prompt_tokens, @@ -609,6 +613,7 @@ class ModelRpcServer: output_and_jump_forward_strs, output_hit_stop_str, output_skip_special_tokens, + output_spaces_between_special_tokens, output_meta_info, output_finished, ) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index f096de2d9..f6b4f5706 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -17,6 +17,7 @@ class SamplingParams: presence_penalty: float = 0.0, ignore_eos: bool = False, skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, dtype: Optional[str] = None, regex: Optional[str] = None, ) -> None: @@ -29,6 +30,7 @@ class SamplingParams: self.max_new_tokens = max_new_tokens self.ignore_eos = ignore_eos self.skip_special_tokens = skip_special_tokens + self.spaces_between_special_tokens = spaces_between_special_tokens self.dtype = dtype self.regex = regex