From b6c14ec0b4f3d7f744c734a3835298b3242a2b90 Mon Sep 17 00:00:00 2001
From: cicirori <32845984+cicirori@users.noreply.github.com>
Date: Wed, 27 Aug 2025 00:01:29 +0200
Subject: [PATCH] add `response_format` support for `completion` API (#9665)
---
.../sglang/srt/entrypoints/openai/protocol.py | 35 ++++++------
.../entrypoints/openai/serving_completions.py | 15 +++++
.../basic/test_serving_completions.py | 57 +++++++++++++++++++
3 files changed, 90 insertions(+), 17 deletions(-)
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 7c1b07318..ab6411b47 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel):
strict: Optional[bool] = False
+class ResponseFormat(BaseModel):
+ type: Literal["text", "json_object", "json_schema"]
+ json_schema: Optional[JsonSchemaResponseFormat] = None
+
+
+class StructuresResponseFormat(BaseModel):
+ begin: str
+ schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+ end: str
+
+
+class StructuralTagResponseFormat(BaseModel):
+ type: Literal["structural_tag"]
+ structures: List[StructuresResponseFormat]
+ triggers: List[str]
+
+
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel):
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
+ response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[
]
-class ResponseFormat(BaseModel):
- type: Literal["text", "json_object", "json_schema"]
- json_schema: Optional[JsonSchemaResponseFormat] = None
-
-
-class StructuresResponseFormat(BaseModel):
- begin: str
- schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
- end: str
-
-
-class StructuralTagResponseFormat(BaseModel):
- type: Literal["structural_tag"]
- structures: List[StructuresResponseFormat]
- triggers: List[str]
-
-
class Function(BaseModel):
"""Function descriptions."""
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 8ad88c3a2..3b30f9070 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import (
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.utils import convert_json_schema_to_str
logger = logging.getLogger(__name__)
@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
"logit_bias": request.logit_bias,
}
+ # Handle response_format constraints
+ if request.response_format and request.response_format.type == "json_schema":
+ sampling_params["json_schema"] = convert_json_schema_to_str(
+ request.response_format.json_schema.schema_
+ )
+ elif request.response_format and request.response_format.type == "json_object":
+ sampling_params["json_schema"] = '{"type": "object"}'
+ elif (
+ request.response_format and request.response_format.type == "structural_tag"
+ ):
+ sampling_params["structural_tag"] = convert_json_schema_to_str(
+ request.response_format.model_dump(by_alias=True)
+ )
+
return sampling_params
async def _handle_streaming_request(
diff --git a/test/srt/openai_server/basic/test_serving_completions.py b/test/srt/openai_server/basic/test_serving_completions.py
index c0568e93b..022ba9ad1 100644
--- a/test/srt/openai_server/basic/test_serving_completions.py
+++ b/test/srt/openai_server/basic/test_serving_completions.py
@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase):
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
+ # ---------- response_format handling ----------
+ def test_response_format_json_object(self):
+ """Test that response_format json_object is correctly processed in sampling params."""
+ req = CompletionRequest(
+ model="x",
+ prompt="Generate a JSON object:",
+ max_tokens=100,
+ response_format={"type": "json_object"},
+ )
+ sampling_params = self.sc._build_sampling_params(req)
+ self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
+
+ def test_response_format_json_schema(self):
+ """Test that response_format json_schema is correctly processed in sampling params."""
+ schema = {
+ "type": "object",
+ "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+ }
+ req = CompletionRequest(
+ model="x",
+ prompt="Generate a JSON object:",
+ max_tokens=100,
+ response_format={
+ "type": "json_schema",
+ "json_schema": {"name": "person", "schema": schema},
+ },
+ )
+ sampling_params = self.sc._build_sampling_params(req)
+ # The schema should be converted to string by convert_json_schema_to_str
+ self.assertIn("json_schema", sampling_params)
+ self.assertIsInstance(sampling_params["json_schema"], str)
+
+ def test_response_format_structural_tag(self):
+ """Test that response_format structural_tag is correctly processed in sampling params."""
+ req = CompletionRequest(
+ model="x",
+ prompt="Generate structured output:",
+ max_tokens=100,
+ response_format={
+ "type": "structural_tag",
+ "structures": [{"begin": "", "end": ""}],
+ "triggers": [""],
+ },
+ )
+ sampling_params = self.sc._build_sampling_params(req)
+ # The structural_tag request should end up as a string-valued "structural_tag" sampling param
+ self.assertIn("structural_tag", sampling_params)
+ self.assertIsInstance(sampling_params["structural_tag"], str)
+
+ def test_response_format_none(self):
+ """Test that no response_format doesn't add extra constraints."""
+ req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
+ sampling_params = self.sc._build_sampling_params(req)
+ # Only structural_tag is asserted: without response_format it must be absent or None,
+ # whereas json_schema might still be populated from the legacy json_schema field.
+ self.assertIsNone(sampling_params.get("structural_tag"))
+
if __name__ == "__main__":
unittest.main(verbosity=2)