Support OpenAI API json_schema response format (#1363)
This commit is contained in:
@@ -23,7 +23,6 @@ from collections import defaultdict
|
||||
|
||||
import interegular
|
||||
import outlines.caching
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
|
||||
from sglang.srt.constrained import (
|
||||
FSMInfo,
|
||||
|
||||
@@ -28,6 +28,13 @@ from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
from outlines.fsm.json_schema import convert_json_schema_to_str
|
||||
except ImportError:
|
||||
# Before outlines 0.0.47, convert_json_schema_to_str is under
|
||||
# outlines.integrations.utils
|
||||
from outlines.integrations.utils import convert_json_schema_to_str
|
||||
|
||||
from sglang.srt.conversation import (
|
||||
Conversation,
|
||||
SeparatorStyle,
|
||||
@@ -888,22 +895,26 @@ def v1_chat_generate_request(
|
||||
return_logprobs.append(request.logprobs)
|
||||
logprob_start_lens.append(-1)
|
||||
top_logprobs_nums.append(request.top_logprobs)
|
||||
sampling_params_list.append(
|
||||
{
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"json_schema": request.json_schema,
|
||||
"n": request.n,
|
||||
}
|
||||
)
|
||||
|
||||
sampling_params = {
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"n": request.n,
|
||||
}
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
request.response_format.json_schema.schema_
|
||||
)
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
modalities_list.extend(modalities)
|
||||
if len(all_requests) == 1:
|
||||
|
||||
@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
|
||||
include_usage: Optional[bool] = False
|
||||
|
||||
|
||||
class JsonSchemaResponseFormat(BaseModel):
    """Schema payload of an OpenAI ``response_format`` with type ``json_schema``.

    Mirrors the OpenAI API shape: a required ``name``, an optional
    ``description``, the JSON schema document itself, and a ``strict`` flag.
    """

    name: str
    description: Optional[str] = None
    # Exposed to clients under the key "schema"; the trailing underscore
    # avoids clashing with pydantic's own ``BaseModel.schema()`` method,
    # and the alias maps the wire name back onto this field.
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    strict: Optional[bool] = False
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
@@ -237,8 +245,8 @@ ChatCompletionMessageParam = Union[
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
    """OpenAI ``response_format`` request field.

    ``type`` selects plain text, free-form JSON object output, or
    schema-constrained JSON; ``json_schema`` is only consulted when
    ``type == "json_schema"``.
    """

    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
@@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
json_schema: Optional[str] = None
|
||||
min_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
|
||||
Reference in New Issue
Block a user