feat(Tool Calling): Support required and specific function mode (#6550)
This commit is contained in:
@@ -54,10 +54,12 @@
|
||||
"source": [
|
||||
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
|
||||
"\n",
|
||||
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
||||
"- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
|
||||
"- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
|
||||
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
|
||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n",
|
||||
"- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -360,6 +362,164 @@
|
||||
"print(final_response.choices[0].message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tool Choice Mode\n",
|
||||
"\n",
|
||||
"SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
|
||||
"\n",
|
||||
"### Supported Tool Choice Options\n",
|
||||
"\n",
|
||||
"- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
|
||||
"- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
|
||||
"\n",
|
||||
"### Backend Compatibility\n",
|
||||
"\n",
|
||||
"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
|
||||
"\n",
|
||||
"### Example: Required Tool Choice"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Response with tool_choice='required':\n",
|
||||
"Content: None\n",
|
||||
"Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from openai import OpenAI\n",
|
||||
"import json\n",
|
||||
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
" import nest_asyncio\n",
|
||||
"\n",
|
||||
" nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"# Start a new server session for tool choice examples\n",
|
||||
"server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
|
||||
")\n",
|
||||
"wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
|
||||
"\n",
|
||||
"# Initialize client for tool choice examples\n",
|
||||
"client_tool_choice = OpenAI(\n",
|
||||
" api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
|
||||
")\n",
|
||||
"model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
|
||||
"\n",
|
||||
"# Example with tool_choice=\"required\" - forces the model to call a tool\n",
|
||||
"messages_required = [\n",
|
||||
" {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Define tools\n",
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\n",
|
||||
" \"name\": \"get_current_weather\",\n",
|
||||
" \"description\": \"Get the current weather in a given location\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"city\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
|
||||
" },\n",
|
||||
" \"unit\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The unit to fetch the temperature in\",\n",
|
||||
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"required\": [\"city\", \"unit\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"response_required = client_tool_choice.chat.completions.create(\n",
|
||||
" model=model_name_tool_choice,\n",
|
||||
" messages=messages_required,\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" tools=tools,\n",
|
||||
" tool_choice=\"required\", # Force the model to call a tool\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(\"Response with tool_choice='required':\")\n",
|
||||
"print(\"Content:\", response_required.choices[0].message.content)\n",
|
||||
"print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Example: Specific Function Choice\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Response with specific function choice:\n",
|
||||
"Content: None\n",
|
||||
"Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n",
|
||||
"Called function: get_current_weather\n",
|
||||
"Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Example with specific function choice - forces the model to call a specific function\n",
|
||||
"messages_specific = [\n",
|
||||
" {\"role\": \"user\", \"content\": \"What are the most attactive places in France?\"}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"response_specific = client_tool_choice.chat.completions.create(\n",
|
||||
" model=model_name_tool_choice,\n",
|
||||
" messages=messages_specific,\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" tools=tools,\n",
|
||||
" tool_choice={\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\"name\": \"get_current_weather\"},\n",
|
||||
" }, # Force the model to call the specific get_current_weather function\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(\"Response with specific function choice:\")\n",
|
||||
"print(\"Content:\", response_specific.choices[0].message.content)\n",
|
||||
"print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
|
||||
"\n",
|
||||
"if response_specific.choices[0].message.tool_calls:\n",
|
||||
" tool_call = response_specific.choices[0].message.tool_calls[0]\n",
|
||||
" print(f\"Called function: {tool_call.function.name}\")\n",
|
||||
" print(f\"Arguments: {tool_call.function.arguments}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -444,7 +604,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sglang as sgl\n",
|
||||
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
||||
"from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
|
||||
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||
"\n",
|
||||
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||
|
||||
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
|
||||
250
python/sglang/srt/function_call/base_format_detector.py
Normal file
250
python/sglang/srt/function_call/base_format_detector.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from partial_json_parser.core.exceptions import MalformedJSON
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.utils import (
|
||||
_find_common_prefix,
|
||||
_is_complete_json,
|
||||
_partial_json_loads,
|
||||
)
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseFormatDetector(ABC):
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(self):
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
self._buffer = ""
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = (
|
||||
[]
|
||||
) # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = ""
|
||||
self.eot_token = ""
|
||||
|
||||
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||
tool_indices = {
|
||||
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
||||
}
|
||||
if not isinstance(action, list):
|
||||
action = [action]
|
||||
|
||||
results = []
|
||||
for act in action:
|
||||
name = act.get("name")
|
||||
if name and name in tool_indices:
|
||||
results.append(
|
||||
ToolCallItem(
|
||||
tool_index=tool_indices[name],
|
||||
name=name,
|
||||
parameters=json.dumps(
|
||||
act.get("parameters") or act.get("arguments", {}),
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Model attempted to call undefined function: {name}")
|
||||
|
||||
return results
|
||||
|
||||
@abstractmethod
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
"""
|
||||
action = json.loads(text)
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing with tool validation.
|
||||
"""
|
||||
# Append new text to buffer
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||
self._buffer = ""
|
||||
if self.eot_token in new_text:
|
||||
new_text = new_text.replace(self.eot_token, "")
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
# Build tool indices if not already built
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function and tool.function.name
|
||||
}
|
||||
|
||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
start_idx = (
|
||||
len(self.bot_token)
|
||||
if current_text.startswith(self.bot_token)
|
||||
else 0
|
||||
)
|
||||
while start_idx < len(current_text):
|
||||
(obj, end_idx) = _partial_json_loads(
|
||||
current_text[start_idx:], flags
|
||||
)
|
||||
is_complete.append(
|
||||
_is_complete_json(current_text[start_idx : start_idx + end_idx])
|
||||
)
|
||||
start_idx += end_idx + len("; ")
|
||||
|
||||
# Validate tool name if present
|
||||
if "name" in obj and obj["name"] not in self._tool_indices:
|
||||
# Invalid tool name - reset state
|
||||
self._buffer = ""
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
if self.streamed_args_for_tool:
|
||||
self.streamed_args_for_tool.pop()
|
||||
return StreamingParseResult()
|
||||
|
||||
# Handle parameters/arguments consistency
|
||||
if "parameters" in obj:
|
||||
assert (
|
||||
"arguments" not in obj
|
||||
), "model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
|
||||
except MalformedJSON:
|
||||
return StreamingParseResult()
|
||||
|
||||
if len(tool_call_arr) == 0:
|
||||
return StreamingParseResult()
|
||||
|
||||
current_tool_call: Dict = (
|
||||
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
||||
)
|
||||
|
||||
# Handle new tool in array
|
||||
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
res = StreamingParseResult(
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name="",
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
return res
|
||||
|
||||
# Handle tool name
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name and function_name in self._tool_indices:
|
||||
res = StreamingParseResult(
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self._tool_indices[function_name],
|
||||
name=function_name,
|
||||
parameters="",
|
||||
)
|
||||
],
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
|
||||
# Handle streaming arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
res = StreamingParseResult()
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments"
|
||||
)
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
self._buffer = ""
|
||||
self.prev_tool_call_arr[self.current_tool_id].clear()
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
||||
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
if cur_args_json != prev_args_json:
|
||||
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
res = StreamingParseResult(
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
if not is_complete[self.current_tool_id]:
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in parse_streaming_increment: {e}")
|
||||
return StreamingParseResult()
|
||||
|
||||
@abstractmethod
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
raise NotImplementedError()
|
||||
34
python/sglang/srt/function_call/core_types.py
Normal file
34
python/sglang/srt/function_call/core_types.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||
|
||||
tool_index: int
|
||||
name: Optional[str] = None
|
||||
parameters: str # JSON string
|
||||
|
||||
|
||||
class StreamingParseResult(BaseModel):
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
normal_text: str = ""
|
||||
calls: List[ToolCallItem] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructureInfo:
|
||||
begin: str
|
||||
end: str
|
||||
trigger: str
|
||||
|
||||
|
||||
"""
|
||||
Helper alias of function
|
||||
Usually it is a function that takes a name string and returns a StructureInfo object,
|
||||
which can be used to construct a structural_tag object
|
||||
"""
|
||||
_GetInfoFunc = Callable[[str], StructureInfo]
|
||||
157
python/sglang/srt/function_call/deepseekv3_detector.py
Normal file
157
python/sglang/srt/function_call/deepseekv3_detector.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.function_call.utils import _is_complete_json
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV3Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for DeepSeek models.
|
||||
Assumes function call format:
|
||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = "<|tool▁calls▁begin|>"
|
||||
self.eot_token = "<|tool▁calls▁end|>"
|
||||
self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
|
||||
self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
|
||||
self._last_arguments = ""
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a deepseek format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
|
||||
calls = []
|
||||
try:
|
||||
for match_result in match_result_list:
|
||||
# Get function name
|
||||
func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
|
||||
func_name = func_detail.group(2)
|
||||
func_args = func_detail.group(3)
|
||||
func_args = json.loads(func_args)
|
||||
# construct match_result for parse_base_json
|
||||
match_result = {"name": func_name, "parameters": func_args}
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in detect_and_parse: {e}")
|
||||
# return the normal text if parsing fails
|
||||
return StreamingParseResult(normal_text=text)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing tool calls for DeepSeekV3 format.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
|
||||
if self.bot_token not in current_text:
|
||||
self._buffer = ""
|
||||
for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
|
||||
if e_token in new_text:
|
||||
new_text = new_text.replace(e_token, "")
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function and tool.function.name
|
||||
}
|
||||
|
||||
calls: list[ToolCallItem] = []
|
||||
try:
|
||||
partial_match = re.search(
|
||||
pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
|
||||
string=current_text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
if partial_match:
|
||||
func_name = partial_match.group(2).strip()
|
||||
func_args_raw = partial_match.group(3).strip()
|
||||
|
||||
if not self.current_tool_name_sent:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self._tool_indices.get(func_name, 0),
|
||||
name=func_name,
|
||||
parameters="",
|
||||
)
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
argument_diff = (
|
||||
func_args_raw[len(self._last_arguments) :]
|
||||
if func_args_raw.startswith(self._last_arguments)
|
||||
else func_args_raw
|
||||
)
|
||||
|
||||
if argument_diff:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self._tool_indices.get(func_name, 0),
|
||||
name=None,
|
||||
parameters=argument_diff,
|
||||
)
|
||||
)
|
||||
self._last_arguments += argument_diff
|
||||
|
||||
if _is_complete_json(func_args_raw):
|
||||
result = StreamingParseResult(normal_text="", calls=calls)
|
||||
self._buffer = ""
|
||||
self._last_arguments = ""
|
||||
self.current_tool_name_sent = False
|
||||
return result
|
||||
|
||||
return StreamingParseResult(normal_text="", calls=calls)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in parse_streaming_increment: {e}")
|
||||
return StreamingParseResult(normal_text=current_text)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin=">" + name + "\n```json\n",
|
||||
end="\n```<",
|
||||
trigger=">" + name + "\n```json\n",
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token=self.bot_token,
|
||||
eot_token=self.eot_token,
|
||||
tool_call_separator="",
|
||||
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
|
||||
function_format="json",
|
||||
)
|
||||
234
python/sglang/srt/function_call/ebnf_composer.py
Normal file
234
python/sglang/srt/function_call/ebnf_composer.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
class EBNFComposer:
|
||||
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
|
||||
json_grammar_ebnf_str = r"""
|
||||
json ::= basic_array | basic_object
|
||||
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
|
||||
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
|
||||
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
||||
basic_string ::= (([\"] basic_string_1 [\"]))
|
||||
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
||||
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
||||
basic_boolean ::= "true" | "false"
|
||||
basic_null ::= "null"
|
||||
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
||||
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
||||
ws ::= [ \n\t]*
|
||||
"""
|
||||
|
||||
pythonic_grammar_ebnf_str = r"""
|
||||
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
|
||||
basic_any ::= basic_number | basic_string | basic_array | basic_object
|
||||
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
||||
basic_string ::= (([\"] basic_string_1 [\"]))
|
||||
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
||||
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
||||
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
||||
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
||||
ws ::= [ \n\t]*
|
||||
"""
|
||||
|
||||
TOOL_CALLS_MAP = {
|
||||
"pythonic": '"[" function_call ("," function_call)* "]"',
|
||||
"json": "function_call",
|
||||
}
|
||||
|
||||
CALL_RULE_MAP = {
|
||||
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
||||
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
||||
}
|
||||
|
||||
ARGUMENTS_RULE_MAP = {
|
||||
"pythonic": "{arg_rules}",
|
||||
"json": '"{{" {arg_rules} "}}"',
|
||||
}
|
||||
|
||||
KEY_VALUE_RULE_MAP = {
|
||||
"pythonic": '"{key}" "=" {valrule}',
|
||||
"json": '"\\"{key}\\"" ":" {valrule}',
|
||||
}
|
||||
|
||||
JSON_TYPE_MAPPING = {
|
||||
"string": "basic_string",
|
||||
"number": "basic_number",
|
||||
"integer": "basic_number",
|
||||
"boolean": "basic_boolean",
|
||||
"null": "basic_null",
|
||||
"array": "basic_array",
|
||||
"object": "basic_object",
|
||||
}
|
||||
|
||||
PYTHONIC_TYPE_MAPPING = {
|
||||
"string": "basic_string",
|
||||
"number": "basic_number",
|
||||
"integer": "basic_number",
|
||||
"boolean": '"True" | "False"',
|
||||
"null": '"None"',
|
||||
"array": "basic_array",
|
||||
"object": "basic_object",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_value_rule(
|
||||
prop: dict, function_format: Literal["pythonic", "json"] = "json"
|
||||
) -> str:
|
||||
if "enum" in prop:
|
||||
return EBNFComposer._handle_enum(prop, function_format)
|
||||
|
||||
if "type" in prop:
|
||||
return EBNFComposer._handle_type(prop, function_format)
|
||||
|
||||
return function_format
|
||||
|
||||
@staticmethod
|
||||
def _handle_enum(prop: dict, function_format: str) -> str:
|
||||
"""Handle enum properties by formatting each value according to type and format."""
|
||||
enum_values = prop["enum"]
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
# Define formatters for different type/format combinations
|
||||
formatters = {
|
||||
("string", "json"): lambda v: f'"\\"{v}\\""',
|
||||
("string", "pythonic"): lambda v: f'"\\"{v}\\""',
|
||||
("number", "json"): str,
|
||||
("number", "pythonic"): str,
|
||||
("integer", "json"): str,
|
||||
("integer", "pythonic"): str,
|
||||
("boolean", "json"): lambda v: "true" if v else "false",
|
||||
("boolean", "pythonic"): lambda v: "True" if v else "False",
|
||||
}
|
||||
|
||||
# Get the formatter or default to string handling
|
||||
formatter = formatters.get(
|
||||
(prop_type, function_format),
|
||||
formatters[("string", function_format)], # Default to string handling
|
||||
)
|
||||
|
||||
formatted_values = [formatter(value) for value in enum_values]
|
||||
enum_rule = " | ".join(formatted_values)
|
||||
|
||||
# Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
|
||||
if len(formatted_values) > 1:
|
||||
enum_rule = f"({enum_rule})"
|
||||
|
||||
return enum_rule
|
||||
|
||||
@staticmethod
|
||||
def _handle_type(prop: dict, function_format: str) -> str:
|
||||
"""Handle type properties using the appropriate type mapping."""
|
||||
prop_type = prop["type"]
|
||||
type_mapping = (
|
||||
EBNFComposer.PYTHONIC_TYPE_MAPPING
|
||||
if function_format == "pythonic"
|
||||
else EBNFComposer.JSON_TYPE_MAPPING
|
||||
)
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
type_rules = [
|
||||
type_mapping[single_type]
|
||||
for single_type in prop_type
|
||||
if single_type in type_mapping
|
||||
]
|
||||
return " | ".join(type_rules) if type_rules else function_format
|
||||
|
||||
return type_mapping.get(prop_type, function_format)
|
||||
|
||||
@staticmethod
|
||||
def build_ebnf(
|
||||
tools,
|
||||
*,
|
||||
call_rule_fmt: Optional[str] = None,
|
||||
function_format: Literal["pythonic", "json"] = "json",
|
||||
bot_token: Optional[str] = None,
|
||||
eot_token: Optional[str] = None,
|
||||
tool_call_separator: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Generalized EBNF builder for all detectors.
|
||||
Args:
|
||||
tools: List of Tool objects to generate EBNF grammar for
|
||||
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
||||
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
||||
format based on function_format will be used.
|
||||
function_format: The format of function calls, either "pythonic" or "json"
|
||||
bot_token: The token that indicates the start of a tool call section
|
||||
eot_token: The token that indicates the end of a tool call section
|
||||
tool_call_separator: The separator between multiple tool calls
|
||||
"""
|
||||
# =================================================================
|
||||
# Step 1: Determine the root tool calls rule
|
||||
# =================================================================
|
||||
if bot_token and eot_token:
|
||||
if tool_call_separator:
|
||||
root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
|
||||
else:
|
||||
root_rule = f'"{bot_token}" function_call "{eot_token}"'
|
||||
else:
|
||||
root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
|
||||
|
||||
# =================================================================
|
||||
# Step 2: Build the header rules
|
||||
# =================================================================
|
||||
ebnf_lines = [
|
||||
f"root ::= {root_rule}",
|
||||
"function_call ::= "
|
||||
+ " | ".join([f"call_{tool.function.name}" for tool in tools]),
|
||||
]
|
||||
|
||||
# =================================================================
|
||||
# Step 3: Set up formatting templates
|
||||
# =================================================================
|
||||
call_template = (
|
||||
f"call_{{name}} ::= {call_rule_fmt}"
|
||||
if call_rule_fmt
|
||||
else EBNFComposer.CALL_RULE_MAP[function_format]
|
||||
)
|
||||
args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
|
||||
key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
|
||||
|
||||
# =================================================================
|
||||
# Step 4: Build rules for each tool
|
||||
# =================================================================
|
||||
for tool in tools:
|
||||
tool_name = tool.function.name
|
||||
params = tool.function.parameters or {}
|
||||
properties = params.get("properties", {})
|
||||
required_props = set(params.get("required", []))
|
||||
|
||||
# Build argument rules for this tool
|
||||
arg_rules = []
|
||||
for prop_name, prop_schema in properties.items():
|
||||
value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
|
||||
# Create key=value pair
|
||||
pair = key_value_template.format(key=prop_name, valrule=value_rule)
|
||||
|
||||
if prop_name not in required_props:
|
||||
pair = f"[ {pair} ]"
|
||||
|
||||
arg_rules.append(pair)
|
||||
|
||||
# Combine all argument rules
|
||||
combined_args = ' "," '.join(arg_rules) if arg_rules else ""
|
||||
arguments_rule = args_template.format(arg_rules=combined_args)
|
||||
|
||||
# Add the function call rule and its arguments rule
|
||||
ebnf_lines.append(
|
||||
call_template.format(
|
||||
name=tool_name, arguments_rule=f"arguments_{tool_name}"
|
||||
)
|
||||
)
|
||||
ebnf_lines.append(f"arguments_{tool_name} ::= {arguments_rule}")
|
||||
|
||||
# =================================================================
|
||||
# Step 5: Add base grammar rules
|
||||
# =================================================================
|
||||
base_grammar = (
|
||||
EBNFComposer.pythonic_grammar_ebnf_str
|
||||
if function_format == "pythonic"
|
||||
else EBNFComposer.json_grammar_ebnf_str
|
||||
)
|
||||
ebnf_lines.append(base_grammar)
|
||||
|
||||
return "\n".join(ebnf_lines)
|
||||
175
python/sglang/srt/function_call/function_call_parser.py
Normal file
175
python/sglang/srt/function_call/function_call_parser.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
|
||||
class FunctionCallParser:
|
||||
"""
|
||||
Parser for function/tool calls in model outputs.
|
||||
|
||||
This class handles both streaming and non-streaming parsing of function calls using a detector.
|
||||
In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
|
||||
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
||||
"""
|
||||
|
||||
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
|
||||
"llama3": Llama32Detector,
|
||||
"qwen25": Qwen25Detector,
|
||||
"mistral": MistralDetector,
|
||||
"deepseekv3": DeepSeekV3Detector,
|
||||
"pythonic": PythonicDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
detector: Type[BaseFormatDetector] = None
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
if detector_class:
|
||||
detector = detector_class()
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
|
||||
|
||||
self.detector = detector
|
||||
self.tools = tools
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""
|
||||
Check if the given text contains a tool call in the format supported by this parser.
|
||||
This delegates to the detector's implementation.
|
||||
|
||||
Args:
|
||||
text: The text to check for tool calls
|
||||
|
||||
Returns:
|
||||
True if the text contains a tool call, False otherwise
|
||||
"""
|
||||
return self.detector.has_tool_call(text)
|
||||
|
||||
def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
One-time parsing of the full text to extract tool calls.
|
||||
|
||||
Args:
|
||||
full_text: The complete text to parse
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
|
||||
- A list of tool calls parsed from the text
|
||||
"""
|
||||
parsed_result = self.detector.detect_and_parse(full_text, self.tools)
|
||||
tool_call_list = parsed_result.calls
|
||||
if tool_call_list:
|
||||
return parsed_result.normal_text, tool_call_list
|
||||
else:
|
||||
return full_text, []
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
Streaming incremental parsing of chunks of text as they arrive.
|
||||
|
||||
Args:
|
||||
chunk_text: The new chunk of text to parse
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The normal text that should be displayed to the user
|
||||
- A list of tool calls parsed from the chunk
|
||||
"""
|
||||
final_normal_text = ""
|
||||
final_calls = []
|
||||
|
||||
sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)
|
||||
if sp_result.normal_text:
|
||||
final_normal_text = sp_result.normal_text
|
||||
if sp_result.calls:
|
||||
final_calls.extend(sp_result.calls)
|
||||
final_normal_text = sp_result.normal_text
|
||||
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def get_structure_tag(self) -> StructuralTagResponseFormat:
|
||||
"""
|
||||
Generate a structural tag response format for all available tools.
|
||||
|
||||
This creates the necessary structural tags that guide the model's output format.
|
||||
"""
|
||||
tool_structures: List[StructuresResponseFormat] = list()
|
||||
tool_trigger_set: Set[str] = set()
|
||||
|
||||
get_structure_info = self.detector.structure_info()
|
||||
for tool in self.tools:
|
||||
function = tool.function
|
||||
name = function.name
|
||||
assert name is not None
|
||||
info = get_structure_info(name)
|
||||
|
||||
# accept all if not strict, otherwise only accept the schema
|
||||
schema = function.parameters if function.strict else {}
|
||||
|
||||
tool_structures.append(
|
||||
StructuresResponseFormat(
|
||||
begin=info.begin,
|
||||
schema=schema, # type: ignore
|
||||
end=info.end,
|
||||
)
|
||||
)
|
||||
tool_trigger_set.add(info.trigger)
|
||||
|
||||
return StructuralTagResponseFormat(
|
||||
type="structural_tag",
|
||||
structures=tool_structures,
|
||||
triggers=list(tool_trigger_set),
|
||||
)
|
||||
|
||||
def get_structure_constraint(
|
||||
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
|
||||
) -> Optional[Tuple[str, Any]]:
|
||||
"""
|
||||
Returns the appropriate structure constraint for tool calls based on the tool_choice.
|
||||
The constraint is used to guide the model's output format.
|
||||
|
||||
Args:
|
||||
tool_choice: The tool choice setting from the request
|
||||
|
||||
Returns:
|
||||
A tuple of (constraint_type, constraint_value) to be added to sampling parameters,
|
||||
or None if no constraint applies.
|
||||
"""
|
||||
# NOTE: structural_tag only supports JSON-compatible content between the begin and end.
|
||||
# It cannot parse or validate Python syntax like function calls.
|
||||
if (
|
||||
not isinstance(self.detector, PythonicDetector)
|
||||
and tool_choice == "auto"
|
||||
and any(tool.function.strict for tool in self.tools)
|
||||
):
|
||||
strict_tag = self.get_structure_tag()
|
||||
return ("structural_tag", strict_tag)
|
||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||
ebnf = self.get_ebnf(tool_choice)
|
||||
return ("ebnf", ebnf) if ebnf is not None else None
|
||||
|
||||
def get_ebnf(
|
||||
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the EBNF grammar for the specified tool choice.
|
||||
"""
|
||||
filtered_tools = []
|
||||
if isinstance(tool_choice, ToolChoice):
|
||||
fn_name = tool_choice.function.name
|
||||
filtered_tools = [t for t in self.tools if t.function.name == fn_name]
|
||||
else:
|
||||
filtered_tools = self.tools
|
||||
return self.detector.build_ebnf(filtered_tools)
|
||||
74
python/sglang/srt/function_call/llama32_detector.py
Normal file
74
python/sglang/srt/function_call/llama32_detector.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama 3.2 models.
|
||||
Assumes function call format:
|
||||
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = "<|python_tag|>"
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Llama 3.2 format tool call."""
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
return "<|python_tag|>" in text or text.startswith("{")
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""Parse function calls from text, handling multiple JSON objects."""
|
||||
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
if "<|python_tag|>" in text:
|
||||
normal_text, action_text = text.split("<|python_tag|>")
|
||||
else:
|
||||
normal_text, action_text = "", text
|
||||
|
||||
# Split by semicolon and process each part
|
||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||
all_actions = []
|
||||
for part in json_parts:
|
||||
try:
|
||||
# Parse each individual JSON object
|
||||
action = json.loads(part)
|
||||
all_actions.append(action)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON part: {part}")
|
||||
logger.warning(f"JSON parse error: {str(e)}")
|
||||
continue
|
||||
calls = []
|
||||
# Only process if we found valid JSON objects
|
||||
if all_actions:
|
||||
calls = self.parse_base_json(all_actions, tools)
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin='<|python_tag|>{"name":"' + name + '", "arguments":',
|
||||
end="}",
|
||||
trigger="<|python_tag|>",
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
function_format="json",
|
||||
tool_call_separator=",",
|
||||
)
|
||||
84
python/sglang/srt/function_call/mistral_detector.py
Normal file
84
python/sglang/srt/function_call/mistral_detector.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "[TOOL_CALLS] ["
|
||||
self.eot_token = "]"
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Mistral format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||
for example,
|
||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
||||
The key pattern is [TOOL_CALLS] [...]
|
||||
"""
|
||||
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
|
||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
||||
if len(find_results) > 0:
|
||||
return find_results[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
calls = []
|
||||
if len(raw_tool_calls) > 0:
|
||||
raw_tool_call = raw_tool_calls[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
||||
end="}]",
|
||||
trigger="[TOOL_CALLS]",
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token=self.bot_token,
|
||||
eot_token=self.eot_token,
|
||||
function_format="json",
|
||||
)
|
||||
163
python/sglang/srt/function_call/pythonic_detector.py
Normal file
163
python/sglang/srt/function_call/pythonic_detector.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PythonicDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
|
||||
Assumes function call format:
|
||||
[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
Arguments are Python literals (not JSON).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tool_call_regex = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
return bool(self.tool_call_regex.match(text.strip()))
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
# Try parsing the text as a Python list of function calls
|
||||
text = text.strip()
|
||||
if not (text.startswith("[") and text.endswith("]")):
|
||||
# Not a pythonic tool call format
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
try:
|
||||
module = ast.parse(text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not (
|
||||
isinstance(parsed, ast.List)
|
||||
and all(isinstance(e, ast.Call) for e in parsed.elts)
|
||||
):
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
calls = []
|
||||
tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function.name
|
||||
}
|
||||
for call in parsed.elts:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
continue
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=tool_indices.get(function_name, -1),
|
||||
name=function_name,
|
||||
parameters=json.dumps(arguments, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
return StreamingParseResult(normal_text="", calls=calls)
|
||||
except Exception:
|
||||
logger.exception("Error in pythonic tool call parsing.")
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
def _find_matching_bracket(self, buffer: str, start: int) -> int:
|
||||
"""
|
||||
Find the matching closing bracket for the opening bracket at start position.
|
||||
Properly handles nested brackets.
|
||||
|
||||
Args:
|
||||
buffer: The text buffer to search in
|
||||
start: Position of the opening bracket '['
|
||||
|
||||
Returns:
|
||||
Position of the matching closing bracket ']', or -1 if not found
|
||||
"""
|
||||
bracket_count = 0
|
||||
for i in range(start, len(buffer)):
|
||||
if buffer[i] == "[":
|
||||
bracket_count += 1
|
||||
elif buffer[i] == "]":
|
||||
bracket_count -= 1
|
||||
if bracket_count == 0:
|
||||
return i
|
||||
return -1 # No matching bracket found
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing for pythonic tool calls.
|
||||
Buffers input until a complete pythonic tool call (from [ to ]) is found,
|
||||
then parses and emits any detected calls.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
start = self._buffer.find("[")
|
||||
|
||||
if start == -1:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
normal_text = self._buffer[:start] if start > 0 else ""
|
||||
|
||||
end = self._find_matching_bracket(self._buffer, start)
|
||||
if end != -1:
|
||||
call_text = self._buffer[start : end + 1]
|
||||
result = self.detect_and_parse(call_text, tools)
|
||||
self._buffer = self._buffer[end + 1 :]
|
||||
|
||||
# If we had normal text before the tool call, add it to the result
|
||||
if normal_text:
|
||||
result.normal_text = normal_text + (result.normal_text or "")
|
||||
|
||||
return result
|
||||
|
||||
# We have an opening bracket but no closing bracket yet
|
||||
if normal_text:
|
||||
self._buffer = self._buffer[start:]
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# Otherwise, we're still accumulating a potential tool call
|
||||
return StreamingParseResult(normal_text="")
|
||||
|
||||
def _get_parameter_value(self, val):
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
return {
|
||||
k.value: self._get_parameter_value(v)
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [self._get_parameter_value(v) for v in val.elts]
|
||||
else:
|
||||
raise ValueError("Tool call arguments must be literals")
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
def info(name: str):
|
||||
return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
|
||||
|
||||
return info
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token="[",
|
||||
eot_token="]",
|
||||
tool_call_separator=",",
|
||||
function_format="pythonic",
|
||||
)
|
||||
67
python/sglang/srt/function_call/qwen25_detector.py
Normal file
67
python/sglang/srt/function_call/qwen25_detector.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 2.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<tool_call>"
|
||||
self.eot_token = "</tool_call>"
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
calls = []
|
||||
for match_result in match_result_list:
|
||||
match_result = json.loads(match_result)
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
||||
end="}</tool_call>",
|
||||
trigger="<tool_call>",
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token=self.bot_token,
|
||||
eot_token=self.eot_token,
|
||||
function_format="json",
|
||||
)
|
||||
35
python/sglang/srt/function_call/utils.py
Normal file
35
python/sglang/srt/function_call/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Tuple
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def _is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
@@ -1,858 +0,0 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.exceptions import MalformedJSON
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOOLS_TAG_LIST = [
|
||||
"<|plugin|>",
|
||||
"<function=",
|
||||
"<tool_call>",
|
||||
"<|python_tag|>",
|
||||
"[TOOL_CALLS]",
|
||||
"<|tool▁calls▁begin|>",
|
||||
]
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||
|
||||
tool_index: int
|
||||
name: Optional[str] = None
|
||||
parameters: str # JSON string
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def _is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
def __init__(
|
||||
self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
|
||||
):
|
||||
self.normal_text = normal_text
|
||||
self.calls = calls or []
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructureInfo:
|
||||
begin: str
|
||||
end: str
|
||||
trigger: str
|
||||
|
||||
|
||||
_GetInfoFunc = Callable[[str], StructureInfo]
|
||||
"""
|
||||
Helper alias of function
|
||||
Usually it is a function that takes a name string and returns a StructureInfo object,
|
||||
which can be used to construct a structural_tag object
|
||||
"""
|
||||
|
||||
|
||||
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses implement a concrete model-specific tool-call syntax by setting
    ``bot_token``/``eot_token`` and overriding ``detect_and_parse``,
    ``has_tool_call`` and ``structure_info``.  The streaming implementation
    here is shared: it buffers partial model output, parses it with a
    partial-JSON loader, and emits incremental ``ToolCallItem`` deltas.
    """

    def __init__(self):
        # State used while incrementally parsing tool calls in streaming mode.
        # Raw, not-yet-consumed model output.
        self._buffer = ""
        # Snapshot of the tool-call objects parsed on the previous increment;
        # used to diff arguments between consecutive increments.
        self.prev_tool_call_arr: List[Dict] = []
        # Index of the tool call currently being streamed (-1 = none yet).
        self.current_tool_id: int = -1
        # Whether the current tool's name has already been emitted downstream.
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # per-tool record of the argument text already streamed downstream
        # Begin/end-of-tool-call delimiter tokens; set by subclasses.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """Convert one parsed JSON object (or a list of them) into ToolCallItems.

        Each object is expected to carry a ``name`` key and either
        ``parameters`` or ``arguments``.  Calls naming a function that is not
        in *tools* are dropped with a warning rather than raising.
        """
        tool_indices = {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        # Accept either key; fall back to {} when both absent.
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")

        return results

    @abstractmethod
    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go. Returns success=True if the format matches, otherwise False.
        Note that leftover_text here represents "content that this parser will not consume further".

        This default body treats the whole text as a bare JSON payload;
        subclasses override it with format-specific extraction.
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Appends *new_text* to the internal buffer, attempts a partial-JSON
        parse of everything buffered so far, and emits only the *delta*
        (new tool name, or new argument characters) since the last call.
        Returns an empty result when no new complete information is available.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        # Fast path: no tool-call delimiter and no bare JSON object start —
        # flush the buffer as plain text (stripping a stray end token).
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built (cached across increments)
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the tool name has been emitted, forbid partial strings so we
        # never stream half of a function name.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                # Consume consecutive (possibly partial) JSON objects,
                # assumed to be separated by "; ".
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency: normalize
                    # "parameters" to "arguments" so later code reads one key.
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except MalformedJSON:
                # Not enough text yet to form even a partial object.
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # NOTE(review): when current_tool_id is still -1 this indexes the
            # last element; the branch below then promotes it to a real index.
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array: a new object appeared after the one
            # being streamed, so flush the previous tool's remaining arguments.
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()

                # Advance to the newest tool and reset per-tool state.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name: emit the function name once it is complete
            # and known, with empty parameters as a placeholder.
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments: emit only the suffix of the
            # serialized arguments that has not been streamed yet.
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Arguments JSON is complete: flush the remainder and
                        # reset state so the next tool call starts fresh.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        # Arguments still growing: stream only the common
                        # prefix beyond what was already sent, to stay safe
                        # against partial-JSON reserialization differences.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            # Defensive catch-all: a parsing bug should degrade to "no delta"
            # rather than kill the stream.
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """Return True if *text* contains this format's tool-call marker."""
        raise NotImplementedError()

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its StructureInfo."""
        raise NotImplementedError()
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.

    Expected tool-call format:
        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """Set up the Qwen 2.5 start/end tool-call delimiter tokens."""
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: detect and parse every <tool_call>...</tool_call>
        span in *text*.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the text before the first tool call
                 as normal_text and the parsed calls.
        """
        marker_pos = text.find(self.bot_token)
        if marker_pos == -1:
            # No tool-call marker at all: everything is plain text.
            return StreamingParseResult(normal_text=text, calls=[])

        normal_text = text[:marker_pos].strip()
        span_pattern = rf"{self.bot_token}(.*?){self.eot_token}"
        calls = []
        for raw_call in re.findall(span_pattern, text, re.DOTALL):
            calls.extend(self.parse_base_json(json.loads(raw_call), tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return the constrained-generation structure for a given tool name."""

        def info(name: str) -> StructureInfo:
            return StructureInfo(
                begin='<tool_call>{"name":"' + name + '", "arguments":',
                end="}</tool_call>",
                trigger="<tool_call>",
            )

        return info
|
||||
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.

    Assumes function call format:
        [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]

    (The previous docstring documented an InternLM-style
    ``<|action_start|><|plugin|>`` format, which this detector never parsed.)
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        # Greedy match from the first '[{' to the last '}]' of the cleaned text.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _clean_text(self, text: str) -> str:
        """
        clean text to only leave '[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]

        Bug fix: the previous implementation used the non-greedy regex
        r"\[TOOL_CALLS\] \[.*?\]", which stops at the FIRST ']' and therefore
        truncates any call whose arguments contain a nested JSON array
        (e.g. {"items": [1, 2]}), yielding invalid JSON.  We now scan with
        bracket balancing, skipping brackets inside JSON string literals.
        """
        start = text.find("[TOOL_CALLS] [")
        if start == -1:
            return ""
        array_start = start + len("[TOOL_CALLS] ")
        depth = 0
        in_string = False
        escaped = False
        for i in range(array_start, len(text)):
            ch = text[i]
            if in_string:
                # Inside a JSON string: only an unescaped quote ends it.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == "[":
                depth += 1
            elif ch == "]":
                depth -= 1
                if depth == 0:
                    # Balanced close of the top-level call array.
                    return text[start : i + 1]
        # Unbalanced / truncated output: behave like "no match".
        return ""

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        text = self._clean_text(text)
        tool_content = text.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if len(raw_tool_calls) > 0:
            raw_tool_call = raw_tool_calls[0]
            try:
                function_call_arr = json.loads(raw_tool_call)
            except json.JSONDecodeError:
                # Malformed model output: fall back to plain text rather
                # than propagating an exception to the server layer.
                logger.warning(f"Failed to parse Mistral tool call: {raw_tool_call}")
                return StreamingParseResult(normal_text=normal_text, calls=[])
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return the constrained-generation structure for a given tool name."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )
||||
|
||||
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
        <|python_tag|>{"name":"xxx", "arguments":{...}}
    Multiple calls may be separated by ';' and/or by repeated <|python_tag|>
    markers (one per call).
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return "<|python_tag|>" in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        Bug fix: the previous two-way unpack ``text.split("<|python_tag|>")``
        raised ValueError whenever the marker appeared more than once.  We now
        split off only the leading normal text and treat any further markers
        as additional call separators.
        """
        if "<|python_tag|>" not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])

        if "<|python_tag|>" in text:
            normal_text, _, action_text = text.partition("<|python_tag|>")
        else:
            normal_text, action_text = "", text

        # Split by marker and semicolon and process each part.
        # NOTE(review): a ';' inside a JSON string argument would also be
        # split here — pre-existing limitation of this delimiter scheme.
        json_parts = [
            part.strip()
            for segment in action_text.split("<|python_tag|>")
            for part in segment.split(";")
            if part.strip()
        ]
        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                action = json.loads(part)
                all_actions.append(action)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON part: {part}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue
        calls = []
        # Only process if we found valid JSON objects
        if all_actions:
            calls = self.parse_base_json(all_actions, tools)
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return the constrained-generation structure for a given tool name."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )
|
||||
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
        '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"
        # Bug fix: the special tokens contain '|', which is regex alternation.
        # The previous raw-string patterns ("<|tool▁call▁begin|>...") were
        # therefore parsed as alternatives of '<', 'tool▁call▁begin', '>', …
        # and matched stray fragments instead of whole tool-call spans.
        # All literal tokens are now re.escape()d.
        self.func_call_regex = (
            re.escape("<|tool▁call▁begin|>") + r".*?" + re.escape("<|tool▁call▁end|>")
        )
        self.func_detail_regex = (
            re.escape("<|tool▁call▁begin|>")
            + r"(.*)"
            + re.escape("<|tool▁sep|>")
            + r"(.*)\n```json\n(.*)\n```"
            + re.escape("<|tool▁call▁end|>")
        )
        # Raw argument text already streamed for the in-progress call.
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])
        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                # Get function name and raw JSON arguments from the span.
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                func_name = func_detail.group(2)
                func_args = func_detail.group(3)
                func_args = json.loads(func_args)
                # construct match_result for parse_base_json
                match_result = {"name": func_name, "parameters": func_args}
                calls.extend(self.parse_base_json(match_result, tools))
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def structure_info(self) -> _GetInfoFunc:
        """Return the constrained-generation structure for a given tool name."""
        return lambda name: StructureInfo(
            begin=">" + name + "\n```json\n",
            end="\n```<",
            trigger=">" + name + "\n```json\n",
        )

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for DeepSeekV3 format.
        """
        self._buffer += new_text
        current_text = self._buffer

        if self.bot_token not in current_text:
            # No tool-call opener yet: flush as normal text, dropping any
            # stray closing tokens the model may emit on their own.
            self._buffer = ""
            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built (cached across increments).
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            # Same escaping fix as in __init__: match a partially-emitted
            # call up to (and including) whatever argument text has arrived.
            partial_match = re.search(
                pattern=re.escape("<|tool▁call▁begin|>")
                + r"(.*)"
                + re.escape("<|tool▁sep|>")
                + r"(.*)\n```json\n(.*)",
                string=current_text,
                flags=re.DOTALL,
            )
            if partial_match:
                func_name = partial_match.group(2).strip()
                func_args_raw = partial_match.group(3).strip()

                if not self.current_tool_name_sent:
                    # First delta for this call: emit the function name.
                    calls.append(
                        ToolCallItem(
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=func_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                else:
                    # Subsequent deltas: emit only the newly-arrived suffix
                    # of the raw argument text.
                    argument_diff = (
                        func_args_raw[len(self._last_arguments) :]
                        if func_args_raw.startswith(self._last_arguments)
                        else func_args_raw
                    )

                    if argument_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self._tool_indices.get(func_name, 0),
                                name=None,
                                parameters=argument_diff,
                            )
                        )
                        self._last_arguments += argument_diff

                    if _is_complete_json(func_args_raw):
                        # Arguments complete: reset state for the next call.
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self._buffer = ""
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)
|
||||
class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(
        self, text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        One-time parsing: try each detector in order; the first one that
        produces tool calls wins.

        Return: (final_text, all_calls)
        - final_text: text not consumed by any detector (plain content)
        - all_calls: tool calls produced by the winning detector
        """
        leftover_text = text
        parsed_calls: list[ToolCallItem] = []
        for detector in self.detectors:
            outcome = detector.detect_and_parse(text, tools)
            if len(outcome.calls) > 0:  # this detector parsed successfully
                parsed_calls = outcome.calls
                leftover_text = outcome.normal_text
                break
        # leftover_text is the normal text not consumed by any detector
        return leftover_text, parsed_calls

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming incremental parsing: feed new_text to each detector's
        parse_streaming_increment and merge the produced normal_text/calls.

        A result carrying calls marks a successful parse and stops the loop;
        a result carrying only normal_text may be either plain content or a
        detector that simply does not apply.
        """
        merged_text = ""
        merged_calls: list[ToolCallItem] = []

        for detector in self.detectors:
            outcome = detector.parse_streaming_increment(new_text, tools)
            if outcome.normal_text:
                merged_text = outcome.normal_text
            if outcome.calls:
                merged_calls.extend(outcome.calls)
                merged_text = outcome.normal_text
                break

        return merged_text, merged_calls
|
||||
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Check whether the (stripped) text starts with a pythonic call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse *text* as a Python list of function calls via the ast module."""
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.

        Bug fix: the previous version silently DROPPED any text preceding the
        '[' and never flushed the buffer when no bracket arrived, so plain
        content around a tool call was lost.  Text before the opening bracket
        is now emitted as normal_text, and bracket-free buffers are flushed.
        """
        self._buffer += new_text
        start = self._buffer.find("[")

        if start == -1:
            # No tool-call opener anywhere: flush everything as normal text.
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        # Emit whatever precedes the opening bracket as plain content.
        normal_text = self._buffer[:start]

        # NOTE(review): a ']' inside a list-valued argument would end the span
        # early — pre-existing limitation of this naive bracket scan.
        end = self._buffer.find("]", start)
        if end == -1:
            # Partial tool call: keep it buffered, release only the prefix.
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=normal_text)

        call_text = self._buffer[start : end + 1]
        result = self.detect_and_parse(call_text, tools)
        self._buffer = self._buffer[end + 1 :]
        return StreamingParseResult(
            normal_text=normal_text + result.normal_text,
            calls=result.calls,
        )

    def _get_parameter_value(self, val):
        """Convert an ast literal node (constant/dict/list) to a Python value."""
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.Dict):
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return the constrained-generation structure for a given tool name."""

        def info(name: str):
            return StructureInfo(begin="[", end="]", trigger="")

        return info
|
||||
class FunctionCallParser:
    """
    In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
    and returns the resulting normal_text and calls to the upper layer (or SSE).
    """

    # Registry mapping the --tool-call-parser CLI value to its detector class.
    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
        "deepseekv3": DeepSeekV3Detector,
        "pythonic": PythonicDetector,
    }

    def __init__(self, tools: List[Tool], tool_call_parser: str):
        if not tool_call_parser:
            raise ValueError("Tool Call Parser Not Given!")
        detector_class = self.ToolCallParserEnum.get(tool_call_parser)
        if detector_class is None:
            raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")

        self.multi_format_parser = MultiFormatParser([detector_class()])
        self.tools = tools

    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains a tool call in the format supported by this parser.
        This delegates to the detector's implementation.

        :param text: The text to check for tool calls
        :return: True if the text contains a tool call, False otherwise
        """
        # Check all detectors in the multi_format_parser
        return any(
            detector.has_tool_call(text)
            for detector in self.multi_format_parser.detectors
        )

    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Non-streaming call: one-time parsing
        """
        return self.multi_format_parser.parse_once(full_text, self.tools)

    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming call: incremental parsing
        """
        return self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )

    def structure_infos(self) -> List[_GetInfoFunc]:
        """
        Returns a list of structure_info functions for each detector
        """
        return [d.structure_info() for d in self.multi_format_parser.detectors]

    def get_structure_tag(self) -> StructuralTagResponseFormat:
        """Build the structural-tag response format covering every (detector, tool) pair."""
        structures: List[StructuresResponseFormat] = []
        triggers: Set[str] = set()

        for get_info in self.structure_infos():
            for tool in self.tools:
                function = tool.function
                tool_name = function.name
                assert tool_name is not None
                info = get_info(tool_name)

                # accept all if not strict, otherwise only accept the schema
                schema = function.parameters if function.strict else {}

                structures.append(
                    StructuresResponseFormat(
                        begin=info.begin,
                        schema=schema,  # type: ignore
                        end=info.end,
                    )
                )
                triggers.add(info.trigger)

        return StructuralTagResponseFormat(
            type="structural_tag",
            structures=structures,
            triggers=list(triggers),
        )
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
|
||||
get_conv_template_by_model_path,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -970,7 +970,7 @@ def v1_chat_generate_request(
|
||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||
# - audio_data: None or a list of audio strings (URLs).
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
strict_tag = None
|
||||
tool_call_constraint = None
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
if not isinstance(request.messages, str):
|
||||
@@ -989,7 +989,9 @@ def v1_chat_generate_request(
|
||||
|
||||
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||
strict_tag = parser.get_structure_tag()
|
||||
tool_call_constraint = parser.get_structure_constraint(
|
||||
request.tool_choice
|
||||
)
|
||||
|
||||
if chat_template_name is None:
|
||||
openai_compatible_messages = []
|
||||
@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
if strict_tag is not None:
|
||||
if (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
):
|
||||
logger.warning(
|
||||
"Constrained decoding is not compatible with tool calls."
|
||||
# Check if there are already existing output constraints
|
||||
has_existing_constraints = (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
)
|
||||
|
||||
if tool_call_constraint and has_existing_constraints:
|
||||
logger.warning("Constrained decoding is not compatible with tool calls.")
|
||||
elif tool_call_constraint:
|
||||
constraint_type, constraint_value = tool_call_constraint
|
||||
if constraint_type == "structural_tag":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value.model_dump(by_alias=True)
|
||||
)
|
||||
else:
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
strict_tag.model_dump(by_alias=True)
|
||||
)
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ suites = {
|
||||
TestFile("test_fa3.py", 376),
|
||||
TestFile("test_fim_completion.py", 40),
|
||||
TestFile("test_fp8_kernel.py", 8),
|
||||
TestFile("test_function_calling.py", 60),
|
||||
TestFile("test_function_call_parser.py", 10),
|
||||
TestFile("test_fused_moe.py", 30),
|
||||
TestFile("test_hicache.py", 116),
|
||||
TestFile("test_hicache_mla.py", 254),
|
||||
@@ -54,6 +54,7 @@ suites = {
|
||||
TestFile("test_flashmla.py", 300),
|
||||
TestFile("test_no_chunked_prefill.py", 108),
|
||||
TestFile("test_no_overlap_scheduler.py", 216),
|
||||
TestFile("test_openai_function_calling.py", 60),
|
||||
TestFile("test_openai_server.py", 149),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
|
||||
408
test/srt/test_function_call_parser.py
Normal file
408
test/srt/test_function_call_parser.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.openai_api.protocol import Function, Tool
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
class TestPythonicDetector(unittest.TestCase):
    """Unit tests for PythonicDetector's incremental (streaming) parsing.

    Each test feeds one or more text chunks to parse_streaming_increment and
    checks three things: the plain text returned (``normal_text``), the parsed
    tool calls (``calls``), and the detector's internal ``_buffer`` (which
    holds any partial, not-yet-parseable tool-call text between chunks).
    The detector instance is stateful, so the order of calls within a test
    matters.
    """

    def setUp(self):
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]
        # Fresh detector per test: its _buffer carries state between calls.
        self.detector = PythonicDetector()

    def test_parse_streaming_no_brackets(self):
        """Test parsing text with no brackets (no tool calls)."""
        text = "This is just normal text without any tool calls."
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, text)
        self.assertEqual(result.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_complete_tool_call(self):
        """Test parsing a complete tool call."""
        text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        # Text before the bracket is emitted as normal text.
        self.assertEqual(result.normal_text, "Here's a tool call: ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing

        # Check the parameters (serialized as a JSON string on the call)
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Test parsing text that appears before a tool call."""
        text = "This is some text before [get_weather(location='London')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "This is some text before ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "London")

    def test_parse_streaming_partial_tool_call(self):
        """Test parsing a partial tool call that spans multiple chunks."""
        # First chunk with opening bracket but no closing bracket
        text1 = "Let me check the weather: [get_weather(location="
        result1 = self.detector.parse_streaming_increment(text1, self.tools)

        self.assertEqual(result1.normal_text, "Let me check the weather: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location="
        )  # Partial tool call remains in buffer

        # Second chunk completing the tool call
        text2 = "'Paris')]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)

        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")

        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Paris")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing

    def test_parse_streaming_bracket_without_text_before(self):
        """Test parsing a tool call that starts at the beginning of the text."""
        text = "[search(query='python programming')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Test parsing text that appears after a tool call."""
        # First chunk with complete tool call and some text after
        text = "[get_weather(location='Tokyo')] Here's the forecast:"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, " Here's the forecast:"
        )  # Text after tool call remains in buffer

        # Process the remaining text in buffer (empty chunk flushes it)
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " Here's the forecast:")
        self.assertEqual(result2.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_multiple_tool_calls(self):
        """Test parsing multiple tool calls in sequence."""
        text = "[get_weather(location='Berlin')] and [search(query='restaurants')]"

        # First tool call; the rest of the chunk stays buffered
        result1 = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(len(result1.calls), 1)
        self.assertEqual(result1.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")

        # Second tool call is parsed from the buffered remainder
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " and ")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_opening_bracket_only(self):
        """Test parsing text with only an opening bracket but no closing bracket."""
        text = "Let's try this: ["
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "Let's try this: ")
        self.assertEqual(result.calls, [])
        self.assertEqual(
            self.detector._buffer, "["
        )  # Opening bracket remains in buffer

    def test_parse_streaming_nested_brackets(self):
        """Test parsing tool calls with nested brackets in arguments."""
        # Test with list argument containing nested brackets
        text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")
        self.assertEqual(params["data"], [1, 2, 3])

    def test_parse_streaming_nested_brackets_dict(self):
        """Test parsing tool calls with nested dictionaries and lists."""
        # Test with nested dict and list arguments
        text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "test")
        self.assertEqual(params["config"]["options"], [1, 2])
        self.assertEqual(params["config"]["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Test parsing multiple tool calls with nested brackets."""
        text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(self.detector._buffer, "")

        # Check first tool call
        params1 = json.loads(result.calls[0].parameters)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(params1["location"], "Paris")
        self.assertEqual(params1["data"], [10, 20])

        # Check second tool call
        params2 = json.loads(result.calls[1].parameters)
        self.assertEqual(result.calls[1].name, "search")
        self.assertEqual(params2["query"], "test")
        self.assertEqual(params2["filters"], ["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """Test parsing partial tool calls with nested brackets across chunks."""
        # First chunk with nested brackets but incomplete
        text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
        result1 = self.detector.parse_streaming_increment(text1, self.tools)

        self.assertEqual(result1.normal_text, "Here's a call: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
        )

        # Second chunk completing the nested brackets
        text2 = ", 3])]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)

        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Tokyo")
        self.assertEqual(params["data"], [1, 2, 3])
||||
|
||||
class TestEBNFGeneration(unittest.TestCase):
    """Tests that each function-call detector emits an EBNF grammar that
    contains the format-specific markers for the sample tools AND compiles
    cleanly with xgrammar's GrammarCompiler.

    The per-detector try/except compile boilerplate that was duplicated in
    every test method is factored into ``_assert_compiles``.
    """

    def setUp(self):
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]

        # A real tokenizer is required so xgrammar can bind grammar terminals
        # to the model's vocabulary when compiling.
        self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

        # Initialize all detectors
        self.pythonic_detector = PythonicDetector()
        self.deepseekv3_detector = DeepSeekV3Detector()
        self.llama32_detector = Llama32Detector()
        self.mistral_detector = MistralDetector()
        self.qwen25_detector = Qwen25Detector()

    def _assert_compiles(self, ebnf):
        """Fail the current test unless ``ebnf`` compiles with GrammarCompiler."""
        try:
            ctx = self.grammar_compiler.compile_grammar(ebnf)
            self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
        except RuntimeError as e:
            self.fail(f"Failed to compile EBNF: {e}")

    def test_pythonic_detector_ebnf(self):
        """Test that the PythonicDetector generates valid EBNF."""
        ebnf = self.pythonic_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
        self.assertIn('"location" "=" basic_string', ebnf)
        # Optional enum argument is rendered as an alternation of quoted values.
        self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)

        self._assert_compiles(ebnf)

    def test_deepseekv3_detector_ebnf(self):
        """Test that the DeepSeekV3Detector generates valid EBNF."""
        ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn("<|tool▁calls▁begin|>", ebnf)
        self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf)
        self.assertIn('\\"location\\"" ":" basic_string ', ebnf)

        self._assert_compiles(ebnf)

    def test_llama32_detector_ebnf(self):
        """Test that the Llama32Detector generates valid EBNF."""
        ebnf = self.llama32_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        self._assert_compiles(ebnf)

    def test_mistral_detector_ebnf(self):
        """Test that the MistralDetector generates valid EBNF."""
        ebnf = self.mistral_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('"[TOOL_CALLS] ["', ebnf)
        self.assertIn("call_get_weather | call_search", ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        self._assert_compiles(ebnf)

    def test_qwen25_detector_ebnf(self):
        """Test that the Qwen25Detector generates valid EBNF."""
        ebnf = self.qwen25_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn("<tool_call>", ebnf)
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        self._assert_compiles(ebnf)
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python this_file.py`).
    unittest.main()
|
||||
@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
|
||||
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
|
||||
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
||||
|
||||
def test_function_call_required(self):
|
||||
"""
|
||||
Test: Whether tool_choice: "required" works as expected
|
||||
- When tool_choice == "required", the model should return one or more tool_calls.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sub",
|
||||
"description": "Compute the difference of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"int_b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["int_a", "int_b"],
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "use this to get latest weather information for a city given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "name of the city to get weather for",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
|
||||
function_name = tool_calls[0].function.name
|
||||
arguments = tool_calls[0].function.arguments
|
||||
args_obj = json.loads(arguments)
|
||||
|
||||
self.assertEqual(
|
||||
function_name, "get_weather", "Function name should be 'get_weather'"
|
||||
)
|
||||
self.assertIn("city", args_obj, "Function arguments should have 'city'")
|
||||
self.assertIn(
|
||||
"Paris", args_obj["city"], "Parameter city should contain 'Paris'"
|
||||
) # might be flaky
|
||||
|
||||
def test_function_call_specific(self):
|
||||
"""
|
||||
Test: Whether tool_choice: ToolChoice works as expected
|
||||
- When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sub",
|
||||
"description": "Compute the difference of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"int_b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["int_a", "int_b"],
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "use this to get latest weather information for a city given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "name of the city to get weather for",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
tool_choice={"type": "function", "function": {"name": "get_weather"}},
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
|
||||
function_name = tool_calls[0].function.name
|
||||
arguments = tool_calls[0].function.arguments
|
||||
args_obj = json.loads(arguments)
|
||||
|
||||
self.assertEqual(
|
||||
function_name, "get_weather", "Function name should be 'get_weather'"
|
||||
)
|
||||
self.assertIn("city", args_obj, "Function arguments should have 'city'")
|
||||
|
||||
|
||||
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
||||
PYTHONIC_TOOLS = [
|
||||
@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
||||
stream=False,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsInstance(tool_calls, list)
|
||||
self.assertIsInstance(tool_calls, list, "No tool_calls found")
|
||||
self.assertGreaterEqual(len(tool_calls), 1)
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
self.assertIn("get_weather", names)
|
||||
self.assertIn("get_tourist_attractions", names)
|
||||
self.assertTrue(
|
||||
"get_weather" in names or "get_tourist_attractions" in names,
|
||||
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||
)
|
||||
|
||||
def test_pythonic_tool_call_streaming(self):
|
||||
"""
|
||||
@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
||||
|
||||
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
|
||||
self.assertTrue(found_index, "No index field found in any streamed tool_call")
|
||||
self.assertIn("get_weather", found_names)
|
||||
self.assertIn("get_tourist_attractions", found_names)
|
||||
self.assertTrue(
|
||||
"get_weather" in found_names or "get_tourist_attractions" in found_names,
|
||||
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Reference in New Issue
Block a user