From 9935f97b3e594e246776466d04134decff1b59ae Mon Sep 17 00:00:00 2001
From: havetc
Date: Mon, 26 Aug 2024 18:37:26 +0200
Subject: [PATCH] [FEAT] JSON constrained support (#1125)

Co-authored-by: Yineng Zhang
---
 docs/en/sampling_params.md                    |  3 +
 python/sglang/srt/constrained/fsm_cache.py    | 13 ++-
 python/sglang/srt/constrained/jump_forward.py |  1 +
 python/sglang/srt/managers/schedule_batch.py  |  7 ++
 python/sglang/srt/managers/tp_worker.py       | 21 +++-
 python/sglang/srt/openai_api/adapter.py       |  2 +
 python/sglang/srt/openai_api/protocol.py      |  2 +
 python/sglang/srt/sampling/sampling_params.py |  4 +
 test/srt/run_suite.py                         |  1 +
 test/srt/test_json_constrained.py             | 96 +++++++++++++++++++
 10 files changed, 147 insertions(+), 3 deletions(-)
 create mode 100644 test/srt/test_json_constrained.py

diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md
index 54b03bf32..0e1c13e4b 100644
--- a/docs/en/sampling_params.md
+++ b/docs/en/sampling_params.md
@@ -60,6 +60,9 @@
 spaces_between_special_tokens: bool = True,
 regex: Optional[str] = None,
 # Do parallel sampling and return `n` outputs.
 n: int = 1,
+# Constrains the output to follow a given JSON schema.
+# `regex` and `json_schema` cannot be set at the same time.
+json_schema: Optional[str] = None,
 
 ## Penalties. See [Performance Implications on Penalties] section below for more informations.
diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py
index fa41f90de..6bc6ea6d2 100644
--- a/python/sglang/srt/constrained/fsm_cache.py
+++ b/python/sglang/srt/constrained/fsm_cache.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Cache for the compressed finite state machine."""
 
+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
+        self.json_schema_mode = json_schema_mode
+
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
             tokenizer_path, **tokenizer_args_dict
         )
 
-    def init_value(self, regex):
-        return RegexGuide(regex, self.outlines_tokenizer)
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py
index b00c48d47..244931e05 100644
--- a/python/sglang/srt/constrained/jump_forward.py
+++ b/python/sglang/srt/constrained/jump_forward.py
@@ -23,6 +23,7 @@ from collections import defaultdict
 
 import interegular
 import outlines.caching
+from outlines.fsm.json_schema import build_regex_from_schema
 
 from sglang.srt.constrained import (
     FSMInfo,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index dfd32dea9..cc180ba21 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -268,7 +268,14 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            warnings.warn("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            warnings.warn("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index ddf20970e..127f71900 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -197,6 +197,16 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=False,
+        )
+        self.json_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -349,8 +359,17 @@ class ModelTpServer:
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
 
+        # Init regex fsm from json
+        if req.sampling_params.json_schema is not None:
+            req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                req.sampling_params.json_schema
+            )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
         # Init regex fsm
-        if req.sampling_params.regex is not None:
+        elif req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.disable_regex_jump_forward:
                 req.jump_forward_map = self.jump_forward_cache.query(
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index f325e84b2..148f2689d 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -434,6 +434,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
                 "frequency_penalty": request.frequency_penalty,
                 "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
+                "json_schema": request.json_schema,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
             }
@@ -802,6 +803,7 @@ def v1_chat_generate_request(
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "json_schema": request.json_schema,
             "n": request.n,
         }
     )
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 758e48ede..ce51e1c02 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
    ignore_eos: Optional[bool] = False
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index c30717dd7..8111757d8 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -39,6 +39,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         regex: Optional[str] = None,
         n: int = 1,
+        json_schema: Optional[str] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -56,6 +57,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
+        self.json_schema = json_schema
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
@@ -106,6 +108,8 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
+        if self.regex is not None and self.json_schema is not None:
+            raise ValueError("regex and json_schema cannot be both set.")
 
     def normalize(self, tokenizer):
         # Process stop strings
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 2351579f1..cafcf3f2d 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -13,6 +13,7 @@ suites = {
         "test_eval_accuracy_mini.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_json_constrained.py",
         "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
         "test_triton_attn_backend.py",
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
new file mode 100644
index 000000000..5393ecc33
--- /dev/null
+++ b/test/srt/test_json_constrained.py
@@ -0,0 +1,96 @@
+import json
+import unittest
+
+import openai
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestJSONConstrained(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.json_schema = json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string", "pattern": "^[\\w]+$"},
+                    "population": {"type": "integer"},
+                },
+                "required": ["name", "population"],
+            }
+        )
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 128,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                    "json_schema": self.json_schema,
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+            headers=headers,
+        )
+        print(json.dumps(response.json()))
+        print("=" * 100)
+        try:
+            js_obj = json.loads(response.json()["text"])
+        except (TypeError, json.decoder.JSONDecodeError):
+            raise
+        assert isinstance(js_obj["name"], str)
+        assert isinstance(js_obj["population"], int)
+
+    def test_json_generate(self):
+        self.run_decode()
+
+    def test_json_openai(self):
+        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
+
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "Introduce the capital of France."},
+            ],
+            temperature=0,
+            max_tokens=128,
+            extra_body={"json_schema": self.json_schema},
+        )
+        text = response.choices[0].message.content
+
+        try:
+            js_obj = json.loads(text)
+        except (TypeError, json.decoder.JSONDecodeError):
+            print("JSONDecodeError", text)
+            raise
+        assert isinstance(js_obj["name"], str)
+        assert isinstance(js_obj["population"], int)
+
+
+if __name__ == "__main__":
+    unittest.main()