[FEAT] JSON constrained support (#1125)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -60,6 +60,9 @@ spaces_between_special_tokens: bool = True,
|
|||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
# Do parallel sampling and return `n` outputs.
|
# Do parallel sampling and return `n` outputs.
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
|
# Constrains the output to follow a given JSON schema.
|
||||||
|
# `regex` and `json_schema` cannot be set at the same time.
|
||||||
|
json_schema: Optional[str] = None,
|
||||||
|
|
||||||
## Penalties. See [Performance Implications on Penalties] section below for more information.
|
## Penalties. See [Performance Implications on Penalties] section below for more information.
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
"""Cache for the compressed finite state machine."""
|
||||||
|
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
|
|||||||
tokenizer_args_dict,
|
tokenizer_args_dict,
|
||||||
enable=True,
|
enable=True,
|
||||||
skip_tokenizer_init=False,
|
skip_tokenizer_init=False,
|
||||||
|
json_schema_mode=False,
|
||||||
):
|
):
|
||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
|
self.json_schema_mode = json_schema_mode
|
||||||
|
|
||||||
if (
|
if (
|
||||||
skip_tokenizer_init
|
skip_tokenizer_init
|
||||||
or tokenizer_path.endswith(".json")
|
or tokenizer_path.endswith(".json")
|
||||||
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
|
|||||||
tokenizer_path, **tokenizer_args_dict
|
tokenizer_path, **tokenizer_args_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_value(self, regex):
|
def init_value(self, value):
|
||||||
return RegexGuide(regex, self.outlines_tokenizer)
|
if self.json_schema_mode:
|
||||||
|
regex = build_regex_from_schema(value)
|
||||||
|
return RegexGuide(regex, self.outlines_tokenizer), regex
|
||||||
|
else:
|
||||||
|
return RegexGuide(value, self.outlines_tokenizer)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
import outlines.caching
|
import outlines.caching
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
from sglang.srt.constrained import (
|
||||||
FSMInfo,
|
FSMInfo,
|
||||||
|
|||||||
@@ -268,7 +268,14 @@ class Req:
|
|||||||
|
|
||||||
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
|
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
|
||||||
all_ids = self.tokenizer.encode(all_text)
|
all_ids = self.tokenizer.encode(all_text)
|
||||||
|
if not all_ids:
|
||||||
|
warnings.warn("Encoded all_text resulted in empty all_ids")
|
||||||
|
return False
|
||||||
|
|
||||||
prompt_tokens = len(self.origin_input_ids_unpadded)
|
prompt_tokens = len(self.origin_input_ids_unpadded)
|
||||||
|
if prompt_tokens > len(all_ids):
|
||||||
|
warnings.warn("prompt_tokens is larger than encoded all_ids")
|
||||||
|
return False
|
||||||
|
|
||||||
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
||||||
# TODO(lsyin): fix token fusion
|
# TODO(lsyin): fix token fusion
|
||||||
|
|||||||
@@ -197,6 +197,16 @@ class ModelTpServer:
|
|||||||
"trust_remote_code": server_args.trust_remote_code,
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
},
|
},
|
||||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
|
json_schema_mode=False,
|
||||||
|
)
|
||||||
|
self.json_fsm_cache = FSMCache(
|
||||||
|
server_args.tokenizer_path,
|
||||||
|
{
|
||||||
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
|
},
|
||||||
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
|
json_schema_mode=True,
|
||||||
)
|
)
|
||||||
self.jump_forward_cache = JumpForwardCache()
|
self.jump_forward_cache = JumpForwardCache()
|
||||||
|
|
||||||
@@ -349,8 +359,17 @@ class ModelTpServer:
|
|||||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
|
|
||||||
|
# Init regex fsm from json
|
||||||
|
if req.sampling_params.json_schema is not None:
|
||||||
|
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
|
||||||
|
req.sampling_params.json_schema
|
||||||
|
)
|
||||||
|
if not self.disable_regex_jump_forward:
|
||||||
|
req.jump_forward_map = self.jump_forward_cache.query(
|
||||||
|
computed_regex_string
|
||||||
|
)
|
||||||
# Init regex fsm
|
# Init regex fsm
|
||||||
if req.sampling_params.regex is not None:
|
elif req.sampling_params.regex is not None:
|
||||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||||
if not self.disable_regex_jump_forward:
|
if not self.disable_regex_jump_forward:
|
||||||
req.jump_forward_map = self.jump_forward_cache.query(
|
req.jump_forward_map = self.jump_forward_cache.query(
|
||||||
|
|||||||
@@ -434,6 +434,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
|
|||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"repetition_penalty": request.repetition_penalty,
|
"repetition_penalty": request.repetition_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
|
"json_schema": request.json_schema,
|
||||||
"n": request.n,
|
"n": request.n,
|
||||||
"ignore_eos": request.ignore_eos,
|
"ignore_eos": request.ignore_eos,
|
||||||
}
|
}
|
||||||
@@ -802,6 +803,7 @@ def v1_chat_generate_request(
|
|||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"repetition_penalty": request.repetition_penalty,
|
"repetition_penalty": request.repetition_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
|
"json_schema": request.json_schema,
|
||||||
"n": request.n,
|
"n": request.n,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
regex: Optional[str] = None
|
regex: Optional[str] = None
|
||||||
|
json_schema: Optional[str] = None
|
||||||
ignore_eos: Optional[bool] = False
|
ignore_eos: Optional[bool] = False
|
||||||
min_tokens: Optional[int] = 0
|
min_tokens: Optional[int] = 0
|
||||||
repetition_penalty: Optional[float] = 1.0
|
repetition_penalty: Optional[float] = 1.0
|
||||||
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
regex: Optional[str] = None
|
regex: Optional[str] = None
|
||||||
|
json_schema: Optional[str] = None
|
||||||
min_tokens: Optional[int] = 0
|
min_tokens: Optional[int] = 0
|
||||||
repetition_penalty: Optional[float] = 1.0
|
repetition_penalty: Optional[float] = 1.0
|
||||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class SamplingParams:
|
|||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
|
json_schema: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
@@ -56,6 +57,7 @@ class SamplingParams:
|
|||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.regex = regex
|
self.regex = regex
|
||||||
self.n = n
|
self.n = n
|
||||||
|
self.json_schema = json_schema
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if self.temperature < _SAMPLING_EPS:
|
if self.temperature < _SAMPLING_EPS:
|
||||||
@@ -106,6 +108,8 @@ class SamplingParams:
|
|||||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||||
f"{self.min_new_tokens}."
|
f"{self.min_new_tokens}."
|
||||||
)
|
)
|
||||||
|
if self.regex is not None and self.json_schema is not None:
|
||||||
|
raise ValueError("regex and json_schema cannot be both set.")
|
||||||
|
|
||||||
def normalize(self, tokenizer):
|
def normalize(self, tokenizer):
|
||||||
# Process stop strings
|
# Process stop strings
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ suites = {
|
|||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
"test_large_max_new_tokens.py",
|
"test_large_max_new_tokens.py",
|
||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
|
"test_json_constrained.py",
|
||||||
"test_skip_tokenizer_init.py",
|
"test_skip_tokenizer_init.py",
|
||||||
"test_torch_compile.py",
|
"test_torch_compile.py",
|
||||||
"test_triton_attn_backend.py",
|
"test_triton_attn_backend.py",
|
||||||
|
|||||||
96
test/srt/test_json_constrained.py
Normal file
96
test/srt/test_json_constrained.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONConstrained(unittest.TestCase):
    """End-to-end tests for the `json_schema` sampling parameter.

    A server is started once for the class; each test sends a request that
    carries `json_schema` and checks that the returned text parses as JSON
    with a string `name` and an integer `population`.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared schema string and start the server process."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # The schema is passed to the server as a JSON-encoded string.
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server process started in `setUpClass`."""
        kill_child_process(cls.process.pid)

    def _check_schema_output(self, text):
        """Assert that `text` is JSON with the field types the schema requires.

        On a decode failure the offending text is printed before the original
        exception is re-raised, so the test log shows what was generated.
        """
        try:
            js_obj = json.loads(text)
        except (TypeError, json.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """POST to `/generate` with the class schema and validate the output.

        Args:
            return_logprob: Forwarded as the request's `return_logprob` field.
            top_logprobs_num: Forwarded as the request's `top_logprobs_num` field.
            n: Number of samples requested; temperature is 0 when `n == 1`
                and 0.5 otherwise.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": self.json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
            headers=headers,
        )
        # Fail with the HTTP error itself rather than a KeyError on "text".
        response.raise_for_status()
        payload = response.json()
        print(json.dumps(payload))
        print("=" * 100)
        self._check_schema_output(payload["text"])

    def test_json_generate(self):
        """The native `/generate` endpoint honours `json_schema`."""
        self.run_decode()

    def test_json_openai(self):
        """The OpenAI-compatible chat endpoint honours `json_schema`."""
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            # `json_schema` is not an OpenAI parameter, so it goes in extra_body.
            extra_body={"json_schema": self.json_schema},
        )
        self._check_schema_output(response.choices[0].message.content)


if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user