add response_format support for completion API (#9665)

This commit is contained in:
cicirori
2025-08-27 00:01:29 +02:00
committed by GitHub
parent 43de1d7304
commit b6c14ec0b4
3 changed files with 90 additions and 17 deletions

View File

@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase):
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
# ---------- response_format handling ----------
def test_response_format_json_object(self):
"""Test that response_format json_object is correctly processed in sampling params."""
req = CompletionRequest(
model="x",
prompt="Generate a JSON object:",
max_tokens=100,
response_format={"type": "json_object"},
)
sampling_params = self.sc._build_sampling_params(req)
self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
def test_response_format_json_schema(self):
"""Test that response_format json_schema is correctly processed in sampling params."""
schema = {
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
}
req = CompletionRequest(
model="x",
prompt="Generate a JSON object:",
max_tokens=100,
response_format={
"type": "json_schema",
"json_schema": {"name": "person", "schema": schema},
},
)
sampling_params = self.sc._build_sampling_params(req)
# The schema should be converted to string by convert_json_schema_to_str
self.assertIn("json_schema", sampling_params)
self.assertIsInstance(sampling_params["json_schema"], str)
def test_response_format_structural_tag(self):
"""Test that response_format structural_tag is correctly processed in sampling params."""
req = CompletionRequest(
model="x",
prompt="Generate structured output:",
max_tokens=100,
response_format={
"type": "structural_tag",
"structures": [{"begin": "<data>", "end": "</data>"}],
"triggers": ["<data>"],
},
)
sampling_params = self.sc._build_sampling_params(req)
# The structural_tag should be processed
self.assertIn("structural_tag", sampling_params)
self.assertIsInstance(sampling_params["structural_tag"], str)
def test_response_format_none(self):
"""Test that no response_format doesn't add extra constraints."""
req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
sampling_params = self.sc._build_sampling_params(req)
# Should not have json_schema or structural_tag from response_format
# (but might have json_schema from the legacy json_schema field)
self.assertIsNone(sampling_params.get("structural_tag"))
if __name__ == "__main__":
unittest.main(verbosity=2)