[Feature] New structural tag support (#10691)
This commit is contained in:
@@ -349,6 +349,50 @@
|
|||||||
"print_highlight(response.choices[0].message.content)"
|
"print_highlight(response.choices[0].message.content)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Support for XGrammar latest structural tag format\n",
|
||||||
|
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||||
|
"\n",
|
||||||
|
"response = client.chat.completions.create(\n",
|
||||||
|
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" response_format={\n",
|
||||||
|
" \"type\": \"structural_tag\",\n",
|
||||||
|
" \"format\": {\n",
|
||||||
|
" \"type\": \"triggered_tags\",\n",
|
||||||
|
" \"triggers\": [\"<function=\"],\n",
|
||||||
|
" \"tags\": [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_weather,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_date>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_date,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"at_least_one\": False,\n",
|
||||||
|
" \"stop_after_first\": False,\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(response.choices[0].message.content)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -594,6 +638,56 @@
|
|||||||
"print_highlight(response.json())"
|
"print_highlight(response.json())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Support for XGrammar latest structural tag format\n",
|
||||||
|
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||||
|
"\n",
|
||||||
|
"payload = {\n",
|
||||||
|
" \"text\": text,\n",
|
||||||
|
" \"sampling_params\": {\n",
|
||||||
|
" \"structural_tag\": json.dumps(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"structural_tag\",\n",
|
||||||
|
" \"format\": {\n",
|
||||||
|
" \"type\": \"triggered_tags\",\n",
|
||||||
|
" \"triggers\": [\"<function=\"],\n",
|
||||||
|
" \"tags\": [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_weather,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_date>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_date,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"at_least_one\": False,\n",
|
||||||
|
" \"stop_after_first\": False,\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
" )\n",
|
||||||
|
" },\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Send POST request to the API endpoint\n",
|
||||||
|
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
|
||||||
|
"print_highlight(response.json())"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@@ -825,6 +919,57 @@
|
|||||||
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
|
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Support for XGrammar latest structural tag format\n",
|
||||||
|
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
|
||||||
|
"\n",
|
||||||
|
"sampling_params = {\n",
|
||||||
|
" \"temperature\": 0.8,\n",
|
||||||
|
" \"top_p\": 0.95,\n",
|
||||||
|
" \"structural_tag\": json.dumps(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"structural_tag\",\n",
|
||||||
|
" \"format\": {\n",
|
||||||
|
" \"type\": \"triggered_tags\",\n",
|
||||||
|
" \"triggers\": [\"<function=\"],\n",
|
||||||
|
" \"tags\": [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_weather>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_weather,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"begin\": \"<function=get_current_date>\",\n",
|
||||||
|
" \"content\": {\n",
|
||||||
|
" \"type\": \"json_schema\",\n",
|
||||||
|
" \"json_schema\": schema_get_current_date,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"end\": \"</function>\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"at_least_one\": False,\n",
|
||||||
|
" \"stop_after_first\": False,\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
" ),\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Generate outputs for all prompts using the structural tag constraint\n",
|
||||||
|
"outputs = llm.generate(prompts, sampling_params)\n",
|
||||||
|
"for prompt, output in zip(prompts, outputs):\n",
|
||||||
|
" print_highlight(\"===============================\")\n",
|
||||||
|
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||||
try:
|
try:
|
||||||
structural_tag = json.loads(key_string)
|
structural_tag = json.loads(key_string)
|
||||||
|
assert is_legacy_structural_tag(structural_tag)
|
||||||
tags = [
|
tags = [
|
||||||
StructTag(
|
StructTag(
|
||||||
begin=structure["begin"],
|
begin=structure["begin"],
|
||||||
|
|||||||
12
python/sglang/srt/constrained/utils.py
Normal file
12
python/sglang/srt/constrained/utils.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
def is_legacy_structural_tag(obj: Dict) -> bool:
    """Return whether ``obj`` is a legacy-format structural tag.

    A legacy tag carries ``structures`` together with ``triggers``; any other
    object is expected to carry ``format``.
    See `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`.

    Args:
        obj: Decoded structural tag mapping.

    Returns:
        True when ``obj["structures"]`` is present and not None, False otherwise.

    Raises:
        AssertionError: If a legacy tag has no ``triggers``, or a non-legacy
            tag has no ``format``.
    """
    # Raise explicitly rather than via ``assert`` so the validation is not
    # stripped under ``python -O``. The exception type stays AssertionError so
    # existing callers observe the same failure mode.
    if obj.get("structures") is not None:
        if obj.get("triggers") is None:
            raise AssertionError(
                "legacy structural tag has 'structures' but no 'triggers'"
            )
        return True

    if obj.get("format") is None:
        raise AssertionError("structural tag has neither 'structures' nor 'format'")
    return False
|
||||||
@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
GrammarStats,
|
GrammarStats,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -241,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
|
|
||||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
|
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
|
||||||
structural_tag = json.loads(key_string)
|
structural_tag = json.loads(key_string)
|
||||||
tags = [
|
if is_legacy_structural_tag(structural_tag):
|
||||||
StructuralTagItem(
|
tags = [
|
||||||
begin=structure["begin"],
|
StructuralTagItem(
|
||||||
schema=json.dumps(structure["schema"]),
|
begin=structure["begin"],
|
||||||
end=structure["end"],
|
schema=json.dumps(structure["schema"]),
|
||||||
|
end=structure["end"],
|
||||||
|
)
|
||||||
|
for structure in structural_tag["structures"]
|
||||||
|
]
|
||||||
|
ctx = self.grammar_compiler.compile_structural_tag(
|
||||||
|
tags, structural_tag["triggers"]
|
||||||
)
|
)
|
||||||
for structure in structural_tag["structures"]
|
else:
|
||||||
]
|
ctx = self.grammar_compiler.compile_structural_tag(key_string)
|
||||||
ctx = self.grammar_compiler.compile_structural_tag(
|
|
||||||
tags, structural_tag["triggers"]
|
|
||||||
)
|
|
||||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||||
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||||
return INVALID_GRAMMAR_OBJ
|
return INVALID_GRAMMAR_OBJ
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union
|
||||||
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
@@ -37,6 +37,7 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
from xgrammar import StructuralTag
|
||||||
|
|
||||||
from sglang.utils import convert_json_schema_to_str
|
from sglang.utils import convert_json_schema_to_str
|
||||||
|
|
||||||
@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
|
|||||||
end: str
|
end: str
|
||||||
|
|
||||||
|
|
||||||
class StructuralTagResponseFormat(BaseModel):
|
# NOTE(dark): keep this for backward compatibility
|
||||||
|
class LegacyStructuralTagResponseFormat(BaseModel):
|
||||||
type: Literal["structural_tag"]
|
type: Literal["structural_tag"]
|
||||||
structures: List[StructuresResponseFormat]
|
structures: List[StructuresResponseFormat]
|
||||||
triggers: List[str]
|
triggers: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
StructuralTagResponseFormat: TypeAlias = Union[
|
||||||
|
LegacyStructuralTagResponseFormat, StructuralTag
|
||||||
|
]
|
||||||
|
|
||||||
|
ToolCallConstraint: TypeAlias = Union[
|
||||||
|
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
|
||||||
|
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FileRequest(BaseModel):
|
class FileRequest(BaseModel):
|
||||||
# https://platform.openai.com/docs/api-reference/files/create
|
# https://platform.openai.com/docs/api-reference/files/create
|
||||||
file: bytes # The File object (not file name) to be uploaded
|
file: bytes # The File object (not file name) to be uploaded
|
||||||
@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
self,
|
self,
|
||||||
stop: List[str],
|
stop: List[str],
|
||||||
model_generation_config: Dict[str, Any],
|
model_generation_config: Dict[str, Any],
|
||||||
tool_call_constraint: Optional[Any] = None,
|
tool_call_constraint: Optional[ToolCallConstraint] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert request to sampling parameters.
|
Convert request to sampling parameters.
|
||||||
@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
)
|
)
|
||||||
elif constraint_type == "json_schema":
|
elif constraint_type == "json_schema":
|
||||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||||
constraint_value
|
constraint_value # type: ignore
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampling_params[constraint_type] = constraint_value
|
sampling_params[constraint_type] = constraint_value
|
||||||
@@ -1177,7 +1189,7 @@ class MessageProcessingResult:
|
|||||||
video_data: Optional[Any]
|
video_data: Optional[Any]
|
||||||
modalities: List[str]
|
modalities: List[str]
|
||||||
stop: List[str]
|
stop: List[str]
|
||||||
tool_call_constraint: Optional[Any] = None
|
tool_call_constraint: Optional[ToolCallConstraint] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolCallProcessingResult(NamedTuple):
|
class ToolCallProcessingResult(NamedTuple):
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ import logging
|
|||||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
StructuralTagResponseFormat,
|
LegacyStructuralTagResponseFormat,
|
||||||
StructuresResponseFormat,
|
StructuresResponseFormat,
|
||||||
Tool,
|
Tool,
|
||||||
|
ToolCallConstraint,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
@@ -51,7 +52,6 @@ class FunctionCallParser:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||||
detector: Type[BaseFormatDetector] = None
|
|
||||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||||
if detector_class:
|
if detector_class:
|
||||||
detector = detector_class()
|
detector = detector_class()
|
||||||
@@ -123,7 +123,7 @@ class FunctionCallParser:
|
|||||||
|
|
||||||
return final_normal_text, final_calls
|
return final_normal_text, final_calls
|
||||||
|
|
||||||
def get_structure_tag(self) -> StructuralTagResponseFormat:
|
def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
|
||||||
"""
|
"""
|
||||||
Generate a structural tag response format for all available tools.
|
Generate a structural tag response format for all available tools.
|
||||||
|
|
||||||
@@ -151,7 +151,9 @@ class FunctionCallParser:
|
|||||||
)
|
)
|
||||||
tool_trigger_set.add(info.trigger)
|
tool_trigger_set.add(info.trigger)
|
||||||
|
|
||||||
return StructuralTagResponseFormat(
|
# TODO(dark): move this into new structural tag format
|
||||||
|
# This requires all grammar backends to support the new format
|
||||||
|
return LegacyStructuralTagResponseFormat(
|
||||||
type="structural_tag",
|
type="structural_tag",
|
||||||
structures=tool_structures,
|
structures=tool_structures,
|
||||||
triggers=list(tool_trigger_set),
|
triggers=list(tool_trigger_set),
|
||||||
@@ -159,7 +161,7 @@ class FunctionCallParser:
|
|||||||
|
|
||||||
def get_structure_constraint(
|
def get_structure_constraint(
|
||||||
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
|
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
|
||||||
) -> Optional[Tuple[str, Any]]:
|
) -> Optional[ToolCallConstraint]:
|
||||||
"""
|
"""
|
||||||
Returns the appropriate structure constraint for tool calls based on the tool_choice.
|
Returns the appropriate structure constraint for tool calls based on the tool_choice.
|
||||||
The constraint is used to guide the model's output format.
|
The constraint is used to guide the model's output format.
|
||||||
@@ -178,8 +180,8 @@ class FunctionCallParser:
|
|||||||
and tool_choice == "auto"
|
and tool_choice == "auto"
|
||||||
and any(tool.function.strict for tool in self.tools)
|
and any(tool.function.strict for tool in self.tools)
|
||||||
):
|
):
|
||||||
strict_tag = self.get_structure_tag()
|
tag = self.get_structure_tag()
|
||||||
return ("structural_tag", strict_tag)
|
return ("structural_tag", tag)
|
||||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
return ("json_schema", json_schema)
|
return ("json_schema", json_schema)
|
||||||
|
|||||||
126
test/srt/openai_server/features/test_structural_tag.py
Normal file
126
test/srt/openai_server/features/test_structural_tag.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""
|
||||||
|
python3 -m unittest test.srt.openai_server.features.test_structural_tag
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_class(cls, backend: str):
    """Configure ``cls`` and launch the test server it will talk to.

    Sets ``cls.model`` and ``cls.base_url`` from the test defaults, then stores
    the handle returned by ``popen_launch_server`` on ``cls.process``.

    Args:
        cls: Test class that receives the ``model``, ``base_url`` and
            ``process`` attributes.
        backend: Value passed to the server's ``--grammar-backend`` flag.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST

    # Extra command-line flags forwarded to the launched server.
    launch_flags = ["--max-running-requests", "10"]
    launch_flags += ["--grammar-backend", backend]

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_flags,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestStructuralTagXGrammarBackend(CustomTestCase):
    """Structural-tag tests run against a server started with the xgrammar backend."""

    model: str
    base_url: str
    process: Any

    @classmethod
    def setUpClass(cls):
        # One server is launched per class; setup_class stores it on cls.process.
        setup_class(cls, backend="xgrammar")

    @classmethod
    def tearDownClass(cls):
        # Stop the server launched in setUpClass.
        kill_process_tree(cls.process.pid)

    def test_stag_constant_str_openai(self):
        """A ``const_string`` format must make the reply equal the given value."""
        # The expected text is deliberately wrong: the constraint, not the
        # model's knowledge, has to determine the output.
        expected = "The capital of France is Berlin."
        structural_tag = {
            "type": "structural_tag",
            "format": {
                "type": "const_string",
                "value": expected,
            },
        }
        conversation = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": "Introduce the capital of France. Return in a JSON format.",
            },
        ]

        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        completion = client.chat.completions.create(
            model=self.model,
            messages=conversation,
            temperature=0,
            max_tokens=128,
            response_format=structural_tag,
        )

        self.assertEqual(completion.choices[0].message.content, expected)

    def test_stag_json_schema_openai(self):
        """A ``json_schema`` format must yield JSON with a str name and int population."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "pattern": "^[\\w]+$"},
                "population": {"type": "integer"},
            },
            "required": ["name", "population"],
            "additionalProperties": False,
        }
        structural_tag = {
            "type": "structural_tag",
            "format": {
                "type": "json_schema",
                "json_schema": schema,
            },
        }
        conversation = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": "Introduce the capital of France. Return in a JSON format.",
            },
        ]

        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        completion = client.chat.completions.create(
            model=self.model,
            messages=conversation,
            temperature=0,
            max_tokens=128,
            response_format=structural_tag,
        )

        raw = completion.choices[0].message.content
        try:
            parsed = json.loads(raw)
        except (TypeError, json.decoder.JSONDecodeError):
            # Print the offending payload before re-raising to ease debugging.
            print("JSONDecodeError", raw)
            raise

        self.assertIsInstance(parsed["name"], str)
        self.assertIsInstance(parsed["population"], int)
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user