diff --git a/docs/advanced_features/structured_outputs.ipynb b/docs/advanced_features/structured_outputs.ipynb index 1382f1e0e..7388adfb4 100644 --- a/docs/advanced_features/structured_outputs.ipynb +++ b/docs/advanced_features/structured_outputs.ipynb @@ -349,6 +349,50 @@ "print_highlight(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Support for XGrammar latest structural tag format\n", + "# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=messages,\n", + " response_format={\n", + " \"type\": \"structural_tag\",\n", + " \"format\": {\n", + " \"type\": \"triggered_tags\",\n", + " \"triggers\": [\"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_weather,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " {\n", + " \"begin\": \"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_date,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " ],\n", + " \"at_least_one\": False,\n", + " \"stop_after_first\": False,\n", + " },\n", + " },\n", + ")\n", + "\n", + "print_highlight(response.choices[0].message.content)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -594,6 +638,56 @@ "print_highlight(response.json())" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Support for XGrammar latest structural tag format\n", + "# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n", + "\n", + "payload = {\n", + " \"text\": text,\n", + " \"sampling_params\": {\n", + " \"structural_tag\": json.dumps(\n", + " {\n", + " \"type\": \"structural_tag\",\n", + " \"format\": {\n", + " \"type\": \"triggered_tags\",\n", + " \"triggers\": [\"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_weather,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " {\n", + " \"begin\": \"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_date,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " ],\n", + " \"at_least_one\": False,\n", + " \"stop_after_first\": False,\n", + " },\n", + " }\n", + " )\n", + " },\n", + "}\n", + "\n", + "\n", + "# Send POST request to the API endpoint\n", + "response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n", + "print_highlight(response.json())" + ] + }, { "cell_type": "code", "execution_count": null, @@ -825,6 +919,57 @@ " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Support for XGrammar latest structural tag format\n", + "# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n", + "\n", + "sampling_params = {\n", + " \"temperature\": 0.8,\n", + " \"top_p\": 0.95,\n", + " \"structural_tag\": json.dumps(\n", + " {\n", + " \"type\": \"structural_tag\",\n", + " \"format\": {\n", + " \"type\": \"triggered_tags\",\n", + " \"triggers\": [\"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_weather,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " {\n", + " \"begin\": \"\",\n", + " \"content\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": schema_get_current_date,\n", + " },\n", + " \"end\": \"\",\n", + " },\n", + " ],\n", + " \"at_least_one\": False,\n", + " \"stop_after_first\": False,\n", + " },\n", + " }\n", + " ),\n", + "}\n", + "\n", + "\n", + "# Send POST request to the API endpoint\n", + "outputs = llm.generate(prompts, sampling_params)\n", + "for prompt, output in zip(prompts, outputs):\n", + " print_highlight(\"===============================\")\n", + " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py index dc34a353d..c7a87fdd7 100644 --- a/python/sglang/srt/constrained/llguidance_backend.py +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, BaseGrammarObject, ) +from sglang.srt.constrained.utils import is_legacy_structural_tag logger = logging.getLogger(__name__) @@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend): def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]: try: structural_tag = json.loads(key_string) + assert is_legacy_structural_tag(structural_tag) tags = [ StructTag( begin=structure["begin"], diff --git a/python/sglang/srt/constrained/utils.py b/python/sglang/srt/constrained/utils.py new file mode 100644 index 000000000..40cdcc434 --- /dev/null +++ b/python/sglang/srt/constrained/utils.py @@ -0,0 +1,12 @@ +from typing import Dict + + +def is_legacy_structural_tag(obj: Dict) -> bool: + # test whether an object is a legacy structural tag + # see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol` + if obj.get("structures", None) is not None: + assert obj.get("triggers", None) is not None + return True + else: + assert obj.get("format", None) is not None + return False diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 00b54baef..58ea764d6 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarObject, GrammarStats, ) +from sglang.srt.constrained.utils import is_legacy_structural_tag from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -241,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend): def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: try: + # TODO(dark): it's REALLY stupid to construct object from string and decode it again structural_tag = json.loads(key_string) - tags = [ - StructuralTagItem( - begin=structure["begin"], - schema=json.dumps(structure["schema"]), - end=structure["end"], + if is_legacy_structural_tag(structural_tag): + tags = [ + StructuralTagItem( + begin=structure["begin"], + schema=json.dumps(structure["schema"]), + end=structure["end"], + ) + for structure in structural_tag["structures"] + ] + ctx = self.grammar_compiler.compile_structural_tag( + tags, structural_tag["triggers"] ) - for structure in structural_tag["structures"] - ] - ctx = self.grammar_compiler.compile_structural_tag( - tags, structural_tag["triggers"] - ) + else: + ctx = self.grammar_compiler.compile_structural_tag(key_string) except (RuntimeError, json.decoder.JSONDecodeError) as e: logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}") return INVALID_GRAMMAR_OBJ diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 46ebc6687..50c42a1ff 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -17,7 +17,7 @@ import logging import time import uuid from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union from openai.types.responses import ( ResponseFunctionToolCall, @@ -37,6 +37,7 @@ from pydantic import ( model_validator, ) from typing_extensions import Literal +from xgrammar import StructuralTag from sglang.utils import convert_json_schema_to_str @@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel): end: str -class StructuralTagResponseFormat(BaseModel): +# NOTE(dark): keep this for backward compatibility +class LegacyStructuralTagResponseFormat(BaseModel): type: Literal["structural_tag"] structures: List[StructuresResponseFormat] triggers: List[str] +StructuralTagResponseFormat: TypeAlias = Union[ + LegacyStructuralTagResponseFormat, StructuralTag +] + +ToolCallConstraint: TypeAlias = Union[ + Tuple[Literal["structural_tag"], StructuralTagResponseFormat], + Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None +] + + class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded @@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel): self, stop: List[str], model_generation_config: Dict[str, Any], - tool_call_constraint: Optional[Any] = None, + tool_call_constraint: Optional[ToolCallConstraint] = None, ) -> Dict[str, Any]: """ Convert request to sampling parameters. @@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel): ) elif constraint_type == "json_schema": sampling_params[constraint_type] = convert_json_schema_to_str( - constraint_value + constraint_value # type: ignore ) else: sampling_params[constraint_type] = constraint_value @@ -1177,7 +1189,7 @@ class MessageProcessingResult: video_data: Optional[Any] modalities: List[str] stop: List[str] - tool_call_constraint: Optional[Any] = None + tool_call_constraint: Optional[ToolCallConstraint] = None class ToolCallProcessingResult(NamedTuple): diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 56588cb1c..7b957ec0e 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -2,9 +2,10 @@ import logging from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union from sglang.srt.entrypoints.openai.protocol import ( - StructuralTagResponseFormat, + LegacyStructuralTagResponseFormat, StructuresResponseFormat, Tool, + ToolCallConstraint, ToolChoice, ) from sglang.srt.function_call.base_format_detector import BaseFormatDetector @@ -51,7 +52,6 @@ class FunctionCallParser: } def __init__(self, tools: List[Tool], tool_call_parser: str): - detector: Type[BaseFormatDetector] = None detector_class = self.ToolCallParserEnum.get(tool_call_parser) if detector_class: detector = detector_class() @@ -123,7 +123,7 @@ class FunctionCallParser: return final_normal_text, final_calls - def get_structure_tag(self) -> StructuralTagResponseFormat: + def get_structure_tag(self) -> LegacyStructuralTagResponseFormat: """ Generate a structural tag response format for all available tools. @@ -151,7 +151,9 @@ class FunctionCallParser: ) tool_trigger_set.add(info.trigger) - return StructuralTagResponseFormat( + # TODO(dark): move this into new structural tag format + # This requires all grammar backend support the new format + return LegacyStructuralTagResponseFormat( type="structural_tag", structures=tool_structures, triggers=list(tool_trigger_set), @@ -159,7 +161,7 @@ class FunctionCallParser: def get_structure_constraint( self, tool_choice: Union[ToolChoice, Literal["auto", "required"]] - ) -> Optional[Tuple[str, Any]]: + ) -> Optional[ToolCallConstraint]: """ Returns the appropriate structure constraint for tool calls based on the tool_choice. The constraint is used to guide the model's output format. @@ -178,8 +180,8 @@ class FunctionCallParser: and tool_choice == "auto" and any(tool.function.strict for tool in self.tools) ): - strict_tag = self.get_structure_tag() - return ("structural_tag", strict_tag) + tag = self.get_structure_tag() + return ("structural_tag", tag) elif tool_choice == "required" or isinstance(tool_choice, ToolChoice): json_schema = get_json_schema_constraint(self.tools, tool_choice) return ("json_schema", json_schema) diff --git a/test/srt/openai_server/features/test_structural_tag.py b/test/srt/openai_server/features/test_structural_tag.py new file mode 100644 index 000000000..f0fed981b --- /dev/null +++ b/test/srt/openai_server/features/test_structural_tag.py @@ -0,0 +1,126 @@ +""" +python3 -m unittest test.srt.openai_server.features.test_structural_tag +""" + +import json +import unittest +from typing import Any + +import openai + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +def setup_class(cls, backend: str): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + backend, + ] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestStructuralTagXGrammarBackend(CustomTestCase): + model: str + base_url: str + process: Any + + @classmethod + def setUpClass(cls): + setup_class(cls, backend="xgrammar") + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_stag_constant_str_openai(self): + client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") + + # even when the answer is ridiculous, the model should follow the instruction + answer = "The capital of France is Berlin." + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": "Introduce the capital of France. Return in a JSON format.", + }, + ], + temperature=0, + max_tokens=128, + response_format={ + "type": "structural_tag", + "format": { + "type": "const_string", + "value": answer, + }, + }, + ) + + text = response.choices[0].message.content + self.assertEqual(text, answer) + + def test_stag_json_schema_openai(self): + client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": "Introduce the capital of France. Return in a JSON format.", + }, + ], + temperature=0, + max_tokens=128, + response_format={ + "type": "structural_tag", + "format": { + "type": "json_schema", + "json_schema": json_schema, + }, + }, + ) + + text = response.choices[0].message.content + try: + js_obj = json.loads(text) + except (TypeError, json.decoder.JSONDecodeError): + print("JSONDecodeError", text) + raise + + self.assertIsInstance(js_obj["name"], str) + self.assertIsInstance(js_obj["population"], int) + + +if __name__ == "__main__": + unittest.main()