add response_format support for completion API (#9665)
This commit is contained in:
@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel):
|
|||||||
strict: Optional[bool] = False
|
strict: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(BaseModel):
    # Request payload describing how the response should be formatted.

    # Output mode selector; only these three literal values validate.
    type: Literal["text", "json_object", "json_schema"]

    # Schema description accompanying the request; optional, defaults to None.
    # NOTE(review): nothing here forces it to be set when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StructuresResponseFormat(BaseModel):
    # One entry of the ``structures`` list of a structural-tag response format.

    # Required string field.
    # NOTE(review): presumably the text that opens the structured region — confirm.
    begin: str

    # Optional schema dict. It is populated from the "schema" key of the
    # incoming payload (alias), while the Python attribute is ``schema_``.
    schema_: Optional[Dict[str, object]] = Field(default=None, alias="schema")

    # Required string field.
    # NOTE(review): presumably the text that closes the structured region — confirm.
    end: str
|
||||||
|
|
||||||
|
|
||||||
|
class StructuralTagResponseFormat(BaseModel):
    # Response-format payload identified by type == "structural_tag".

    # Discriminating literal; no other value validates for this model.
    type: Literal["structural_tag"]

    # Structure entries, each validated as a StructuresResponseFormat.
    structures: List[StructuresResponseFormat]

    # List of trigger strings.
    # NOTE(review): their exact semantics are defined by the consumer of this
    # payload, which is not visible here.
    triggers: List[str]
|
||||||
|
|
||||||
|
|
||||||
class FileRequest(BaseModel):
|
class FileRequest(BaseModel):
|
||||||
# https://platform.openai.com/docs/api-reference/files/create
|
# https://platform.openai.com/docs/api-reference/files/create
|
||||||
file: bytes # The File object (not file name) to be uploaded
|
file: bytes # The File object (not file name) to be uploaded
|
||||||
@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel):
|
|||||||
skip_special_tokens: bool = True
|
skip_special_tokens: bool = True
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
session_params: Optional[Dict] = None
|
session_params: Optional[Dict] = None
|
||||||
|
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||||
|
|
||||||
# For PD disaggregation
|
# For PD disaggregation
|
||||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormat(BaseModel):
|
|
||||||
type: Literal["text", "json_object", "json_schema"]
|
|
||||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
|
||||||
|
|
||||||
|
|
||||||
class StructuresResponseFormat(BaseModel):
|
|
||||||
begin: str
|
|
||||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
|
||||||
end: str
|
|
||||||
|
|
||||||
|
|
||||||
class StructuralTagResponseFormat(BaseModel):
|
|
||||||
type: Literal["structural_tag"]
|
|
||||||
structures: List[StructuresResponseFormat]
|
|
||||||
triggers: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
class Function(BaseModel):
|
||||||
"""Function descriptions."""
|
"""Function descriptions."""
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.template_manager import TemplateManager
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.utils import convert_json_schema_to_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
"logit_bias": request.logit_bias,
|
"logit_bias": request.logit_bias,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Handle response_format constraints
|
||||||
|
if request.response_format and request.response_format.type == "json_schema":
|
||||||
|
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||||
|
request.response_format.json_schema.schema_
|
||||||
|
)
|
||||||
|
elif request.response_format and request.response_format.type == "json_object":
|
||||||
|
sampling_params["json_schema"] = '{"type": "object"}'
|
||||||
|
elif (
|
||||||
|
request.response_format and request.response_format.type == "structural_tag"
|
||||||
|
):
|
||||||
|
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||||
|
request.response_format.model_dump(by_alias=True)
|
||||||
|
)
|
||||||
|
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
async def _handle_streaming_request(
|
async def _handle_streaming_request(
|
||||||
|
|||||||
@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase):
|
|||||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||||
|
|
||||||
|
# ---------- response_format handling ----------
|
||||||
|
def test_response_format_json_object(self):
    """A json_object response_format maps to the generic object schema string."""
    request = CompletionRequest(
        model="x",
        prompt="Generate a JSON object:",
        max_tokens=100,
        response_format={"type": "json_object"},
    )

    params = self.sc._build_sampling_params(request)

    # json_object carries no user schema, so the fixed object schema is expected.
    expected_schema = '{"type": "object"}'
    self.assertEqual(params["json_schema"], expected_schema)
|
||||||
|
|
||||||
|
def test_response_format_json_schema(self):
    """A json_schema response_format yields a string json_schema sampling param."""
    person_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }
    response_format = {
        "type": "json_schema",
        "json_schema": {"name": "person", "schema": person_schema},
    }
    request = CompletionRequest(
        model="x",
        prompt="Generate a JSON object:",
        max_tokens=100,
        response_format=response_format,
    )

    params = self.sc._build_sampling_params(request)

    # The schema is passed through convert_json_schema_to_str, so only the
    # key's presence and its string type are checked here.
    self.assertIn("json_schema", params)
    self.assertIsInstance(params["json_schema"], str)
|
||||||
|
|
||||||
|
def test_response_format_structural_tag(self):
    """A structural_tag response_format yields a string structural_tag sampling param."""
    response_format = {
        "type": "structural_tag",
        "structures": [{"begin": "<data>", "end": "</data>"}],
        "triggers": ["<data>"],
    }
    request = CompletionRequest(
        model="x",
        prompt="Generate structured output:",
        max_tokens=100,
        response_format=response_format,
    )

    params = self.sc._build_sampling_params(request)

    # Only presence and string type of the structural_tag entry are checked.
    self.assertIn("structural_tag", params)
    self.assertIsInstance(params["structural_tag"], str)
|
||||||
|
|
||||||
|
def test_response_format_none(self):
    """Omitting response_format leaves structural_tag unset in the sampling params."""
    request = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)

    params = self.sc._build_sampling_params(request)

    # Only structural_tag is asserted: a json_schema entry might still come
    # from the request's legacy json_schema field rather than response_format.
    self.assertIsNone(params.get("structural_tag"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user