feat(Tool Calling): Support required and specific function mode (#6550)
This commit is contained in:
@@ -54,10 +54,12 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
|
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
"- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
|
||||||
|
"- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
|
||||||
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||||
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
|
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). For QwQ in particular, the reasoning parser can be enabled together with the tool call parser; see [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html) for details.\n",
|
||||||
|
"- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -360,6 +362,164 @@
|
|||||||
"print(final_response.choices[0].message.content)"
|
"print(final_response.choices[0].message.content)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tool Choice Mode\n",
|
||||||
|
"\n",
|
||||||
|
"SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
|
||||||
|
"\n",
|
||||||
|
"### Supported Tool Choice Options\n",
|
||||||
|
"\n",
|
||||||
|
"- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
|
||||||
|
"- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
|
||||||
|
"\n",
|
||||||
|
"### Backend Compatibility\n",
|
||||||
|
"\n",
|
||||||
|
"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
|
||||||
|
"\n",
|
||||||
|
"### Example: Required Tool Choice"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Response with tool_choice='required':\n",
|
||||||
|
"Content: None\n",
|
||||||
|
"Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from openai import OpenAI\n",
|
||||||
|
"import json\n",
|
||||||
|
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
|
||||||
|
"from sglang.test.test_utils import is_in_ci\n",
|
||||||
|
"\n",
|
||||||
|
"if is_in_ci():\n",
|
||||||
|
" from patch import launch_server_cmd\n",
|
||||||
|
"else:\n",
|
||||||
|
" from sglang.utils import launch_server_cmd\n",
|
||||||
|
" import nest_asyncio\n",
|
||||||
|
"\n",
|
||||||
|
" nest_asyncio.apply()\n",
|
||||||
|
"\n",
|
||||||
|
"# Start a new server session for tool choice examples\n",
|
||||||
|
"server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
|
||||||
|
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
|
||||||
|
")\n",
|
||||||
|
"wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize client for tool choice examples\n",
|
||||||
|
"client_tool_choice = OpenAI(\n",
|
||||||
|
" api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
|
||||||
|
")\n",
|
||||||
|
"model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
|
||||||
|
"\n",
|
||||||
|
"# Example with tool_choice=\"required\" - forces the model to call a tool\n",
|
||||||
|
"messages_required = [\n",
|
||||||
|
" {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"# Define tools\n",
|
||||||
|
"tools = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"function\",\n",
|
||||||
|
" \"function\": {\n",
|
||||||
|
" \"name\": \"get_current_weather\",\n",
|
||||||
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"city\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The unit to fetch the temperature in\",\n",
|
||||||
|
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"city\", \"unit\"],\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"response_required = client_tool_choice.chat.completions.create(\n",
|
||||||
|
" model=model_name_tool_choice,\n",
|
||||||
|
" messages=messages_required,\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
" max_tokens=1024,\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
" tool_choice=\"required\", # Force the model to call a tool\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(\"Response with tool_choice='required':\")\n",
|
||||||
|
"print(\"Content:\", response_required.choices[0].message.content)\n",
|
||||||
|
"print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Example: Specific Function Choice\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Response with specific function choice:\n",
|
||||||
|
"Content: None\n",
|
||||||
|
"Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n",
|
||||||
|
"Called function: get_current_weather\n",
|
||||||
|
"Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Example with specific function choice - forces the model to call a specific function\n",
|
||||||
|
"messages_specific = [\n",
|
||||||
|
" {\"role\": \"user\", \"content\": \"What are the most attactive places in France?\"}\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"response_specific = client_tool_choice.chat.completions.create(\n",
|
||||||
|
" model=model_name_tool_choice,\n",
|
||||||
|
" messages=messages_specific,\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
" max_tokens=1024,\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
" tool_choice={\n",
|
||||||
|
" \"type\": \"function\",\n",
|
||||||
|
" \"function\": {\"name\": \"get_current_weather\"},\n",
|
||||||
|
" }, # Force the model to call the specific get_current_weather function\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(\"Response with specific function choice:\")\n",
|
||||||
|
"print(\"Content:\", response_specific.choices[0].message.content)\n",
|
||||||
|
"print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
|
||||||
|
"\n",
|
||||||
|
"if response_specific.choices[0].message.tool_calls:\n",
|
||||||
|
" tool_call = response_specific.choices[0].message.tool_calls[0]\n",
|
||||||
|
" print(f\"Called function: {tool_call.function.name}\")\n",
|
||||||
|
" print(f\"Arguments: {tool_call.function.arguments}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -444,7 +604,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import sglang as sgl\n",
|
"import sglang as sgl\n",
|
||||||
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
"from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
|
||||||
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
register_disaggregation_server,
|
register_disaggregation_server,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
from sglang.srt.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
|||||||
250
python/sglang/srt/function_call/base_format_detector.py
Normal file
250
python/sglang/srt/function_call/base_format_detector.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from partial_json_parser.core.exceptions import MalformedJSON
|
||||||
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
ToolCallItem,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.utils import (
|
||||||
|
_find_common_prefix,
|
||||||
|
_is_complete_json,
|
||||||
|
_partial_json_loads,
|
||||||
|
)
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFormatDetector(ABC):
|
||||||
|
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||||
|
|
||||||
|
def __init__(self):
    """Create the parsing state shared by one-shot and streaming parsing."""
    # --- State carried between parse_streaming_increment() calls ---
    # Text received so far that has not been consumed yet.
    self._buffer = ""
    # Tool-call objects parsed on the previous streaming step.
    self.prev_tool_call_arr: List[Dict] = []
    # Position of the tool call currently being streamed; -1 means "none yet".
    self.current_tool_id: int = -1
    # Whether the current tool call's name has already been emitted.
    self.current_tool_name_sent: bool = False
    # One entry per tool call: the argument text already streamed for it.
    self.streamed_args_for_tool: List[str] = []

    # --- Delimiters of the tool-call section ---
    # Empty here; a concrete detector assigns its own after calling this.
    self.bot_token = ""
    self.eot_token = ""
|
||||||
|
|
||||||
|
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
    """Convert already-decoded tool-call object(s) into ``ToolCallItem``s.

    Args:
        action: A single mapping, or a list of mappings, each expected to
            carry a ``"name"`` and either ``"parameters"`` or ``"arguments"``.
        tools: The tools offered to the model; a call is kept only when its
            name matches one of them.

    Returns:
        One ``ToolCallItem`` per recognised call, with ``tool_index`` set to
        the position of the matching tool in ``tools`` and ``parameters`` set
        to the JSON-encoded arguments. Calls naming an unknown function are
        logged as warnings and dropped.
    """
    # Map each named tool to its position in ``tools``.
    index_by_name = {}
    for position, tool in enumerate(tools):
        if tool.function.name:
            index_by_name[tool.function.name] = position

    candidates = action if isinstance(action, list) else [action]

    parsed_calls = []
    for candidate in candidates:
        name = candidate.get("name")
        if not (name and name in index_by_name):
            logger.warning(f"Model attempted to call undefined function: {name}")
            continue

        # "parameters" wins when present and truthy; otherwise fall back to
        # "arguments", defaulting to an empty object.
        arguments = candidate.get("parameters") or candidate.get("arguments", {})
        parsed_calls.append(
            ToolCallItem(
                tool_index=index_by_name[name],
                name=name,
                parameters=json.dumps(arguments, ensure_ascii=False),
            )
        )

    return parsed_calls
|
||||||
|
|
||||||
|
@abstractmethod
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
    """Parse a complete model output in one go.

    Concrete detectors must override this method. The body below is a
    fallback that treats the whole of ``text`` as JSON and converts the
    decoded value with ``parse_base_json``.

    Args:
        text: The complete text to parse.
        tools: The tools offered to the model.

    Returns:
        A ``StreamingParseResult`` whose ``calls`` holds the recognised
        tool calls.
    """
    decoded = json.loads(text)
    recognised_calls = self.parse_base_json(decoded, tools)
    return StreamingParseResult(calls=recognised_calls)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing with tool validation.

    Appends ``new_text`` to the internal buffer, re-parses the buffered text
    as (possibly partial) JSON tool-call objects, and returns only what is new
    since the previous call: first a tool's name, then successive pieces of
    its JSON-encoded arguments.

    :param new_text: The newly generated chunk of model output.
    :param tools: The tools offered to the model. They are indexed by name on
        the first call that reaches the parsing stage and the index is cached
        on the instance, so later calls reuse it.
    :return: A StreamingParseResult carrying either normal text, a tool-call
        fragment, or nothing (when more input is needed or an error occurred).
    """
    # Append new text to buffer
    self._buffer += new_text
    current_text = self._buffer
    # Text that neither contains the begin token nor starts with "{" is not a
    # tool call: drop the buffer and hand the chunk back as normal text, with
    # any end token removed.
    if not (self.bot_token in current_text or current_text.startswith("{")):
        self._buffer = ""
        if self.eot_token in new_text:
            new_text = new_text.replace(self.eot_token, "")
        return StreamingParseResult(normal_text=new_text)

    # Build tool indices if not already built
    if not hasattr(self, "_tool_indices"):
        self._tool_indices = {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }

    # Until the tool name has been sent, partial strings are not allowed in
    # the partial-JSON parse (presumably so a half-generated name is never
    # emitted — confirm against partial_json_parser semantics).
    flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
    try:
        tool_call_arr = []
        is_complete = []
        try:
            # Skip the begin token only when the buffer starts with it.
            start_idx = (
                len(self.bot_token)
                if current_text.startswith(self.bot_token)
                else 0
            )
            while start_idx < len(current_text):
                (obj, end_idx) = _partial_json_loads(
                    current_text[start_idx:], flags
                )
                is_complete.append(
                    _is_complete_json(current_text[start_idx : start_idx + end_idx])
                )
                # Advance past the parsed object plus two characters
                # (presumably a "; " separator between calls — TODO confirm).
                start_idx += end_idx + len("; ")

                # Validate tool name if present
                if "name" in obj and obj["name"] not in self._tool_indices:
                    # Invalid tool name - reset state
                    self._buffer = ""
                    self.current_tool_id = -1
                    self.current_tool_name_sent = False
                    if self.streamed_args_for_tool:
                        self.streamed_args_for_tool.pop()
                    return StreamingParseResult()

                # Handle parameters/arguments consistency
                if "parameters" in obj:
                    assert (
                        "arguments" not in obj
                    ), "model generated both parameters and arguments"
                    obj["arguments"] = obj["parameters"]
                tool_call_arr.append(obj)

        except MalformedJSON:
            # Not enough text yet to parse; wait for the next chunk.
            return StreamingParseResult()

        if len(tool_call_arr) == 0:
            return StreamingParseResult()

        # NOTE: when current_tool_id is still -1 this indexes the LAST parsed
        # call. The len() > 0 guard is always true here because the empty
        # case returned just above.
        current_tool_call: Dict = (
            tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
        )

        # Handle new tool in array
        if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
            if self.current_tool_id >= 0:
                # Flush whatever is left of the previous tool's arguments
                # before moving on to the new one.
                cur_arguments = current_tool_call.get("arguments")
                if cur_arguments:
                    cur_args_json = json.dumps(cur_arguments)
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    argument_diff = cur_args_json[sent:]

                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name="",
                                parameters=argument_diff,
                            )
                        ],
                    )
                    self.streamed_args_for_tool[
                        self.current_tool_id
                    ] += argument_diff
                else:
                    res = StreamingParseResult()
            else:
                res = StreamingParseResult()

            # Switch to the newest tool call and start tracking its arguments.
            self.current_tool_id = len(tool_call_arr) - 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            return res

        # Handle tool name
        elif not self.current_tool_name_sent:
            function_name = current_tool_call.get("name")
            if function_name and function_name in self._tool_indices:
                res = StreamingParseResult(
                    calls=[
                        ToolCallItem(
                            tool_index=self._tool_indices[function_name],
                            name=function_name,
                            parameters="",
                        )
                    ],
                )
                self.current_tool_name_sent = True
            else:
                res = StreamingParseResult()

        # Handle streaming arguments
        else:
            cur_arguments = current_tool_call.get("arguments")
            res = StreamingParseResult()

            if cur_arguments:
                sent = len(self.streamed_args_for_tool[self.current_tool_id])
                cur_args_json = json.dumps(cur_arguments)
                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments"
                )

                argument_diff = None
                if is_complete[self.current_tool_id]:
                    # The object is complete: send the remainder and reset
                    # the streaming state for this call.
                    argument_diff = cur_args_json[sent:]
                    self._buffer = ""
                    self.prev_tool_call_arr[self.current_tool_id].clear()
                    self.current_tool_name_sent = False
                    self.streamed_args_for_tool[self.current_tool_id] = ""

                elif prev_arguments:
                    # Still partial: only send the part that is identical in
                    # the previous and current serialisations, minus what was
                    # already sent.
                    prev_args_json = json.dumps(prev_arguments)
                    if cur_args_json != prev_args_json:
                        prefix = _find_common_prefix(prev_args_json, cur_args_json)
                        argument_diff = prefix[sent:]

                if argument_diff is not None:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                parameters=argument_diff,
                            )
                        ],
                    )
                    if not is_complete[self.current_tool_id]:
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff

        # Remember this step's parse so the next step can diff against it.
        self.prev_tool_call_arr = tool_call_arr
        return res

    except Exception as e:
        # Any unexpected failure is logged and reported as "nothing new".
        logger.error(f"Error in parse_streaming_increment: {e}")
        return StreamingParseResult()
|
||||||
|
|
||||||
|
@abstractmethod
def has_tool_call(self, text: str) -> bool:
    """Report whether ``text`` contains a tool call in this detector's format.

    Must be overridden by concrete detectors; this base version always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
    """Return the callable that maps a function name to its ``StructureInfo``.

    Must be overridden by concrete detectors; this base version always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
def build_ebnf(self, tools: List[Tool]) -> str:
    """Return an EBNF grammar string for calls to the given ``tools``.

    Must be overridden by concrete detectors; this base version always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
|
||||||
34
python/sglang/srt/function_call/core_types.py
Normal file
34
python/sglang/srt/function_call/core_types.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    # Index associated with this call. One-shot parsing sets it to the
    # position of the matching tool in the offered tool list; the base
    # streaming parser also uses the running call counter for argument pieces.
    tool_index: int
    # Function name. Optional: streamed argument fragments are emitted with
    # no name (None) or an empty string.
    name: Optional[str] = None
    parameters: str  # JSON string
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingParseResult(BaseModel):
    """Result of streaming incremental parsing."""

    # Text that is not part of a tool call and is passed through as-is.
    normal_text: str = ""
    # Tool-call items (names and/or argument fragments) produced by this step.
    calls: List[ToolCallItem] = []
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StructureInfo:
    # Describes one structured region for a given function name.
    # NOTE(review): field meanings below are inferred from how detectors fill
    # them in; confirm against the structural-tag consumer.
    begin: str  # presumably the text that opens the structured region
    end: str  # presumably the text that closes the structured region
    trigger: str  # presumably the text that activates the structure
|
||||||
|
|
||||||
|
|
||||||
|
# Helper alias for a function type.
# Usually it is a function that takes a name string and returns a
# StructureInfo object, which can be used to construct a structural_tag object.
_GetInfoFunc = Callable[[str], StructureInfo]
|
||||||
157
python/sglang/srt/function_call/deepseekv3_detector.py
Normal file
157
python/sglang/srt/function_call/deepseekv3_detector.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
ToolCallItem,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
from sglang.srt.function_call.utils import _is_complete_json
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekV3Detector(BaseFormatDetector):
|
||||||
|
"""
|
||||||
|
Detector for DeepSeek models.
|
||||||
|
Assumes function call format:
|
||||||
|
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
    """Initialise the base state and the DeepSeek-specific delimiters."""
    super().__init__()

    # Markers wrapping the whole tool-call section.
    self.bot_token = "<|tool▁calls▁begin|>"
    self.eot_token = "<|tool▁calls▁end|>"

    # NOTE(review): these patterns are used unescaped. If the bar in the
    # delimiters is a plain ASCII '|', the regex engine treats it as
    # alternation — confirm the actual token characters.
    # Matches one complete call, begin marker through end marker.
    self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
    # Captures (1) call type, (2) function name, (3) JSON arguments.
    self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"

    # Argument text already emitted for the call currently being streamed.
    self._last_arguments = ""
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
    """Return True when ``text`` contains the tool-calls begin marker."""
    begin_marker = self.bot_token
    return begin_marker in text
|
||||||
|
|
||||||
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
    """
    One-time parsing: extract every tool call from a complete model output.

    :param text: The complete text to parse.
    :param tools: List of available tools.
    :return: A StreamingParseResult. Without a begin marker the whole text is
        returned as normal text. Otherwise normal text is the stripped text
        before the marker and ``calls`` holds the recognised calls. If
        extracting any call fails, the error is logged and the full text is
        returned as normal text with no calls.
    """
    marker_pos = text.find(self.bot_token)
    if marker_pos == -1:
        # No tool-call section at all.
        return StreamingParseResult(normal_text=text, calls=[])

    leading_text = text[:marker_pos].strip()
    raw_calls = re.findall(self.func_call_regex, text, re.DOTALL)

    parsed_calls = []
    try:
        for raw_call in raw_calls:
            # Pull the function name and JSON arguments out of this call.
            detail = re.search(self.func_detail_regex, raw_call, re.DOTALL)
            func_name = detail.group(2)
            func_args = json.loads(detail.group(3))
            # Shape the data the way parse_base_json expects it.
            payload = {"name": func_name, "parameters": func_args}
            parsed_calls.extend(self.parse_base_json(payload, tools))
    except Exception as e:
        logger.error(f"Error in detect_and_parse: {e}")
        # return the normal text if parsing fails
        return StreamingParseResult(normal_text=text)

    return StreamingParseResult(normal_text=leading_text, calls=parsed_calls)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing tool calls for DeepSeekV3 format.

    Buffers ``new_text`` and, once the begin marker has been seen, emits the
    current call's function name first and then successive pieces of its raw
    argument text.

    :param new_text: The newly generated chunk of model output.
    :param tools: The tools offered to the model. They are indexed by name
        once and the index is cached on the instance.
    :return: A StreamingParseResult with normal text and/or tool-call items.
    """
    self._buffer += new_text
    current_text = self._buffer

    # No begin marker yet: this is ordinary text. Drop the buffer and strip
    # any trailing call delimiters from the chunk before returning it.
    if self.bot_token not in current_text:
        self._buffer = ""
        for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
            if e_token in new_text:
                new_text = new_text.replace(e_token, "")
        return StreamingParseResult(normal_text=new_text)

    # Build the name -> position index once and reuse it on later calls.
    if not hasattr(self, "_tool_indices"):
        self._tool_indices = {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }

    calls: list[ToolCallItem] = []
    try:
        # Group 2 is the function name, group 3 the (possibly partial)
        # argument text after the opening json fence.
        partial_match = re.search(
            pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
            string=current_text,
            flags=re.DOTALL,
        )
        if partial_match:
            func_name = partial_match.group(2).strip()
            func_args_raw = partial_match.group(3).strip()

            if not self.current_tool_name_sent:
                # First emission for this call: the name, with no arguments.
                # A name missing from the index falls back to tool index 0.
                calls.append(
                    ToolCallItem(
                        tool_index=self._tool_indices.get(func_name, 0),
                        name=func_name,
                        parameters="",
                    )
                )
                self.current_tool_name_sent = True
            else:
                # Send only the new suffix when the raw arguments extend what
                # was already sent; otherwise send the raw arguments whole.
                argument_diff = (
                    func_args_raw[len(self._last_arguments) :]
                    if func_args_raw.startswith(self._last_arguments)
                    else func_args_raw
                )

                if argument_diff:
                    calls.append(
                        ToolCallItem(
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=None,
                            parameters=argument_diff,
                        )
                    )
                    self._last_arguments += argument_diff

                # Arguments form complete JSON: the call is finished, so
                # reset the per-call streaming state.
                if _is_complete_json(func_args_raw):
                    result = StreamingParseResult(normal_text="", calls=calls)
                    self._buffer = ""
                    self._last_arguments = ""
                    self.current_tool_name_sent = False
                    return result

        return StreamingParseResult(normal_text="", calls=calls)

    except Exception as e:
        # On any failure, log it and hand the buffered text back as normal text.
        logger.error(f"Error in parse_streaming_increment: {e}")
        return StreamingParseResult(normal_text=current_text)
|
||||||
|
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
    """Return a callable that builds the ``StructureInfo`` for a function name.

    The produced info uses ``">" + name + "\\n```json\\n"`` as both the begin
    text and the trigger, and ``"\\n```<"`` as the end text.
    """

    def _info_for(name):
        # Begin and trigger are the same text for this format.
        opening = ">" + name + "\n```json\n"
        return StructureInfo(begin=opening, end="\n```<", trigger=opening)

    return _info_for
|
||||||
|
|
||||||
|
def build_ebnf(self, tools: List[Tool]):
    """Build the EBNF grammar for DeepSeek-V3 style calls to ``tools``.

    Delegates to ``EBNFComposer.build_ebnf`` with this detector's begin/end
    markers, no separator between calls, and JSON-formatted arguments.
    """
    # Template for a single call; ``{name}`` and ``{arguments_rule}`` are
    # placeholders left for the composer to fill in.
    single_call_template = '"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"'
    return EBNFComposer.build_ebnf(
        tools,
        bot_token=self.bot_token,
        eot_token=self.eot_token,
        tool_call_separator="",
        call_rule_fmt=single_call_template,
        function_format="json",
    )
|
||||||
234
python/sglang/srt/function_call/ebnf_composer.py
Normal file
234
python/sglang/srt/function_call/ebnf_composer.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class EBNFComposer:
|
||||||
|
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
|
||||||
|
json_grammar_ebnf_str = r"""
|
||||||
|
json ::= basic_array | basic_object
|
||||||
|
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
|
||||||
|
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
|
||||||
|
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
||||||
|
basic_string ::= (([\"] basic_string_1 [\"]))
|
||||||
|
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
||||||
|
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
||||||
|
basic_boolean ::= "true" | "false"
|
||||||
|
basic_null ::= "null"
|
||||||
|
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
||||||
|
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
||||||
|
ws ::= [ \n\t]*
|
||||||
|
"""
|
||||||
|
|
||||||
|
pythonic_grammar_ebnf_str = r"""
|
||||||
|
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
|
||||||
|
basic_any ::= basic_number | basic_string | basic_array | basic_object
|
||||||
|
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
||||||
|
basic_string ::= (([\"] basic_string_1 [\"]))
|
||||||
|
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
||||||
|
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
||||||
|
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
||||||
|
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
||||||
|
ws ::= [ \n\t]*
|
||||||
|
"""
|
||||||
|
|
||||||
|
TOOL_CALLS_MAP = {
|
||||||
|
"pythonic": '"[" function_call ("," function_call)* "]"',
|
||||||
|
"json": "function_call",
|
||||||
|
}
|
||||||
|
|
||||||
|
CALL_RULE_MAP = {
|
||||||
|
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
||||||
|
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
||||||
|
}
|
||||||
|
|
||||||
|
ARGUMENTS_RULE_MAP = {
|
||||||
|
"pythonic": "{arg_rules}",
|
||||||
|
"json": '"{{" {arg_rules} "}}"',
|
||||||
|
}
|
||||||
|
|
||||||
|
KEY_VALUE_RULE_MAP = {
|
||||||
|
"pythonic": '"{key}" "=" {valrule}',
|
||||||
|
"json": '"\\"{key}\\"" ":" {valrule}',
|
||||||
|
}
|
||||||
|
|
||||||
|
JSON_TYPE_MAPPING = {
|
||||||
|
"string": "basic_string",
|
||||||
|
"number": "basic_number",
|
||||||
|
"integer": "basic_number",
|
||||||
|
"boolean": "basic_boolean",
|
||||||
|
"null": "basic_null",
|
||||||
|
"array": "basic_array",
|
||||||
|
"object": "basic_object",
|
||||||
|
}
|
||||||
|
|
||||||
|
PYTHONIC_TYPE_MAPPING = {
|
||||||
|
"string": "basic_string",
|
||||||
|
"number": "basic_number",
|
||||||
|
"integer": "basic_number",
|
||||||
|
"boolean": '"True" | "False"',
|
||||||
|
"null": '"None"',
|
||||||
|
"array": "basic_array",
|
||||||
|
"object": "basic_object",
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
def get_value_rule(
    prop: dict, function_format: Literal["pythonic", "json"] = "json"
) -> str:
    """Return the EBNF expression describing the value of one property.

    Args:
        prop: The property's schema mapping.
        function_format: Which call syntax the grammar targets.

    Returns:
        The enum alternation when ``prop`` has an ``"enum"`` key, otherwise
        the type-based rule when it has a ``"type"`` key. With neither key,
        the format name itself is returned, which is also the name of the
        generic root rule in the corresponding grammar string.
    """
    has_enum = "enum" in prop
    if not has_enum and "type" not in prop:
        # Unconstrained property: fall back to the generic value rule.
        return function_format

    # "enum" takes precedence over "type" when both are present.
    if has_enum:
        return EBNFComposer._handle_enum(prop, function_format)
    return EBNFComposer._handle_type(prop, function_format)
|
||||||
|
|
||||||
|
@staticmethod
def _handle_enum(prop: dict, function_format: str) -> str:
    """Build the EBNF alternation for a property restricted by ``enum``.

    Each enum value is rendered according to the property's ``type``
    (defaulting to ``"string"``) and the target ``function_format``, then the
    rendered values are joined with ``" | "``.

    String values are emitted as quoted EBNF terminals that match the
    double-quoted literal the model has to produce. Quotes, backslashes and
    control characters inside a value are escaped, so such values no longer
    yield a malformed grammar. Output for values without those characters is
    unchanged.

    Args:
        prop: Property schema containing an ``"enum"`` list and optionally a
            ``"type"``.
        function_format: ``"json"`` or ``"pythonic"``.

    Returns:
        The alternation, wrapped in parentheses when it has more than one
        value so that it composes safely inside a larger rule.

    Raises:
        KeyError: If ``function_format`` is neither ``"json"`` nor
            ``"pythonic"``.
    """
    import json  # local import: the module itself only imports from typing

    def _string_terminal(value) -> str:
        # Step 1: text as it appears between the quotes of the generated
        # literal. Real strings are JSON-escaped; other values keep their
        # plain str() form, as before.
        if isinstance(value, str):
            inner = json.dumps(value, ensure_ascii=False)[1:-1]
        else:
            inner = str(value)
        # Step 2: escape that text for use inside an EBNF "..." terminal.
        inner = inner.replace("\\", "\\\\").replace('"', '\\"')
        return f'"\\"{inner}\\""'

    enum_values = prop["enum"]
    prop_type = prop.get("type", "string")

    # Renderers keyed by (property type, target format).
    # NOTE(review): numeric and boolean values are emitted bare (unquoted),
    # exactly as before — confirm the grammar backend accepts that form.
    formatters = {
        ("string", "json"): _string_terminal,
        ("string", "pythonic"): _string_terminal,
        ("number", "json"): str,
        ("number", "pythonic"): str,
        ("integer", "json"): str,
        ("integer", "pythonic"): str,
        ("boolean", "json"): lambda v: "true" if v else "false",
        ("boolean", "pythonic"): lambda v: "True" if v else "False",
    }

    # Unknown property types are rendered like strings. Looking up the
    # fallback first keeps the KeyError for an unsupported function_format.
    fallback = formatters[("string", function_format)]
    formatter = formatters.get((prop_type, function_format), fallback)

    formatted_values = [formatter(value) for value in enum_values]
    enum_rule = " | ".join(formatted_values)

    # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
    if len(formatted_values) > 1:
        enum_rule = f"({enum_rule})"

    return enum_rule
|
||||||
|
|
||||||
|
@staticmethod
def _handle_type(prop: dict, function_format: str) -> str:
    """Map the property's ``type`` to a grammar rule for its value.

    ``type`` may be a single name or a list of names. Names missing from the
    mapping are dropped from a list; if nothing usable remains (or a single
    name is unknown) ``function_format`` is returned and used as a rule name.

    Any result that contains an alternation is wrapped in parentheses. The
    caller embeds the rule in a sequence such as ``"key" ":" <rule>``, where a
    bare ``a | b`` would split the whole sequence instead of just the value.
    This mirrors what ``_handle_enum`` does for multi-valued enums.
    """
    prop_type = prop["type"]
    type_mapping = (
        EBNFComposer.PYTHONIC_TYPE_MAPPING
        if function_format == "pythonic"
        else EBNFComposer.JSON_TYPE_MAPPING
    )

    if isinstance(prop_type, list):
        type_rules = [
            type_mapping[single_type]
            for single_type in prop_type
            if single_type in type_mapping
        ]
        rule = " | ".join(type_rules) if type_rules else function_format
    else:
        rule = type_mapping.get(prop_type, function_format)

    # Covers both joined lists and single mapping values that are themselves
    # alternations (e.g. the pythonic boolean rule).
    return f"({rule})" if " | " in rule else rule
|
||||||
|
|
||||||
|
@staticmethod
def build_ebnf(
    tools,
    *,
    call_rule_fmt: Optional[str] = None,
    function_format: Literal["pythonic", "json"] = "json",
    bot_token: Optional[str] = None,
    eot_token: Optional[str] = None,
    tool_call_separator: Optional[str] = None,
):
    """
    Generalized EBNF builder for all detectors.

    Args:
        tools: List of Tool objects to generate EBNF grammar for
        call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
            the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
            format based on function_format will be used.
        function_format: The format of function calls, either "pythonic" or "json"
        bot_token: The token that indicates the start of a tool call section
        eot_token: The token that indicates the end of a tool call section
        tool_call_separator: The separator between multiple tool calls

    Returns:
        The grammar as one newline-joined string: the root rule, the
        function_call alternation, one call_/arguments_ rule pair per tool,
        and finally the base grammar for the chosen format.

    Argument layout per tool: required properties come first, in schema
    order, separated by ",". Optional properties follow in schema order; any
    subset of them may be omitted, and each one carries its own leading ","
    so omitting it never leaves a dangling separator.
    """
    # =================================================================
    # Step 1: Determine the root tool calls rule
    # =================================================================
    if bot_token and eot_token:
        if tool_call_separator:
            root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
        else:
            root_rule = f'"{bot_token}" function_call "{eot_token}"'
    else:
        # Without both tokens the separator is not used; fall back to the
        # per-format default root.
        root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]

    # =================================================================
    # Step 2: Build the header rules
    # =================================================================
    ebnf_lines = [
        f"root ::= {root_rule}",
        "function_call ::= "
        + " | ".join([f"call_{tool.function.name}" for tool in tools]),
    ]

    # =================================================================
    # Step 3: Set up formatting templates
    # =================================================================
    call_template = (
        f"call_{{name}} ::= {call_rule_fmt}"
        if call_rule_fmt
        else EBNFComposer.CALL_RULE_MAP[function_format]
    )
    args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
    key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]

    # =================================================================
    # Step 4: Build rules for each tool
    # =================================================================
    for tool in tools:
        tool_name = tool.function.name
        params = tool.function.parameters or {}
        properties = params.get("properties", {})
        required_props = set(params.get("required", []))

        # One key/value rule per property, keyed by name in schema order.
        pairs = {
            prop_name: key_value_template.format(
                key=prop_name,
                valrule=EBNFComposer.get_value_rule(prop_schema, function_format),
            )
            for prop_name, prop_schema in properties.items()
        }
        required = [name for name in pairs if name in required_props]
        optional = [name for name in pairs if name not in required_props]

        rule_parts = []
        if required:
            rule_parts.append(' "," '.join(pairs[name] for name in required))

        if optional:
            # Alternative i starts with optional[i] and lets every later
            # optional follow as `( "," pair )?`; together the alternatives
            # cover every non-empty in-order subset.
            alternatives = []
            for index, first in enumerate(optional):
                tail = "".join(
                    f' ( "," {pairs[name]} )?' for name in optional[index + 1 :]
                )
                alternatives.append(pairs[first] + tail)
            optional_group = " | ".join(alternatives)

            # `( ... )?` is used for optionality: the base grammar already
            # uses `[...]` as a character class, so square brackets cannot
            # mean "optional" here.
            if required:
                rule_parts.append(f' ( "," ( {optional_group} ) )?')
            else:
                rule_parts.append(f"( {optional_group} )?")

        combined_args = "".join(rule_parts)
        # A rule body must not be empty (pythonic call with no parameters);
        # match the empty string explicitly instead.
        arguments_rule = args_template.format(arg_rules=combined_args) or '""'

        # Add the function call rule and its arguments rule
        ebnf_lines.append(
            call_template.format(
                name=tool_name, arguments_rule=f"arguments_{tool_name}"
            )
        )
        ebnf_lines.append(f"arguments_{tool_name} ::= {arguments_rule}")

    # =================================================================
    # Step 5: Add base grammar rules
    # =================================================================
    base_grammar = (
        EBNFComposer.pythonic_grammar_ebnf_str
        if function_format == "pythonic"
        else EBNFComposer.json_grammar_ebnf_str
    )
    ebnf_lines.append(base_grammar)

    return "\n".join(ebnf_lines)
|
||||||
175
python/sglang/srt/function_call/function_call_parser.py
Normal file
175
python/sglang/srt/function_call/function_call_parser.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import ToolCallItem
|
||||||
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
|
from sglang.srt.openai_api.protocol import (
|
||||||
|
StructuralTagResponseFormat,
|
||||||
|
StructuresResponseFormat,
|
||||||
|
Tool,
|
||||||
|
ToolChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCallParser:
    """
    Parser for function/tool calls in model outputs.

    This class handles both streaming and non-streaming parsing of function calls using a detector.
    In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
    and returns the resulting normal_text and calls to the upper layer (or SSE).
    """

    # Maps the `tool_call_parser` name to the detector class implementing it.
    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
        "deepseekv3": DeepSeekV3Detector,
        "pythonic": PythonicDetector,
    }

    def __init__(self, tools: List[Tool], tool_call_parser: str):
        """
        Create a parser bound to one detector and one tool list.

        Args:
            tools: The tools the model may call
            tool_call_parser: Key into ToolCallParserEnum selecting the detector

        Raises:
            ValueError: If tool_call_parser is not a known parser name
        """
        detector_class = self.ToolCallParserEnum.get(tool_call_parser)
        if not detector_class:
            raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")

        self.detector: BaseFormatDetector = detector_class()
        self.tools = tools

    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains a tool call in the format supported by this parser.
        This delegates to the detector's implementation.

        Args:
            text: The text to check for tool calls

        Returns:
            True if the text contains a tool call, False otherwise
        """
        return self.detector.has_tool_call(text)

    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        One-time parsing of the full text to extract tool calls.

        Args:
            full_text: The complete text to parse

        Returns:
            A tuple containing:
            - The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
            - A list of tool calls parsed from the text

            When no call is found, the untouched full_text and an empty list are returned.
        """
        parsed_result = self.detector.detect_and_parse(full_text, self.tools)
        tool_call_list = parsed_result.calls
        if tool_call_list:
            return parsed_result.normal_text, tool_call_list
        return full_text, []

    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming incremental parsing of chunks of text as they arrive.

        Args:
            chunk_text: The new chunk of text to parse

        Returns:
            A tuple containing:
            - The normal text that should be displayed to the user
            - A list of tool calls parsed from the chunk
        """
        sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)

        final_calls = []
        if sp_result.calls:
            final_calls.extend(sp_result.calls)
            # With calls present the detector's normal_text is passed through
            # unchanged, even if it is empty or None.
            return sp_result.normal_text, final_calls

        # Without calls, a falsy normal_text is normalized to "".
        return (sp_result.normal_text or ""), final_calls

    def get_structure_tag(self) -> StructuralTagResponseFormat:
        """
        Generate a structural tag response format for all available tools.

        This creates the necessary structural tags that guide the model's output format.
        """
        tool_structures: List[StructuresResponseFormat] = list()
        tool_trigger_set: Set[str] = set()

        get_structure_info = self.detector.structure_info()
        for tool in self.tools:
            function = tool.function
            name = function.name
            assert name is not None
            info = get_structure_info(name)

            # accept all if not strict, otherwise only accept the schema
            schema = function.parameters if function.strict else {}

            tool_structures.append(
                StructuresResponseFormat(
                    begin=info.begin,
                    schema=schema,  # type: ignore
                    end=info.end,
                )
            )
            tool_trigger_set.add(info.trigger)

        return StructuralTagResponseFormat(
            type="structural_tag",
            structures=tool_structures,
            triggers=list(tool_trigger_set),
        )

    def get_structure_constraint(
        self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
    ) -> Optional[Tuple[str, Any]]:
        """
        Returns the appropriate structure constraint for tool calls based on the tool_choice.
        The constraint is used to guide the model's output format.

        Args:
            tool_choice: The tool choice setting from the request

        Returns:
            A tuple of (constraint_type, constraint_value) to be added to sampling parameters,
            or None if no constraint applies.
        """
        # NOTE: structural_tag only supports JSON-compatible content between the begin and end.
        # It cannot parse or validate Python syntax like function calls.
        if (
            not isinstance(self.detector, PythonicDetector)
            and tool_choice == "auto"
            and any(tool.function.strict for tool in self.tools)
        ):
            strict_tag = self.get_structure_tag()
            return ("structural_tag", strict_tag)
        elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
            ebnf = self.get_ebnf(tool_choice)
            return ("ebnf", ebnf) if ebnf is not None else None
        return None

    def get_ebnf(
        self, tool_choice: Union[ToolChoice, Literal["required"]]
    ) -> Optional[str]:
        """
        Get the EBNF grammar for the specified tool choice.

        Args:
            tool_choice: "required" to allow every registered tool, or a
                ToolChoice naming the single function that must be called

        Returns:
            The grammar string, or None when there is no tool to build a
            grammar for (no tools registered, or the named function is not
            among them). An empty tool list would otherwise yield a grammar
            whose function_call rule has no alternatives.
        """
        if isinstance(tool_choice, ToolChoice):
            fn_name = tool_choice.function.name
            filtered_tools = [t for t in self.tools if t.function.name == fn_name]
        else:
            filtered_tools = self.tools

        if not filtered_tools:
            return None
        return self.detector.build_ebnf(filtered_tools)
|
||||||
74
python/sglang/srt/function_call/llama32_detector.py
Normal file
74
python/sglang/srt/function_call/llama32_detector.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return self.bot_token in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        Text before the first <|python_tag|> is returned as normal text; what
        follows is split on ";" and each piece is parsed as one JSON object.
        Pieces that are not valid JSON are logged and skipped.

        :param text: The complete model output.
        :param tools: List of available tools.
        :return: StreamingParseResult with the leading normal text and the parsed calls.
        """
        if not self.has_tool_call(text):
            return StreamingParseResult(normal_text=text, calls=[])

        if self.bot_token in text:
            # Split only at the first tag: unpacking a plain split() raised
            # ValueError whenever the tag occurred more than once.
            normal_text, _, action_text = text.partition(self.bot_token)
        else:
            normal_text, action_text = "", text

        # Split by semicolon and process each part
        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                all_actions.append(json.loads(part))
            except json.JSONDecodeError as e:
                logger.warning("Failed to parse JSON part: %s", part)
                logger.warning("JSON parse error: %s", e)

        # Only process if we found valid JSON objects
        calls = self.parse_base_json(all_actions, tools) if all_actions else []
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its begin/end/trigger strings."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the JSON-format tool-call grammar for the given tools."""
        return EBNFComposer.build_ebnf(
            tools,
            function_format="json",
            tool_call_separator=",",
        )
|
||||||
84
python/sglang/srt/function_call/mistral_detector.py
Normal file
84
python/sglang/srt/function_call/mistral_detector.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
      [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        self.eot_token = "]"
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _clean_text(self, text: str) -> str:
        """
        Cut the text down to the first '[TOOL_CALLS] [...]' section.

        For example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is ...'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'

        The closing bracket is located by counting brackets and skipping JSON
        string contents. Stopping at the first "]" would truncate any call
        whose arguments contain a list or a "]" inside a string.

        Returns "" when the marker is missing or the array is never closed.
        """
        start = text.find(self.bot_token)
        if start == -1:
            return ""

        # The last character of bot_token is the "[" that opens the array.
        array_start = start + len(self.bot_token) - 1
        depth = 0
        in_string = False
        escaped = False
        for pos in range(array_start, len(text)):
            char = text[pos]
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
                continue
            if char == '"':
                in_string = True
            elif char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
                if depth == 0:
                    return text[start : pos + 1]
        return ""

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the text preceding the tool-call marker
                 (the whole text if there is none) and the parsed calls.
        :raises json.JSONDecodeError: If the extracted array is not valid JSON.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text

        cleaned = self._clean_text(text)
        # Drop the "[TOOL_CALLS] " prefix but keep the array's opening "[".
        tool_content = cleaned[len(self.bot_token) - 1 :].strip() if cleaned else ""

        calls = []
        match = self.tool_call_regex.search(tool_content)
        if match:
            function_call_arr = json.loads(match.group(0))
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its begin/end/trigger strings."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the JSON-format tool-call grammar delimited by the detector's tokens."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
|
||||||
163
python/sglang/srt/function_call/pythonic_detector.py
Normal file
163
python/sglang/srt/function_call/pythonic_detector.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
ToolCallItem,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Check whether the stripped text starts with a pythonic tool-call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse the text as a Python list of keyword-only function calls.

        :param text: The complete text to parse.
        :param tools: List of available tools; a call to a name not in this list gets tool_index -1.
        :return: The parsed calls with empty normal text on success; otherwise the
                 stripped text as normal text and no calls. Parsing errors are
                 logged, not raised.
        """
        # Try parsing the text as a Python list of function calls
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                # Skip anything that is not a plain `name(...)` call.
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        bracket_count = 0
        for i in range(start, len(buffer)):
            if buffer[i] == "[":
                bracket_count += 1
            elif buffer[i] == "]":
                bracket_count -= 1
                if bracket_count == 0:
                    return i
        return -1  # No matching bracket found

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.
        """
        self._buffer += new_text
        start = self._buffer.find("[")

        if start == -1:
            # No bracket at all: everything buffered is plain text.
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        normal_text = self._buffer[:start] if start > 0 else ""

        end = self._find_matching_bracket(self._buffer, start)
        if end != -1:
            call_text = self._buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)
            self._buffer = self._buffer[end + 1 :]

            # If we had normal text before the tool call, add it to the result
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")

            return result

        # We have an opening bracket but no closing bracket yet
        if normal_text:
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=normal_text)

        # Otherwise, we're still accumulating a potential tool call
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """
        Convert an AST node for an argument value into a plain Python value.

        Supports constants, signed numbers, dicts and lists (recursively).

        Raises:
            ValueError: If the node is any other kind of expression
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.UnaryOp) and isinstance(
            val.op, (ast.USub, ast.UAdd)
        ):
            # `-1` / `+1.5` parse as UnaryOp(Constant), not as a Constant;
            # without this branch any negative argument invalidated the call.
            operand = self._get_parameter_value(val.operand)
            if isinstance(operand, bool) or not isinstance(operand, (int, float)):
                raise ValueError("Tool call arguments must be literals")
            return -operand if isinstance(val.op, ast.USub) else operand
        elif isinstance(val, ast.Dict):
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its begin/end/trigger strings."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build the pythonic-format tool-call grammar: a bracketed, comma-separated call list."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token="[",
            eot_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )
|
||||||
67
python/sglang/srt/function_call/qwen25_detector.py
Normal file
67
python/sglang/srt/function_call/qwen25_detector.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Set the begin/end markers that delimit a single tool call.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Return True when the opening tool-call marker occurs in the text."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: extract every <tool_call>...</tool_call> section.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult holding the text before the first marker
                 (stripped; the untouched text if there is no marker) and the parsed calls.
        """
        first_marker = text.find(self.bot_token)
        if first_marker == -1:
            return StreamingParseResult(normal_text=text, calls=[])

        leading_text = text[:first_marker].strip()
        section_pattern = rf"{self.bot_token}(.*?){self.eot_token}"

        calls = []
        for payload in re.findall(section_pattern, text, re.DOTALL):
            # Each section body is expected to be one JSON document.
            calls.extend(self.parse_base_json(json.loads(payload), tools))
        return StreamingParseResult(normal_text=leading_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its begin/end/trigger strings."""

        def info(name):
            return StructureInfo(
                begin='<tool_call>{"name":"' + name + '", "arguments":',
                end="}</tool_call>",
                trigger="<tool_call>",
            )

        return info

    def build_ebnf(self, tools: List[Tool]):
        """Build the JSON-format tool-call grammar delimited by the detector's tokens."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
|
||||||
35
python/sglang/srt/function_call/utils.py
Normal file
35
python/sglang/srt/function_call/utils.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import json
|
||||||
|
from json import JSONDecodeError, JSONDecoder
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
|
import partial_json_parser
|
||||||
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
|
|
||||||
|
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||||
|
prefix = ""
|
||||||
|
min_length = min(len(s1), len(s2))
|
||||||
|
for i in range(0, min_length):
|
||||||
|
if s1[i] == s2[i]:
|
||||||
|
prefix += s1[i]
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
|
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse possibly incomplete JSON and report how much input was consumed.

    Returns ``(value, consumed_length)``. Normally the whole input is
    consumed. If the parser reports "Extra data", the first complete JSON
    document is decoded instead and the index where it ends is returned.
    Any other decode error is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        # Trailing content after a complete document: decode just the head.
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_complete_json(input_str: str) -> bool:
|
||||||
|
try:
|
||||||
|
json.loads(input_str)
|
||||||
|
return True
|
||||||
|
except JSONDecodeError:
|
||||||
|
return False
|
||||||
@@ -1,858 +0,0 @@
|
|||||||
import ast
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from json import JSONDecodeError, JSONDecoder
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
|
||||||
|
|
||||||
import partial_json_parser
|
|
||||||
from partial_json_parser.core.exceptions import MalformedJSON
|
|
||||||
from partial_json_parser.core.options import Allow
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from sglang.srt.openai_api.protocol import (
|
|
||||||
StructuralTagResponseFormat,
|
|
||||||
StructuresResponseFormat,
|
|
||||||
Tool,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tag strings associated with tool-call output. Where a detector in this module
# uses the tag as (part of) its ``bot_token``, the detector is noted.
TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",  # Qwen25Detector
    "<|python_tag|>",  # Llama32Detector
    "[TOOL_CALLS]",  # MistralDetector (its bot_token is "[TOOL_CALLS] [")
    "<|tool▁calls▁begin|>",  # DeepSeekV3Detector
]
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCallItem(BaseModel):
    """One parsed tool call, or one streamed fragment of a tool call.

    A small wrapper around the parse result so streaming callers can pass
    name and argument pieces around as a single object.
    """

    # Index identifying the call; detectors fill it with either the tool's
    # position in the request's tool list or the running call counter.
    tool_index: int
    # Function name; may be omitted on fragments that only carry arguments.
    name: Optional[str] = None
    # JSON string (in streaming mode: the newly produced slice of it).
    parameters: str
|
|
||||||
|
|
||||||
|
|
||||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
|
||||||
prefix = ""
|
|
||||||
min_length = min(len(s1), len(s2))
|
|
||||||
for i in range(0, min_length):
|
|
||||||
if s1[i] == s2[i]:
|
|
||||||
prefix += s1[i]
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return prefix
|
|
||||||
|
|
||||||
|
|
||||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse a possibly incomplete JSON prefix of *input_str*.

    Returns ``(value, consumed)``. When the partial parser accepts the whole
    string, ``consumed`` is ``len(input_str)``. When it fails with an
    "Extra data" error, only the first JSON value is decoded and ``consumed``
    is the index reported by ``JSONDecoder.raw_decode``. Any other
    ``JSONDecodeError`` is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        # Something follows the first JSON value: decode just that value.
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_complete_json(input_str: str) -> bool:
|
|
||||||
try:
|
|
||||||
json.loads(input_str)
|
|
||||||
return True
|
|
||||||
except JSONDecodeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingParseResult:
    """Outcome of one parsing step: plain text plus any tool calls found."""

    def __init__(
        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
    ):
        """Store *normal_text* and *calls*; a missing or empty *calls* becomes a fresh list."""
        self.normal_text = normal_text
        self.calls = calls if calls else []
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StructureInfo:
|
|
||||||
begin: str
|
|
||||||
end: str
|
|
||||||
trigger: str
|
|
||||||
|
|
||||||
|
|
||||||
_GetInfoFunc = Callable[[str], StructureInfo]
|
|
||||||
"""
|
|
||||||
Helper alias of function
|
|
||||||
Usually it is a function that takes a name string and returns a StructureInfo object,
|
|
||||||
which can be used to construct a structural_tag object
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses implement ``detect_and_parse`` (one-shot parsing), ``has_tool_call``
    and ``structure_info``. ``parse_streaming_increment`` has a default
    implementation for formats whose payload is one or more JSON objects,
    optionally preceded by ``bot_token``.
    """

    def __init__(self):
        # State used when parsing tool calls in streaming mode.
        self._buffer = ""  # text accumulated across increments
        self.prev_tool_call_arr: List[Dict] = []  # calls parsed on the previous increment
        self.current_tool_id: int = -1  # position of the call being streamed; -1 = none yet
        self.current_tool_name_sent: bool = False
        # What has been streamed so far for each tool call's arguments.
        self.streamed_args_for_tool: List[str] = []
        # Markers around a tool call; empty here, set by subclasses that use them.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """Convert parsed call object(s) into ``ToolCallItem`` instances.

        :param action: A single parsed call (dict) or a list of them. Each call
            carries ``"name"`` and either ``"parameters"`` or ``"arguments"``.
        :param tools: List of available tools; a call is kept only when its name
            matches one of them.
        :return: One item per recognised call, with ``parameters`` serialised as JSON.
            Unrecognised or malformed entries are logged and skipped.
        """
        # Same guard as the streaming index builder: tolerate tools without a function.
        tool_indices = {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            # A non-dict entry has no name; treat it like an unknown function
            # instead of failing on ``.get``.
            name = act.get("name") if isinstance(act, dict) else None
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")

        return results

    @abstractmethod
    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go.

        This default body treats the whole text as a single JSON document and
        returns the calls extracted from it; subclasses must override the method
        and may reuse this body through ``super()``.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: A StreamingParseResult holding the parsed calls.
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Each call appends *new_text* to an internal buffer, re-parses the buffered
        JSON object(s) and returns only what is new since the previous call: first
        a call's name, then successive slices of its serialised arguments.

        :param new_text: The newly generated chunk of text.
        :param tools: List of available tools (indexed once per detector instance).
        :return: Normal text when the buffer holds no tool call, otherwise the
            newly available ToolCallItem pieces (possibly none). Any error during
            parsing is logged and yields an empty result.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        if not (self.bot_token in current_text or current_text.startswith("{")):
            # Not a tool call: pass the chunk through (minus the end marker).
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the name has been sent, do not accept partial strings so that a
        # half-generated function name is never emitted.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Consecutive calls are assumed to be separated by "; ".
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency
                    if "parameters" in obj:
                        if "arguments" in obj:
                            # Raised rather than asserted so the check also runs
                            # under ``python -O``; the outer handler logs it and
                            # returns an empty result.
                            raise ValueError(
                                "model generated both parameters and arguments"
                            )
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except MalformedJSON:
                return StreamingParseResult()

            if not tool_call_arr:
                return StreamingParseResult()

            # With current_tool_id == -1 this picks the last parsed call.
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            # Handle new tool in array
            if len(tool_call_arr) > self.current_tool_id + 1:
                if self.current_tool_id >= 0:
                    # Flush whatever remains of the previous call's arguments.
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()

                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # The call is finished: send the remainder and reset state.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        # Only the prefix shared with the previous parse is sent,
                        # so text that may still change is held back.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """Return True when *text* contains a tool call in this detector's format."""
        raise NotImplementedError()

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its StructureInfo."""
        raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the text preceding the first tool call
            as normal text and every call that could be parsed. A payload that is
            not valid JSON is logged and skipped; the remaining calls are kept.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        # Escape the markers so they are always matched literally.
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        calls = []
        for raw_call in re.findall(pattern, text, re.DOTALL):
            try:
                parsed_call = json.loads(raw_call)
            except JSONDecodeError as e:
                # One broken payload must not discard the other calls.
                logger.warning(f"Failed to parse JSON part: {raw_call}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue
            calls.extend(self.parse_base_json(parsed_call, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the StructureInfo for a given tool name."""
        return lambda name: StructureInfo(
            begin='<tool_call>{"name":"' + name + '", "arguments":',
            end="}</tool_call>",
            trigger="<tool_call>",
        )
|
|
||||||
|
|
||||||
|
|
||||||
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
      [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _clean_text(self, text: str) -> str:
        """
        clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]

        The end of the array is located by decoding it as JSON, so a "]" inside
        the arguments (a list value, or a string containing "]") does not cut
        the call short. Returns "" when no tool call marker is present.
        """
        start = text.find(self.bot_token)
        if start == -1:
            return ""
        # bot_token ends with the "[" that opens the JSON array.
        array_start = start + len(self.bot_token) - 1
        try:
            _, array_end = JSONDecoder().raw_decode(text, array_start)
        except JSONDecodeError:
            # Not decodable as JSON: fall back to the shortest "[...]" span.
            find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
            return find_results[0] if find_results else ""
        return text[start:array_end]

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the text preceding the tool call marker
            as normal text and the calls parsed from the first tool call array.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        text = self._clean_text(text)
        tool_content = text.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if len(raw_tool_calls) > 0:
            raw_tool_call = raw_tool_calls[0]
            function_call_arr = json.loads(raw_tool_call)
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the StructureInfo for a given tool name."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )
|
|
||||||
|
|
||||||
|
|
||||||
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
    Multiple calls are separated by ";".
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return self.bot_token in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult whose normal text is whatever precedes the
            first <|python_tag|> (empty when the text starts directly with JSON).
            Parts that are not valid JSON are logged and skipped.
        """
        if self.bot_token not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])

        if self.bot_token in text:
            # Split at the first tag only: unpacking a plain split() raises
            # ValueError when the tag occurs more than once.
            normal_text, _, action_text = text.partition(self.bot_token)
        else:
            normal_text, action_text = "", text

        # Split by semicolon and process each part
        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                action = json.loads(part)
                all_actions.append(action)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON part: {part}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue
        calls = []
        # Only process if we found valid JSON objects
        if all_actions:
            calls = self.parse_base_json(all_actions, tools)
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the StructureInfo for a given tool name."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
      '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"
        # The marker tokens contain "|", which is the alternation operator in a
        # regular expression, so they are escaped before being embedded.
        call_begin = re.escape("<|tool▁call▁begin|>")
        call_end = re.escape("<|tool▁call▁end|>")
        tool_sep = re.escape("<|tool▁sep|>")
        # One complete call, from its begin marker to its end marker.
        self.func_call_regex = call_begin + r".*?" + call_end
        # Groups: (1) call type, (2) function name, (3) JSON arguments.
        self.func_detail_regex = (
            call_begin + r"(.*)" + tool_sep + r"(.*)\n```json\n(.*)\n```" + call_end
        )
        # Same as above but without the closing part, for calls still being generated.
        self._partial_call_regex = (
            call_begin + r"(.*)" + tool_sep + r"(.*)\n```json\n(.*)"
        )
        # Arguments text already emitted for the call currently being streamed.
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the text preceding the tool calls as
            normal text and the parsed calls. If any call fails to parse, the
            error is logged and the whole text is returned as normal text.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])
        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                # Get function name
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                func_name = func_detail.group(2)
                func_args = func_detail.group(3)
                func_args = json.loads(func_args)
                # construct match_result for parse_base_json
                match_result = {"name": func_name, "parameters": func_args}
                calls.extend(self.parse_base_json(match_result, tools))
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the StructureInfo for a given tool name."""
        return lambda name: StructureInfo(
            begin=">" + name + "\n```json\n",
            end="\n```<",
            trigger=">" + name + "\n```json\n",
        )

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for DeepSeekV3 format.

        The first increment that reveals a call emits its name; later increments
        emit the newly generated part of the raw arguments text. Once the
        arguments form complete JSON, the streaming state is reset.

        :param new_text: The newly generated chunk of text.
        :param tools: List of available tools (indexed once per detector instance).
        :return: Normal text while no tool call has started, otherwise the newly
            available ToolCallItem pieces. On error the buffered text is returned
            as normal text.
        """
        self._buffer += new_text
        current_text = self._buffer

        if self.bot_token not in current_text:
            # Not inside a tool call: pass the chunk through without end markers.
            self._buffer = ""
            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            partial_match = re.search(
                pattern=self._partial_call_regex,
                string=current_text,
                flags=re.DOTALL,
            )
            if partial_match:
                func_name = partial_match.group(2).strip()
                func_args_raw = partial_match.group(3).strip()

                if not self.current_tool_name_sent:
                    calls.append(
                        ToolCallItem(
                            # NOTE: an unknown function name falls back to index 0.
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=func_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                else:
                    # Emit only what extends the arguments already sent.
                    argument_diff = (
                        func_args_raw[len(self._last_arguments) :]
                        if func_args_raw.startswith(self._last_arguments)
                        else func_args_raw
                    )

                    if argument_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self._tool_indices.get(func_name, 0),
                                name=None,
                                parameters=argument_diff,
                            )
                        )
                        self._last_arguments += argument_diff

                    if _is_complete_json(func_args_raw):
                        # Call finished: reset state for the next one.
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self._buffer = ""
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiFormatParser:
    """Runs a list of detectors over model output and returns the first hit."""

    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: Detector instances to try, in order
        """
        self.detectors = detectors

    def parse_once(
        self, text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        One-time parsing: try each detector in order and stop at the first one
        that yields at least one call.

        Return: (final_text, all_calls)
        - final_text: that detector's normal text, or the unchanged input text
          when no detector produced a call
        - all_calls: the calls parsed by that detector (empty if none matched)
        """
        for detector in self.detectors:
            outcome = detector.detect_and_parse(text, tools)
            found_calls = outcome.calls
            if len(found_calls) > 0:  # parsed successfully
                return outcome.normal_text, found_calls

        # No detector consumed anything: the whole input is normal text.
        return text, []

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming incremental parsing: feed new_text to each detector's
        parse_streaming_increment in order.

        A result carrying calls is a successful parse and ends the loop; its
        normal text is returned with it. A result carrying only normal text may
        simply mean the detector does not apply, so the text is remembered and
        the next detector is still tried.
        """
        collected_text = ""
        collected_calls = []

        for detector in self.detectors:
            step = detector.parse_streaming_increment(new_text, tools)
            if step.calls:
                collected_calls.extend(step.calls)
                collected_text = step.normal_text
                break
            if step.normal_text:
                collected_text = step.normal_text

        return collected_text, collected_calls
|
|
||||||
|
|
||||||
|
|
||||||
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Check if the stripped text starts with a pythonic tool call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing of a pythonic call list.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: The parsed calls with empty normal text, or the stripped text as
            normal text when it is not a list of calls or cannot be parsed.
        """
        # Try parsing the text as a Python list of function calls
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        # NOTE: a function not present in tools gets index -1.
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    @staticmethod
    def _find_matching_bracket(buffer: str, start: int) -> int:
        """Return the index of the "]" closing the "[" at *start*, or -1 if still open.

        Nested brackets are counted and brackets inside quoted strings are ignored.
        """
        depth = 0
        quote = None  # the quote character of the string we are inside, if any
        escaped = False
        for pos in range(start, len(buffer)):
            char = buffer[pos]
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == quote:
                    quote = None
                continue
            if char in ("'", '"'):
                quote = char
            elif char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
                if depth == 0:
                    return pos
        return -1

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to its matching ])
        is found, then parses and emits any detected calls. Text that cannot be
        part of a call list (nothing buffered contains "[", or text preceding
        the "[") is returned as normal text instead of being held back or dropped.
        """
        self._buffer += new_text
        start = self._buffer.find("[")
        if start == -1:
            # No call list can begin in what has been buffered so far.
            normal_text, self._buffer = self._buffer, ""
            return StreamingParseResult(normal_text=normal_text)

        leading_text = self._buffer[:start]
        end = self._find_matching_bracket(self._buffer, start)
        if end == -1:
            # The list is still open: release the leading text, keep the rest.
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=leading_text)

        call_text = self._buffer[start : end + 1]
        result = self.detect_and_parse(call_text, tools)
        self._buffer = self._buffer[end + 1 :]
        if leading_text:
            result.normal_text = leading_text + (result.normal_text or "")
        return result

    def _get_parameter_value(self, val):
        """Convert an AST literal node (constant, dict or list) to a Python value.

        Raises ValueError for any other kind of node.
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.Dict):
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the StructureInfo (same for every tool name)."""

        def info(name: str):
            return StructureInfo(begin="[", end="]", trigger="")

        return info
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionCallParser:
    """
    Front end used by the serving layer: picks the detector for the configured
    tool-call parser and forwards one-shot and streaming parsing to it through a
    MultiFormatParser, returning the resulting normal_text and calls.
    """

    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
        "deepseekv3": DeepSeekV3Detector,
        "pythonic": PythonicDetector,
    }

    def __init__(self, tools: List[Tool], tool_call_parser: str):
        """
        :param tools: List of available tools
        :param tool_call_parser: Key of ToolCallParserEnum selecting the detector
        :raises ValueError: If no parser name is given or the name is unknown
        """
        if not tool_call_parser:
            raise ValueError("Tool Call Parser Not Given!")

        detector_class = self.ToolCallParserEnum.get(tool_call_parser)
        if not detector_class:
            raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")

        self.multi_format_parser = MultiFormatParser([detector_class()])
        self.tools = tools

    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains a tool call in the format supported by this parser.
        This delegates to the detector's implementation.

        :param text: The text to check for tool calls
        :return: True if the text contains a tool call, False otherwise
        """
        # Check all detectors in the multi_format_parser
        return any(
            detector.has_tool_call(text)
            for detector in self.multi_format_parser.detectors
        )

    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Non-streaming call: one-time parsing
        """
        remaining_text, parsed_calls = self.multi_format_parser.parse_once(
            full_text, self.tools
        )
        return remaining_text, parsed_calls

    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming call: incremental parsing
        """
        chunk_normal_text, chunk_calls = (
            self.multi_format_parser.parse_streaming_increment(chunk_text, self.tools)
        )
        return chunk_normal_text, chunk_calls

    def structure_infos(self) -> List[_GetInfoFunc]:
        """
        Returns a list of structure_info functions for each detector
        """
        info_funcs = []
        for detector in self.multi_format_parser.detectors:
            info_funcs.append(detector.structure_info())
        return info_funcs

    def get_structure_tag(self) -> StructuralTagResponseFormat:
        """
        Build the structural_tag response format covering every tool for every
        detector: one structure (begin / schema / end) per pair, plus the set of
        distinct trigger strings.
        """
        structures: List[StructuresResponseFormat] = []
        triggers: Set[str] = set()

        for make_info in self.structure_infos():
            for tool in self.tools:
                function = tool.function
                name = function.name
                assert name is not None
                info = make_info(name)

                # accept all if not strict, otherwise only accept the schema
                if function.strict:
                    schema = function.parameters
                else:
                    schema = {}

                structures.append(
                    StructuresResponseFormat(
                        begin=info.begin,
                        schema=schema,  # type: ignore
                        end=info.end,
                    )
                )
                triggers.add(info.trigger)

        return StructuralTagResponseFormat(
            type="structural_tag",
            structures=structures,
            triggers=list(triggers),
        )
|
|
||||||
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
|
|||||||
get_conv_template_by_model_path,
|
get_conv_template_by_model_path,
|
||||||
register_conv_template,
|
register_conv_template,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||||
from sglang.srt.openai_api.protocol import (
|
from sglang.srt.openai_api.protocol import (
|
||||||
BatchRequest,
|
BatchRequest,
|
||||||
@@ -970,7 +970,7 @@ def v1_chat_generate_request(
|
|||||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||||
# - audio_data: None or a list of audio strings (URLs).
|
# - audio_data: None or a list of audio strings (URLs).
|
||||||
# None skips any image processing in GenerateReqInput.
|
# None skips any image processing in GenerateReqInput.
|
||||||
strict_tag = None
|
tool_call_constraint = None
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt_ids = []
|
prompt_ids = []
|
||||||
if not isinstance(request.messages, str):
|
if not isinstance(request.messages, str):
|
||||||
@@ -989,7 +989,9 @@ def v1_chat_generate_request(
|
|||||||
|
|
||||||
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
||||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||||
strict_tag = parser.get_structure_tag()
|
tool_call_constraint = parser.get_structure_constraint(
|
||||||
|
request.tool_choice
|
||||||
|
)
|
||||||
|
|
||||||
if chat_template_name is None:
|
if chat_template_name is None:
|
||||||
openai_compatible_messages = []
|
openai_compatible_messages = []
|
||||||
@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
|
|||||||
request.response_format.model_dump(by_alias=True)
|
request.response_format.model_dump(by_alias=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
if strict_tag is not None:
|
# Check if there are already existing output constraints
|
||||||
if (
|
has_existing_constraints = (
|
||||||
sampling_params.get("regex")
|
sampling_params.get("regex")
|
||||||
or sampling_params.get("ebnf")
|
or sampling_params.get("ebnf")
|
||||||
or sampling_params.get("structural_tag")
|
or sampling_params.get("structural_tag")
|
||||||
or sampling_params.get("json_schema")
|
or sampling_params.get("json_schema")
|
||||||
):
|
)
|
||||||
logger.warning(
|
|
||||||
"Constrained decoding is not compatible with tool calls."
|
if tool_call_constraint and has_existing_constraints:
|
||||||
|
logger.warning("Constrained decoding is not compatible with tool calls.")
|
||||||
|
elif tool_call_constraint:
|
||||||
|
constraint_type, constraint_value = tool_call_constraint
|
||||||
|
if constraint_type == "structural_tag":
|
||||||
|
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||||
|
constraint_value.model_dump(by_alias=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
sampling_params[constraint_type] = constraint_value
|
||||||
strict_tag.model_dump(by_alias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params_list.append(sampling_params)
|
sampling_params_list.append(sampling_params)
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ suites = {
|
|||||||
TestFile("test_fa3.py", 376),
|
TestFile("test_fa3.py", 376),
|
||||||
TestFile("test_fim_completion.py", 40),
|
TestFile("test_fim_completion.py", 40),
|
||||||
TestFile("test_fp8_kernel.py", 8),
|
TestFile("test_fp8_kernel.py", 8),
|
||||||
TestFile("test_function_calling.py", 60),
|
TestFile("test_function_call_parser.py", 10),
|
||||||
TestFile("test_fused_moe.py", 30),
|
TestFile("test_fused_moe.py", 30),
|
||||||
TestFile("test_hicache.py", 116),
|
TestFile("test_hicache.py", 116),
|
||||||
TestFile("test_hicache_mla.py", 254),
|
TestFile("test_hicache_mla.py", 254),
|
||||||
@@ -54,6 +54,7 @@ suites = {
|
|||||||
TestFile("test_flashmla.py", 300),
|
TestFile("test_flashmla.py", 300),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
TestFile("test_no_overlap_scheduler.py", 216),
|
TestFile("test_no_overlap_scheduler.py", 216),
|
||||||
|
TestFile("test_openai_function_calling.py", 60),
|
||||||
TestFile("test_openai_server.py", 149),
|
TestFile("test_openai_server.py", 149),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
|
|||||||
408
test/srt/test_function_call_parser.py
Normal file
408
test/srt/test_function_call_parser.py
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||||
|
|
||||||
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.openai_api.protocol import Function, Tool
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class TestPythonicDetector(unittest.TestCase):
    """Streaming-parse tests for ``PythonicDetector.parse_streaming_increment``."""

    def setUp(self):
        # Two sample tools shared by every test
        weather_function = Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "Location to get weather for",
                    },
                    "unit": {
                        "type": "string",
                        "description": "Temperature unit",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        )
        search_function = Function(
            name="search",
            description="Search for information",
            parameters={
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query",
                    },
                },
                "required": ["query"],
            },
        )
        self.tools = [
            Tool(type="function", function=weather_function),
            Tool(type="function", function=search_function),
        ]
        self.detector = PythonicDetector()

    def _feed(self, chunk):
        """Push one chunk through the detector and return its parse result."""
        return self.detector.parse_streaming_increment(chunk, self.tools)

    @staticmethod
    def _arguments(call):
        """Decode the JSON-encoded parameters of a parsed call."""
        return json.loads(call.parameters)

    def test_parse_streaming_no_brackets(self):
        """Plain text without brackets is returned unchanged with no calls."""
        plain = "This is just normal text without any tool calls."
        parsed = self._feed(plain)

        self.assertEqual(parsed.normal_text, plain)
        self.assertEqual(parsed.calls, [])
        # Nothing may be left buffered
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_complete_tool_call(self):
        """A complete call in a single chunk is parsed in one step."""
        parsed = self._feed(
            "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        )

        self.assertEqual(parsed.normal_text, "Here's a tool call: ")
        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.calls[0].name, "get_weather")
        # The buffer is emptied once the call has been processed
        self.assertEqual(self.detector._buffer, "")

        arguments = self._arguments(parsed.calls[0])
        self.assertEqual(arguments["location"], "New York")
        self.assertEqual(arguments["unit"], "celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Text preceding a call is emitted as normal text."""
        parsed = self._feed("This is some text before [get_weather(location='London')]")

        self.assertEqual(parsed.normal_text, "This is some text before ")
        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.calls[0].name, "get_weather")

        arguments = self._arguments(parsed.calls[0])
        self.assertEqual(arguments["location"], "London")

    def test_parse_streaming_partial_tool_call(self):
        """A call split over two chunks is only emitted once it is complete."""
        # Chunk one: bracket opened, never closed
        head = self._feed("Let me check the weather: [get_weather(location=")

        self.assertEqual(head.normal_text, "Let me check the weather: ")
        self.assertEqual(head.calls, [])
        # The unfinished call stays buffered
        self.assertEqual(self.detector._buffer, "[get_weather(location=")

        # Chunk two: the rest of the call
        tail = self._feed("'Paris')]")

        self.assertEqual(tail.normal_text, "")
        self.assertEqual(len(tail.calls), 1)
        self.assertEqual(tail.calls[0].name, "get_weather")

        arguments = self._arguments(tail.calls[0])
        self.assertEqual(arguments["location"], "Paris")
        # The buffer is emptied once the call has been processed
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_bracket_without_text_before(self):
        """A call at the very start of the text yields no normal text."""
        parsed = self._feed("[search(query='python programming')]")

        self.assertEqual(parsed.normal_text, "")
        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.calls[0].name, "search")

        arguments = self._arguments(parsed.calls[0])
        self.assertEqual(arguments["query"], "python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Text following a call is buffered and emitted on the next increment."""
        # Complete call plus trailing text in one chunk
        first = self._feed("[get_weather(location='Tokyo')] Here's the forecast:")

        self.assertEqual(first.normal_text, "")
        self.assertEqual(len(first.calls), 1)
        self.assertEqual(first.calls[0].name, "get_weather")
        # The trailing text waits in the buffer
        self.assertEqual(self.detector._buffer, " Here's the forecast:")

        # An empty increment flushes what is buffered
        second = self._feed("")
        self.assertEqual(second.normal_text, " Here's the forecast:")
        self.assertEqual(second.calls, [])
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_multiple_tool_calls(self):
        """Two bracketed calls in one chunk are emitted one per increment."""
        chunk = "[get_weather(location='Berlin')] and [search(query='restaurants')]"

        # First increment yields the first call only
        first = self._feed(chunk)
        self.assertEqual(len(first.calls), 1)
        self.assertEqual(first.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")

        # Second increment yields the remaining call
        second = self._feed("")
        self.assertEqual(second.normal_text, " and ")
        self.assertEqual(len(second.calls), 1)
        self.assertEqual(second.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_opening_bracket_only(self):
        """A lone opening bracket is held back in the buffer."""
        parsed = self._feed("Let's try this: [")

        self.assertEqual(parsed.normal_text, "Let's try this: ")
        self.assertEqual(parsed.calls, [])
        # The opening bracket stays buffered
        self.assertEqual(self.detector._buffer, "[")

    def test_parse_streaming_nested_brackets(self):
        """A list argument does not end the call early."""
        parsed = self._feed(
            "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        )

        self.assertEqual(parsed.normal_text, "")
        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        arguments = self._arguments(parsed.calls[0])
        self.assertEqual(arguments["location"], "New York")
        self.assertEqual(arguments["unit"], "celsius")
        self.assertEqual(arguments["data"], [1, 2, 3])

    def test_parse_streaming_nested_brackets_dict(self):
        """Nested dict and list arguments are parsed intact."""
        parsed = self._feed(
            "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        )

        self.assertEqual(parsed.normal_text, "")
        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

        arguments = self._arguments(parsed.calls[0])
        self.assertEqual(arguments["query"], "test")
        self.assertEqual(arguments["config"]["options"], [1, 2])
        self.assertEqual(arguments["config"]["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Two calls inside one bracket pair, each with list arguments."""
        parsed = self._feed(
            "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        )

        self.assertEqual(parsed.normal_text, "")
        self.assertEqual(len(parsed.calls), 2)
        self.assertEqual(self.detector._buffer, "")

        weather_call, search_call = parsed.calls[0], parsed.calls[1]

        # First call
        weather_arguments = self._arguments(weather_call)
        self.assertEqual(weather_call.name, "get_weather")
        self.assertEqual(weather_arguments["location"], "Paris")
        self.assertEqual(weather_arguments["data"], [10, 20])

        # Second call
        search_arguments = self._arguments(search_call)
        self.assertEqual(search_call.name, "search")
        self.assertEqual(search_arguments["query"], "test")
        self.assertEqual(search_arguments["filters"], ["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """A call with a list argument split across two chunks."""
        # Chunk one stops inside the nested list
        head = self._feed("Here's a call: [get_weather(location='Tokyo', data=[1, 2")

        self.assertEqual(head.normal_text, "Here's a call: ")
        self.assertEqual(head.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
        )

        # Chunk two closes the list and the call
        tail = self._feed(", 3])]")

        self.assertEqual(tail.normal_text, "")
        self.assertEqual(len(tail.calls), 1)
        self.assertEqual(tail.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        arguments = self._arguments(tail.calls[0])
        self.assertEqual(arguments["location"], "Tokyo")
        self.assertEqual(arguments["data"], [1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
|
class TestEBNFGeneration(unittest.TestCase):
    """
    Check that each detector's ``build_ebnf`` output contains the expected
    grammar fragments and is accepted by xgrammar's ``GrammarCompiler``.
    """

    @classmethod
    def setUpClass(cls):
        # Loading the tokenizer and building the compiler does not depend on the
        # individual test, so do it once per class instead of once per test.
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        tokenizer_info = TokenizerInfo.from_huggingface(cls.tokenizer)
        cls.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

    def setUp(self):
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]

        # Fresh detectors for every test
        self.pythonic_detector = PythonicDetector()
        self.deepseekv3_detector = DeepSeekV3Detector()
        self.llama32_detector = Llama32Detector()
        self.mistral_detector = MistralDetector()
        self.qwen25_detector = Qwen25Detector()

    def _assert_valid_ebnf(self, detector, expected_fragments):
        """
        Build the EBNF for ``self.tools`` with ``detector``, check that every
        string in ``expected_fragments`` occurs in it, and check that it compiles.
        """
        ebnf = detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        for fragment in expected_fragments:
            self.assertIn(fragment, ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        try:
            ctx = self.grammar_compiler.compile_grammar(ebnf)
        except RuntimeError as e:
            self.fail(f"Failed to compile EBNF: {e}")
        self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")

    def test_pythonic_detector_ebnf(self):
        """Test that the PythonicDetector generates valid EBNF."""
        self._assert_valid_ebnf(
            self.pythonic_detector,
            [
                'call_get_weather ::= "get_weather" "(" ',
                '"location" "=" basic_string',
                '[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]',
            ],
        )

    def test_deepseekv3_detector_ebnf(self):
        """Test that the DeepSeekV3Detector generates valid EBNF."""
        self._assert_valid_ebnf(
            self.deepseekv3_detector,
            [
                "<|tool▁calls▁begin|>",
                "<|tool▁call▁begin|>function<|tool▁sep|>get_weather",
                '\\"location\\"" ":" basic_string ',
            ],
        )

    def test_llama32_detector_ebnf(self):
        """Test that the Llama32Detector generates valid EBNF."""
        self._assert_valid_ebnf(
            self.llama32_detector,
            [
                '\\"name\\"" ":" "\\"get_weather\\"',
                '"\\"arguments\\"" ":"',
            ],
        )

    def test_mistral_detector_ebnf(self):
        """Test that the MistralDetector generates valid EBNF."""
        self._assert_valid_ebnf(
            self.mistral_detector,
            [
                '"[TOOL_CALLS] ["',
                "call_get_weather | call_search",
                '"\\"arguments\\"" ":"',
            ],
        )

    def test_qwen25_detector_ebnf(self):
        """Test that the Qwen25Detector generates valid EBNF."""
        self._assert_valid_ebnf(
            self.qwen25_detector,
            [
                "<tool_call>",
                '\\"name\\"" ":" "\\"get_weather\\"',
                '"\\"arguments\\"" ":"',
            ],
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
|
|||||||
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
|
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
|
||||||
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
||||||
|
|
||||||
|
def test_function_call_required(self):
    """
    Test: Whether tool_choice: "required" works as expected
    - When tool_choice == "required", the model should return one or more tool_calls.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    tools = [
        {
            "type": "function",
            "function": {
                "name": "sub",
                "description": "Compute the difference of two integers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "int_a": {
                            "type": "integer",
                            "description": "First integer",
                        },
                        "int_b": {
                            "type": "integer",
                            "description": "Second integer",
                        },
                    },
                    "required": ["int_a", "int_b"],
                },
                "strict": True,
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "use this to get latest weather information for a city given its name",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "name of the city to get weather for",
                        }
                    },
                    "required": ["city"],
                },
            },
        },
    ]

    # The question does not call for a tool on its own; only
    # tool_choice="required" forces the model to produce a call.
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    response = client.chat.completions.create(
        model=self.model,
        messages=messages,
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=tools,
        tool_choice="required",
    )

    tool_calls = response.choices[0].message.tool_calls
    # assertTrue rejects both None and an empty list, so the indexing below
    # cannot fail with an IndexError instead of a readable assertion message.
    self.assertTrue(tool_calls, "No tool_calls in the response")
    function_name = tool_calls[0].function.name
    arguments = tool_calls[0].function.arguments
    args_obj = json.loads(arguments)

    self.assertEqual(
        function_name, "get_weather", "Function name should be 'get_weather'"
    )
    self.assertIn("city", args_obj, "Function arguments should have 'city'")
    self.assertIn(
        "Paris", args_obj["city"], "Parameter city should contain 'Paris'"
    )  # might be flaky
|
||||||
|
|
||||||
|
def test_function_call_specific(self):
    """
    Test: Whether tool_choice: ToolChoice works as expected
    - When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    tools = [
        {
            "type": "function",
            "function": {
                "name": "sub",
                "description": "Compute the difference of two integers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "int_a": {
                            "type": "integer",
                            "description": "First integer",
                        },
                        "int_b": {
                            "type": "integer",
                            "description": "Second integer",
                        },
                    },
                    "required": ["int_a", "int_b"],
                },
                "strict": True,
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "use this to get latest weather information for a city given its name",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "name of the city to get weather for",
                        }
                    },
                    "required": ["city"],
                },
            },
        },
    ]

    # The question does not call for a tool on its own; the explicit
    # tool_choice is what pins the call to get_weather.
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    response = client.chat.completions.create(
        model=self.model,
        messages=messages,
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "get_weather"}},
    )

    tool_calls = response.choices[0].message.tool_calls
    # assertTrue rejects both None and an empty list, so the indexing below
    # cannot fail with an IndexError instead of a readable assertion message.
    self.assertTrue(tool_calls, "No tool_calls in the response")
    function_name = tool_calls[0].function.name
    arguments = tool_calls[0].function.arguments
    args_obj = json.loads(arguments)

    self.assertEqual(
        function_name, "get_weather", "Function name should be 'get_weather'"
    )
    self.assertIn("city", args_obj, "Function arguments should have 'city'")
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
||||||
PYTHONIC_TOOLS = [
|
PYTHONIC_TOOLS = [
|
||||||
@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
|||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
self.assertIsInstance(tool_calls, list)
|
self.assertIsInstance(tool_calls, list, "No tool_calls found")
|
||||||
self.assertGreaterEqual(len(tool_calls), 1)
|
self.assertGreaterEqual(len(tool_calls), 1)
|
||||||
names = [tc.function.name for tc in tool_calls]
|
names = [tc.function.name for tc in tool_calls]
|
||||||
self.assertIn("get_weather", names)
|
self.assertTrue(
|
||||||
self.assertIn("get_tourist_attractions", names)
|
"get_weather" in names or "get_tourist_attractions" in names,
|
||||||
|
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||||
|
)
|
||||||
|
|
||||||
def test_pythonic_tool_call_streaming(self):
|
def test_pythonic_tool_call_streaming(self):
|
||||||
"""
|
"""
|
||||||
@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
|||||||
|
|
||||||
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
|
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
|
||||||
self.assertTrue(found_index, "No index field found in any streamed tool_call")
|
self.assertTrue(found_index, "No index field found in any streamed tool_call")
|
||||||
self.assertIn("get_weather", found_names)
|
self.assertTrue(
|
||||||
self.assertIn("get_tourist_attractions", found_names)
|
"get_weather" in found_names or "get_tourist_attractions" in found_names,
|
||||||
|
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Reference in New Issue
Block a user