diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb
index 18ee1a431..25d7e6273 100644
--- a/docs/backend/function_calling.ipynb
+++ b/docs/backend/function_calling.ipynb
@@ -54,10 +54,12 @@
"source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n",
- "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
+ "- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
+ "- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
- "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
+ "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n",
+ "- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n"
]
},
{
@@ -360,6 +362,164 @@
"print(final_response.choices[0].message.content)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Tool Choice Mode\n",
+ "\n",
+ "SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
+ "\n",
+ "### Supported Tool Choice Options\n",
+ "\n",
+ "- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
+ "- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
+ "\n",
+ "### Backend Compatibility\n",
+ "\n",
+ "Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
+ "\n",
+ "### Example: Required Tool Choice"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Response with tool_choice='required':\n",
+ "Content: None\n",
+ "Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from openai import OpenAI\n",
+ "import json\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "from sglang.test.test_utils import is_in_ci\n",
+ "\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ " import nest_asyncio\n",
+ "\n",
+ " nest_asyncio.apply()\n",
+ "\n",
+ "# Start a new server session for tool choice examples\n",
+ "server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
+ " \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
+ ")\n",
+ "wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
+ "\n",
+ "# Initialize client for tool choice examples\n",
+ "client_tool_choice = OpenAI(\n",
+ " api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
+ ")\n",
+ "model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
+ "\n",
+ "# Example with tool_choice=\"required\" - forces the model to call a tool\n",
+ "messages_required = [\n",
+ " {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
+ "]\n",
+ "\n",
+ "# Define tools\n",
+ "tools = [\n",
+ " {\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\n",
+ " \"name\": \"get_current_weather\",\n",
+ " \"description\": \"Get the current weather in a given location\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"city\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
+ " },\n",
+ " \"unit\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The unit to fetch the temperature in\",\n",
+ " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
+ " },\n",
+ " },\n",
+ " \"required\": [\"city\", \"unit\"],\n",
+ " },\n",
+ " },\n",
+ " }\n",
+ "]\n",
+ "\n",
+ "response_required = client_tool_choice.chat.completions.create(\n",
+ " model=model_name_tool_choice,\n",
+ " messages=messages_required,\n",
+ " temperature=0,\n",
+ " max_tokens=1024,\n",
+ " tools=tools,\n",
+ " tool_choice=\"required\", # Force the model to call a tool\n",
+ ")\n",
+ "\n",
+ "print_highlight(\"Response with tool_choice='required':\")\n",
+ "print(\"Content:\", response_required.choices[0].message.content)\n",
+ "print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example: Specific Function Choice\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Response with specific function choice:\n",
+ "Content: None\n",
+ "Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n",
+ "Called function: get_current_weather\n",
+ "Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Example with specific function choice - forces the model to call a specific function\n",
+ "messages_specific = [\n",
+    "    {\"role\": \"user\", \"content\": \"What are the most attractive places in France?\"}\n",
+ "]\n",
+ "\n",
+ "response_specific = client_tool_choice.chat.completions.create(\n",
+ " model=model_name_tool_choice,\n",
+ " messages=messages_specific,\n",
+ " temperature=0,\n",
+ " max_tokens=1024,\n",
+ " tools=tools,\n",
+ " tool_choice={\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\"name\": \"get_current_weather\"},\n",
+ " }, # Force the model to call the specific get_current_weather function\n",
+ ")\n",
+ "\n",
+ "print_highlight(\"Response with specific function choice:\")\n",
+ "print(\"Content:\", response_specific.choices[0].message.content)\n",
+ "print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
+ "\n",
+ "if response_specific.choices[0].message.tool_calls:\n",
+ " tool_call = response_specific.choices[0].message.tool_calls[0]\n",
+ " print(f\"Called function: {tool_call.function.name}\")\n",
+ " print(f\"Arguments: {tool_call.function.arguments}\")"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -444,7 +604,7 @@
"outputs": [],
"source": [
"import sglang as sgl\n",
- "from sglang.srt.function_call_parser import FunctionCallParser\n",
+ "from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index e98e3d3de..ff0978e38 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
-from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py
new file mode 100644
index 000000000..497e29c56
--- /dev/null
+++ b/python/sglang/srt/function_call/base_format_detector.py
@@ -0,0 +1,250 @@
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from partial_json_parser.core.exceptions import MalformedJSON
+from partial_json_parser.core.options import Allow
+
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.utils import (
+ _find_common_prefix,
+ _is_complete_json,
+ _partial_json_loads,
+)
+from sglang.srt.openai_api.protocol import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class BaseFormatDetector(ABC):
+ """Base class providing two sets of interfaces: one-time and streaming incremental."""
+
+ def __init__(self):
+        # Initialize state used when parsing tool calls in streaming mode
+        self._buffer = ""
+        # Bookkeeping for incremental (streaming) multi-tool-call parsing
+ self.prev_tool_call_arr: List[Dict] = []
+ self.current_tool_id: int = -1
+ self.current_tool_name_sent: bool = False
+ self.streamed_args_for_tool: List[str] = (
+ []
+ ) # map what has been streamed for each tool so far to a list
+ self.bot_token = ""
+ self.eot_token = ""
+
+ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
+ tool_indices = {
+ tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+ }
+ if not isinstance(action, list):
+ action = [action]
+
+ results = []
+ for act in action:
+ name = act.get("name")
+ if name and name in tool_indices:
+ results.append(
+ ToolCallItem(
+ tool_index=tool_indices[name],
+ name=name,
+ parameters=json.dumps(
+ act.get("parameters") or act.get("arguments", {}),
+ ensure_ascii=False,
+ ),
+ )
+ )
+ else:
+ logger.warning(f"Model attempted to call undefined function: {name}")
+
+ return results
+
+ @abstractmethod
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        Parses the text in one go and returns a StreamingParseResult.
+        Note that normal_text here represents "content that this parser will not consume further".
+        """
+ action = json.loads(text)
+ return StreamingParseResult(calls=self.parse_base_json(action, tools))
+
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing with tool validation.
+ """
+ # Append new text to buffer
+ self._buffer += new_text
+ current_text = self._buffer
+ if not (self.bot_token in current_text or current_text.startswith("{")):
+ self._buffer = ""
+ if self.eot_token in new_text:
+ new_text = new_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ # Build tool indices if not already built
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = {
+ tool.function.name: i
+ for i, tool in enumerate(tools)
+ if tool.function and tool.function.name
+ }
+
+ flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+ try:
+ tool_call_arr = []
+ is_complete = []
+ try:
+ start_idx = (
+ len(self.bot_token)
+ if current_text.startswith(self.bot_token)
+ else 0
+ )
+ while start_idx < len(current_text):
+ (obj, end_idx) = _partial_json_loads(
+ current_text[start_idx:], flags
+ )
+ is_complete.append(
+ _is_complete_json(current_text[start_idx : start_idx + end_idx])
+ )
+ start_idx += end_idx + len("; ")
+
+ # Validate tool name if present
+ if "name" in obj and obj["name"] not in self._tool_indices:
+ # Invalid tool name - reset state
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ if self.streamed_args_for_tool:
+ self.streamed_args_for_tool.pop()
+ return StreamingParseResult()
+
+ # Handle parameters/arguments consistency
+ if "parameters" in obj:
+ assert (
+ "arguments" not in obj
+ ), "model generated both parameters and arguments"
+ obj["arguments"] = obj["parameters"]
+ tool_call_arr.append(obj)
+
+ except MalformedJSON:
+ return StreamingParseResult()
+
+ if len(tool_call_arr) == 0:
+ return StreamingParseResult()
+
+ current_tool_call: Dict = (
+ tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+ )
+
+ # Handle new tool in array
+ if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
+ if self.current_tool_id >= 0:
+ cur_arguments = current_tool_call.get("arguments")
+ if cur_arguments:
+ cur_args_json = json.dumps(cur_arguments)
+ sent = len(self.streamed_args_for_tool[self.current_tool_id])
+ argument_diff = cur_args_json[sent:]
+
+ res = StreamingParseResult(
+ calls=[
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name="",
+ parameters=argument_diff,
+ )
+ ],
+ )
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += argument_diff
+ else:
+ res = StreamingParseResult()
+ else:
+ res = StreamingParseResult()
+
+ self.current_tool_id = len(tool_call_arr) - 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ return res
+
+ # Handle tool name
+ elif not self.current_tool_name_sent:
+ function_name = current_tool_call.get("name")
+ if function_name and function_name in self._tool_indices:
+ res = StreamingParseResult(
+ calls=[
+ ToolCallItem(
+ tool_index=self._tool_indices[function_name],
+ name=function_name,
+ parameters="",
+ )
+ ],
+ )
+ self.current_tool_name_sent = True
+ else:
+ res = StreamingParseResult()
+
+ # Handle streaming arguments
+ else:
+ cur_arguments = current_tool_call.get("arguments")
+ res = StreamingParseResult()
+
+ if cur_arguments:
+ sent = len(self.streamed_args_for_tool[self.current_tool_id])
+ cur_args_json = json.dumps(cur_arguments)
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments"
+ )
+
+ argument_diff = None
+ if is_complete[self.current_tool_id]:
+ argument_diff = cur_args_json[sent:]
+ self._buffer = ""
+ self.prev_tool_call_arr[self.current_tool_id].clear()
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool[self.current_tool_id] = ""
+
+ elif prev_arguments:
+ prev_args_json = json.dumps(prev_arguments)
+ if cur_args_json != prev_args_json:
+ prefix = _find_common_prefix(prev_args_json, cur_args_json)
+ argument_diff = prefix[sent:]
+
+ if argument_diff is not None:
+ res = StreamingParseResult(
+ calls=[
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ parameters=argument_diff,
+ )
+ ],
+ )
+ if not is_complete[self.current_tool_id]:
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += argument_diff
+
+ self.prev_tool_call_arr = tool_call_arr
+ return res
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult()
+
+ @abstractmethod
+ def has_tool_call(self, text: str) -> bool:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def build_ebnf(self, tools: List[Tool]) -> str:
+ raise NotImplementedError()
diff --git a/python/sglang/srt/function_call/core_types.py b/python/sglang/srt/function_call/core_types.py
new file mode 100644
index 000000000..1ea87df79
--- /dev/null
+++ b/python/sglang/srt/function_call/core_types.py
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from pydantic import BaseModel
+
+
+class ToolCallItem(BaseModel):
+ """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
+
+ tool_index: int
+ name: Optional[str] = None
+ parameters: str # JSON string
+
+
+class StreamingParseResult(BaseModel):
+ """Result of streaming incremental parsing."""
+
+ normal_text: str = ""
+ calls: List[ToolCallItem] = []
+
+
+@dataclass
+class StructureInfo:
+ begin: str
+ end: str
+ trigger: str
+
+
+"""
+Type alias for a helper function.
+Usually it is a function that takes a tool name string and returns a StructureInfo object,
+which can be used to construct a structural_tag object
+"""
+_GetInfoFunc = Callable[[str], StructureInfo]
diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py
new file mode 100644
index 000000000..b53c68ab6
--- /dev/null
+++ b/python/sglang/srt/function_call/deepseekv3_detector.py
@@ -0,0 +1,157 @@
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+from sglang.srt.openai_api.protocol import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class DeepSeekV3Detector(BaseFormatDetector):
+ """
+ Detector for DeepSeek models.
+ Assumes function call format:
+ '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "<|tool▁calls▁begin|>"
+ self.eot_token = "<|tool▁calls▁end|>"
+ self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+ self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+ self._last_arguments = ""
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a deepseek format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+        :return: StreamingParseResult containing any leading normal text and the parsed tool calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+ match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+ calls = []
+ try:
+ for match_result in match_result_list:
+ # Get function name
+ func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+ func_name = func_detail.group(2)
+ func_args = func_detail.group(3)
+ func_args = json.loads(func_args)
+ # construct match_result for parse_base_json
+ match_result = {"name": func_name, "parameters": func_args}
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
+
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing tool calls for DeepSeekV3 format.
+ """
+ self._buffer += new_text
+ current_text = self._buffer
+
+ if self.bot_token not in current_text:
+ self._buffer = ""
+ for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+ if e_token in new_text:
+ new_text = new_text.replace(e_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = {
+ tool.function.name: i
+ for i, tool in enumerate(tools)
+ if tool.function and tool.function.name
+ }
+
+ calls: list[ToolCallItem] = []
+ try:
+ partial_match = re.search(
+ pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+ string=current_text,
+ flags=re.DOTALL,
+ )
+ if partial_match:
+ func_name = partial_match.group(2).strip()
+ func_args_raw = partial_match.group(3).strip()
+
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self._tool_indices.get(func_name, 0),
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ else:
+ argument_diff = (
+ func_args_raw[len(self._last_arguments) :]
+ if func_args_raw.startswith(self._last_arguments)
+ else func_args_raw
+ )
+
+ if argument_diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=self._tool_indices.get(func_name, 0),
+ name=None,
+ parameters=argument_diff,
+ )
+ )
+ self._last_arguments += argument_diff
+
+ if _is_complete_json(func_args_raw):
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self._buffer = ""
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text=current_text)
+
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin=">" + name + "\n```json\n",
+ end="\n```<",
+ trigger=">" + name + "\n```json\n",
+ )
+
+ def build_ebnf(self, tools: List[Tool]):
+ return EBNFComposer.build_ebnf(
+ tools,
+ bot_token=self.bot_token,
+ eot_token=self.eot_token,
+ tool_call_separator="",
+ call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
+ function_format="json",
+ )
diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py
new file mode 100644
index 000000000..d749f65d7
--- /dev/null
+++ b/python/sglang/srt/function_call/ebnf_composer.py
@@ -0,0 +1,234 @@
+from typing import Literal, Optional
+
+
+class EBNFComposer:
+ # Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
+ json_grammar_ebnf_str = r"""
+ json ::= basic_array | basic_object
+ basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+ basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
+ basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+ basic_string ::= (([\"] basic_string_1 [\"]))
+ basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+ escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+ basic_boolean ::= "true" | "false"
+ basic_null ::= "null"
+ basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+ basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+ ws ::= [ \n\t]*
+ """
+
+ pythonic_grammar_ebnf_str = r"""
+ pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
+ basic_any ::= basic_number | basic_string | basic_array | basic_object
+ basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+ basic_string ::= (([\"] basic_string_1 [\"]))
+ basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+ escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+ basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+ basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+ ws ::= [ \n\t]*
+ """
+
+ TOOL_CALLS_MAP = {
+ "pythonic": '"[" function_call ("," function_call)* "]"',
+ "json": "function_call",
+ }
+
+ CALL_RULE_MAP = {
+ "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
+ "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
+ }
+
+ ARGUMENTS_RULE_MAP = {
+ "pythonic": "{arg_rules}",
+ "json": '"{{" {arg_rules} "}}"',
+ }
+
+ KEY_VALUE_RULE_MAP = {
+ "pythonic": '"{key}" "=" {valrule}',
+ "json": '"\\"{key}\\"" ":" {valrule}',
+ }
+
+ JSON_TYPE_MAPPING = {
+ "string": "basic_string",
+ "number": "basic_number",
+ "integer": "basic_number",
+ "boolean": "basic_boolean",
+ "null": "basic_null",
+ "array": "basic_array",
+ "object": "basic_object",
+ }
+
+ PYTHONIC_TYPE_MAPPING = {
+ "string": "basic_string",
+ "number": "basic_number",
+ "integer": "basic_number",
+ "boolean": '"True" | "False"',
+ "null": '"None"',
+ "array": "basic_array",
+ "object": "basic_object",
+ }
+
+ @staticmethod
+ def get_value_rule(
+ prop: dict, function_format: Literal["pythonic", "json"] = "json"
+ ) -> str:
+ if "enum" in prop:
+ return EBNFComposer._handle_enum(prop, function_format)
+
+ if "type" in prop:
+ return EBNFComposer._handle_type(prop, function_format)
+
+ return function_format
+
+ @staticmethod
+ def _handle_enum(prop: dict, function_format: str) -> str:
+ """Handle enum properties by formatting each value according to type and format."""
+ enum_values = prop["enum"]
+ prop_type = prop.get("type", "string")
+
+ # Define formatters for different type/format combinations
+ formatters = {
+ ("string", "json"): lambda v: f'"\\"{v}\\""',
+ ("string", "pythonic"): lambda v: f'"\\"{v}\\""',
+ ("number", "json"): str,
+ ("number", "pythonic"): str,
+ ("integer", "json"): str,
+ ("integer", "pythonic"): str,
+ ("boolean", "json"): lambda v: "true" if v else "false",
+ ("boolean", "pythonic"): lambda v: "True" if v else "False",
+ }
+
+ # Get the formatter or default to string handling
+ formatter = formatters.get(
+ (prop_type, function_format),
+ formatters[("string", function_format)], # Default to string handling
+ )
+
+ formatted_values = [formatter(value) for value in enum_values]
+ enum_rule = " | ".join(formatted_values)
+
+ # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
+ if len(formatted_values) > 1:
+ enum_rule = f"({enum_rule})"
+
+ return enum_rule
+
+ @staticmethod
+ def _handle_type(prop: dict, function_format: str) -> str:
+ """Handle type properties using the appropriate type mapping."""
+ prop_type = prop["type"]
+ type_mapping = (
+ EBNFComposer.PYTHONIC_TYPE_MAPPING
+ if function_format == "pythonic"
+ else EBNFComposer.JSON_TYPE_MAPPING
+ )
+
+ if isinstance(prop_type, list):
+ type_rules = [
+ type_mapping[single_type]
+ for single_type in prop_type
+ if single_type in type_mapping
+ ]
+ return " | ".join(type_rules) if type_rules else function_format
+
+ return type_mapping.get(prop_type, function_format)
+
+ @staticmethod
+ def build_ebnf(
+ tools,
+ *,
+ call_rule_fmt: Optional[str] = None,
+ function_format: Literal["pythonic", "json"] = "json",
+ bot_token: Optional[str] = None,
+ eot_token: Optional[str] = None,
+ tool_call_separator: Optional[str] = None,
+ ):
+ """
+ Generalized EBNF builder for all detectors.
+ Args:
+ tools: List of Tool objects to generate EBNF grammar for
+ call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
+ the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
+ format based on function_format will be used.
+ function_format: The format of function calls, either "pythonic" or "json"
+ bot_token: The token that indicates the start of a tool call section
+ eot_token: The token that indicates the end of a tool call section
+ tool_call_separator: The separator between multiple tool calls
+ """
+ # =================================================================
+ # Step 1: Determine the root tool calls rule
+ # =================================================================
+ if bot_token and eot_token:
+ if tool_call_separator:
+ root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
+ else:
+ root_rule = f'"{bot_token}" function_call "{eot_token}"'
+ else:
+ root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
+
+ # =================================================================
+ # Step 2: Build the header rules
+ # =================================================================
+ ebnf_lines = [
+ f"root ::= {root_rule}",
+ "function_call ::= "
+ + " | ".join([f"call_{tool.function.name}" for tool in tools]),
+ ]
+
+ # =================================================================
+ # Step 3: Set up formatting templates
+ # =================================================================
+ call_template = (
+ f"call_{{name}} ::= {call_rule_fmt}"
+ if call_rule_fmt
+ else EBNFComposer.CALL_RULE_MAP[function_format]
+ )
+ args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
+ key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
+
+ # =================================================================
+ # Step 4: Build rules for each tool
+ # =================================================================
+ for tool in tools:
+ tool_name = tool.function.name
+ params = tool.function.parameters or {}
+ properties = params.get("properties", {})
+ required_props = set(params.get("required", []))
+
+ # Build argument rules for this tool
+ arg_rules = []
+ for prop_name, prop_schema in properties.items():
+ value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
+ # Create key=value pair
+ pair = key_value_template.format(key=prop_name, valrule=value_rule)
+
+ if prop_name not in required_props:
+ pair = f"[ {pair} ]"
+
+ arg_rules.append(pair)
+
+ # Combine all argument rules
+ combined_args = ' "," '.join(arg_rules) if arg_rules else ""
+ arguments_rule = args_template.format(arg_rules=combined_args)
+
+ # Add the function call rule and its arguments rule
+ ebnf_lines.append(
+ call_template.format(
+ name=tool_name, arguments_rule=f"arguments_{tool_name}"
+ )
+ )
+ ebnf_lines.append(f"arguments_{tool_name} ::= {arguments_rule}")
+
+ # =================================================================
+ # Step 5: Add base grammar rules
+ # =================================================================
+ base_grammar = (
+ EBNFComposer.pythonic_grammar_ebnf_str
+ if function_format == "pythonic"
+ else EBNFComposer.json_grammar_ebnf_str
+ )
+ ebnf_lines.append(base_grammar)
+
+ return "\n".join(ebnf_lines)
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
new file mode 100644
index 000000000..f73217086
--- /dev/null
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -0,0 +1,175 @@
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import ToolCallItem
+from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.llama32_detector import Llama32Detector
+from sglang.srt.function_call.mistral_detector import MistralDetector
+from sglang.srt.function_call.pythonic_detector import PythonicDetector
+from sglang.srt.function_call.qwen25_detector import Qwen25Detector
+from sglang.srt.openai_api.protocol import (
+ StructuralTagResponseFormat,
+ StructuresResponseFormat,
+ Tool,
+ ToolChoice,
+)
+
+
+class FunctionCallParser:
+ """
+ Parser for function/tool calls in model outputs.
+
+ This class handles both streaming and non-streaming parsing of function calls using a detector.
+ In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
+ and returns the resulting normal_text and calls to the upper layer (or SSE).
+ """
+
+ ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
+ "llama3": Llama32Detector,
+ "qwen25": Qwen25Detector,
+ "mistral": MistralDetector,
+ "deepseekv3": DeepSeekV3Detector,
+ "pythonic": PythonicDetector,
+ }
+
+ def __init__(self, tools: List[Tool], tool_call_parser: str):
+        detector: Optional[BaseFormatDetector] = None
+ detector_class = self.ToolCallParserEnum.get(tool_call_parser)
+ if detector_class:
+ detector = detector_class()
+ else:
+ raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
+
+ self.detector = detector
+ self.tools = tools
+
+ def has_tool_call(self, text: str) -> bool:
+ """
+ Check if the given text contains a tool call in the format supported by this parser.
+ This delegates to the detector's implementation.
+
+ Args:
+ text: The text to check for tool calls
+
+ Returns:
+ True if the text contains a tool call, False otherwise
+ """
+ return self.detector.has_tool_call(text)
+
+    def parse_non_stream(self, full_text: str) -> Tuple[str, List[ToolCallItem]]:
+ """
+ One-time parsing of the full text to extract tool calls.
+
+ Args:
+ full_text: The complete text to parse
+
+ Returns:
+ A tuple containing:
+ - The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
+ - A list of tool calls parsed from the text
+ """
+ parsed_result = self.detector.detect_and_parse(full_text, self.tools)
+ tool_call_list = parsed_result.calls
+ if tool_call_list:
+ return parsed_result.normal_text, tool_call_list
+ else:
+ return full_text, []
+
+    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, List[ToolCallItem]]:
+ """
+ Streaming incremental parsing of chunks of text as they arrive.
+
+ Args:
+ chunk_text: The new chunk of text to parse
+
+ Returns:
+ A tuple containing:
+ - The normal text that should be displayed to the user
+ - A list of tool calls parsed from the chunk
+ """
+ final_normal_text = ""
+ final_calls = []
+
+ sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)
+ if sp_result.normal_text:
+ final_normal_text = sp_result.normal_text
+ if sp_result.calls:
+ final_calls.extend(sp_result.calls)
+ final_normal_text = sp_result.normal_text
+
+ return final_normal_text, final_calls
+
+ def get_structure_tag(self) -> StructuralTagResponseFormat:
+ """
+ Generate a structural tag response format for all available tools.
+
+ This creates the necessary structural tags that guide the model's output format.
+ """
+ tool_structures: List[StructuresResponseFormat] = list()
+ tool_trigger_set: Set[str] = set()
+
+ get_structure_info = self.detector.structure_info()
+ for tool in self.tools:
+ function = tool.function
+ name = function.name
+ assert name is not None
+ info = get_structure_info(name)
+
+ # accept all if not strict, otherwise only accept the schema
+ schema = function.parameters if function.strict else {}
+
+ tool_structures.append(
+ StructuresResponseFormat(
+ begin=info.begin,
+ schema=schema, # type: ignore
+ end=info.end,
+ )
+ )
+ tool_trigger_set.add(info.trigger)
+
+ return StructuralTagResponseFormat(
+ type="structural_tag",
+ structures=tool_structures,
+ triggers=list(tool_trigger_set),
+ )
+
+ def get_structure_constraint(
+ self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
+ ) -> Optional[Tuple[str, Any]]:
+ """
+ Returns the appropriate structure constraint for tool calls based on the tool_choice.
+ The constraint is used to guide the model's output format.
+
+ Args:
+ tool_choice: The tool choice setting from the request
+
+ Returns:
+ A tuple of (constraint_type, constraint_value) to be added to sampling parameters,
+ or None if no constraint applies.
+ """
+ # NOTE: structural_tag only supports JSON-compatible content between the begin and end.
+ # It cannot parse or validate Python syntax like function calls.
+ if (
+ not isinstance(self.detector, PythonicDetector)
+ and tool_choice == "auto"
+ and any(tool.function.strict for tool in self.tools)
+ ):
+ strict_tag = self.get_structure_tag()
+ return ("structural_tag", strict_tag)
+ elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
+ ebnf = self.get_ebnf(tool_choice)
+ return ("ebnf", ebnf) if ebnf is not None else None
+
+ def get_ebnf(
+ self, tool_choice: Union[ToolChoice, Literal["required"]]
+ ) -> Optional[str]:
+ """
+ Get the EBNF grammar for the specified tool choice.
+ """
+ filtered_tools = []
+ if isinstance(tool_choice, ToolChoice):
+ fn_name = tool_choice.function.name
+ filtered_tools = [t for t in self.tools if t.function.name == fn_name]
+ else:
+ filtered_tools = self.tools
+ return self.detector.build_ebnf(filtered_tools)
diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py
new file mode 100644
index 000000000..32670782c
--- /dev/null
+++ b/python/sglang/srt/function_call/llama32_detector.py
@@ -0,0 +1,74 @@
+import json
+import logging
+from typing import List
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.openai_api.protocol import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class Llama32Detector(BaseFormatDetector):
+ """
+ Detector for Llama 3.2 models.
+ Assumes function call format:
+ <|python_tag|>{"name":"xxx", "arguments":{...}}
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "<|python_tag|>"
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Llama 3.2 format tool call."""
+ # depending on the prompt format the Llama model may or may not
+ # prefix the output with the <|python_tag|> token
+ return "<|python_tag|>" in text or text.startswith("{")
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """Parse function calls from text, handling multiple JSON objects."""
+ if "<|python_tag|>" not in text and not text.startswith("{"):
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ if "<|python_tag|>" in text:
+ normal_text, action_text = text.split("<|python_tag|>")
+ else:
+ normal_text, action_text = "", text
+
+ # Split by semicolon and process each part
+ json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+ all_actions = []
+ for part in json_parts:
+ try:
+ # Parse each individual JSON object
+ action = json.loads(part)
+ all_actions.append(action)
+ except json.JSONDecodeError as e:
+ logger.warning(f"Failed to parse JSON part: {part}")
+ logger.warning(f"JSON parse error: {str(e)}")
+ continue
+ calls = []
+ # Only process if we found valid JSON objects
+ if all_actions:
+ calls = self.parse_base_json(all_actions, tools)
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='<|python_tag|>{"name":"' + name + '", "arguments":',
+ end="}",
+ trigger="<|python_tag|>",
+ )
+
+ def build_ebnf(self, tools: List[Tool]):
+ return EBNFComposer.build_ebnf(
+ tools,
+ function_format="json",
+ tool_call_separator=",",
+ )
diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py
new file mode 100644
index 000000000..a5d2475ea
--- /dev/null
+++ b/python/sglang/srt/function_call/mistral_detector.py
@@ -0,0 +1,84 @@
+import json
+import re
+from typing import List
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.openai_api.protocol import Tool
+
+
+class MistralDetector(BaseFormatDetector):
+ """
+ Detector for Mistral models.
+ Assumes function call format:
+ [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
+ """
+
+ def __init__(self):
+ """
+ Initializes the detector with necessary state variables.
+ """
+ super().__init__()
+ self.bot_token = "[TOOL_CALLS] ["
+ self.eot_token = "]"
+ self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Mistral format tool call."""
+ return self.bot_token in text
+
+ def _clean_text(self, text: str) -> str:
+ """
+        clean text to only leave '[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
+ for example,
+ text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
+ return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
+ The key pattern is [TOOL_CALLS] [...]
+ """
+ # TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
+ find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
+ if len(find_results) > 0:
+ return find_results[0]
+ else:
+ return ""
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ text = self._clean_text(text)
+ tool_content = text.replace("[TOOL_CALLS]", "").strip()
+ raw_tool_calls = self.tool_call_regex.findall(tool_content)
+ calls = []
+ if len(raw_tool_calls) > 0:
+ raw_tool_call = raw_tool_calls[0]
+ function_call_arr = json.loads(raw_tool_call)
+ for match_result in function_call_arr:
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
+ end="}]",
+ trigger="[TOOL_CALLS]",
+ )
+
+ def build_ebnf(self, tools: List[Tool]):
+ return EBNFComposer.build_ebnf(
+ tools,
+ bot_token=self.bot_token,
+ eot_token=self.eot_token,
+ function_format="json",
+ )
diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py
new file mode 100644
index 000000000..e60ab63bf
--- /dev/null
+++ b/python/sglang/srt/function_call/pythonic_detector.py
@@ -0,0 +1,163 @@
+import ast
+import json
+import logging
+import re
+from typing import List, Optional
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.openai_api.protocol import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class PythonicDetector(BaseFormatDetector):
+ """
+ Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
+ Assumes function call format:
+ [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+ Arguments are Python literals (not JSON).
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.tool_call_regex = re.compile(
+ r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
+ re.DOTALL,
+ )
+
+ def has_tool_call(self, text: str) -> bool:
+ return bool(self.tool_call_regex.match(text.strip()))
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ # Try parsing the text as a Python list of function calls
+ text = text.strip()
+ if not (text.startswith("[") and text.endswith("]")):
+ # Not a pythonic tool call format
+ return StreamingParseResult(normal_text=text, calls=[])
+ try:
+ module = ast.parse(text)
+ parsed = getattr(module.body[0], "value", None)
+ if not (
+ isinstance(parsed, ast.List)
+ and all(isinstance(e, ast.Call) for e in parsed.elts)
+ ):
+ return StreamingParseResult(normal_text=text, calls=[])
+ calls = []
+ tool_indices = {
+ tool.function.name: i
+ for i, tool in enumerate(tools)
+ if tool.function.name
+ }
+ for call in parsed.elts:
+ if not isinstance(call.func, ast.Name):
+ continue
+ function_name = call.func.id
+ arguments = {}
+ for keyword in call.keywords:
+ arguments[keyword.arg] = self._get_parameter_value(keyword.value)
+ calls.append(
+ ToolCallItem(
+ tool_index=tool_indices.get(function_name, -1),
+ name=function_name,
+ parameters=json.dumps(arguments, ensure_ascii=False),
+ )
+ )
+ return StreamingParseResult(normal_text="", calls=calls)
+ except Exception:
+ logger.exception("Error in pythonic tool call parsing.")
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ def _find_matching_bracket(self, buffer: str, start: int) -> int:
+ """
+ Find the matching closing bracket for the opening bracket at start position.
+ Properly handles nested brackets.
+
+ Args:
+ buffer: The text buffer to search in
+ start: Position of the opening bracket '['
+
+ Returns:
+ Position of the matching closing bracket ']', or -1 if not found
+ """
+ bracket_count = 0
+ for i in range(start, len(buffer)):
+ if buffer[i] == "[":
+ bracket_count += 1
+ elif buffer[i] == "]":
+ bracket_count -= 1
+ if bracket_count == 0:
+ return i
+ return -1 # No matching bracket found
+
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for pythonic tool calls.
+ Buffers input until a complete pythonic tool call (from [ to ]) is found,
+ then parses and emits any detected calls.
+ """
+ self._buffer += new_text
+ start = self._buffer.find("[")
+
+ if start == -1:
+ normal_text = self._buffer
+ self._buffer = ""
+ return StreamingParseResult(normal_text=normal_text)
+
+ normal_text = self._buffer[:start] if start > 0 else ""
+
+ end = self._find_matching_bracket(self._buffer, start)
+ if end != -1:
+ call_text = self._buffer[start : end + 1]
+ result = self.detect_and_parse(call_text, tools)
+ self._buffer = self._buffer[end + 1 :]
+
+ # If we had normal text before the tool call, add it to the result
+ if normal_text:
+ result.normal_text = normal_text + (result.normal_text or "")
+
+ return result
+
+ # We have an opening bracket but no closing bracket yet
+ if normal_text:
+ self._buffer = self._buffer[start:]
+ return StreamingParseResult(normal_text=normal_text)
+
+ # Otherwise, we're still accumulating a potential tool call
+ return StreamingParseResult(normal_text="")
+
+ def _get_parameter_value(self, val):
+ if isinstance(val, ast.Constant):
+ return val.value
+ elif isinstance(val, ast.Dict):
+ return {
+ k.value: self._get_parameter_value(v)
+ for k, v in zip(val.keys, val.values)
+ }
+ elif isinstance(val, ast.List):
+ return [self._get_parameter_value(v) for v in val.elts]
+ else:
+ raise ValueError("Tool call arguments must be literals")
+
+ def structure_info(self) -> _GetInfoFunc:
+ def info(name: str):
+ return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+
+ return info
+
+ def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
+ return EBNFComposer.build_ebnf(
+ tools,
+ bot_token="[",
+ eot_token="]",
+ tool_call_separator=",",
+ function_format="pythonic",
+ )
diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py
new file mode 100644
index 000000000..1d32099f7
--- /dev/null
+++ b/python/sglang/srt/function_call/qwen25_detector.py
@@ -0,0 +1,67 @@
+import json
+import re
+from typing import List
+
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.openai_api.protocol import Tool
+
+
+class Qwen25Detector(BaseFormatDetector):
+ """
+ Detector for Qwen 2.5 models.
+ Assumes function call format:
+ {"name":"xxx", "arguments":{...}}
+ """
+
+ def __init__(self):
+ """
+ Initializes the detector with necessary state variables.
+ """
+ super().__init__()
+ self.bot_token = ""
+ self.eot_token = ""
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Qwen 2.5 format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+ pattern = rf"{self.bot_token}(.*?){self.eot_token}"
+ match_result_list = re.findall(pattern, text, re.DOTALL)
+ calls = []
+ for match_result in match_result_list:
+ match_result = json.loads(match_result)
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+            begin='<tool_call>{"name":"' + name + '", "arguments":',
+            end="}</tool_call>",
+            trigger="<tool_call>",
+ )
+
+ def build_ebnf(self, tools: List[Tool]):
+ return EBNFComposer.build_ebnf(
+ tools,
+ bot_token=self.bot_token,
+ eot_token=self.eot_token,
+ function_format="json",
+ )
diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py
new file mode 100644
index 000000000..e8a585bb2
--- /dev/null
+++ b/python/sglang/srt/function_call/utils.py
@@ -0,0 +1,35 @@
+import json
+from json import JSONDecodeError, JSONDecoder
+from typing import Any, Tuple
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+
+def _find_common_prefix(s1: str, s2: str) -> str:
+ prefix = ""
+ min_length = min(len(s1), len(s2))
+ for i in range(0, min_length):
+ if s1[i] == s2[i]:
+ prefix += s1[i]
+ else:
+ break
+ return prefix
+
+
+def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+ try:
+ return (partial_json_parser.loads(input_str, flags), len(input_str))
+ except JSONDecodeError as e:
+ if "Extra data" in e.msg:
+ dec = JSONDecoder()
+ return dec.raw_decode(input_str)
+ raise
+
+
+def _is_complete_json(input_str: str) -> bool:
+ try:
+ json.loads(input_str)
+ return True
+ except JSONDecodeError:
+ return False
diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py
deleted file mode 100644
index 549843146..000000000
--- a/python/sglang/srt/function_call_parser.py
+++ /dev/null
@@ -1,858 +0,0 @@
-import ast
-import json
-import logging
-import re
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from json import JSONDecodeError, JSONDecoder
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
-
-import partial_json_parser
-from partial_json_parser.core.exceptions import MalformedJSON
-from partial_json_parser.core.options import Allow
-from pydantic import BaseModel
-
-from sglang.srt.openai_api.protocol import (
- StructuralTagResponseFormat,
- StructuresResponseFormat,
- Tool,
-)
-
-logger = logging.getLogger(__name__)
-
-TOOLS_TAG_LIST = [
- "<|plugin|>",
- "",
- "<|python_tag|>",
- "[TOOL_CALLS]",
- "<|tool▁calls▁begin|>",
-]
-
-
-class ToolCallItem(BaseModel):
- """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
-
- tool_index: int
- name: Optional[str] = None
- parameters: str # JSON string
-
-
-def _find_common_prefix(s1: str, s2: str) -> str:
- prefix = ""
- min_length = min(len(s1), len(s2))
- for i in range(0, min_length):
- if s1[i] == s2[i]:
- prefix += s1[i]
- else:
- break
- return prefix
-
-
-def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
- try:
- return (partial_json_parser.loads(input_str, flags), len(input_str))
- except JSONDecodeError as e:
- if "Extra data" in e.msg:
- dec = JSONDecoder()
- return dec.raw_decode(input_str)
- raise
-
-
-def _is_complete_json(input_str: str) -> bool:
- try:
- json.loads(input_str)
- return True
- except JSONDecodeError:
- return False
-
-
-class StreamingParseResult:
- """Result of streaming incremental parsing."""
-
- def __init__(
- self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
- ):
- self.normal_text = normal_text
- self.calls = calls or []
-
-
-@dataclass
-class StructureInfo:
- begin: str
- end: str
- trigger: str
-
-
-_GetInfoFunc = Callable[[str], StructureInfo]
-"""
-Helper alias of function
-Usually it is a function that takes a name string and returns a StructureInfo object,
-which can be used to construct a structural_tag object
-"""
-
-
-class BaseFormatDetector(ABC):
- """Base class providing two sets of interfaces: one-time and streaming incremental."""
-
- def __init__(self):
- # initialize properties used for state when parsing tool calls in
- self._buffer = ""
- # streaming mode
- self.prev_tool_call_arr: List[Dict] = []
- self.current_tool_id: int = -1
- self.current_tool_name_sent: bool = False
- self.streamed_args_for_tool: List[str] = (
- []
- ) # map what has been streamed for each tool so far to a list
- self.bot_token = ""
- self.eot_token = ""
-
- def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
- tool_indices = {
- tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
- }
- if not isinstance(action, list):
- action = [action]
-
- results = []
- for act in action:
- name = act.get("name")
- if name and name in tool_indices:
- results.append(
- ToolCallItem(
- tool_index=tool_indices[name],
- name=name,
- parameters=json.dumps(
- act.get("parameters") or act.get("arguments", {}),
- ensure_ascii=False,
- ),
- )
- )
- else:
- logger.warning(f"Model attempted to call undefined function: {name}")
-
- return results
-
- @abstractmethod
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """
- Parses the text in one go. Returns success=True if the format matches, otherwise False.
- Note that leftover_text here represents "content that this parser will not consume further".
- """
- action = json.loads(text)
- return StreamingParseResult(calls=self.parse_base_json(action, tools))
-
- def parse_streaming_increment(
- self, new_text: str, tools: List[Tool]
- ) -> StreamingParseResult:
- """
- Streaming incremental parsing with tool validation.
- """
- # Append new text to buffer
- self._buffer += new_text
- current_text = self._buffer
- if not (self.bot_token in current_text or current_text.startswith("{")):
- self._buffer = ""
- if self.eot_token in new_text:
- new_text = new_text.replace(self.eot_token, "")
- return StreamingParseResult(normal_text=new_text)
-
- # Build tool indices if not already built
- if not hasattr(self, "_tool_indices"):
- self._tool_indices = {
- tool.function.name: i
- for i, tool in enumerate(tools)
- if tool.function and tool.function.name
- }
-
- flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
- try:
- tool_call_arr = []
- is_complete = []
- try:
- start_idx = (
- len(self.bot_token)
- if current_text.startswith(self.bot_token)
- else 0
- )
- while start_idx < len(current_text):
- (obj, end_idx) = _partial_json_loads(
- current_text[start_idx:], flags
- )
- is_complete.append(
- _is_complete_json(current_text[start_idx : start_idx + end_idx])
- )
- start_idx += end_idx + len("; ")
-
- # Validate tool name if present
- if "name" in obj and obj["name"] not in self._tool_indices:
- # Invalid tool name - reset state
- self._buffer = ""
- self.current_tool_id = -1
- self.current_tool_name_sent = False
- if self.streamed_args_for_tool:
- self.streamed_args_for_tool.pop()
- return StreamingParseResult()
-
- # Handle parameters/arguments consistency
- if "parameters" in obj:
- assert (
- "arguments" not in obj
- ), "model generated both parameters and arguments"
- obj["arguments"] = obj["parameters"]
- tool_call_arr.append(obj)
-
- except MalformedJSON:
- return StreamingParseResult()
-
- if len(tool_call_arr) == 0:
- return StreamingParseResult()
-
- current_tool_call: Dict = (
- tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
- )
-
- # Handle new tool in array
- if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
- if self.current_tool_id >= 0:
- cur_arguments = current_tool_call.get("arguments")
- if cur_arguments:
- cur_args_json = json.dumps(cur_arguments)
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
- argument_diff = cur_args_json[sent:]
-
- res = StreamingParseResult(
- calls=[
- ToolCallItem(
- tool_index=self.current_tool_id,
- name="",
- parameters=argument_diff,
- )
- ],
- )
- self.streamed_args_for_tool[
- self.current_tool_id
- ] += argument_diff
- else:
- res = StreamingParseResult()
- else:
- res = StreamingParseResult()
-
- self.current_tool_id = len(tool_call_arr) - 1
- self.current_tool_name_sent = False
- self.streamed_args_for_tool.append("")
- return res
-
- # Handle tool name
- elif not self.current_tool_name_sent:
- function_name = current_tool_call.get("name")
- if function_name and function_name in self._tool_indices:
- res = StreamingParseResult(
- calls=[
- ToolCallItem(
- tool_index=self._tool_indices[function_name],
- name=function_name,
- parameters="",
- )
- ],
- )
- self.current_tool_name_sent = True
- else:
- res = StreamingParseResult()
-
- # Handle streaming arguments
- else:
- cur_arguments = current_tool_call.get("arguments")
- res = StreamingParseResult()
-
- if cur_arguments:
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
- cur_args_json = json.dumps(cur_arguments)
- prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
- "arguments"
- )
-
- argument_diff = None
- if is_complete[self.current_tool_id]:
- argument_diff = cur_args_json[sent:]
- self._buffer = ""
- self.prev_tool_call_arr[self.current_tool_id].clear()
- self.current_tool_name_sent = False
- self.streamed_args_for_tool[self.current_tool_id] = ""
-
- elif prev_arguments:
- prev_args_json = json.dumps(prev_arguments)
- if cur_args_json != prev_args_json:
- prefix = _find_common_prefix(prev_args_json, cur_args_json)
- argument_diff = prefix[sent:]
-
- if argument_diff is not None:
- res = StreamingParseResult(
- calls=[
- ToolCallItem(
- tool_index=self.current_tool_id,
- parameters=argument_diff,
- )
- ],
- )
- if not is_complete[self.current_tool_id]:
- self.streamed_args_for_tool[
- self.current_tool_id
- ] += argument_diff
-
- self.prev_tool_call_arr = tool_call_arr
- return res
-
- except Exception as e:
- logger.error(f"Error in parse_streaming_increment: {e}")
- return StreamingParseResult()
-
- @abstractmethod
- def has_tool_call(self, text: str) -> bool:
- raise NotImplementedError()
-
- @abstractmethod
- def structure_info(self) -> _GetInfoFunc:
- raise NotImplementedError()
-
-
-class Qwen25Detector(BaseFormatDetector):
- """
- Detector for Qwen 2.5 models.
- Assumes function call format:
- {"name":"xxx", "arguments":{...}}
- """
-
- def __init__(self):
- """
- Initializes the detector with necessary state variables.
- """
- super().__init__()
- self.bot_token = ""
- self.eot_token = ""
-
- def has_tool_call(self, text: str) -> bool:
- """Check if the text contains a Qwen 2.5 format tool call."""
- return self.bot_token in text
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """
- One-time parsing: Detects and parses tool calls in the provided text.
-
- :param text: The complete text to parse.
- :param tools: List of available tools.
- :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
- """
- idx = text.find(self.bot_token)
- normal_text = text[:idx].strip() if idx != -1 else text
- if self.bot_token not in text:
- return StreamingParseResult(normal_text=normal_text, calls=[])
- pattern = rf"{self.bot_token}(.*?){self.eot_token}"
- match_result_list = re.findall(pattern, text, re.DOTALL)
- calls = []
- for match_result in match_result_list:
- match_result = json.loads(match_result)
- calls.extend(self.parse_base_json(match_result, tools))
- return StreamingParseResult(normal_text=normal_text, calls=calls)
-
- def structure_info(self) -> _GetInfoFunc:
- return lambda name: StructureInfo(
- begin='{"name":"' + name + '", "arguments":',
- end="}",
- trigger="",
- )
-
-
-class MistralDetector(BaseFormatDetector):
- """
- Detector for Mistral models.
- Assumes function call format:
- <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
- """
-
- def __init__(self):
- """
- Initializes the detector with necessary state variables.
- """
- super().__init__()
- self.bot_token = "[TOOL_CALLS] ["
- self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
-
- def has_tool_call(self, text: str) -> bool:
- """Check if the text contains a Mistral format tool call."""
- return self.bot_token in text
-
- def _clean_text(self, text: str) -> str:
- """
- clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
- for example,
- text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
- return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
- The key pattern is [TOOL_CALLS] [...]
- """
- find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
- if len(find_results) > 0:
- return find_results[0]
- else:
- return ""
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """
- One-time parsing: Detects and parses tool calls in the provided text.
-
- :param text: The complete text to parse.
- :param tools: List of available tools.
- :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
- """
- idx = text.find(self.bot_token)
- normal_text = text[:idx].strip() if idx != -1 else text
- text = self._clean_text(text)
- tool_content = text.replace("[TOOL_CALLS]", "").strip()
- raw_tool_calls = self.tool_call_regex.findall(tool_content)
- calls = []
- if len(raw_tool_calls) > 0:
- raw_tool_call = raw_tool_calls[0]
- function_call_arr = json.loads(raw_tool_call)
- for match_result in function_call_arr:
- calls.extend(self.parse_base_json(match_result, tools))
- return StreamingParseResult(normal_text=normal_text, calls=calls)
-
- def structure_info(self) -> _GetInfoFunc:
- return lambda name: StructureInfo(
- begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
- end="}]",
- trigger="[TOOL_CALLS]",
- )
-
-
-class Llama32Detector(BaseFormatDetector):
- """
- Detector for Llama 3.2 models.
- Assumes function call format:
- <|python_tag|>{"name":"xxx", "arguments":{...}}
- """
-
- def __init__(self):
- super().__init__()
- self.bot_token = "<|python_tag|>"
-
- def has_tool_call(self, text: str) -> bool:
- """Check if the text contains a Llama 3.2 format tool call."""
- # depending on the prompt format the Llama model may or may not
- # prefix the output with the <|python_tag|> token
- return "<|python_tag|>" in text or text.startswith("{")
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """Parse function calls from text, handling multiple JSON objects."""
- if "<|python_tag|>" not in text and not text.startswith("{"):
- return StreamingParseResult(normal_text=text, calls=[])
-
- if "<|python_tag|>" in text:
- normal_text, action_text = text.split("<|python_tag|>")
- else:
- normal_text, action_text = "", text
-
- # Split by semicolon and process each part
- json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
- all_actions = []
- for part in json_parts:
- try:
- # Parse each individual JSON object
- action = json.loads(part)
- all_actions.append(action)
- except json.JSONDecodeError as e:
- logger.warning(f"Failed to parse JSON part: {part}")
- logger.warning(f"JSON parse error: {str(e)}")
- continue
- calls = []
- # Only process if we found valid JSON objects
- if all_actions:
- calls = self.parse_base_json(all_actions, tools)
- return StreamingParseResult(normal_text=normal_text, calls=calls)
-
- def structure_info(self) -> _GetInfoFunc:
- return lambda name: StructureInfo(
- begin='<|python_tag|>{"name":"' + name + '", "arguments":',
- end="}",
- trigger="<|python_tag|>",
- )
-
-
-class DeepSeekV3Detector(BaseFormatDetector):
- """
- Detector for DeepSeek models.
- Assumes function call format:
- '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
- """
-
- def __init__(self):
- super().__init__()
- self.bot_token = "<|tool▁calls▁begin|>"
- self.eot_token = "<|tool▁calls▁end|>"
- self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
- self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
- self._last_arguments = ""
-
- def has_tool_call(self, text: str) -> bool:
- """Check if the text contains a deepseek format tool call."""
- return self.bot_token in text
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """
- One-time parsing: Detects and parses tool calls in the provided text.
-
- :param text: The complete text to parse.
- :param tools: List of available tools.
- :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
- """
- idx = text.find(self.bot_token)
- normal_text = text[:idx].strip() if idx != -1 else text
- if self.bot_token not in text:
- return StreamingParseResult(normal_text=normal_text, calls=[])
- match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
- calls = []
- try:
- for match_result in match_result_list:
- # Get function name
- func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
- func_name = func_detail.group(2)
- func_args = func_detail.group(3)
- func_args = json.loads(func_args)
- # construct match_result for parse_base_json
- match_result = {"name": func_name, "parameters": func_args}
- calls.extend(self.parse_base_json(match_result, tools))
- return StreamingParseResult(normal_text=normal_text, calls=calls)
- except Exception as e:
- logger.error(f"Error in detect_and_parse: {e}")
- # return the normal text if parsing fails
- return StreamingParseResult(normal_text=text)
-
- def structure_info(self) -> _GetInfoFunc:
- return lambda name: StructureInfo(
- begin=">" + name + "\n```json\n",
- end="\n```<",
- trigger=">" + name + "\n```json\n",
- )
-
- def parse_streaming_increment(
- self, new_text: str, tools: List[Tool]
- ) -> StreamingParseResult:
- """
- Streaming incremental parsing tool calls for DeepSeekV3 format.
- """
- self._buffer += new_text
- current_text = self._buffer
-
- if self.bot_token not in current_text:
- self._buffer = ""
- for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
- if e_token in new_text:
- new_text = new_text.replace(e_token, "")
- return StreamingParseResult(normal_text=new_text)
-
- if not hasattr(self, "_tool_indices"):
- self._tool_indices = {
- tool.function.name: i
- for i, tool in enumerate(tools)
- if tool.function and tool.function.name
- }
-
- calls: list[ToolCallItem] = []
- try:
- partial_match = re.search(
- pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
- string=current_text,
- flags=re.DOTALL,
- )
- if partial_match:
- func_name = partial_match.group(2).strip()
- func_args_raw = partial_match.group(3).strip()
-
- if not self.current_tool_name_sent:
- calls.append(
- ToolCallItem(
- tool_index=self._tool_indices.get(func_name, 0),
- name=func_name,
- parameters="",
- )
- )
- self.current_tool_name_sent = True
- else:
- argument_diff = (
- func_args_raw[len(self._last_arguments) :]
- if func_args_raw.startswith(self._last_arguments)
- else func_args_raw
- )
-
- if argument_diff:
- calls.append(
- ToolCallItem(
- tool_index=self._tool_indices.get(func_name, 0),
- name=None,
- parameters=argument_diff,
- )
- )
- self._last_arguments += argument_diff
-
- if _is_complete_json(func_args_raw):
- result = StreamingParseResult(normal_text="", calls=calls)
- self._buffer = ""
- self._last_arguments = ""
- self.current_tool_name_sent = False
- return result
-
- return StreamingParseResult(normal_text="", calls=calls)
-
- except Exception as e:
- logger.error(f"Error in parse_streaming_increment: {e}")
- return StreamingParseResult(normal_text=current_text)
-
-
-class MultiFormatParser:
- def __init__(self, detectors: List[BaseFormatDetector]):
- """
- :param detectors: A series of available Detector instances passed in
- """
- self.detectors = detectors
-
- def parse_once(
- self, text: str, tools: List[Tool]
- ) -> Tuple[str, list[ToolCallItem]]:
- """
- One-time parsing: Loop through detectors until there are no new matches or text is exhausted
- Return: (final_text, all_calls)
- - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
- - all_calls: All calls parsed by the Detectors
- """
- final_calls = []
- final_normal_text = text
- for detector in self.detectors:
- parsed_result = detector.detect_and_parse(text, tools)
- tool_call_list = parsed_result.calls
- if len(tool_call_list) > 0: # parsed successfully
- final_calls = tool_call_list
- final_normal_text = parsed_result.normal_text
- break
-
- # leftover_text is the normal text not consumed by any Detector
- return final_normal_text, final_calls
-
- def parse_streaming_increment(
- self, new_text: str, tools: List[Tool]
- ) -> Tuple[str, list[ToolCallItem]]:
- """
- Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
- and merge their produced normal_text/calls to return.
- (The logic here can be "priority-based" or "parallel parsing" based on your needs)
- """
- final_normal_text = ""
- final_calls = []
-
- for detector in self.detectors:
- sp_result = detector.parse_streaming_increment(new_text, tools)
- # Merge normal_text and calls
- # If one sp_result contains result call, this should be a successful parse
- # If one sp_result only contains normal_text, this can either be a successful
- # parse or it is not using the desired parsing tool.
- if sp_result.normal_text:
- final_normal_text = sp_result.normal_text
- if sp_result.calls:
- final_calls.extend(sp_result.calls)
- final_normal_text = sp_result.normal_text
- break
-
- return final_normal_text, final_calls
-
-
-class PythonicDetector(BaseFormatDetector):
- """
- Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
- Assumes function call format:
- [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
- Arguments are Python literals (not JSON).
- """
-
- def __init__(self):
- super().__init__()
- self.tool_call_regex = re.compile(
- r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
- re.DOTALL,
- )
-
- def has_tool_call(self, text: str) -> bool:
- return bool(self.tool_call_regex.match(text.strip()))
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- # Try parsing the text as a Python list of function calls
- text = text.strip()
- if not (text.startswith("[") and text.endswith("]")):
- # Not a pythonic tool call format
- return StreamingParseResult(normal_text=text, calls=[])
- try:
- module = ast.parse(text)
- parsed = getattr(module.body[0], "value", None)
- if not (
- isinstance(parsed, ast.List)
- and all(isinstance(e, ast.Call) for e in parsed.elts)
- ):
- return StreamingParseResult(normal_text=text, calls=[])
- calls = []
- tool_indices = {
- tool.function.name: i
- for i, tool in enumerate(tools)
- if tool.function.name
- }
- for call in parsed.elts:
- if not isinstance(call.func, ast.Name):
- continue
- function_name = call.func.id
- arguments = {}
- for keyword in call.keywords:
- arguments[keyword.arg] = self._get_parameter_value(keyword.value)
- calls.append(
- ToolCallItem(
- tool_index=tool_indices.get(function_name, -1),
- name=function_name,
- parameters=json.dumps(arguments, ensure_ascii=False),
- )
- )
- return StreamingParseResult(normal_text="", calls=calls)
- except Exception:
- logger.exception("Error in pythonic tool call parsing.")
- return StreamingParseResult(normal_text=text, calls=[])
-
- def parse_streaming_increment(
- self, new_text: str, tools: List[Tool]
- ) -> StreamingParseResult:
- """
- Streaming incremental parsing for pythonic tool calls.
- Buffers input until a complete pythonic tool call (from [ to ]) is found,
- then parses and emits any detected calls.
- """
- self._buffer += new_text
- start = self._buffer.find("[")
- end = self._buffer.find("]", start)
- if start != -1 and end != -1:
- call_text = self._buffer[start : end + 1]
- result = self.detect_and_parse(call_text, tools)
- self._buffer = self._buffer[end + 1 :]
- return result
- return StreamingParseResult(normal_text="")
-
- def _get_parameter_value(self, val):
- if isinstance(val, ast.Constant):
- return val.value
- elif isinstance(val, ast.Dict):
- return {
- k.value: self._get_parameter_value(v)
- for k, v in zip(val.keys, val.values)
- }
- elif isinstance(val, ast.List):
- return [self._get_parameter_value(v) for v in val.elts]
- else:
- raise ValueError("Tool call arguments must be literals")
-
- def structure_info(self) -> _GetInfoFunc:
- def info(name: str):
- return StructureInfo(begin="[", end="]", trigger="")
-
- return info
-
-
-class FunctionCallParser:
- """
- In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
- and returns the resulting normal_text and calls to the upper layer (or SSE).
- """
-
- ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
- "llama3": Llama32Detector,
- "qwen25": Qwen25Detector,
- "mistral": MistralDetector,
- "deepseekv3": DeepSeekV3Detector,
- "pythonic": PythonicDetector,
- }
-
- def __init__(self, tools: List[Tool], tool_call_parser: str):
- detectors = []
- if tool_call_parser:
- detector_class = self.ToolCallParserEnum.get(tool_call_parser)
- if detector_class:
- detectors.append(detector_class())
- else:
- raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
- else:
- raise ValueError("Tool Call Parser Not Given!")
-
- self.multi_format_parser = MultiFormatParser(detectors)
- self.tools = tools
-
- def has_tool_call(self, text: str) -> bool:
- """
- Check if the given text contains a tool call in the format supported by this parser.
- This delegates to the detector's implementation.
-
- :param text: The text to check for tool calls
- :return: True if the text contains a tool call, False otherwise
- """
- # Check all detectors in the multi_format_parser
- for detector in self.multi_format_parser.detectors:
- if detector.has_tool_call(text):
- return True
- return False
-
- def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
- """
- Non-streaming call: one-time parsing
- """
- full_normal_text, calls = self.multi_format_parser.parse_once(
- full_text, self.tools
- )
- return full_normal_text, calls
-
- def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
- """
- Streaming call: incremental parsing
- """
- normal_text, calls = self.multi_format_parser.parse_streaming_increment(
- chunk_text, self.tools
- )
- return normal_text, calls
-
- def structure_infos(self) -> List[_GetInfoFunc]:
- """
- Returns a list of structure_info functions for each detector
- """
- return [
- detector.structure_info() for detector in self.multi_format_parser.detectors
- ]
-
- def get_structure_tag(self) -> StructuralTagResponseFormat:
- tool_structures: List[StructuresResponseFormat] = list()
- tool_trigger_set: Set[str] = set()
-
- for wrapper in self.structure_infos():
- for tool in self.tools:
- function = tool.function
- name = function.name
- assert name is not None
- info = wrapper(name)
-
- # accept all if not strict, otherwise only accept the schema
- schema = function.parameters if function.strict else {}
-
- tool_structures.append(
- StructuresResponseFormat(
- begin=info.begin,
- schema=schema, # type: ignore
- end=info.end,
- )
- )
- tool_trigger_set.add(info.trigger)
-
- return StructuralTagResponseFormat(
- type="structural_tag",
- structures=tool_structures,
- triggers=list(tool_trigger_set),
- )
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 9f54641d9..d23cf5c05 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
get_conv_template_by_model_path,
register_conv_template,
)
-from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
@@ -970,7 +970,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
- strict_tag = None
+ tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
@@ -989,7 +989,9 @@ def v1_chat_generate_request(
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
- strict_tag = parser.get_structure_tag()
+ tool_call_constraint = parser.get_structure_constraint(
+ request.tool_choice
+ )
if chat_template_name is None:
openai_compatible_messages = []
@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
request.response_format.model_dump(by_alias=True)
)
- if strict_tag is not None:
- if (
- sampling_params.get("regex")
- or sampling_params.get("ebnf")
- or sampling_params.get("structural_tag")
- or sampling_params.get("json_schema")
- ):
- logger.warning(
- "Constrained decoding is not compatible with tool calls."
+ # Check if there are already existing output constraints
+ has_existing_constraints = (
+ sampling_params.get("regex")
+ or sampling_params.get("ebnf")
+ or sampling_params.get("structural_tag")
+ or sampling_params.get("json_schema")
+ )
+
+ if tool_call_constraint and has_existing_constraints:
+ logger.warning("Constrained decoding is not compatible with tool calls.")
+ elif tool_call_constraint:
+ constraint_type, constraint_value = tool_call_constraint
+ if constraint_type == "structural_tag":
+ sampling_params[constraint_type] = convert_json_schema_to_str(
+ constraint_value.model_dump(by_alias=True)
)
else:
- sampling_params["structural_tag"] = convert_json_schema_to_str(
- strict_tag.model_dump(by_alias=True)
- )
+ sampling_params[constraint_type] = constraint_value
sampling_params_list.append(sampling_params)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index dcb4970fc..b63d7b5e7 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -36,7 +36,7 @@ suites = {
TestFile("test_fa3.py", 376),
TestFile("test_fim_completion.py", 40),
TestFile("test_fp8_kernel.py", 8),
- TestFile("test_function_calling.py", 60),
+ TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 254),
@@ -54,6 +54,7 @@ suites = {
TestFile("test_flashmla.py", 300),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 216),
+ TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py
new file mode 100644
index 000000000..0a00a7dbd
--- /dev/null
+++ b/test/srt/test_function_call_parser.py
@@ -0,0 +1,408 @@
+import json
+import unittest
+
+from xgrammar import GrammarCompiler, TokenizerInfo
+
+from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.llama32_detector import Llama32Detector
+from sglang.srt.function_call.mistral_detector import MistralDetector
+from sglang.srt.function_call.pythonic_detector import PythonicDetector
+from sglang.srt.function_call.qwen25_detector import Qwen25Detector
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.openai_api.protocol import Function, Tool
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+
+class TestPythonicDetector(unittest.TestCase):
+ def setUp(self):
+ # Create sample tools for testing
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "Location to get weather for",
+ },
+ "unit": {
+ "type": "string",
+ "description": "Temperature unit",
+ "enum": ["celsius", "fahrenheit"],
+ },
+ },
+ "required": ["location"],
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="search",
+ description="Search for information",
+ parameters={
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "Search query",
+ },
+ },
+ "required": ["query"],
+ },
+ ),
+ ),
+ ]
+ self.detector = PythonicDetector()
+
+ def test_parse_streaming_no_brackets(self):
+ """Test parsing text with no brackets (no tool calls)."""
+ text = "This is just normal text without any tool calls."
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, text)
+ self.assertEqual(result.calls, [])
+ self.assertEqual(self.detector._buffer, "") # Buffer should be cleared
+
+ def test_parse_streaming_complete_tool_call(self):
+ """Test parsing a complete tool call."""
+ text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "Here's a tool call: ")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ self.detector._buffer, ""
+ ) # Buffer should be cleared after processing
+
+ # Check the parameters
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["location"], "New York")
+ self.assertEqual(params["unit"], "celsius")
+
+ def test_parse_streaming_text_before_tool_call(self):
+ """Test parsing text that appears before a tool call."""
+ text = "This is some text before [get_weather(location='London')]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "This is some text before ")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+
+ # Check the parameters
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["location"], "London")
+
+ def test_parse_streaming_partial_tool_call(self):
+ """Test parsing a partial tool call that spans multiple chunks."""
+ # First chunk with opening bracket but no closing bracket
+ text1 = "Let me check the weather: [get_weather(location="
+ result1 = self.detector.parse_streaming_increment(text1, self.tools)
+
+ self.assertEqual(result1.normal_text, "Let me check the weather: ")
+ self.assertEqual(result1.calls, [])
+ self.assertEqual(
+ self.detector._buffer, "[get_weather(location="
+ ) # Partial tool call remains in buffer
+
+ # Second chunk completing the tool call
+ text2 = "'Paris')]"
+ result2 = self.detector.parse_streaming_increment(text2, self.tools)
+
+ self.assertEqual(result2.normal_text, "")
+ self.assertEqual(len(result2.calls), 1)
+ self.assertEqual(result2.calls[0].name, "get_weather")
+
+ # Check the parameters
+ params = json.loads(result2.calls[0].parameters)
+ self.assertEqual(params["location"], "Paris")
+ self.assertEqual(
+ self.detector._buffer, ""
+ ) # Buffer should be cleared after processing
+
+ def test_parse_streaming_bracket_without_text_before(self):
+ """Test parsing a tool call that starts at the beginning of the text."""
+ text = "[search(query='python programming')]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+
+ # Check the parameters
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "python programming")
+
+ def test_parse_streaming_text_after_tool_call(self):
+ """Test parsing text that appears after a tool call."""
+ # First chunk with complete tool call and some text after
+ text = "[get_weather(location='Tokyo')] Here's the forecast:"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ self.detector._buffer, " Here's the forecast:"
+ ) # Text after tool call remains in buffer
+
+ # Process the remaining text in buffer
+ result2 = self.detector.parse_streaming_increment("", self.tools)
+ self.assertEqual(result2.normal_text, " Here's the forecast:")
+ self.assertEqual(result2.calls, [])
+ self.assertEqual(self.detector._buffer, "") # Buffer should be cleared
+
+ def test_parse_streaming_multiple_tool_calls(self):
+ """Test parsing multiple tool calls in sequence."""
+ text = "[get_weather(location='Berlin')] and [search(query='restaurants')]"
+
+ # First tool call
+ result1 = self.detector.parse_streaming_increment(text, self.tools)
+ self.assertEqual(len(result1.calls), 1)
+ self.assertEqual(result1.calls[0].name, "get_weather")
+ self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")
+
+ # Second tool call
+ result2 = self.detector.parse_streaming_increment("", self.tools)
+ self.assertEqual(result2.normal_text, " and ")
+ self.assertEqual(len(result2.calls), 1)
+ self.assertEqual(result2.calls[0].name, "search")
+ self.assertEqual(self.detector._buffer, "")
+
+ def test_parse_streaming_opening_bracket_only(self):
+ """Test parsing text with only an opening bracket but no closing bracket."""
+ text = "Let's try this: ["
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "Let's try this: ")
+ self.assertEqual(result.calls, [])
+ self.assertEqual(
+ self.detector._buffer, "["
+ ) # Opening bracket remains in buffer
+
+ def test_parse_streaming_nested_brackets(self):
+ """Test parsing tool calls with nested brackets in arguments."""
+ # Test with list argument containing nested brackets
+ text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(self.detector._buffer, "")
+
+ # Check the parameters
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["location"], "New York")
+ self.assertEqual(params["unit"], "celsius")
+ self.assertEqual(params["data"], [1, 2, 3])
+
+ def test_parse_streaming_nested_brackets_dict(self):
+ """Test parsing tool calls with nested dictionaries and lists."""
+ # Test with nested dict and list arguments
+ text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+ self.assertEqual(self.detector._buffer, "")
+
+ # Check the parameters
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "test")
+ self.assertEqual(params["config"]["options"], [1, 2])
+ self.assertEqual(params["config"]["nested"]["key"], "value")
+
+ def test_parse_streaming_multiple_tools_with_nested_brackets(self):
+ """Test parsing multiple tool calls with nested brackets."""
+ text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
+ result = self.detector.parse_streaming_increment(text, self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 2)
+ self.assertEqual(self.detector._buffer, "")
+
+ # Check first tool call
+ params1 = json.loads(result.calls[0].parameters)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(params1["location"], "Paris")
+ self.assertEqual(params1["data"], [10, 20])
+
+ # Check second tool call
+ params2 = json.loads(result.calls[1].parameters)
+ self.assertEqual(result.calls[1].name, "search")
+ self.assertEqual(params2["query"], "test")
+ self.assertEqual(params2["filters"], ["a", "b"])
+
+ def test_parse_streaming_partial_nested_brackets(self):
+ """Test parsing partial tool calls with nested brackets across chunks."""
+ # First chunk with nested brackets but incomplete
+ text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
+ result1 = self.detector.parse_streaming_increment(text1, self.tools)
+
+ self.assertEqual(result1.normal_text, "Here's a call: ")
+ self.assertEqual(result1.calls, [])
+ self.assertEqual(
+ self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
+ )
+
+ # Second chunk completing the nested brackets
+ text2 = ", 3])]"
+ result2 = self.detector.parse_streaming_increment(text2, self.tools)
+
+ self.assertEqual(result2.normal_text, "")
+ self.assertEqual(len(result2.calls), 1)
+ self.assertEqual(result2.calls[0].name, "get_weather")
+ self.assertEqual(self.detector._buffer, "")
+
+ # Check the parameters
+ params = json.loads(result2.calls[0].parameters)
+ self.assertEqual(params["location"], "Tokyo")
+ self.assertEqual(params["data"], [1, 2, 3])
+
+
+class TestEBNFGeneration(unittest.TestCase):
+ def setUp(self):
+ # Create sample tools for testing
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "Location to get weather for",
+ },
+ "unit": {
+ "type": "string",
+ "description": "Temperature unit",
+ "enum": ["celsius", "fahrenheit"],
+ },
+ },
+ "required": ["location"],
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="search",
+ description="Search for information",
+ parameters={
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "Search query",
+ },
+ },
+ "required": ["query"],
+ },
+ ),
+ ),
+ ]
+
+ self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+ tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
+ self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
+
+ # Initialize all detectors
+ self.pythonic_detector = PythonicDetector()
+ self.deepseekv3_detector = DeepSeekV3Detector()
+ self.llama32_detector = Llama32Detector()
+ self.mistral_detector = MistralDetector()
+ self.qwen25_detector = Qwen25Detector()
+
+ def test_pythonic_detector_ebnf(self):
+ """Test that the PythonicDetector generates valid EBNF."""
+ ebnf = self.pythonic_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+
+ # Check that the EBNF contains expected patterns
+ self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
+ self.assertIn('"location" "=" basic_string', ebnf)
+ self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)
+
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
+ def test_deepseekv3_detector_ebnf(self):
+ """Test that the DeepSeekV3Detector generates valid EBNF."""
+ ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+
+ # Check that the EBNF contains expected patterns
+ self.assertIn("<|tool▁calls▁begin|>", ebnf)
+ self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf)
+ self.assertIn('\\"location\\"" ":" basic_string ', ebnf)
+
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
+ def test_llama32_detector_ebnf(self):
+ """Test that the Llama32Detector generates valid EBNF."""
+ ebnf = self.llama32_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+
+ # Check that the EBNF contains expected patterns
+ self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
+ self.assertIn('"\\"arguments\\"" ":"', ebnf)
+
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
+ def test_mistral_detector_ebnf(self):
+ """Test that the MistralDetector generates valid EBNF."""
+ ebnf = self.mistral_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+
+ # Check that the EBNF contains expected patterns
+ self.assertIn('"[TOOL_CALLS] ["', ebnf)
+ self.assertIn("call_get_weather | call_search", ebnf)
+ self.assertIn('"\\"arguments\\"" ":"', ebnf)
+
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
+ def test_qwen25_detector_ebnf(self):
+ """Test that the Qwen25Detector generates valid EBNF."""
+ ebnf = self.qwen25_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+
+ # Check that the EBNF contains expected patterns
+ self.assertIn("<tool_call>", ebnf)
+ self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
+ self.assertIn('"\\"arguments\\"" ":"', ebnf)
+
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_function_calling.py b/test/srt/test_openai_function_calling.py
similarity index 71%
rename from test/srt/test_function_calling.py
rename to test/srt/test_openai_function_calling.py
index 9556cf87b..7755d3729 100644
--- a/test/srt/test_function_calling.py
+++ b/test/srt/test_openai_function_calling.py
@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
+ def test_function_call_required(self):
+ """
+ Test: Whether tool_choice: "required" works as expected
+ - When tool_choice == "required", the model should return one or more tool_calls.
+ """
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "sub",
+ "description": "Compute the difference of two integers",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "int_a": {
+ "type": "integer",
+ "description": "First integer",
+ },
+ "int_b": {
+ "type": "integer",
+ "description": "Second integer",
+ },
+ },
+ "required": ["int_a", "int_b"],
+ },
+ "strict": True,
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "use this to get latest weather information for a city given its name",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "name of the city to get weather for",
+ }
+ },
+ "required": ["city"],
+ },
+ },
+ },
+ ]
+
+ messages = [{"role": "user", "content": "What is the capital of France?"}]
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ temperature=0.8,
+ top_p=0.8,
+ stream=False,
+ tools=tools,
+ tool_choice="required",
+ )
+
+ tool_calls = response.choices[0].message.tool_calls
+ self.assertIsNotNone(tool_calls, "No tool_calls in the response")
+ function_name = tool_calls[0].function.name
+ arguments = tool_calls[0].function.arguments
+ args_obj = json.loads(arguments)
+
+ self.assertEqual(
+ function_name, "get_weather", "Function name should be 'get_weather'"
+ )
+ self.assertIn("city", args_obj, "Function arguments should have 'city'")
+ self.assertIn(
+ "Paris", args_obj["city"], "Parameter city should contain 'Paris'"
+ ) # might be flaky
+
+ def test_function_call_specific(self):
+ """
+ Test: Whether tool_choice: ToolChoice works as expected
+ - When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
+ """
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "sub",
+ "description": "Compute the difference of two integers",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "int_a": {
+ "type": "integer",
+ "description": "First integer",
+ },
+ "int_b": {
+ "type": "integer",
+ "description": "Second integer",
+ },
+ },
+ "required": ["int_a", "int_b"],
+ },
+ "strict": True,
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "use this to get latest weather information for a city given its name",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "name of the city to get weather for",
+ }
+ },
+ "required": ["city"],
+ },
+ },
+ },
+ ]
+
+ messages = [{"role": "user", "content": "What is the capital of France?"}]
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ temperature=0.8,
+ top_p=0.8,
+ stream=False,
+ tools=tools,
+ tool_choice={"type": "function", "function": {"name": "get_weather"}},
+ )
+
+ tool_calls = response.choices[0].message.tool_calls
+ self.assertIsNotNone(tool_calls, "No tool_calls in the response")
+ function_name = tool_calls[0].function.name
+ arguments = tool_calls[0].function.arguments
+ args_obj = json.loads(arguments)
+
+ self.assertEqual(
+ function_name, "get_weather", "Function name should be 'get_weather'"
+ )
+ self.assertIn("city", args_obj, "Function arguments should have 'city'")
+
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [
@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
- self.assertIsInstance(tool_calls, list)
+ self.assertIsInstance(tool_calls, list, "No tool_calls found")
self.assertGreaterEqual(len(tool_calls), 1)
names = [tc.function.name for tc in tool_calls]
- self.assertIn("get_weather", names)
- self.assertIn("get_tourist_attractions", names)
+ self.assertTrue(
+ "get_weather" in names or "get_tourist_attractions" in names,
+            f"Function name '{names}' should contain either 'get_weather' or 'get_tourist_attractions'",
+ )
def test_pythonic_tool_call_streaming(self):
"""
@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
self.assertTrue(found_index, "No index field found in any streamed tool_call")
- self.assertIn("get_weather", found_names)
- self.assertIn("get_tourist_attractions", found_names)
+ self.assertTrue(
+ "get_weather" in found_names or "get_tourist_attractions" in found_names,
+ f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
+ )
if __name__ == "__main__":