add response_format support for completion API (#9665)
This commit is contained in:
@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel):
|
||||
strict: Optional[bool] = False
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel):
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[
|
||||
]
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
"logit_bias": request.logit_bias,
|
||||
}
|
||||
|
||||
# Handle response_format constraints
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
request.response_format.json_schema.schema_
|
||||
)
|
||||
elif request.response_format and request.response_format.type == "json_object":
|
||||
sampling_params["json_schema"] = '{"type": "object"}'
|
||||
elif (
|
||||
request.response_format and request.response_format.type == "structural_tag"
|
||||
):
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
return sampling_params
|
||||
|
||||
async def _handle_streaming_request(
|
||||
|
||||
@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||
|
||||
# ---------- response_format handling ----------
|
||||
def test_response_format_json_object(self):
|
||||
"""Test that response_format json_object is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
|
||||
|
||||
def test_response_format_json_schema(self):
|
||||
"""Test that response_format json_schema is correctly processed in sampling params."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
}
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": "person", "schema": schema},
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The schema should be converted to string by convert_json_schema_to_str
|
||||
self.assertIn("json_schema", sampling_params)
|
||||
self.assertIsInstance(sampling_params["json_schema"], str)
|
||||
|
||||
def test_response_format_structural_tag(self):
|
||||
"""Test that response_format structural_tag is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate structured output:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"structures": [{"begin": "<data>", "end": "</data>"}],
|
||||
"triggers": ["<data>"],
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The structural_tag should be processed
|
||||
self.assertIn("structural_tag", sampling_params)
|
||||
self.assertIsInstance(sampling_params["structural_tag"], str)
|
||||
|
||||
def test_response_format_none(self):
|
||||
"""Test that no response_format doesn't add extra constraints."""
|
||||
req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# Should not have json_schema or structural_tag from response_format
|
||||
# (but might have json_schema from the legacy json_schema field)
|
||||
self.assertIsNone(sampling_params.get("structural_tag"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user