accomendate json schema in the "schema" field, not in "json_schema" field of response_format (#9786)
This commit is contained in:
@@ -460,6 +460,38 @@ class ChatCompletionRequest(BaseModel):
|
||||
values["tool_choice"] = "auto"
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_json_schema(cls, values):
|
||||
response_format = values.get("response_format")
|
||||
if not response_format:
|
||||
return values
|
||||
|
||||
if response_format.get("type") != "json_schema":
|
||||
return values
|
||||
|
||||
schema = response_format.pop("schema", None)
|
||||
json_schema = response_format.get("json_schema")
|
||||
|
||||
if json_schema:
|
||||
return values
|
||||
|
||||
if schema:
|
||||
name_ = schema.get("title", "Schema")
|
||||
strict_ = False
|
||||
if "properties" in schema and "strict" in schema["properties"]:
|
||||
item = schema["properties"].pop("strict", None)
|
||||
if item and item.get("default", False):
|
||||
strict_ = True
|
||||
|
||||
response_format["json_schema"] = {
|
||||
"name": name_,
|
||||
"schema": schema,
|
||||
"strict": strict_,
|
||||
}
|
||||
|
||||
return values
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
|
||||
@@ -18,7 +18,7 @@ import time
|
||||
import unittest
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
BatchRequest,
|
||||
@@ -192,6 +192,81 @@ class TestChatCompletionRequest(unittest.TestCase):
|
||||
self.assertFalse(request.stream_reasoning)
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
def test_chat_completion_json_format(self):
|
||||
"""Test chat completion json format"""
|
||||
transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
|
||||
"so let's get started. First, I need to make a quick breakfast. I think I'll have some "
|
||||
"scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
|
||||
"emails to see if there's anything urgent."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "The following is a voice message transcript. Only answer in JSON.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": transcript,
|
||||
},
|
||||
]
|
||||
|
||||
class VoiceNote(BaseModel):
|
||||
title: str = Field(description="A title for the voice note")
|
||||
summary: str = Field(
|
||||
description="A short one sentence summary of the voice note."
|
||||
)
|
||||
strict: Optional[bool] = True
|
||||
actionItems: List[str] = Field(
|
||||
description="A list of action items from the voice note"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
self.assertNotIn("strict", schema["properties"])
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "VoiceNote",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
Reference in New Issue
Block a user