diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index 244931e05..b00c48d47 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -23,7 +23,6 @@ from collections import defaultdict import interegular import outlines.caching -from outlines.fsm.json_schema import build_regex_from_schema from sglang.srt.constrained import ( FSMInfo, diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index f1195aff7..d1b296e9b 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -28,6 +28,13 @@ from fastapi import HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse from pydantic import ValidationError +try: + from outlines.fsm.json_schema import convert_json_schema_to_str +except ImportError: + # Before outlines 0.0.47, convert_json_schema_to_str is under + # outlines.integrations.utils + from outlines.integrations.utils import convert_json_schema_to_str + from sglang.srt.conversation import ( Conversation, SeparatorStyle, @@ -888,22 +895,26 @@ def v1_chat_generate_request( return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs) - sampling_params_list.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "json_schema": request.json_schema, - "n": request.n, - } - ) + + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + } + if request.response_format and request.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + request.response_format.json_schema.schema_ + ) + sampling_params_list.append(sampling_params) + image_data_list.append(image_data) modalities_list.extend(modalities) if len(all_requests) == 1: diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 5525cd882..3d7d450c9 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -82,6 +82,14 @@ class StreamOptions(BaseModel): include_usage: Optional[bool] = False +class JsonSchemaResponseFormat(BaseModel): + name: str + description: Optional[str] = None + # use alias to workaround pydantic conflict + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + strict: Optional[bool] = False + + class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded @@ -237,8 +245,8 @@ ChatCompletionMessageParam = Union[ class ResponseFormat(BaseModel): - # type must be "json_object" or "text" - type: Literal["text", "json_object"] + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None class ChatCompletionRequest(BaseModel): @@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None - json_schema: Optional[str] = None min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 5393ecc33..122d79968 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -79,7 +79,10 @@ class TestJSONConstrained(unittest.TestCase): ], temperature=0, max_tokens=128, - extra_body={"json_schema": self.json_schema}, + response_format={ + "type": "json_schema", + "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)}, + }, ) text = response.choices[0].message.content