[Feature] New structural tag support (#10691)
This commit is contained in:
@@ -349,6 +349,50 @@
|
||||
"print_highlight(response.choices[0].message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Support for XGrammar latest structural tag format\n",
|
||||
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||
"\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
|
||||
" messages=messages,\n",
|
||||
" response_format={\n",
|
||||
" \"type\": \"structural_tag\",\n",
|
||||
" \"format\": {\n",
|
||||
" \"type\": \"triggered_tags\",\n",
|
||||
" \"triggers\": [\"<function=\"],\n",
|
||||
" \"tags\": [\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_weather,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_date>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_date,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" \"at_least_one\": False,\n",
|
||||
" \"stop_after_first\": False,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(response.choices[0].message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -594,6 +638,56 @@
|
||||
"print_highlight(response.json())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Support for XGrammar latest structural tag format\n",
|
||||
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||
"\n",
|
||||
"payload = {\n",
|
||||
" \"text\": text,\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"structural_tag\": json.dumps(\n",
|
||||
" {\n",
|
||||
" \"type\": \"structural_tag\",\n",
|
||||
" \"format\": {\n",
|
||||
" \"type\": \"triggered_tags\",\n",
|
||||
" \"triggers\": [\"<function=\"],\n",
|
||||
" \"tags\": [\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_weather,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_date>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_date,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" \"at_least_one\": False,\n",
|
||||
" \"stop_after_first\": False,\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Send POST request to the API endpoint\n",
|
||||
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
|
||||
"print_highlight(response.json())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -825,6 +919,57 @@
|
||||
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Support for XGrammar latest structural tag format\n",
|
||||
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||
"\n",
|
||||
"sampling_params = {\n",
|
||||
" \"temperature\": 0.8,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" \"structural_tag\": json.dumps(\n",
|
||||
" {\n",
|
||||
" \"type\": \"structural_tag\",\n",
|
||||
" \"format\": {\n",
|
||||
" \"type\": \"triggered_tags\",\n",
|
||||
" \"triggers\": [\"<function=\"],\n",
|
||||
" \"tags\": [\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_weather,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"begin\": \"<function=get_current_date>\",\n",
|
||||
" \"content\": {\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": schema_get_current_date,\n",
|
||||
" },\n",
|
||||
" \"end\": \"</function>\",\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" \"at_least_one\": False,\n",
|
||||
" \"stop_after_first\": False,\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" ),\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Send POST request to the API endpoint\n",
|
||||
"outputs = llm.generate(prompts, sampling_params)\n",
|
||||
"for prompt, output in zip(prompts, outputs):\n",
|
||||
" print_highlight(\"===============================\")\n",
|
||||
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend):
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||
try:
|
||||
structural_tag = json.loads(key_string)
|
||||
assert is_legacy_structural_tag(structural_tag)
|
||||
tags = [
|
||||
StructTag(
|
||||
begin=structure["begin"],
|
||||
|
||||
12
python/sglang/srt/constrained/utils.py
Normal file
12
python/sglang/srt/constrained/utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def is_legacy_structural_tag(obj: Dict) -> bool:
|
||||
# test whether an object is a legacy structural tag
|
||||
# see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`
|
||||
if obj.get("structures", None) is not None:
|
||||
assert obj.get("triggers", None) is not None
|
||||
return True
|
||||
else:
|
||||
assert obj.get("format", None) is not None
|
||||
return False
|
||||
@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarObject,
|
||||
GrammarStats,
|
||||
)
|
||||
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -241,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
|
||||
structural_tag = json.loads(key_string)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
if is_legacy_structural_tag(structural_tag):
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
else:
|
||||
ctx = self.grammar_compiler.compile_structural_tag(key_string)
|
||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
@@ -37,6 +37,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
from xgrammar import StructuralTag
|
||||
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
# NOTE(dark): keep this for backward compatibility
|
||||
class LegacyStructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
StructuralTagResponseFormat: TypeAlias = Union[
|
||||
LegacyStructuralTagResponseFormat, StructuralTag
|
||||
]
|
||||
|
||||
ToolCallConstraint: TypeAlias = Union[
|
||||
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
|
||||
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
|
||||
]
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
self,
|
||||
stop: List[str],
|
||||
model_generation_config: Dict[str, Any],
|
||||
tool_call_constraint: Optional[Any] = None,
|
||||
tool_call_constraint: Optional[ToolCallConstraint] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert request to sampling parameters.
|
||||
@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
constraint_value # type: ignore
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
@@ -1177,7 +1189,7 @@ class MessageProcessingResult:
|
||||
video_data: Optional[Any]
|
||||
modalities: List[str]
|
||||
stop: List[str]
|
||||
tool_call_constraint: Optional[Any] = None
|
||||
tool_call_constraint: Optional[ToolCallConstraint] = None
|
||||
|
||||
|
||||
class ToolCallProcessingResult(NamedTuple):
|
||||
|
||||
@@ -2,9 +2,10 @@ import logging
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
ToolCallConstraint,
|
||||
ToolChoice,
|
||||
)
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
@@ -51,7 +52,6 @@ class FunctionCallParser:
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
detector: Type[BaseFormatDetector] = None
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
if detector_class:
|
||||
detector = detector_class()
|
||||
@@ -123,7 +123,7 @@ class FunctionCallParser:
|
||||
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def get_structure_tag(self) -> StructuralTagResponseFormat:
|
||||
def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
|
||||
"""
|
||||
Generate a structural tag response format for all available tools.
|
||||
|
||||
@@ -151,7 +151,9 @@ class FunctionCallParser:
|
||||
)
|
||||
tool_trigger_set.add(info.trigger)
|
||||
|
||||
return StructuralTagResponseFormat(
|
||||
# TODO(dark): move this into new structural tag format
|
||||
# This requires all grammar backend support the new format
|
||||
return LegacyStructuralTagResponseFormat(
|
||||
type="structural_tag",
|
||||
structures=tool_structures,
|
||||
triggers=list(tool_trigger_set),
|
||||
@@ -159,7 +161,7 @@ class FunctionCallParser:
|
||||
|
||||
def get_structure_constraint(
|
||||
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
|
||||
) -> Optional[Tuple[str, Any]]:
|
||||
) -> Optional[ToolCallConstraint]:
|
||||
"""
|
||||
Returns the appropriate structure constraint for tool calls based on the tool_choice.
|
||||
The constraint is used to guide the model's output format.
|
||||
@@ -178,8 +180,8 @@ class FunctionCallParser:
|
||||
and tool_choice == "auto"
|
||||
and any(tool.function.strict for tool in self.tools)
|
||||
):
|
||||
strict_tag = self.get_structure_tag()
|
||||
return ("structural_tag", strict_tag)
|
||||
tag = self.get_structure_tag()
|
||||
return ("structural_tag", tag)
|
||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||
return ("json_schema", json_schema)
|
||||
|
||||
126
test/srt/openai_server/features/test_structural_tag.py
Normal file
126
test/srt/openai_server/features/test_structural_tag.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
python3 -m unittest test.srt.openai_server.features.test_structural_tag
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend: str):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
other_args = [
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
backend,
|
||||
]
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
|
||||
class TestStructuralTagXGrammarBackend(CustomTestCase):
|
||||
model: str
|
||||
base_url: str
|
||||
process: Any
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="xgrammar")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_stag_constant_str_openai(self):
|
||||
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||
|
||||
# even when the answer is ridiculous, the model should follow the instruction
|
||||
answer = "The capital of France is Berlin."
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Introduce the capital of France. Return in a JSON format.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"format": {
|
||||
"type": "const_string",
|
||||
"value": answer,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
text = response.choices[0].message.content
|
||||
self.assertEqual(text, answer)
|
||||
|
||||
def test_stag_json_schema_openai(self):
|
||||
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Introduce the capital of France. Return in a JSON format.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
text = response.choices[0].message.content
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
|
||||
self.assertIsInstance(js_obj["name"], str)
|
||||
self.assertIsInstance(js_obj["population"], int)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user