add response_format support for completion API (#9665)
This commit is contained in:
@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel):
|
||||
strict: Optional[bool] = False
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel):
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[
|
||||
]
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
"logit_bias": request.logit_bias,
|
||||
}
|
||||
|
||||
# Handle response_format constraints
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
request.response_format.json_schema.schema_
|
||||
)
|
||||
elif request.response_format and request.response_format.type == "json_object":
|
||||
sampling_params["json_schema"] = '{"type": "object"}'
|
||||
elif (
|
||||
request.response_format and request.response_format.type == "structural_tag"
|
||||
):
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
return sampling_params
|
||||
|
||||
async def _handle_streaming_request(
|
||||
|
||||
Reference in New Issue
Block a user