Accommodate a JSON schema in the "schema" field, not in the "json_schema" field, of response_format (#9786)
This commit is contained in:
@@ -460,6 +460,38 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
values["tool_choice"] = "auto"
|
values["tool_choice"] = "auto"
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def set_json_schema(cls, values):
    """Lift ``response_format["schema"]`` into ``response_format["json_schema"]``.

    Some clients send the JSON schema directly under the ``"schema"`` key of
    ``response_format`` instead of wrapping it in ``"json_schema"``.  When
    ``type`` is ``"json_schema"`` and no ``"json_schema"`` entry is present,
    this validator builds one of the form
    ``{"name": ..., "schema": ..., "strict": ...}`` where:

    * ``name`` is the schema's ``"title"`` (``"Schema"`` when absent);
    * ``strict`` is true only when the schema declares a ``"strict"``
      property whose ``"default"`` is truthy; that pseudo-property is
      removed from the schema that is forwarded.

    An explicit ``"json_schema"`` entry always wins; a stray ``"schema"`` key
    is dropped in that case.

    Args:
        values: Raw input handed to the model before field validation.

    Returns:
        ``values`` itself when nothing needs rewriting, otherwise a new dict
        with a rewritten ``"response_format"`` entry.  The caller's
        ``response_format`` and schema dicts are never modified.

    Raises:
        ValueError: If ``"schema"`` is provided but is not a dict.
    """
    # A "before" validator can receive arbitrary input; only plain dicts
    # are rewritten here, everything else is left to field validation.
    if not isinstance(values, dict):
        return values

    response_format = values.get("response_format")
    # Also covers an already-constructed response-format object, which has
    # no dict interface and needs no rewriting.
    if not response_format or not isinstance(response_format, dict):
        return values

    if response_format.get("type") != "json_schema":
        return values

    # Work on copies so the dicts owned by the caller are left untouched.
    new_format = dict(response_format)
    schema = new_format.pop("schema", None)
    rewritten = {**values, "response_format": new_format}

    if new_format.get("json_schema") or not schema:
        return rewritten

    if not isinstance(schema, dict):
        raise ValueError('response_format["schema"] must be a JSON object')

    schema = dict(schema)
    strict = False
    properties = schema.get("properties")
    if isinstance(properties, dict) and "strict" in properties:
        properties = dict(properties)
        strict_item = properties.pop("strict")
        schema["properties"] = properties
        strict = bool(
            isinstance(strict_item, dict) and strict_item.get("default", False)
        )
        # Keep "required" consistent with the property that was just removed,
        # otherwise the schema would demand a key it no longer declares.
        required = schema.get("required")
        if isinstance(required, list) and "strict" in required:
            schema["required"] = [key for key in required if key != "strict"]

    new_format["json_schema"] = {
        "name": schema.get("title", "Schema"),
        "schema": schema,
        "strict": strict,
    }
    return rewritten
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
top_k: int = -1
|
top_k: int = -1
|
||||||
min_p: float = 0.0
|
min_p: float = 0.0
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import time
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
BatchRequest,
|
BatchRequest,
|
||||||
@@ -192,6 +192,81 @@ class TestChatCompletionRequest(unittest.TestCase):
|
|||||||
self.assertFalse(request.stream_reasoning)
|
self.assertFalse(request.stream_reasoning)
|
||||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||||
|
|
||||||
|
def test_chat_completion_json_format(self):
    """Check both accepted spellings of a ``json_schema`` response format.

    Scenario 1 passes the schema directly under ``response_format["schema"]``
    and expects it to be exposed through ``response_format.json_schema`` with
    the name ``"VoiceNote"``, ``strict`` set to true, and the ``strict``
    pseudo-property removed from the schema's properties.

    Scenario 2 passes the standard ``response_format["json_schema"]`` wrapper
    and expects its name and strict flag to be preserved.
    """
    # Parenthesised so that all four fragments are concatenated; without the
    # parentheses only the first fragment is assigned to ``transcript``.
    transcript = (
        "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
        "so let's get started. First, I need to make a quick breakfast. I think I'll have some "
        "scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
        "emails to see if there's anything urgent."
    )

    messages = [
        {
            "role": "system",
            "content": "The following is a voice message transcript. Only answer in JSON.",
        },
        {
            "role": "user",
            "content": transcript,
        },
    ]

    class VoiceNote(BaseModel):
        title: str = Field(description="A title for the voice note")
        summary: str = Field(
            description="A short one sentence summary of the voice note."
        )
        # Pseudo-field: its default is what the request is expected to
        # turn into the json_schema "strict" flag.
        strict: Optional[bool] = True
        actionItems: List[str] = Field(
            description="A list of action items from the voice note"
        )

    def build_request(response_format):
        # Both scenarios share every argument except ``response_format``.
        return ChatCompletionRequest(
            model="test-model",
            messages=messages,
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs={"custom_param": "value"},
            response_format=response_format,
        )

    # Scenario 1: schema supplied directly in the "schema" field.
    request = build_request(
        {
            "type": "json_schema",
            "schema": VoiceNote.model_json_schema(),
        }
    )
    json_format = request.response_format.json_schema
    self.assertEqual(json_format.name, "VoiceNote")
    self.assertEqual(json_format.strict, True)
    self.assertNotIn("strict", json_format.schema_["properties"])

    # Scenario 2: schema supplied in the standard "json_schema" wrapper.
    request = build_request(
        {
            "type": "json_schema",
            "json_schema": {
                "name": "VoiceNote",
                "schema": VoiceNote.model_json_schema(),
                "strict": True,
            },
        }
    )
    json_format = request.response_format.json_schema
    self.assertEqual(json_format.name, "VoiceNote")
    self.assertEqual(json_format.strict, True)
|
||||||
|
|
||||||
|
|
||||||
class TestModelSerialization(unittest.TestCase):
|
class TestModelSerialization(unittest.TestCase):
|
||||||
"""Test model serialization with hidden states"""
|
"""Test model serialization with hidden states"""
|
||||||
|
|||||||
Reference in New Issue
Block a user