From b6c14ec0b4f3d7f744c734a3835298b3242a2b90 Mon Sep 17 00:00:00 2001 From: cicirori <32845984+cicirori@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:01:29 +0200 Subject: [PATCH] add `response_format` support for `completion` API (#9665) --- .../sglang/srt/entrypoints/openai/protocol.py | 35 ++++++------ .../entrypoints/openai/serving_completions.py | 15 +++++ .../basic/test_serving_completions.py | 57 +++++++++++++++++++ 3 files changed, 90 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 7c1b07318..ab6411b47 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel): strict: Optional[bool] = False +class ResponseFormat(BaseModel): + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StructuresResponseFormat(BaseModel): + begin: str + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + end: str + + +class StructuralTagResponseFormat(BaseModel): + type: Literal["structural_tag"] + structures: List[StructuresResponseFormat] + triggers: List[str] + + class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded @@ -200,6 +217,7 @@ class CompletionRequest(BaseModel): skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None session_params: Optional[Dict] = None + response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None # For PD disaggregation bootstrap_host: Optional[Union[List[str], str]] = None @@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[ ] -class ResponseFormat(BaseModel): - type: Literal["text", "json_object", "json_schema"] - json_schema: Optional[JsonSchemaResponseFormat] = None - - -class StructuresResponseFormat(BaseModel): - begin: str - schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) - end: str - - -class StructuralTagResponseFormat(BaseModel): - type: Literal["structural_tag"] - structures: List[StructuresResponseFormat] - triggers: List[str] - - class Function(BaseModel): """Function descriptions.""" diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 8ad88c3a2..3b30f9070 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import ( from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.utils import convert_json_schema_to_str logger = logging.getLogger(__name__) @@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase): "logit_bias": request.logit_bias, } + # Handle response_format constraints + if request.response_format and request.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + request.response_format.json_schema.schema_ + ) + elif request.response_format and request.response_format.type == "json_object": + sampling_params["json_schema"] = '{"type": "object"}' + elif ( + request.response_format and request.response_format.type == "structural_tag" + ): + sampling_params["structural_tag"] = convert_json_schema_to_str( + request.response_format.model_dump(by_alias=True) + ) + return sampling_params async def _handle_streaming_request( diff --git a/test/srt/openai_server/basic/test_serving_completions.py b/test/srt/openai_server/basic/test_serving_completions.py index c0568e93b..022ba9ad1 100644 --- a/test/srt/openai_server/basic/test_serving_completions.py +++ b/test/srt/openai_server/basic/test_serving_completions.py @@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase): self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded" self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"]) + # ---------- response_format handling ---------- + def test_response_format_json_object(self): + """Test that response_format json_object is correctly processed in sampling params.""" + req = CompletionRequest( + model="x", + prompt="Generate a JSON object:", + max_tokens=100, + response_format={"type": "json_object"}, + ) + sampling_params = self.sc._build_sampling_params(req) + self.assertEqual(sampling_params["json_schema"], '{"type": "object"}') + + def test_response_format_json_schema(self): + """Test that response_format json_schema is correctly processed in sampling params.""" + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + } + req = CompletionRequest( + model="x", + prompt="Generate a JSON object:", + max_tokens=100, + response_format={ + "type": "json_schema", + "json_schema": {"name": "person", "schema": schema}, + }, + ) + sampling_params = self.sc._build_sampling_params(req) + # The schema should be converted to string by convert_json_schema_to_str + self.assertIn("json_schema", sampling_params) + self.assertIsInstance(sampling_params["json_schema"], str) + + def test_response_format_structural_tag(self): + """Test that response_format structural_tag is correctly processed in sampling params.""" + req = CompletionRequest( + model="x", + prompt="Generate structured output:", + max_tokens=100, + response_format={ + "type": "structural_tag", + "structures": [{"begin": "", "end": ""}], + "triggers": [""], + }, + ) + sampling_params = self.sc._build_sampling_params(req) + # The structural_tag should be processed + self.assertIn("structural_tag", sampling_params) + self.assertIsInstance(sampling_params["structural_tag"], str) + + def test_response_format_none(self): + """Test that no response_format doesn't add extra constraints.""" + req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100) + sampling_params = self.sc._build_sampling_params(req) + # Should not have json_schema or structural_tag from response_format + # (but might have json_schema from the legacy json_schema field) + self.assertIsNone(sampling_params.get("structural_tag")) + if __name__ == "__main__": unittest.main(verbosity=2)