Support OpenAI API json_schema response format (#1363)
This commit is contained in:
@@ -23,7 +23,6 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
import outlines.caching
|
import outlines.caching
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
from sglang.srt.constrained import (
|
||||||
FSMInfo,
|
FSMInfo,
|
||||||
|
|||||||
@@ -28,6 +28,13 @@ from fastapi import HTTPException, Request, UploadFile
|
|||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
try:
|
||||||
|
from outlines.fsm.json_schema import convert_json_schema_to_str
|
||||||
|
except ImportError:
|
||||||
|
# Before outlines 0.0.47, convert_json_schema_to_str is under
|
||||||
|
# outlines.integrations.utils
|
||||||
|
from outlines.integrations.utils import convert_json_schema_to_str
|
||||||
|
|
||||||
from sglang.srt.conversation import (
|
from sglang.srt.conversation import (
|
||||||
Conversation,
|
Conversation,
|
||||||
SeparatorStyle,
|
SeparatorStyle,
|
||||||
@@ -888,22 +895,26 @@ def v1_chat_generate_request(
|
|||||||
return_logprobs.append(request.logprobs)
|
return_logprobs.append(request.logprobs)
|
||||||
logprob_start_lens.append(-1)
|
logprob_start_lens.append(-1)
|
||||||
top_logprobs_nums.append(request.top_logprobs)
|
top_logprobs_nums.append(request.top_logprobs)
|
||||||
sampling_params_list.append(
|
|
||||||
{
|
sampling_params = {
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": request.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
"min_new_tokens": request.min_tokens,
|
"min_new_tokens": request.min_tokens,
|
||||||
"stop": stop,
|
"stop": stop,
|
||||||
"stop_token_ids": request.stop_token_ids,
|
"stop_token_ids": request.stop_token_ids,
|
||||||
"top_p": request.top_p,
|
"top_p": request.top_p,
|
||||||
"presence_penalty": request.presence_penalty,
|
"presence_penalty": request.presence_penalty,
|
||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"repetition_penalty": request.repetition_penalty,
|
"repetition_penalty": request.repetition_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
"json_schema": request.json_schema,
|
"n": request.n,
|
||||||
"n": request.n,
|
}
|
||||||
}
|
if request.response_format and request.response_format.type == "json_schema":
|
||||||
)
|
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||||
|
request.response_format.json_schema.schema_
|
||||||
|
)
|
||||||
|
sampling_params_list.append(sampling_params)
|
||||||
|
|
||||||
image_data_list.append(image_data)
|
image_data_list.append(image_data)
|
||||||
modalities_list.extend(modalities)
|
modalities_list.extend(modalities)
|
||||||
if len(all_requests) == 1:
|
if len(all_requests) == 1:
|
||||||
|
|||||||
@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
|
|||||||
include_usage: Optional[bool] = False
|
include_usage: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSchemaResponseFormat(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
# Use an alias to work around the field-name conflict between "schema" and pydantic
|
||||||
|
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||||
|
strict: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class FileRequest(BaseModel):
|
class FileRequest(BaseModel):
|
||||||
# https://platform.openai.com/docs/api-reference/files/create
|
# https://platform.openai.com/docs/api-reference/files/create
|
||||||
file: bytes # The File object (not file name) to be uploaded
|
file: bytes # The File object (not file name) to be uploaded
|
||||||
@@ -237,8 +245,8 @@ ChatCompletionMessageParam = Union[
|
|||||||
|
|
||||||
|
|
||||||
class ResponseFormat(BaseModel):
|
class ResponseFormat(BaseModel):
|
||||||
# type must be "json_object" or "text"
|
type: Literal["text", "json_object", "json_schema"]
|
||||||
type: Literal["text", "json_object"]
|
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
@@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
regex: Optional[str] = None
|
regex: Optional[str] = None
|
||||||
json_schema: Optional[str] = None
|
|
||||||
min_tokens: Optional[int] = 0
|
min_tokens: Optional[int] = 0
|
||||||
repetition_penalty: Optional[float] = 1.0
|
repetition_penalty: Optional[float] = 1.0
|
||||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -79,7 +79,10 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=128,
|
max_tokens=128,
|
||||||
extra_body={"json_schema": self.json_schema},
|
response_format={
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content
|
text = response.choices[0].message.content
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user