diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 18ee1a431..25d7e6273 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -54,10 +54,12 @@ "source": [ "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", "\n", - "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", + "- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n", + "- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n", "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", - "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)." + "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n", + "- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n" ] }, { @@ -360,6 +362,164 @@ "print(final_response.choices[0].message.content)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Choice Mode\n", + "\n", + "SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n", + "\n", + "### Supported Tool Choice Options\n", + "\n", + "- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n", + "- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n", + "\n", + "### Backend Compatibility\n", + "\n", + "Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n", + "\n", + "### Example: Required Tool Choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response with tool_choice='required':\n", + "Content: None\n", + "Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n" + ] + } + ], + "source": [ + "from openai import OpenAI\n", + "import json\n", + "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", + "from sglang.test.test_utils import is_in_ci\n", + "\n", + "if is_in_ci():\n", + " from patch import launch_server_cmd\n", + "else:\n", + " from sglang.utils import launch_server_cmd\n", + " import nest_asyncio\n", + "\n", + " nest_asyncio.apply()\n", + "\n", + "# Start a new server session for tool choice examples\n", + "server_process_tool_choice, port_tool_choice = launch_server_cmd(\n", + " \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n", + ")\n", + "wait_for_server(f\"http://localhost:{port_tool_choice}\")\n", + "\n", + "# Initialize client for tool choice examples\n", + "client_tool_choice = OpenAI(\n", + " api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n", + ")\n", + "model_name_tool_choice = client_tool_choice.models.list().data[0].id\n", + "\n", + "# Example with tool_choice=\"required\" - forces the model to call a tool\n", + "messages_required = [\n", + " {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n", + "]\n", + "\n", + "# Define tools\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather in a given location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The unit to fetch the temperature in\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " },\n", + " },\n", + " \"required\": [\"city\", \"unit\"],\n", + " },\n", + " },\n", + " }\n", + "]\n", + "\n", + "response_required = client_tool_choice.chat.completions.create(\n", + " model=model_name_tool_choice,\n", + " messages=messages_required,\n", + " temperature=0,\n", + " max_tokens=1024,\n", + " tools=tools,\n", + " tool_choice=\"required\", # Force the model to call a tool\n", + ")\n", + "\n", + "print_highlight(\"Response with tool_choice='required':\")\n", + "print(\"Content:\", response_required.choices[0].message.content)\n", + "print(\"Tool calls:\", response_required.choices[0].message.tool_calls)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example: Specific Function Choice\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response with specific function choice:\n", + "Content: None\n", + "Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n", + "Called function: get_current_weather\n", + "Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n" + ] + } + ], + "source": [ + "# Example with specific function choice - forces the model to call a specific function\n", + "messages_specific = [\n", + " {\"role\": \"user\", \"content\": \"What are the most attactive places in France?\"}\n", + "]\n", + "\n", + "response_specific = client_tool_choice.chat.completions.create(\n", + " model=model_name_tool_choice,\n", + " messages=messages_specific,\n", + " temperature=0,\n", + " max_tokens=1024,\n", + " tools=tools,\n", + " tool_choice={\n", + " \"type\": \"function\",\n", + " \"function\": {\"name\": \"get_current_weather\"},\n", + " }, # Force the model to call the specific get_current_weather function\n", + ")\n", + "\n", + "print_highlight(\"Response with specific function choice:\")\n", + "print(\"Content:\", response_specific.choices[0].message.content)\n", + "print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n", + "\n", + "if response_specific.choices[0].message.tool_calls:\n", + " tool_call = response_specific.choices[0].message.tool_calls[0]\n", + " print(f\"Called function: {tool_call.function.name}\")\n", + " print(f\"Arguments: {tool_call.function.arguments}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -444,7 +604,7 @@ "outputs": [], "source": [ "import sglang as sgl\n", - "from sglang.srt.function_call_parser import FunctionCallParser\n", + "from sglang.srt.function_call.function_call_parser import FunctionCallParser\n", "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", "llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n", diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index e98e3d3de..ff0978e38 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import ( register_disaggregation_server, ) from sglang.srt.entrypoints.engine import _launch_subprocesses -from sglang.srt.function_call_parser import FunctionCallParser +from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py new file mode 100644 index 000000000..497e29c56 --- /dev/null +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -0,0 +1,250 @@ +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from partial_json_parser.core.exceptions import MalformedJSON +from partial_json_parser.core.options import Allow + +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.utils import ( + _find_common_prefix, + _is_complete_json, + _partial_json_loads, +) +from sglang.srt.openai_api.protocol import Tool + +logger = logging.getLogger(__name__) + + +class BaseFormatDetector(ABC): + """Base class providing two sets of interfaces: one-time and streaming incremental.""" + + def __init__(self): + # initialize properties used for state when parsing tool calls in + self._buffer = "" + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = ( + [] + ) # map what has been streamed for each tool so far to a list + self.bot_token = "" + self.eot_token = "" + + def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: + tool_indices = { + tool.function.name: i for i, tool in enumerate(tools) if tool.function.name + } + if not isinstance(action, list): + action = [action] + + results = [] + for act in action: + name = act.get("name") + if name and name in tool_indices: + results.append( + ToolCallItem( + tool_index=tool_indices[name], + name=name, + parameters=json.dumps( + act.get("parameters") or act.get("arguments", {}), + ensure_ascii=False, + ), + ) + ) + else: + logger.warning(f"Model attempted to call undefined function: {name}") + + return results + + @abstractmethod + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + Parses the text in one go. Returns success=True if the format matches, otherwise False. + Note that leftover_text here represents "content that this parser will not consume further". + """ + action = json.loads(text) + return StreamingParseResult(calls=self.parse_base_json(action, tools)) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing with tool validation. + """ + # Append new text to buffer + self._buffer += new_text + current_text = self._buffer + if not (self.bot_token in current_text or current_text.startswith("{")): + self._buffer = "" + if self.eot_token in new_text: + new_text = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=new_text) + + # Build tool indices if not already built + if not hasattr(self, "_tool_indices"): + self._tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function and tool.function.name + } + + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) + while start_idx < len(current_text): + (obj, end_idx) = _partial_json_loads( + current_text[start_idx:], flags + ) + is_complete.append( + _is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") + + # Validate tool name if present + if "name" in obj and obj["name"] not in self._tool_indices: + # Invalid tool name - reset state + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + if self.streamed_args_for_tool: + self.streamed_args_for_tool.pop() + return StreamingParseResult() + + # Handle parameters/arguments consistency + if "parameters" in obj: + assert ( + "arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + + except MalformedJSON: + return StreamingParseResult() + + if len(tool_call_arr) == 0: + return StreamingParseResult() + + current_tool_call: Dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) + + # Handle new tool in array + if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + else: + res = StreamingParseResult() + else: + res = StreamingParseResult() + + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + return res + + # Handle tool name + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name and function_name in self._tool_indices: + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self._tool_indices[function_name], + name=function_name, + parameters="", + ) + ], + ) + self.current_tool_name_sent = True + else: + res = StreamingParseResult() + + # Handle streaming arguments + else: + cur_arguments = current_tool_call.get("arguments") + res = StreamingParseResult() + + if cur_arguments: + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + self._buffer = "" + self.prev_tool_call_arr[self.current_tool_id].clear() + self.current_tool_name_sent = False + self.streamed_args_for_tool[self.current_tool_id] = "" + + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = _find_common_prefix(prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + parameters=argument_diff, + ) + ], + ) + if not is_complete[self.current_tool_id]: + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return res + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}") + return StreamingParseResult() + + @abstractmethod + def has_tool_call(self, text: str) -> bool: + raise NotImplementedError() + + @abstractmethod + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError() + + @abstractmethod + def build_ebnf(self, tools: List[Tool]) -> str: + raise NotImplementedError() diff --git a/python/sglang/srt/function_call/core_types.py b/python/sglang/srt/function_call/core_types.py new file mode 100644 index 000000000..1ea87df79 --- /dev/null +++ b/python/sglang/srt/function_call/core_types.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Callable, List, Optional + +from pydantic import BaseModel + + +class ToolCallItem(BaseModel): + """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" + + tool_index: int + name: Optional[str] = None + parameters: str # JSON string + + +class StreamingParseResult(BaseModel): + """Result of streaming incremental parsing.""" + + normal_text: str = "" + calls: List[ToolCallItem] = [] + + +@dataclass +class StructureInfo: + begin: str + end: str + trigger: str + + +""" +Helper alias of function +Usually it is a function that takes a name string and returns a StructureInfo object, +which can be used to construct a structural_tag object +""" +_GetInfoFunc = Callable[[str], StructureInfo] diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py new file mode 100644 index 000000000..b53c68ab6 --- /dev/null +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -0,0 +1,157 @@ +import json +import logging +import re +from typing import List + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.function_call.utils import _is_complete_json +from sglang.srt.openai_api.protocol import Tool + +logger = logging.getLogger(__name__) + + +class DeepSeekV3Detector(BaseFormatDetector): + """ + Detector for DeepSeek models. + Assumes function call format: + '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> + """ + + def __init__(self): + super().__init__() + self.bot_token = "<|tool▁calls▁begin|>" + self.eot_token = "<|tool▁calls▁end|>" + self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>" + self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>" + self._last_arguments = "" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a deepseek format tool call.""" + return self.bot_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + match_result_list = re.findall(self.func_call_regex, text, re.DOTALL) + calls = [] + try: + for match_result in match_result_list: + # Get function name + func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL) + func_name = func_detail.group(2) + func_args = func_detail.group(3) + func_args = json.loads(func_args) + # construct match_result for parse_base_json + match_result = {"name": func_name, "parameters": func_args} + calls.extend(self.parse_base_json(match_result, tools)) + return StreamingParseResult(normal_text=normal_text, calls=calls) + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}") + # return the normal text if parsing fails + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing tool calls for DeepSeekV3 format. + """ + self._buffer += new_text + current_text = self._buffer + + if self.bot_token not in current_text: + self._buffer = "" + for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) + + if not hasattr(self, "_tool_indices"): + self._tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function and tool.function.name + } + + calls: list[ToolCallItem] = [] + try: + partial_match = re.search( + pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)", + string=current_text, + flags=re.DOTALL, + ) + if partial_match: + func_name = partial_match.group(2).strip() + func_args_raw = partial_match.group(3).strip() + + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get(func_name, 0), + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + else: + argument_diff = ( + func_args_raw[len(self._last_arguments) :] + if func_args_raw.startswith(self._last_arguments) + else func_args_raw + ) + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get(func_name, 0), + name=None, + parameters=argument_diff, + ) + ) + self._last_arguments += argument_diff + + if _is_complete_json(func_args_raw): + result = StreamingParseResult(normal_text="", calls=calls) + self._buffer = "" + self._last_arguments = "" + self.current_tool_name_sent = False + return result + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}") + return StreamingParseResult(normal_text=current_text) + + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin=">" + name + "\n```json\n", + end="\n```<", + trigger=">" + name + "\n```json\n", + ) + + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + bot_token=self.bot_token, + eot_token=self.eot_token, + tool_call_separator="", + call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"', + function_format="json", + ) diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py new file mode 100644 index 000000000..d749f65d7 --- /dev/null +++ b/python/sglang/srt/function_call/ebnf_composer.py @@ -0,0 +1,234 @@ +from typing import Literal, Optional + + +class EBNFComposer: + # Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers + json_grammar_ebnf_str = r""" + json ::= basic_array | basic_object + basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object + basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? + basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? + basic_string ::= (([\"] basic_string_1 [\"])) + basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 + escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] + basic_boolean ::= "true" | "false" + basic_null ::= "null" + basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" + basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" + ws ::= [ \n\t]* + """ + + pythonic_grammar_ebnf_str = r""" + pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None" + basic_any ::= basic_number | basic_string | basic_array | basic_object + basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? + basic_string ::= (([\"] basic_string_1 [\"])) + basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 + escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] + basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" + basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" + ws ::= [ \n\t]* + """ + + TOOL_CALLS_MAP = { + "pythonic": '"[" function_call ("," function_call)* "]"', + "json": "function_call", + } + + CALL_RULE_MAP = { + "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"', + "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"', + } + + ARGUMENTS_RULE_MAP = { + "pythonic": "{arg_rules}", + "json": '"{{" {arg_rules} "}}"', + } + + KEY_VALUE_RULE_MAP = { + "pythonic": '"{key}" "=" {valrule}', + "json": '"\\"{key}\\"" ":" {valrule}', + } + + JSON_TYPE_MAPPING = { + "string": "basic_string", + "number": "basic_number", + "integer": "basic_number", + "boolean": "basic_boolean", + "null": "basic_null", + "array": "basic_array", + "object": "basic_object", + } + + PYTHONIC_TYPE_MAPPING = { + "string": "basic_string", + "number": "basic_number", + "integer": "basic_number", + "boolean": '"True" | "False"', + "null": '"None"', + "array": "basic_array", + "object": "basic_object", + } + + @staticmethod + def get_value_rule( + prop: dict, function_format: Literal["pythonic", "json"] = "json" + ) -> str: + if "enum" in prop: + return EBNFComposer._handle_enum(prop, function_format) + + if "type" in prop: + return EBNFComposer._handle_type(prop, function_format) + + return function_format + + @staticmethod + def _handle_enum(prop: dict, function_format: str) -> str: + """Handle enum properties by formatting each value according to type and format.""" + enum_values = prop["enum"] + prop_type = prop.get("type", "string") + + # Define formatters for different type/format combinations + formatters = { + ("string", "json"): lambda v: f'"\\"{v}\\""', + ("string", "pythonic"): lambda v: f'"\\"{v}\\""', + ("number", "json"): str, + ("number", "pythonic"): str, + ("integer", "json"): str, + ("integer", "pythonic"): str, + ("boolean", "json"): lambda v: "true" if v else "false", + ("boolean", "pythonic"): lambda v: "True" if v else "False", + } + + # Get the formatter or default to string handling + formatter = formatters.get( + (prop_type, function_format), + formatters[("string", function_format)], # Default to string handling + ) + + formatted_values = [formatter(value) for value in enum_values] + enum_rule = " | ".join(formatted_values) + + # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence + if len(formatted_values) > 1: + enum_rule = f"({enum_rule})" + + return enum_rule + + @staticmethod + def _handle_type(prop: dict, function_format: str) -> str: + """Handle type properties using the appropriate type mapping.""" + prop_type = prop["type"] + type_mapping = ( + EBNFComposer.PYTHONIC_TYPE_MAPPING + if function_format == "pythonic" + else EBNFComposer.JSON_TYPE_MAPPING + ) + + if isinstance(prop_type, list): + type_rules = [ + type_mapping[single_type] + for single_type in prop_type + if single_type in type_mapping + ] + return " | ".join(type_rules) if type_rules else function_format + + return type_mapping.get(prop_type, function_format) + + @staticmethod + def build_ebnf( + tools, + *, + call_rule_fmt: Optional[str] = None, + function_format: Literal["pythonic", "json"] = "json", + bot_token: Optional[str] = None, + eot_token: Optional[str] = None, + tool_call_separator: Optional[str] = None, + ): + """ + Generalized EBNF builder for all detectors. + Args: + tools: List of Tool objects to generate EBNF grammar for + call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with + the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default + format based on function_format will be used. + function_format: The format of function calls, either "pythonic" or "json" + bot_token: The token that indicates the start of a tool call section + eot_token: The token that indicates the end of a tool call section + tool_call_separator: The separator between multiple tool calls + """ + # ================================================================= + # Step 1: Determine the root tool calls rule + # ================================================================= + if bot_token and eot_token: + if tool_call_separator: + root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"' + else: + root_rule = f'"{bot_token}" function_call "{eot_token}"' + else: + root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format] + + # ================================================================= + # Step 2: Build the header rules + # ================================================================= + ebnf_lines = [ + f"root ::= {root_rule}", + "function_call ::= " + + " | ".join([f"call_{tool.function.name}" for tool in tools]), + ] + + # ================================================================= + # Step 3: Set up formatting templates + # ================================================================= + call_template = ( + f"call_{{name}} ::= {call_rule_fmt}" + if call_rule_fmt + else EBNFComposer.CALL_RULE_MAP[function_format] + ) + args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format] + key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format] + + # ================================================================= + # Step 4: Build rules for each tool + # ================================================================= + for tool in tools: + tool_name = tool.function.name + params = tool.function.parameters or {} + properties = params.get("properties", {}) + required_props = set(params.get("required", [])) + + # Build argument rules for this tool + arg_rules = [] + for prop_name, prop_schema in properties.items(): + value_rule = EBNFComposer.get_value_rule(prop_schema, function_format) + # Create key=value pair + pair = key_value_template.format(key=prop_name, valrule=value_rule) + + if prop_name not in required_props: + pair = f"[ {pair} ]" + + arg_rules.append(pair) + + # Combine all argument rules + combined_args = ' "," '.join(arg_rules) if arg_rules else "" + arguments_rule = args_template.format(arg_rules=combined_args) + + # Add the function call rule and its arguments rule + ebnf_lines.append( + call_template.format( + name=tool_name, arguments_rule=f"arguments_{tool_name}" + ) + ) + ebnf_lines.append(f"arguments_{tool_name} ::= {arguments_rule}") + + # ================================================================= + # Step 5: Add base grammar rules + # ================================================================= + base_grammar = ( + EBNFComposer.pythonic_grammar_ebnf_str + if function_format == "pythonic" + else EBNFComposer.json_grammar_ebnf_str + ) + ebnf_lines.append(base_grammar) + + return "\n".join(ebnf_lines) diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py new file mode 100644 index 000000000..f73217086 --- /dev/null +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -0,0 +1,175 @@ +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.llama32_detector import Llama32Detector +from sglang.srt.function_call.mistral_detector import MistralDetector +from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen25_detector import Qwen25Detector +from sglang.srt.openai_api.protocol import ( + StructuralTagResponseFormat, + StructuresResponseFormat, + Tool, + ToolChoice, +) + + +class FunctionCallParser: + """ + Parser for function/tool calls in model outputs. + + This class handles both streaming and non-streaming parsing of function calls using a detector. + In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment + and returns the resulting normal_text and calls to the upper layer (or SSE). + """ + + ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { + "llama3": Llama32Detector, + "qwen25": Qwen25Detector, + "mistral": MistralDetector, + "deepseekv3": DeepSeekV3Detector, + "pythonic": PythonicDetector, + } + + def __init__(self, tools: List[Tool], tool_call_parser: str): + detector: Type[BaseFormatDetector] = None + detector_class = self.ToolCallParserEnum.get(tool_call_parser) + if detector_class: + detector = detector_class() + else: + raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}") + + self.detector = detector + self.tools = tools + + def has_tool_call(self, text: str) -> bool: + """ + Check if the given text contains a tool call in the format supported by this parser. + This delegates to the detector's implementation. + + Args: + text: The text to check for tool calls + + Returns: + True if the text contains a tool call, False otherwise + """ + return self.detector.has_tool_call(text) + + def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]: + """ + One-time parsing of the full text to extract tool calls. + + Args: + full_text: The complete text to parse + + Returns: + A tuple containing: + - The remaining text after parsing that was not consumed by the detector (can be treated as normal text) + - A list of tool calls parsed from the text + """ + parsed_result = self.detector.detect_and_parse(full_text, self.tools) + tool_call_list = parsed_result.calls + if tool_call_list: + return parsed_result.normal_text, tool_call_list + else: + return full_text, [] + + def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: + """ + Streaming incremental parsing of chunks of text as they arrive. + + Args: + chunk_text: The new chunk of text to parse + + Returns: + A tuple containing: + - The normal text that should be displayed to the user + - A list of tool calls parsed from the chunk + """ + final_normal_text = "" + final_calls = [] + + sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools) + if sp_result.normal_text: + final_normal_text = sp_result.normal_text + if sp_result.calls: + final_calls.extend(sp_result.calls) + final_normal_text = sp_result.normal_text + + return final_normal_text, final_calls + + def get_structure_tag(self) -> StructuralTagResponseFormat: + """ + Generate a structural tag response format for all available tools. + + This creates the necessary structural tags that guide the model's output format. + """ + tool_structures: List[StructuresResponseFormat] = list() + tool_trigger_set: Set[str] = set() + + get_structure_info = self.detector.structure_info() + for tool in self.tools: + function = tool.function + name = function.name + assert name is not None + info = get_structure_info(name) + + # accept all if not strict, otherwise only accept the schema + schema = function.parameters if function.strict else {} + + tool_structures.append( + StructuresResponseFormat( + begin=info.begin, + schema=schema, # type: ignore + end=info.end, + ) + ) + tool_trigger_set.add(info.trigger) + + return StructuralTagResponseFormat( + type="structural_tag", + structures=tool_structures, + triggers=list(tool_trigger_set), + ) + + def get_structure_constraint( + self, tool_choice: Union[ToolChoice, Literal["auto", "required"]] + ) -> Optional[Tuple[str, Any]]: + """ + Returns the appropriate structure constraint for tool calls based on the tool_choice. + The constraint is used to guide the model's output format. + + Args: + tool_choice: The tool choice setting from the request + + Returns: + A tuple of (constraint_type, constraint_value) to be added to sampling parameters, + or None if no constraint applies. + """ + # NOTE: structural_tag only supports JSON-compatible content between the begin and end. + # It cannot parse or validate Python syntax like function calls. + if ( + not isinstance(self.detector, PythonicDetector) + and tool_choice == "auto" + and any(tool.function.strict for tool in self.tools) + ): + strict_tag = self.get_structure_tag() + return ("structural_tag", strict_tag) + elif tool_choice == "required" or isinstance(tool_choice, ToolChoice): + ebnf = self.get_ebnf(tool_choice) + return ("ebnf", ebnf) if ebnf is not None else None + + def get_ebnf( + self, tool_choice: Union[ToolChoice, Literal["required"]] + ) -> Optional[str]: + """ + Get the EBNF grammar for the specified tool choice. + """ + filtered_tools = [] + if isinstance(tool_choice, ToolChoice): + fn_name = tool_choice.function.name + filtered_tools = [t for t in self.tools if t.function.name == fn_name] + else: + filtered_tools = self.tools + return self.detector.build_ebnf(filtered_tools) diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py new file mode 100644 index 000000000..32670782c --- /dev/null +++ b/python/sglang/srt/function_call/llama32_detector.py @@ -0,0 +1,74 @@ +import json +import logging +from typing import List + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.openai_api.protocol import Tool + +logger = logging.getLogger(__name__) + + +class Llama32Detector(BaseFormatDetector): + """ + Detector for Llama 3.2 models. + Assumes function call format: + <|python_tag|>{"name":"xxx", "arguments":{...}} + """ + + def __init__(self): + super().__init__() + self.bot_token = "<|python_tag|>" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a Llama 3.2 format tool call.""" + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + return "<|python_tag|>" in text or text.startswith("{") + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """Parse function calls from text, handling multiple JSON objects.""" + if "<|python_tag|>" not in text and not text.startswith("{"): + return StreamingParseResult(normal_text=text, calls=[]) + + if "<|python_tag|>" in text: + normal_text, action_text = text.split("<|python_tag|>") + else: + normal_text, action_text = "", text + + # Split by semicolon and process each part + json_parts = [part.strip() for part in action_text.split(";") if part.strip()] + all_actions = [] + for part in json_parts: + try: + # Parse each individual JSON object + action = json.loads(part) + all_actions.append(action) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON part: {part}") + logger.warning(f"JSON parse error: {str(e)}") + continue + calls = [] + # Only process if we found valid JSON objects + if all_actions: + calls = self.parse_base_json(all_actions, tools) + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='<|python_tag|>{"name":"' + name + '", "arguments":', + end="}", + trigger="<|python_tag|>", + ) + + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + function_format="json", + tool_call_separator=",", + ) diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py new file mode 100644 index 000000000..a5d2475ea --- /dev/null +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -0,0 +1,84 @@ +import json +import re +from typing import List + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.openai_api.protocol import Tool + + +class MistralDetector(BaseFormatDetector): + """ + Detector for Mistral models. + Assumes function call format: + [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}] + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "[TOOL_CALLS] [" + self.eot_token = "]" + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a Mistral format tool call.""" + return self.bot_token in text + + def _clean_text(self, text: str) -> str: + """ + clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' + for example, + text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' + return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' + The key pattern is [TOOL_CALLS] [...] + """ + # TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call + find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) + if len(find_results) > 0: + return find_results[0] + else: + return "" + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + text = self._clean_text(text) + tool_content = text.replace("[TOOL_CALLS]", "").strip() + raw_tool_calls = self.tool_call_regex.findall(tool_content) + calls = [] + if len(raw_tool_calls) > 0: + raw_tool_call = raw_tool_calls[0] + function_call_arr = json.loads(raw_tool_call) + for match_result in function_call_arr: + calls.extend(self.parse_base_json(match_result, tools)) + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":', + end="}]", + trigger="[TOOL_CALLS]", + ) + + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + bot_token=self.bot_token, + eot_token=self.eot_token, + function_format="json", + ) diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py new file mode 100644 index 000000000..e60ab63bf --- /dev/null +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -0,0 +1,163 @@ +import ast +import json +import logging +import re +from typing import List, Optional + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.openai_api.protocol import Tool + +logger = logging.getLogger(__name__) + + +class PythonicDetector(BaseFormatDetector): + """ + Detector for Llama-3.2 and Llama-4 models with pythonic tool call format. + Assumes function call format: + [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] + Arguments are Python literals (not JSON). + """ + + def __init__(self): + super().__init__() + self.tool_call_regex = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL, + ) + + def has_tool_call(self, text: str) -> bool: + return bool(self.tool_call_regex.match(text.strip())) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + # Try parsing the text as a Python list of function calls + text = text.strip() + if not (text.startswith("[") and text.endswith("]")): + # Not a pythonic tool call format + return StreamingParseResult(normal_text=text, calls=[]) + try: + module = ast.parse(text) + parsed = getattr(module.body[0], "value", None) + if not ( + isinstance(parsed, ast.List) + and all(isinstance(e, ast.Call) for e in parsed.elts) + ): + return StreamingParseResult(normal_text=text, calls=[]) + calls = [] + tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function.name + } + for call in parsed.elts: + if not isinstance(call.func, ast.Name): + continue + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = self._get_parameter_value(keyword.value) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(function_name, -1), + name=function_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + ) + return StreamingParseResult(normal_text="", calls=calls) + except Exception: + logger.exception("Error in pythonic tool call parsing.") + return StreamingParseResult(normal_text=text, calls=[]) + + def _find_matching_bracket(self, buffer: str, start: int) -> int: + """ + Find the matching closing bracket for the opening bracket at start position. + Properly handles nested brackets. + + Args: + buffer: The text buffer to search in + start: Position of the opening bracket '[' + + Returns: + Position of the matching closing bracket ']', or -1 if not found + """ + bracket_count = 0 + for i in range(start, len(buffer)): + if buffer[i] == "[": + bracket_count += 1 + elif buffer[i] == "]": + bracket_count -= 1 + if bracket_count == 0: + return i + return -1 # No matching bracket found + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing for pythonic tool calls. + Buffers input until a complete pythonic tool call (from [ to ]) is found, + then parses and emits any detected calls. + """ + self._buffer += new_text + start = self._buffer.find("[") + + if start == -1: + normal_text = self._buffer + self._buffer = "" + return StreamingParseResult(normal_text=normal_text) + + normal_text = self._buffer[:start] if start > 0 else "" + + end = self._find_matching_bracket(self._buffer, start) + if end != -1: + call_text = self._buffer[start : end + 1] + result = self.detect_and_parse(call_text, tools) + self._buffer = self._buffer[end + 1 :] + + # If we had normal text before the tool call, add it to the result + if normal_text: + result.normal_text = normal_text + (result.normal_text or "") + + return result + + # We have an opening bracket but no closing bracket yet + if normal_text: + self._buffer = self._buffer[start:] + return StreamingParseResult(normal_text=normal_text) + + # Otherwise, we're still accumulating a potential tool call + return StreamingParseResult(normal_text="") + + def _get_parameter_value(self, val): + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + return { + k.value: self._get_parameter_value(v) + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [self._get_parameter_value(v) for v in val.elts] + else: + raise ValueError("Tool call arguments must be literals") + + def structure_info(self) -> _GetInfoFunc: + def info(name: str): + return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(") + + return info + + def build_ebnf(self, tools: List[Tool]) -> Optional[str]: + return EBNFComposer.build_ebnf( + tools, + bot_token="[", + eot_token="]", + tool_call_separator=",", + function_format="pythonic", + ) diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py new file mode 100644 index 000000000..1d32099f7 --- /dev/null +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -0,0 +1,67 @@ +import json +import re +from typing import List + +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.openai_api.protocol import Tool + + +class Qwen25Detector(BaseFormatDetector): + """ + Detector for Qwen 2.5 models. + Assumes function call format: + {"name":"xxx", "arguments":{...}} + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a Qwen 2.5 format tool call.""" + return self.bot_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + pattern = rf"{self.bot_token}(.*?){self.eot_token}" + match_result_list = re.findall(pattern, text, re.DOTALL) + calls = [] + for match_result in match_result_list: + match_result = json.loads(match_result) + calls.extend(self.parse_base_json(match_result, tools)) + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='{"name":"' + name + '", "arguments":', + end="}", + trigger="", + ) + + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + bot_token=self.bot_token, + eot_token=self.eot_token, + function_format="json", + ) diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py new file mode 100644 index 000000000..e8a585bb2 --- /dev/null +++ b/python/sglang/srt/function_call/utils.py @@ -0,0 +1,35 @@ +import json +from json import JSONDecodeError, JSONDecoder +from typing import Any, Tuple + +import partial_json_parser +from partial_json_parser.core.options import Allow + + +def _find_common_prefix(s1: str, s2: str) -> str: + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def _is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py deleted file mode 100644 index 549843146..000000000 --- a/python/sglang/srt/function_call_parser.py +++ /dev/null @@ -1,858 +0,0 @@ -import ast -import json -import logging -import re -from abc import ABC, abstractmethod -from dataclasses import dataclass -from json import JSONDecodeError, JSONDecoder -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type - -import partial_json_parser -from partial_json_parser.core.exceptions import MalformedJSON -from partial_json_parser.core.options import Allow -from pydantic import BaseModel - -from sglang.srt.openai_api.protocol import ( - StructuralTagResponseFormat, - StructuresResponseFormat, - Tool, -) - -logger = logging.getLogger(__name__) - -TOOLS_TAG_LIST = [ - "<|plugin|>", - "", - "<|python_tag|>", - "[TOOL_CALLS]", - "<|tool▁calls▁begin|>", -] - - -class ToolCallItem(BaseModel): - """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" - - tool_index: int - name: Optional[str] = None - parameters: str # JSON string - - -def _find_common_prefix(s1: str, s2: str) -> str: - prefix = "" - min_length = min(len(s1), len(s2)) - for i in range(0, min_length): - if s1[i] == s2[i]: - prefix += s1[i] - else: - break - return prefix - - -def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: - try: - return (partial_json_parser.loads(input_str, flags), len(input_str)) - except JSONDecodeError as e: - if "Extra data" in e.msg: - dec = JSONDecoder() - return dec.raw_decode(input_str) - raise - - -def _is_complete_json(input_str: str) -> bool: - try: - json.loads(input_str) - return True - except JSONDecodeError: - return False - - -class StreamingParseResult: - """Result of streaming incremental parsing.""" - - def __init__( - self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None - ): - self.normal_text = normal_text - self.calls = calls or [] - - -@dataclass -class StructureInfo: - begin: str - end: str - trigger: str - - -_GetInfoFunc = Callable[[str], StructureInfo] -""" -Helper alias of function -Usually it is a function that takes a name string and returns a StructureInfo object, -which can be used to construct a structural_tag object -""" - - -class BaseFormatDetector(ABC): - """Base class providing two sets of interfaces: one-time and streaming incremental.""" - - def __init__(self): - # initialize properties used for state when parsing tool calls in - self._buffer = "" - # streaming mode - self.prev_tool_call_arr: List[Dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = ( - [] - ) # map what has been streamed for each tool so far to a list - self.bot_token = "" - self.eot_token = "" - - def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: - tool_indices = { - tool.function.name: i for i, tool in enumerate(tools) if tool.function.name - } - if not isinstance(action, list): - action = [action] - - results = [] - for act in action: - name = act.get("name") - if name and name in tool_indices: - results.append( - ToolCallItem( - tool_index=tool_indices[name], - name=name, - parameters=json.dumps( - act.get("parameters") or act.get("arguments", {}), - ensure_ascii=False, - ), - ) - ) - else: - logger.warning(f"Model attempted to call undefined function: {name}") - - return results - - @abstractmethod - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - Parses the text in one go. Returns success=True if the format matches, otherwise False. - Note that leftover_text here represents "content that this parser will not consume further". - """ - action = json.loads(text) - return StreamingParseResult(calls=self.parse_base_json(action, tools)) - - def parse_streaming_increment( - self, new_text: str, tools: List[Tool] - ) -> StreamingParseResult: - """ - Streaming incremental parsing with tool validation. - """ - # Append new text to buffer - self._buffer += new_text - current_text = self._buffer - if not (self.bot_token in current_text or current_text.startswith("{")): - self._buffer = "" - if self.eot_token in new_text: - new_text = new_text.replace(self.eot_token, "") - return StreamingParseResult(normal_text=new_text) - - # Build tool indices if not already built - if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } - - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR - try: - tool_call_arr = [] - is_complete = [] - try: - start_idx = ( - len(self.bot_token) - if current_text.startswith(self.bot_token) - else 0 - ) - while start_idx < len(current_text): - (obj, end_idx) = _partial_json_loads( - current_text[start_idx:], flags - ) - is_complete.append( - _is_complete_json(current_text[start_idx : start_idx + end_idx]) - ) - start_idx += end_idx + len("; ") - - # Validate tool name if present - if "name" in obj and obj["name"] not in self._tool_indices: - # Invalid tool name - reset state - self._buffer = "" - self.current_tool_id = -1 - self.current_tool_name_sent = False - if self.streamed_args_for_tool: - self.streamed_args_for_tool.pop() - return StreamingParseResult() - - # Handle parameters/arguments consistency - if "parameters" in obj: - assert ( - "arguments" not in obj - ), "model generated both parameters and arguments" - obj["arguments"] = obj["parameters"] - tool_call_arr.append(obj) - - except MalformedJSON: - return StreamingParseResult() - - if len(tool_call_arr) == 0: - return StreamingParseResult() - - current_tool_call: Dict = ( - tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} - ) - - # Handle new tool in array - if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: - if self.current_tool_id >= 0: - cur_arguments = current_tool_call.get("arguments") - if cur_arguments: - cur_args_json = json.dumps(cur_arguments) - sent = len(self.streamed_args_for_tool[self.current_tool_id]) - argument_diff = cur_args_json[sent:] - - res = StreamingParseResult( - calls=[ - ToolCallItem( - tool_index=self.current_tool_id, - name="", - parameters=argument_diff, - ) - ], - ) - self.streamed_args_for_tool[ - self.current_tool_id - ] += argument_diff - else: - res = StreamingParseResult() - else: - res = StreamingParseResult() - - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - return res - - # Handle tool name - elif not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name and function_name in self._tool_indices: - res = StreamingParseResult( - calls=[ - ToolCallItem( - tool_index=self._tool_indices[function_name], - name=function_name, - parameters="", - ) - ], - ) - self.current_tool_name_sent = True - else: - res = StreamingParseResult() - - # Handle streaming arguments - else: - cur_arguments = current_tool_call.get("arguments") - res = StreamingParseResult() - - if cur_arguments: - sent = len(self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments" - ) - - argument_diff = None - if is_complete[self.current_tool_id]: - argument_diff = cur_args_json[sent:] - self._buffer = "" - self.prev_tool_call_arr[self.current_tool_id].clear() - self.current_tool_name_sent = False - self.streamed_args_for_tool[self.current_tool_id] = "" - - elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) - if cur_args_json != prev_args_json: - prefix = _find_common_prefix(prev_args_json, cur_args_json) - argument_diff = prefix[sent:] - - if argument_diff is not None: - res = StreamingParseResult( - calls=[ - ToolCallItem( - tool_index=self.current_tool_id, - parameters=argument_diff, - ) - ], - ) - if not is_complete[self.current_tool_id]: - self.streamed_args_for_tool[ - self.current_tool_id - ] += argument_diff - - self.prev_tool_call_arr = tool_call_arr - return res - - except Exception as e: - logger.error(f"Error in parse_streaming_increment: {e}") - return StreamingParseResult() - - @abstractmethod - def has_tool_call(self, text: str) -> bool: - raise NotImplementedError() - - @abstractmethod - def structure_info(self) -> _GetInfoFunc: - raise NotImplementedError() - - -class Qwen25Detector(BaseFormatDetector): - """ - Detector for Qwen 2.5 models. - Assumes function call format: - {"name":"xxx", "arguments":{...}} - """ - - def __init__(self): - """ - Initializes the detector with necessary state variables. - """ - super().__init__() - self.bot_token = "" - self.eot_token = "" - - def has_tool_call(self, text: str) -> bool: - """Check if the text contains a Qwen 2.5 format tool call.""" - return self.bot_token in text - - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - One-time parsing: Detects and parses tool calls in the provided text. - - :param text: The complete text to parse. - :param tools: List of available tools. - :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. - """ - idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx != -1 else text - if self.bot_token not in text: - return StreamingParseResult(normal_text=normal_text, calls=[]) - pattern = rf"{self.bot_token}(.*?){self.eot_token}" - match_result_list = re.findall(pattern, text, re.DOTALL) - calls = [] - for match_result in match_result_list: - match_result = json.loads(match_result) - calls.extend(self.parse_base_json(match_result, tools)) - return StreamingParseResult(normal_text=normal_text, calls=calls) - - def structure_info(self) -> _GetInfoFunc: - return lambda name: StructureInfo( - begin='{"name":"' + name + '", "arguments":', - end="}", - trigger="", - ) - - -class MistralDetector(BaseFormatDetector): - """ - Detector for Mistral models. - Assumes function call format: - <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|> - """ - - def __init__(self): - """ - Initializes the detector with necessary state variables. - """ - super().__init__() - self.bot_token = "[TOOL_CALLS] [" - self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) - - def has_tool_call(self, text: str) -> bool: - """Check if the text contains a Mistral format tool call.""" - return self.bot_token in text - - def _clean_text(self, text: str) -> str: - """ - clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' - for example, - text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' - return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' - The key pattern is [TOOL_CALLS] [...] - """ - find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) - if len(find_results) > 0: - return find_results[0] - else: - return "" - - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - One-time parsing: Detects and parses tool calls in the provided text. - - :param text: The complete text to parse. - :param tools: List of available tools. - :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. - """ - idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx != -1 else text - text = self._clean_text(text) - tool_content = text.replace("[TOOL_CALLS]", "").strip() - raw_tool_calls = self.tool_call_regex.findall(tool_content) - calls = [] - if len(raw_tool_calls) > 0: - raw_tool_call = raw_tool_calls[0] - function_call_arr = json.loads(raw_tool_call) - for match_result in function_call_arr: - calls.extend(self.parse_base_json(match_result, tools)) - return StreamingParseResult(normal_text=normal_text, calls=calls) - - def structure_info(self) -> _GetInfoFunc: - return lambda name: StructureInfo( - begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":', - end="}]", - trigger="[TOOL_CALLS]", - ) - - -class Llama32Detector(BaseFormatDetector): - """ - Detector for Llama 3.2 models. - Assumes function call format: - <|python_tag|>{"name":"xxx", "arguments":{...}} - """ - - def __init__(self): - super().__init__() - self.bot_token = "<|python_tag|>" - - def has_tool_call(self, text: str) -> bool: - """Check if the text contains a Llama 3.2 format tool call.""" - # depending on the prompt format the Llama model may or may not - # prefix the output with the <|python_tag|> token - return "<|python_tag|>" in text or text.startswith("{") - - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """Parse function calls from text, handling multiple JSON objects.""" - if "<|python_tag|>" not in text and not text.startswith("{"): - return StreamingParseResult(normal_text=text, calls=[]) - - if "<|python_tag|>" in text: - normal_text, action_text = text.split("<|python_tag|>") - else: - normal_text, action_text = "", text - - # Split by semicolon and process each part - json_parts = [part.strip() for part in action_text.split(";") if part.strip()] - all_actions = [] - for part in json_parts: - try: - # Parse each individual JSON object - action = json.loads(part) - all_actions.append(action) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse JSON part: {part}") - logger.warning(f"JSON parse error: {str(e)}") - continue - calls = [] - # Only process if we found valid JSON objects - if all_actions: - calls = self.parse_base_json(all_actions, tools) - return StreamingParseResult(normal_text=normal_text, calls=calls) - - def structure_info(self) -> _GetInfoFunc: - return lambda name: StructureInfo( - begin='<|python_tag|>{"name":"' + name + '", "arguments":', - end="}", - trigger="<|python_tag|>", - ) - - -class DeepSeekV3Detector(BaseFormatDetector): - """ - Detector for DeepSeek models. - Assumes function call format: - '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> - """ - - def __init__(self): - super().__init__() - self.bot_token = "<|tool▁calls▁begin|>" - self.eot_token = "<|tool▁calls▁end|>" - self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>" - self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>" - self._last_arguments = "" - - def has_tool_call(self, text: str) -> bool: - """Check if the text contains a deepseek format tool call.""" - return self.bot_token in text - - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - One-time parsing: Detects and parses tool calls in the provided text. - - :param text: The complete text to parse. - :param tools: List of available tools. - :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. - """ - idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx != -1 else text - if self.bot_token not in text: - return StreamingParseResult(normal_text=normal_text, calls=[]) - match_result_list = re.findall(self.func_call_regex, text, re.DOTALL) - calls = [] - try: - for match_result in match_result_list: - # Get function name - func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL) - func_name = func_detail.group(2) - func_args = func_detail.group(3) - func_args = json.loads(func_args) - # construct match_result for parse_base_json - match_result = {"name": func_name, "parameters": func_args} - calls.extend(self.parse_base_json(match_result, tools)) - return StreamingParseResult(normal_text=normal_text, calls=calls) - except Exception as e: - logger.error(f"Error in detect_and_parse: {e}") - # return the normal text if parsing fails - return StreamingParseResult(normal_text=text) - - def structure_info(self) -> _GetInfoFunc: - return lambda name: StructureInfo( - begin=">" + name + "\n```json\n", - end="\n```<", - trigger=">" + name + "\n```json\n", - ) - - def parse_streaming_increment( - self, new_text: str, tools: List[Tool] - ) -> StreamingParseResult: - """ - Streaming incremental parsing tool calls for DeepSeekV3 format. - """ - self._buffer += new_text - current_text = self._buffer - - if self.bot_token not in current_text: - self._buffer = "" - for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]: - if e_token in new_text: - new_text = new_text.replace(e_token, "") - return StreamingParseResult(normal_text=new_text) - - if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } - - calls: list[ToolCallItem] = [] - try: - partial_match = re.search( - pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)", - string=current_text, - flags=re.DOTALL, - ) - if partial_match: - func_name = partial_match.group(2).strip() - func_args_raw = partial_match.group(3).strip() - - if not self.current_tool_name_sent: - calls.append( - ToolCallItem( - tool_index=self._tool_indices.get(func_name, 0), - name=func_name, - parameters="", - ) - ) - self.current_tool_name_sent = True - else: - argument_diff = ( - func_args_raw[len(self._last_arguments) :] - if func_args_raw.startswith(self._last_arguments) - else func_args_raw - ) - - if argument_diff: - calls.append( - ToolCallItem( - tool_index=self._tool_indices.get(func_name, 0), - name=None, - parameters=argument_diff, - ) - ) - self._last_arguments += argument_diff - - if _is_complete_json(func_args_raw): - result = StreamingParseResult(normal_text="", calls=calls) - self._buffer = "" - self._last_arguments = "" - self.current_tool_name_sent = False - return result - - return StreamingParseResult(normal_text="", calls=calls) - - except Exception as e: - logger.error(f"Error in parse_streaming_increment: {e}") - return StreamingParseResult(normal_text=current_text) - - -class MultiFormatParser: - def __init__(self, detectors: List[BaseFormatDetector]): - """ - :param detectors: A series of available Detector instances passed in - """ - self.detectors = detectors - - def parse_once( - self, text: str, tools: List[Tool] - ) -> Tuple[str, list[ToolCallItem]]: - """ - One-time parsing: Loop through detectors until there are no new matches or text is exhausted - Return: (final_text, all_calls) - - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text) - - all_calls: All calls parsed by the Detectors - """ - final_calls = [] - final_normal_text = text - for detector in self.detectors: - parsed_result = detector.detect_and_parse(text, tools) - tool_call_list = parsed_result.calls - if len(tool_call_list) > 0: # parsed successfully - final_calls = tool_call_list - final_normal_text = parsed_result.normal_text - break - - # leftover_text is the normal text not consumed by any Detector - return final_normal_text, final_calls - - def parse_streaming_increment( - self, new_text: str, tools: List[Tool] - ) -> Tuple[str, list[ToolCallItem]]: - """ - Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment - and merge their produced normal_text/calls to return. - (The logic here can be "priority-based" or "parallel parsing" based on your needs) - """ - final_normal_text = "" - final_calls = [] - - for detector in self.detectors: - sp_result = detector.parse_streaming_increment(new_text, tools) - # Merge normal_text and calls - # If one sp_result contains result call, this should be a successful parse - # If one sp_result only contains normal_text, this can either be a successful - # parse or it is not using the desired parsing tool. - if sp_result.normal_text: - final_normal_text = sp_result.normal_text - if sp_result.calls: - final_calls.extend(sp_result.calls) - final_normal_text = sp_result.normal_text - break - - return final_normal_text, final_calls - - -class PythonicDetector(BaseFormatDetector): - """ - Detector for Llama-3.2 and Llama-4 models with pythonic tool call format. - Assumes function call format: - [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] - Arguments are Python literals (not JSON). - """ - - def __init__(self): - super().__init__() - self.tool_call_regex = re.compile( - r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", - re.DOTALL, - ) - - def has_tool_call(self, text: str) -> bool: - return bool(self.tool_call_regex.match(text.strip())) - - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - # Try parsing the text as a Python list of function calls - text = text.strip() - if not (text.startswith("[") and text.endswith("]")): - # Not a pythonic tool call format - return StreamingParseResult(normal_text=text, calls=[]) - try: - module = ast.parse(text) - parsed = getattr(module.body[0], "value", None) - if not ( - isinstance(parsed, ast.List) - and all(isinstance(e, ast.Call) for e in parsed.elts) - ): - return StreamingParseResult(normal_text=text, calls=[]) - calls = [] - tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function.name - } - for call in parsed.elts: - if not isinstance(call.func, ast.Name): - continue - function_name = call.func.id - arguments = {} - for keyword in call.keywords: - arguments[keyword.arg] = self._get_parameter_value(keyword.value) - calls.append( - ToolCallItem( - tool_index=tool_indices.get(function_name, -1), - name=function_name, - parameters=json.dumps(arguments, ensure_ascii=False), - ) - ) - return StreamingParseResult(normal_text="", calls=calls) - except Exception: - logger.exception("Error in pythonic tool call parsing.") - return StreamingParseResult(normal_text=text, calls=[]) - - def parse_streaming_increment( - self, new_text: str, tools: List[Tool] - ) -> StreamingParseResult: - """ - Streaming incremental parsing for pythonic tool calls. - Buffers input until a complete pythonic tool call (from [ to ]) is found, - then parses and emits any detected calls. - """ - self._buffer += new_text - start = self._buffer.find("[") - end = self._buffer.find("]", start) - if start != -1 and end != -1: - call_text = self._buffer[start : end + 1] - result = self.detect_and_parse(call_text, tools) - self._buffer = self._buffer[end + 1 :] - return result - return StreamingParseResult(normal_text="") - - def _get_parameter_value(self, val): - if isinstance(val, ast.Constant): - return val.value - elif isinstance(val, ast.Dict): - return { - k.value: self._get_parameter_value(v) - for k, v in zip(val.keys, val.values) - } - elif isinstance(val, ast.List): - return [self._get_parameter_value(v) for v in val.elts] - else: - raise ValueError("Tool call arguments must be literals") - - def structure_info(self) -> _GetInfoFunc: - def info(name: str): - return StructureInfo(begin="[", end="]", trigger="") - - return info - - -class FunctionCallParser: - """ - In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment - and returns the resulting normal_text and calls to the upper layer (or SSE). - """ - - ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { - "llama3": Llama32Detector, - "qwen25": Qwen25Detector, - "mistral": MistralDetector, - "deepseekv3": DeepSeekV3Detector, - "pythonic": PythonicDetector, - } - - def __init__(self, tools: List[Tool], tool_call_parser: str): - detectors = [] - if tool_call_parser: - detector_class = self.ToolCallParserEnum.get(tool_call_parser) - if detector_class: - detectors.append(detector_class()) - else: - raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}") - else: - raise ValueError("Tool Call Parser Not Given!") - - self.multi_format_parser = MultiFormatParser(detectors) - self.tools = tools - - def has_tool_call(self, text: str) -> bool: - """ - Check if the given text contains a tool call in the format supported by this parser. - This delegates to the detector's implementation. - - :param text: The text to check for tool calls - :return: True if the text contains a tool call, False otherwise - """ - # Check all detectors in the multi_format_parser - for detector in self.multi_format_parser.detectors: - if detector.has_tool_call(text): - return True - return False - - def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]: - """ - Non-streaming call: one-time parsing - """ - full_normal_text, calls = self.multi_format_parser.parse_once( - full_text, self.tools - ) - return full_normal_text, calls - - def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: - """ - Streaming call: incremental parsing - """ - normal_text, calls = self.multi_format_parser.parse_streaming_increment( - chunk_text, self.tools - ) - return normal_text, calls - - def structure_infos(self) -> List[_GetInfoFunc]: - """ - Returns a list of structure_info functions for each detector - """ - return [ - detector.structure_info() for detector in self.multi_format_parser.detectors - ] - - def get_structure_tag(self) -> StructuralTagResponseFormat: - tool_structures: List[StructuresResponseFormat] = list() - tool_trigger_set: Set[str] = set() - - for wrapper in self.structure_infos(): - for tool in self.tools: - function = tool.function - name = function.name - assert name is not None - info = wrapper(name) - - # accept all if not strict, otherwise only accept the schema - schema = function.parameters if function.strict else {} - - tool_structures.append( - StructuresResponseFormat( - begin=info.begin, - schema=schema, # type: ignore - end=info.end, - ) - ) - tool_trigger_set.add(info.trigger) - - return StructuralTagResponseFormat( - type="structural_tag", - structures=tool_structures, - triggers=list(tool_trigger_set), - ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 9f54641d9..d23cf5c05 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -40,7 +40,7 @@ from sglang.srt.conversation import ( get_conv_template_by_model_path, register_conv_template, ) -from sglang.srt.function_call_parser import FunctionCallParser +from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchRequest, @@ -970,7 +970,7 @@ def v1_chat_generate_request( # - image_data: None or a list of image strings (URLs or base64 strings). # - audio_data: None or a list of audio strings (URLs). # None skips any image processing in GenerateReqInput. - strict_tag = None + tool_call_constraint = None prompt = "" prompt_ids = [] if not isinstance(request.messages, str): @@ -989,7 +989,9 @@ def v1_chat_generate_request( tool_call_parser = tokenizer_manager.server_args.tool_call_parser parser = FunctionCallParser(request.tools, tool_call_parser) - strict_tag = parser.get_structure_tag() + tool_call_constraint = parser.get_structure_constraint( + request.tool_choice + ) if chat_template_name is None: openai_compatible_messages = [] @@ -1156,20 +1158,24 @@ def v1_chat_generate_request( request.response_format.model_dump(by_alias=True) ) - if strict_tag is not None: - if ( - sampling_params.get("regex") - or sampling_params.get("ebnf") - or sampling_params.get("structural_tag") - or sampling_params.get("json_schema") - ): - logger.warning( - "Constrained decoding is not compatible with tool calls." + # Check if there are already existing output constraints + has_existing_constraints = ( + sampling_params.get("regex") + or sampling_params.get("ebnf") + or sampling_params.get("structural_tag") + or sampling_params.get("json_schema") + ) + + if tool_call_constraint and has_existing_constraints: + logger.warning("Constrained decoding is not compatible with tool calls.") + elif tool_call_constraint: + constraint_type, constraint_value = tool_call_constraint + if constraint_type == "structural_tag": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value.model_dump(by_alias=True) ) else: - sampling_params["structural_tag"] = convert_json_schema_to_str( - strict_tag.model_dump(by_alias=True) - ) + sampling_params[constraint_type] = constraint_value sampling_params_list.append(sampling_params) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index dcb4970fc..b63d7b5e7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -36,7 +36,7 @@ suites = { TestFile("test_fa3.py", 376), TestFile("test_fim_completion.py", 40), TestFile("test_fp8_kernel.py", 8), - TestFile("test_function_calling.py", 60), + TestFile("test_function_call_parser.py", 10), TestFile("test_fused_moe.py", 30), TestFile("test_hicache.py", 116), TestFile("test_hicache_mla.py", 254), @@ -54,6 +54,7 @@ suites = { TestFile("test_flashmla.py", 300), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 216), + TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_server.py", 149), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py new file mode 100644 index 000000000..0a00a7dbd --- /dev/null +++ b/test/srt/test_function_call_parser.py @@ -0,0 +1,408 @@ +import json +import unittest + +from xgrammar import GrammarCompiler, TokenizerInfo + +from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.llama32_detector import Llama32Detector +from sglang.srt.function_call.mistral_detector import MistralDetector +from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen25_detector import Qwen25Detector +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.openai_api.protocol import Function, Tool +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + +class TestPythonicDetector(unittest.TestCase): + def setUp(self): + # Create sample tools for testing + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "properties": { + "location": { + "type": "string", + "description": "Location to get weather for", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + ] + self.detector = PythonicDetector() + + def test_parse_streaming_no_brackets(self): + """Test parsing text with no brackets (no tool calls).""" + text = "This is just normal text without any tool calls." + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(result.calls, []) + self.assertEqual(self.detector._buffer, "") # Buffer should be cleared + + def test_parse_streaming_complete_tool_call(self): + """Test parsing a complete tool call.""" + text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "Here's a tool call: ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + self.detector._buffer, "" + ) # Buffer should be cleared after processing + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "New York") + self.assertEqual(params["unit"], "celsius") + + def test_parse_streaming_text_before_tool_call(self): + """Test parsing text that appears before a tool call.""" + text = "This is some text before [get_weather(location='London')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "This is some text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "London") + + def test_parse_streaming_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + # First chunk with opening bracket but no closing bracket + text1 = "Let me check the weather: [get_weather(location=" + result1 = self.detector.parse_streaming_increment(text1, self.tools) + + self.assertEqual(result1.normal_text, "Let me check the weather: ") + self.assertEqual(result1.calls, []) + self.assertEqual( + self.detector._buffer, "[get_weather(location=" + ) # Partial tool call remains in buffer + + # Second chunk completing the tool call + text2 = "'Paris')]" + result2 = self.detector.parse_streaming_increment(text2, self.tools) + + self.assertEqual(result2.normal_text, "") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + + # Check the parameters + params = json.loads(result2.calls[0].parameters) + self.assertEqual(params["location"], "Paris") + self.assertEqual( + self.detector._buffer, "" + ) # Buffer should be cleared after processing + + def test_parse_streaming_bracket_without_text_before(self): + """Test parsing a tool call that starts at the beginning of the text.""" + text = "[search(query='python programming')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "search") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "python programming") + + def test_parse_streaming_text_after_tool_call(self): + """Test parsing text that appears after a tool call.""" + # First chunk with complete tool call and some text after + text = "[get_weather(location='Tokyo')] Here's the forecast:" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + self.detector._buffer, " Here's the forecast:" + ) # Text after tool call remains in buffer + + # Process the remaining text in buffer + result2 = self.detector.parse_streaming_increment("", self.tools) + self.assertEqual(result2.normal_text, " Here's the forecast:") + self.assertEqual(result2.calls, []) + self.assertEqual(self.detector._buffer, "") # Buffer should be cleared + + def test_parse_streaming_multiple_tool_calls(self): + """Test parsing multiple tool calls in sequence.""" + text = "[get_weather(location='Berlin')] and [search(query='restaurants')]" + + # First tool call + result1 = self.detector.parse_streaming_increment(text, self.tools) + self.assertEqual(len(result1.calls), 1) + self.assertEqual(result1.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]") + + # Second tool call + result2 = self.detector.parse_streaming_increment("", self.tools) + self.assertEqual(result2.normal_text, " and ") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "search") + self.assertEqual(self.detector._buffer, "") + + def test_parse_streaming_opening_bracket_only(self): + """Test parsing text with only an opening bracket but no closing bracket.""" + text = "Let's try this: [" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "Let's try this: ") + self.assertEqual(result.calls, []) + self.assertEqual( + self.detector._buffer, "[" + ) # Opening bracket remains in buffer + + def test_parse_streaming_nested_brackets(self): + """Test parsing tool calls with nested brackets in arguments.""" + # Test with list argument containing nested brackets + text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "New York") + self.assertEqual(params["unit"], "celsius") + self.assertEqual(params["data"], [1, 2, 3]) + + def test_parse_streaming_nested_brackets_dict(self): + """Test parsing tool calls with nested dictionaries and lists.""" + # Test with nested dict and list arguments + text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "search") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "test") + self.assertEqual(params["config"]["options"], [1, 2]) + self.assertEqual(params["config"]["nested"]["key"], "value") + + def test_parse_streaming_multiple_tools_with_nested_brackets(self): + """Test parsing multiple tool calls with nested brackets.""" + text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 2) + self.assertEqual(self.detector._buffer, "") + + # Check first tool call + params1 = json.loads(result.calls[0].parameters) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(params1["location"], "Paris") + self.assertEqual(params1["data"], [10, 20]) + + # Check second tool call + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(result.calls[1].name, "search") + self.assertEqual(params2["query"], "test") + self.assertEqual(params2["filters"], ["a", "b"]) + + def test_parse_streaming_partial_nested_brackets(self): + """Test parsing partial tool calls with nested brackets across chunks.""" + # First chunk with nested brackets but incomplete + text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2" + result1 = self.detector.parse_streaming_increment(text1, self.tools) + + self.assertEqual(result1.normal_text, "Here's a call: ") + self.assertEqual(result1.calls, []) + self.assertEqual( + self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2" + ) + + # Second chunk completing the nested brackets + text2 = ", 3])]" + result2 = self.detector.parse_streaming_increment(text2, self.tools) + + self.assertEqual(result2.normal_text, "") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result2.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + +class TestEBNFGeneration(unittest.TestCase): + def setUp(self): + # Create sample tools for testing + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "properties": { + "location": { + "type": "string", + "description": "Location to get weather for", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + ] + + self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer) + self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) + + # Initialize all detectors + self.pythonic_detector = PythonicDetector() + self.deepseekv3_detector = DeepSeekV3Detector() + self.llama32_detector = Llama32Detector() + self.mistral_detector = MistralDetector() + self.qwen25_detector = Qwen25Detector() + + def test_pythonic_detector_ebnf(self): + """Test that the PythonicDetector generates valid EBNF.""" + ebnf = self.pythonic_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns + self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf) + self.assertIn('"location" "=" basic_string', ebnf) + self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf) + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + + def test_deepseekv3_detector_ebnf(self): + """Test that the DeepSeekV3Detector generates valid EBNF.""" + ebnf = self.deepseekv3_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns + self.assertIn("<|tool▁calls▁begin|>", ebnf) + self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf) + self.assertIn('\\"location\\"" ":" basic_string ', ebnf) + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + + def test_llama32_detector_ebnf(self): + """Test that the Llama32Detector generates valid EBNF.""" + ebnf = self.llama32_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns + self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf) + self.assertIn('"\\"arguments\\"" ":"', ebnf) + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + + def test_mistral_detector_ebnf(self): + """Test that the MistralDetector generates valid EBNF.""" + ebnf = self.mistral_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns + self.assertIn('"[TOOL_CALLS] ["', ebnf) + self.assertIn("call_get_weather | call_search", ebnf) + self.assertIn('"\\"arguments\\"" ":"', ebnf) + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + + def test_qwen25_detector_ebnf(self): + """Test that the Qwen25Detector generates valid EBNF.""" + ebnf = self.qwen25_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns + self.assertIn("", ebnf) + self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf) + self.assertIn('"\\"arguments\\"" ":"', ebnf) + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_function_calling.py b/test/srt/test_openai_function_calling.py similarity index 71% rename from test/srt/test_function_calling.py rename to test/srt/test_openai_function_calling.py index 9556cf87b..7755d3729 100644 --- a/test/srt/test_function_calling.py +++ b/test/srt/test_openai_function_calling.py @@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5") self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7") + def test_function_call_required(self): + """ + Test: Whether tool_choice: "required" works as expected + - When tool_choice == "required", the model should return one or more tool_calls. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "sub", + "description": "Compute the difference of two integers", + "parameters": { + "type": "object", + "properties": { + "int_a": { + "type": "integer", + "description": "First integer", + }, + "int_b": { + "type": "integer", + "description": "Second integer", + }, + }, + "required": ["int_a", "int_b"], + }, + "strict": True, + }, + }, + { + "type": "function", + "function": { + "name": "get_weather", + "description": "use this to get latest weather information for a city given its name", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "name of the city to get weather for", + } + }, + "required": ["city"], + }, + }, + }, + ] + + messages = [{"role": "user", "content": "What is the capital of France?"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + tool_choice="required", + ) + + tool_calls = response.choices[0].message.tool_calls + self.assertIsNotNone(tool_calls, "No tool_calls in the response") + function_name = tool_calls[0].function.name + arguments = tool_calls[0].function.arguments + args_obj = json.loads(arguments) + + self.assertEqual( + function_name, "get_weather", "Function name should be 'get_weather'" + ) + self.assertIn("city", args_obj, "Function arguments should have 'city'") + self.assertIn( + "Paris", args_obj["city"], "Parameter city should contain 'Paris'" + ) # might be flaky + + def test_function_call_specific(self): + """ + Test: Whether tool_choice: ToolChoice works as expected + - When tool_choice is a specific ToolChoice, the model should return one or more tool_calls. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "sub", + "description": "Compute the difference of two integers", + "parameters": { + "type": "object", + "properties": { + "int_a": { + "type": "integer", + "description": "First integer", + }, + "int_b": { + "type": "integer", + "description": "Second integer", + }, + }, + "required": ["int_a", "int_b"], + }, + "strict": True, + }, + }, + { + "type": "function", + "function": { + "name": "get_weather", + "description": "use this to get latest weather information for a city given its name", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "name of the city to get weather for", + } + }, + "required": ["city"], + }, + }, + }, + ] + + messages = [{"role": "user", "content": "What is the capital of France?"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + tool_choice={"type": "function", "function": {"name": "get_weather"}}, + ) + + tool_calls = response.choices[0].message.tool_calls + self.assertIsNotNone(tool_calls, "No tool_calls in the response") + function_name = tool_calls[0].function.name + arguments = tool_calls[0].function.arguments + args_obj = json.loads(arguments) + + self.assertEqual( + function_name, "get_weather", "Function name should be 'get_weather'" + ) + self.assertIn("city", args_obj, "Function arguments should have 'city'") + class TestOpenAIPythonicFunctionCalling(CustomTestCase): PYTHONIC_TOOLS = [ @@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): stream=False, ) tool_calls = response.choices[0].message.tool_calls - self.assertIsInstance(tool_calls, list) + self.assertIsInstance(tool_calls, list, "No tool_calls found") self.assertGreaterEqual(len(tool_calls), 1) names = [tc.function.name for tc in tool_calls] - self.assertIn("get_weather", names) - self.assertIn("get_tourist_attractions", names) + self.assertTrue( + "get_weather" in names or "get_tourist_attractions" in names, + f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'", + ) def test_pythonic_tool_call_streaming(self): """ @@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): self.assertTrue(found_tool_calls, "No tool_calls found in streaming response") self.assertTrue(found_index, "No index field found in any streamed tool_call") - self.assertIn("get_weather", found_names) - self.assertIn("get_tourist_attractions", found_names) + self.assertTrue( + "get_weather" in found_names or "get_tourist_attractions" in found_names, + f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'", + ) if __name__ == "__main__":