diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 47a2e2278..3de80aadf 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -4,62 +4,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Function Calling\n", + "# Tool and Function Calling\n", "\n", - "This notebook provides a quick-start guide to use function tooling using SGLang chat completions API\n", - "\n", - "## Supported Models\n", - "\n", - "Currently, we added the support for tools calling in the following models:\n", - " - Llama 3.2 models\n", - " - Llama 3.1 models\n", - " - Qwen 2.5 models\n", - " - InternLM Models" + "This guide demonstrates how to use SGLang’s **Tool Calling** functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Usage\n", - "\n", - "### Launch a server\n", - "\n", - "This code block is equivalent to executing\n", - "\n", - "`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - "--port 30000 --host 0.0.0.0`\n", - "in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sglang.utils import (\n", - " execute_shell_command,\n", - " wait_for_server,\n", - " terminate_process,\n", - " print_highlight,\n", - ")\n", - "\n", - "\n", - "server_process = execute_shell_command(\n", - " \"\"\"\n", - " python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n", - "\"\"\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30000\")" + "## OpenAI Compatible API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Single Round Invocation" + "### Launching the Server" ] }, { @@ -69,7 +30,47 @@ "outputs": [], "source": [ "from openai import OpenAI\n", + "import json\n", + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", "\n", + "server_process = execute_shell_command(\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n", + ")\n", + "wait_for_server(\"http://localhost:30333\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", + "\n", + "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", + "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", + "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", + "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Tools for Function Call\n", + "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define tools\n", "tools = [\n", " {\n", " \"type\": \"function\",\n", @@ -79,22 +80,264 @@ " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", - " \"location\": {\n", + " \"city\": {\n", " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n", + " },\n", + " \"state\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"the two-letter abbreviation for the state that the city is\"\n", + " \" in, e.g. 'CA' which would mean 'California'\",\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The unit to fetch the temperature in\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", " },\n", - " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", " },\n", - " \"required\": [\"location\"],\n", + " \"required\": [\"city\", \"state\", \"unit\"],\n", " },\n", " },\n", " }\n", - "]\n", - "messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_messages():\n", + " return [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n", + " }\n", + " ]\n", "\n", - "client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n", - "model_name = client.models.list().data[0].id\n", - "response = client.chat.completions.create(\n", + "\n", + "messages = get_messages()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the Client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize OpenAI-like client\n", + "client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n", + "model_name = client.models.list().data[0].id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Non-Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Non-streaming mode test\n", + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False, # Non-streaming\n", + " tools=tools,\n", + ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(response_non_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming mode test\n", + "print_highlight(\"Streaming response:\")\n", + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=True, # Enable streaming\n", + " tools=tools,\n", + ")\n", + "\n", + "chunks = []\n", + "for chunk in response_stream:\n", + " chunks.append(chunk)\n", + " if chunk.choices[0].delta.tool_calls:\n", + " print(chunk.choices[0].delta.tool_calls[0])" + 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Handle Tool Calls\n", + "\n", + "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Non-Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n", + "arguments_non_stream = (\n", + " response_non_stream.choices[0].message.tool_calls[0].function.arguments\n", + ")\n", + "\n", + "print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n", + "print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parse and combine function call arguments\n", + "arguments = []\n", + "for chunk in chunks:\n", + " choice = chunk.choices[0]\n", + " delta = choice.delta\n", + " if delta.tool_calls:\n", + " tool_call = delta.tool_calls[0]\n", + " if tool_call.function.name:\n", + " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", + "\n", + " if tool_call.function.arguments:\n", + " arguments.append(tool_call.function.arguments)\n", + " print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n", + "\n", + "# Combine all fragments into a single JSON string\n", + "full_arguments = \"\".join(arguments)\n", + "print_highlight(f\"Final streamed function call arguments: {full_arguments}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define a Tool Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is a demonstration, define real function according to your usage.\n", + "def get_current_weather(city: str, state: str, unit: \"str\"):\n", + " return (\n", + " f\"The weather in {city}, {state} is 85 degrees {unit}. 
It is \"\n", + " \"partly cloudly, with highs in the 90's.\"\n", + " )\n", + "\n", + "\n", + "available_tools = {\"get_current_weather\": get_current_weather}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Execute the Tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "call_data = json.loads(full_arguments)\n", + "\n", + "messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"\",\n", + " \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n", + " }\n", + ")\n", + "\n", + "# Call the corresponding tool function\n", + "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n", + "tool_to_call = available_tools[tool_name]\n", + "result = tool_to_call(**call_data)\n", + "print_highlight(f\"Function call result: {result}\")\n", + "messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n", + "\n", + "print_highlight(f\"Updated message history: {messages}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Send Results Back to Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_response = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", " temperature=0.8,\n", @@ -102,17 +345,56 @@ " stream=False,\n", " tools=tools,\n", ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(final_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Native API and SGLang Runtime (SRT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "import requests\n", "\n", - "print(response)\n", + "# generate an answer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", "\n", - "\"\"\"\n", + "messages = get_messages()\n", "\n", - "ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n", - "role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n", - "matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n", - "usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n", + "input = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " tools=tools,\n", + ")\n", "\n", - "\"\"\"" + "gen_url = \"http://localhost:30333/generate\"\n", + "gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n", + "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", + "print(gen_response)\n", + "\n", + "# parse the response\n", + "parse_url = \"http://localhost:30333/function_call\"\n", + "\n", + "function_call_input = {\n", + " \"text\": gen_response,\n", + " \"tool_call_parser\": \"llama3\",\n", + " \"tools\": tools,\n", + "}\n", + "\n", + "function_call_response = requests.post(parse_url, json=function_call_input)\n", + "function_call_response_json = 
function_call_response.json()\n", + "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n", + "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])" ] }, { @@ -128,11 +410,98 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## How to support a new model?\n", + "## Offline Engine API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sglang as sgl\n", + "from sglang.srt.function_call_parser import FunctionCallParser\n", + "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", - "For adding support of more different models:\n", - " 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n", - " 2. Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n" + "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "tokenizer = llm.tokenizer_manager.tokenizer\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", + ")\n", + "\n", + "sampling_params = {\n", + " \"max_new_tokens\": 128,\n", + " \"temperature\": 0.3,\n", + " \"top_p\": 0.95,\n", + " \"skip_special_tokens\": False,\n", + "}\n", + "\n", + "# 1) Offline generation\n", + "result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n", + "generated_text = result[\"text\"] # Assume there is only one prompt\n", + "\n", + "print(\"=== Offline Engine Output Text ===\")\n", + "print(generated_text)\n", + "\n", + "\n", + "# 2) Parse using FunctionCallParser\n", + "def convert_dict_to_tool(tool_dict: dict) -> Tool:\n", + " function_dict = tool_dict.get(\"function\", {})\n", + " return Tool(\n", + " type=tool_dict.get(\"type\", \"function\"),\n", + " function=Function(\n", + " name=function_dict.get(\"name\"),\n", + " description=function_dict.get(\"description\"),\n", + " parameters=function_dict.get(\"parameters\"),\n", + " ),\n", + " )\n", + "\n", + "\n", + "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n", + "\n", + "parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n", + "normal_text, calls = parser.parse_non_stream(generated_text)\n", + "\n", + "print(\"\\n=== Parsing Result ===\")\n", + "print(\"Normal text portion:\", normal_text)\n", + "print(\"Function call portion:\")\n", + "for call in calls:\n", + " # call: ToolCallItem\n", + " print(f\" - tool name: {call.name}\")\n", + " print(f\" parameters: {call.parameters}\")\n", + "\n", + "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm.shutdown()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## How to support a new model?\n", + "1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n", + "```\n", + "\tTOOLS_TAG_LIST = [\n", + "\t “<|plugin|>“,\n", + "\t ““,\n", + "\t “<|python_tag|>“,\n", + "\t “[TOOL_CALLS]”\n", + "\t]\n", + "```\n", + "2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. 
For example:\n", + "```\n", + " class NewModelDetector(BaseFormatDetector):\n", + "```\n", + "3. Add the new detector to the MultiFormatParser class that manages all the format detectors." ] } ], diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 0ebce1a85..1759cd2bb 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -39,10 +39,12 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.srt.entrypoints.engine import _launch_subprocesses +from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, + FunctionCallReqInput, GenerateReqInput, GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -369,6 +371,28 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request): return Response(status_code=200) +@app.post("/function_call") +async def function_call_request(obj: FunctionCallReqInput, request: Request): + """ + A native API endpoint to parse function calls from a text. + """ + # 1) Initialize the parser based on the request body + parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser) + + # 2) Call the non-stream parsing method (non-stream) + normal_text, calls = parser.parse_non_stream(obj.text) + + # 3) Organize the response content + response_data = { + "normal_text": normal_text, + "calls": [ + call.model_dump() for call in calls + ], # Convert pydantic objects to dictionaries + } + + return ORJSONResponse(content=response_data, status_code=200) + + ##### OpenAI-compatible API endpoints ##### diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py new file mode 100644 index 000000000..3def4e1eb --- /dev/null +++ b/python/sglang/srt/function_call_parser.py @@ -0,0 +1,494 @@ +import json +import re +from abc import ABC, abstractmethod +from json import JSONDecodeError, JSONDecoder +from typing import Any, Dict, List, Optional, Tuple + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import BaseModel, Field + +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", +] + + +class Function(BaseModel): + """Function Tool Template.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + + +class ToolCallItem(BaseModel): + """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" + + tool_index: int + name: Optional[str] = None + parameters: str # JSON string + + +def _find_common_prefix(s1: str, s2: str) -> str: + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def _is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +class StreamingParseResult: + """Result of streaming incremental parsing.""" + + def __init__( + self, normal_text: str = "", 
calls: Optional[List[ToolCallItem]] = None + ): + self.normal_text = normal_text + self.calls = calls or [] + + +class BaseFormatDetector: + """Base class providing two sets of interfaces: one-time and streaming incremental.""" + + def __init__(self): + # initialize properties used for state when parsing tool calls in + self._buffer = "" + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = ( + [] + ) # map what has been streamed for each tool so far to a list + self.bot_token = "" + self.eot_token = "" + + def parse_base_json(self, action: Dict, tools: List[Function]): + name, parameters = action["name"], json.dumps( + action.get("parameters", action.get("arguments", {})), + ensure_ascii=False, + ) + tool_index = [tool.function.name for tool in tools].index(name) + tool_call_item = ToolCallItem( + tool_index=tool_index, name=name, parameters=parameters + ) + calls = [tool_call_item] + return calls + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + Parses the text in one go. Returns success=True if the format matches, otherwise False. + Note that leftover_text here represents "content that this parser will not consume further". + """ + action = json.loads(text) + return self.parse_base_json(action, tools) + + def parse_streaming_increment( + self, new_text: str, tools: List[Function] + ) -> StreamingParseResult: + """ + Streaming incremental parsing, referencing the logic of Llama32Detector. + We partially parse JSON within ..., and handle + incremental argument output. + """ + # Append new text to buffer + self._buffer += new_text + current_text = self._buffer + if not (self.bot_token in current_text or current_text.startswith("{")): + self._buffer = "" + if self.eot_token in new_text: + new_text = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=new_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) + while start_idx < len(current_text): + (obj, end_idx) = _partial_json_loads( + current_text[start_idx:], flags + ) + is_complete.append( + _is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert ( + "arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + + except partial_json_parser.core.exceptions.MalformedJSON: + # not enough tokens to parse into JSON yet + return StreamingParseResult() + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return StreamingParseResult() + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + else: + res = StreamingParseResult() + else: + res = StreamingParseResult() + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + print("starting on new tool %d", self.current_tool_id) + return res + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ], + ) + self.current_tool_name_sent = True + else: + res = StreamingParseResult() + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + res = StreamingParseResult() + + if cur_arguments: + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + self._buffer = "" + self.prev_tool_call_arr[self.current_tool_id].clear() + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool[self.current_tool_id] = "" + + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = _find_common_prefix(prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + if not is_complete[self.current_tool_id]: + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return res + + except Exception as e: + print(e) + # Skipping chunk as a result of tool streaming extraction error + return StreamingParseResult() + + +class Qwen25Detector(BaseFormatDetector): + """ + Detector for Qwen 2.5 models. + Assumes function call format: + {"name":"xxx", "arguments":{...}} + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. 
+ """ + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + if "" not in text: + return [] + pattern = r"(.*?)" + match_result_list = re.findall(pattern, text, re.DOTALL) + calls = [] + for match_result in match_result_list: + match_result = json.loads(match_result) + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class MistralDetector(BaseFormatDetector): + """ + Detector for Mistral models. + Assumes function call format: + <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|> + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "[TOOL_CALLS] [" + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + + def _clean_text(self, text: str) -> str: + """ + clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' + for example, + text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' + return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' + The key pattern is [TOOL_CALLS] [...] + """ + find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) + if len(find_results) > 0: + return find_results[0] + else: + return "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + text = self._clean_text(text) + tool_content = text.replace("[TOOL_CALLS]", "").strip() + raw_tool_calls = self.tool_call_regex.findall(tool_content) + calls = [] + if len(raw_tool_calls) > 0: + raw_tool_call = raw_tool_calls[0] + function_call_arr = json.loads(raw_tool_call) + for match_result in function_call_arr: + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class Llama32Detector(BaseFormatDetector): + """ + Detector for Llama 3.2 models. + Assumes function call format: + <|python_tag|>{"name":"xxx", "arguments":{...}} + Does not require a closing tag "", + relies on json.loads(...) success to determine if JSON is complete. + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "<|python_tag|>" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
+ """ + + if "<|python_tag|>" not in text: + return [] + _, action = text.split("<|python_tag|>") + action = json.loads(action) + return self.parse_base_json(action, tools) + + +class MultiFormatParser: + def __init__(self, detectors: List[BaseFormatDetector]): + """ + :param detectors: A series of available Detector instances passed in + """ + self.detectors = detectors + + def parse_once(self, text: str, tools: List[Function]): + """ + One-time parsing: Loop through detectors until there are no new matches or text is exhausted + Return: (final_text, all_calls) + - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text) + - all_calls: All calls parsed by the Detectors + """ + final_calls = [] + final_normal_text = text + for detector in self.detectors: + tool_call_list = detector.detect_and_parse(text, tools) + if len(tool_call_list) > 0: # parsed successfully + final_calls = tool_call_list + break + + # leftover_text is the normal text not consumed by any Detector + return final_normal_text, final_calls + + def parse_streaming_increment(self, new_text: str, tools: List[Function]): + """ + Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment + and merge their produced normal_text/calls to return. + (The logic here can be "priority-based" or "parallel parsing" based on your needs) + """ + final_normal_text = "" + final_calls = [] + + for detector in self.detectors: + sp_result = detector.parse_streaming_increment(new_text, tools) + # Merge normal_text and calls + # If one sp_result contains result call, this should be a successful parse + # If one sp_result only contains normal_text, this can either be a successful + # parse or it is not using the desired parsing tool. + if sp_result.normal_text: + final_normal_text = sp_result.normal_text + if sp_result.calls: + final_calls.extend(sp_result.calls) + final_normal_text = sp_result.normal_text + break + + return final_normal_text, final_calls + + +class FunctionCallParser: + """ + In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment + and returns the resulting normal_text and calls to the upper layer (or SSE). 
+ """ + + ToolCallParserEnum: Dict[str, BaseFormatDetector] = { + "llama3": Llama32Detector, + "qwen25": Qwen25Detector, + "mistral": MistralDetector, + } + + def __init__(self, tools: List[Function], tool_call_parser: str = None): + detectors = [] + if tool_call_parser: + detector_class = self.ToolCallParserEnum.get(tool_call_parser) + if detector_class: + detectors.append(detector_class()) + else: + raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}") + else: + raise ValueError("Tool Call Parser Not Given!") + + self.multi_format_parser = MultiFormatParser(detectors) + self.tools = tools + + def parse_non_stream(self, full_text: str): + """ + Non-streaming call: one-time parsing + """ + full_normal_text, calls = self.multi_format_parser.parse_once( + full_text, self.tools + ) + return full_normal_text, calls + + def parse_stream_chunk(self, chunk_text: str): + """ + Streaming call: incremental parsing + """ + normal_text, calls = self.multi_format_parser.parse_streaming_increment( + chunk_text, self.tools + ) + return normal_text, calls diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index a2f25abc2..f7419d04f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller). """ import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Union @@ -540,3 +540,27 @@ class CloseSessionReqInput: class OpenSessionReqOutput: session_id: Optional[str] success: bool + + +@dataclass +class Function: + description: Optional[str] = None + name: Optional[str] = None + parameters: Optional[object] = None + + +@dataclass +class Tool: + function: Function + type: Optional[str] = "function" + + +@dataclass +class FunctionCallReqInput: + text: str # The text to parse. + tools: List[Tool] = field( + default_factory=list + ) # A list of available function tools (name, parameters, etc.). + tool_call_parser: Optional[str] = ( + None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all. 
+ ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 5056ba22e..6687a4c01 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -20,7 +20,7 @@ import os import time import uuid from http import HTTPStatus -from typing import Dict, List +from typing import Dict, List, Optional from fastapi import HTTPException, Request, UploadFile from fastapi.responses import ORJSONResponse, StreamingResponse @@ -40,6 +40,7 @@ from sglang.srt.conversation import ( generate_chat_conv, register_conv_template, ) +from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchRequest, @@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import ( TopLogprob, UsageInfo, ) -from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ret, to_file=True, cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) else: responses = v1_generate_response( @@ -877,9 +878,6 @@ def v1_chat_generate_request( tools = None if request.tools and request.tool_choice != "none": request.skip_special_tokens = False - if request.stream: - logger.warning("Streaming is not supported with tools.") - request.stream = False if not isinstance(request.tool_choice, str): tools = [ item.function.model_dump() @@ -908,12 +906,26 @@ def v1_chat_generate_request( openai_compatible_messages = openai_compatible_messages[:-1] else: assistant_prefix = None - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - openai_compatible_messages, - tokenize=True, - add_generation_prompt=True, - tools=tools, - ) + + try: + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + except: + # This except branch will be triggered when the chosen model + # has a different tools input format that is not compatiable + # with openAI's apply_chat_template tool_call format, like Mistral. 
+ tools = [t if "function" in t else {"function": t} for t in tools] + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + if assistant_prefix: prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix) stop = request.stop @@ -1005,7 +1017,9 @@ def v1_chat_generate_request( return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] -def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): +def v1_chat_generate_response( + request, ret, to_file=False, cache_report=False, tool_call_parser=None +): choices = [] for idx, ret_item in enumerate(ret): @@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): if finish_reason == "stop": finish_reason = "tool_calls" try: - text, call_info_list = parse_tool_response(text, tools) # noqa + parser = FunctionCallParser(tools, tool_call_parser) + full_normal_text, call_info_list = parser.parse_non_stream(text) tool_calls = [ ToolCall( - id=str(call_info[0]), + id=str(call_info.tool_index), function=FunctionResponse( - name=call_info[1], arguments=call_info[2] + name=call_info.name, arguments=call_info.parameters ), ) for call_info in call_info_list @@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) if adapted_request.stream: + parser_dict = {} async def generate_stream_resp(): is_firsts = {} @@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): adapted_request, raw_request ): index = content.get("index", 0) + text = content["text"] is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") @@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): text = content["text"] delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=delta), - finish_reason=(finish_reason["type"] if finish_reason else ""), - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) + new_stream_buffer = stream_buffer + delta - is_firsts[index] = is_first - stream_buffers[index] = stream_buffer - n_prev_tokens[index] = n_prev_token + if request.tool_choice != "none" and request.tools: + if index not in parser_dict: + parser_dict[index] = FunctionCallParser( + tools=request.tools, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + ) + parser = parser_dict[index] - yield f"data: {chunk.model_dump_json()}\n\n" + # parse_increment => returns (normal_text, calls) + normal_text, calls = parser.parse_stream_chunk(delta) + + # 1) if there's normal_text, output it as normal content + if normal_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=normal_text), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # 2) if we found calls, we output them as separate chunk(s) + for 
call_item in calls: + # transform call_item -> FunctionResponse + ToolCall + + if ( + content["meta_info"]["finish_reason"] + and content["meta_info"]["finish_reason"]["type"] + == "stop" + ): + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + parser.multi_format_parser.detectors[0] + .prev_tool_call_arr[index] + .get("arguments", {}), + ensure_ascii=False, + ) + actual_call = parser.multi_format_parser.detectors[ + 0 + ].streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace( + actual_call, "", 1 + ) + call_item.parameters = remaining_call + + tool_call = ToolCall( + id=str(call_item.tool_index), + function=FunctionResponse( + name=call_item.name, + arguments=call_item.parameters, + ), + ) + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage( + role="assistant", tool_calls=[tool_call] + ), + finish_reason="tool_call", + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + + else: + # No tool calls => just treat this as normal text + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first if request.stream_options and request.stream_options.include_usage: total_prompt_tokens = sum( tokens @@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): ret = [ret] response = v1_chat_generate_response( - request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report + request, + ret, + cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 2ed9006c0..95b34527e 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -262,7 +262,7 @@ class Function(BaseModel): """Function descriptions.""" description: Optional[str] = Field(default=None, examples=[None]) - name: str + name: Optional[str] = None parameters: Optional[object] = None @@ -276,7 +276,7 @@ class Tool(BaseModel): class ToolChoiceFuncName(BaseModel): """The name of tool choice function.""" - name: str + name: Optional[str] = None class ToolChoice(BaseModel): @@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel): class FunctionResponse(BaseModel): """Function response.""" - name: str - arguments: str + name: Optional[str] = None + arguments: Optional[str] = None class ToolCall(BaseModel): @@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel): class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) class 
ChatCompletionResponseStreamChoice(BaseModel): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 330c38132..e841a4799 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -161,6 +161,7 @@ class ServerArgs: # Custom logit processor enable_custom_logit_processor: bool = False + tool_call_parser: str = None def __post_init__(self): # Set missing default values @@ -877,6 +878,14 @@ class ServerArgs: action="store_true", help="Enable users to pass custom logit processors to the server (disabled by default for security)", ) + # Function Calling + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["qwen25", "mistral", "llama3"], + default=ServerArgs.tool_call_parser, + help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0568f0fd4..ff6f3a981 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1243,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048): return str(data) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>"] - - -def parse_tool_response(text, tools, **kwargs): - """Parse model response containing tool information. - - Args: - text(str): model response in string format - tools(List): tools from user request - """ - if "<|plugin|>" in text: # internlm2 - text, action = text.split("<|action_start|><|plugin|>") - action = action.split("<|action_end|>".strip())[0] - action = action[action.find("{") :] - action = json.loads(action) - name, parameters = action["name"], json.dumps( - action.get("parameters", action.get("arguments", {})), ensure_ascii=False - ) - call_info_list = [(name, parameters)] - elif "") - parameters = action[action.find("{") :] - name = action.split("{")[0] - call_info_list = [(name, parameters)] - elif "" in text and "" in text: # qwen2.5 - # get tool_call in text - pattern = r"(.*?)" - match_result_list = re.findall(pattern, text, re.DOTALL) - call_info_list = [] - for match_result in match_result_list: - action = json.loads(match_result) - call_info_list.append( - (action["name"], json.dumps(action["arguments"], ensure_ascii=False)) - ) - # get text outside of tags - if not text.startswith(""): - text = text[: text.find("")] - elif not text.endswith(""): - text = text[text.rfind("") + len("") :] - else: - text = "" - elif "<|python_tag|>" in text: # llama3.2 - _, action = text.split("<|python_tag|>") - action = json.loads(action) - name, parameters = action["name"], json.dumps( - action.get("parameters", action.get("arguments", {})), ensure_ascii=False - ) - call_info_list = [(name, parameters)] - else: - raise RuntimeError(f"Unexpected model response: {text}") - - call_info_list = [ - ( - [tool.function.name for tool in tools].index(call_info[0]), - call_info[0], - call_info[1], - ) - for call_info in call_info_list - ] - return text, call_info_list - - def permute_weight(x: torch.Tensor) -> torch.Tensor: b_ = x.shape[0] n_ = x.shape[1] diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py new file mode 100644 index 000000000..24f341a5e --- /dev/null +++ b/test/srt/test_function_calling.py @@ -0,0 +1,249 @@ +import json +import time +import unittest + +import openai + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils 
import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestOpenAIServerFunctionCalling(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools. + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + # If your server needs extra parameters to test function calling, please add them here. + "--tool-call-parser", + "llama3", + ], + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_function_calling_format(self): + """ + Test: Whether the function call format returned by the AI is correct. + When returning a tool call, message.content should be None, and tool_calls should be a list. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two numbers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "A number", + }, + "b": { + "type": "int", + "description": "A number", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "Compute (3+5)"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + ) + + content = response.choices[0].message.content + tool_calls = response.choices[0].message.tool_calls + + assert content is None, ( + "When function call is successful, message.content should be None, " + f"but got: {content}" + ) + assert ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ), "tool_calls should be a non-empty list" + + function_name = tool_calls[0].function.name + assert function_name == "add", "Function name should be 'add'" + + def test_function_calling_streaming_simple(self): + """ + Test: Whether the function name can be correctly recognized in streaming mode. + - Expect a function call to be found, and the function name to be correct. + - Verify that streaming mode returns at least multiple chunks. 
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for", + }, + "unit": { + "type": "string", + "description": "Weather unit (celsius or fahrenheit)", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "unit"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "What is the temperature in Paris?"}] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=True, + tools=tools, + ) + + chunks = list(response_stream) + self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk") + + found_function_name = False + for chunk in chunks: + choice = chunk.choices[0] + # Check whether the current chunk contains tool_calls + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + if tool_call.function.name: + self.assertEqual( + tool_call.function.name, + "get_current_weather", + "Function name should be 'get_current_weather'", + ) + found_function_name = True + break + + self.assertTrue( + found_function_name, + "Target function name 'get_current_weather' was not found in the streaming chunks", + ) + + def test_function_calling_streaming_args_parsing(self): + """ + Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. + - The user request requires multiple parameters. + - AI may return the arguments in chunks that need to be concatenated. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two integers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "First integer", + }, + "b": { + "type": "int", + "description": "Second integer", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [ + {"role": "user", "content": "Please sum 5 and 7, just call the function."} + ] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.9, + top_p=0.9, + stream=True, + tools=tools, + ) + + argument_fragments = [] + function_name = None + for chunk in response_stream: + choice = chunk.choices[0] + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + # Record the function name on first occurrence + function_name = tool_call.function.name or function_name + # In case of multiple chunks, JSON fragments may need to be concatenated + if tool_call.function.arguments: + argument_fragments.append(tool_call.function.arguments) + + self.assertEqual(function_name, "add", "Function name should be 'add'") + joined_args = "".join(argument_fragments) + self.assertTrue( + len(joined_args) > 0, + "No parameter fragments were returned in the function call", + ) + + # Check whether the concatenated JSON is valid + try: + args_obj = json.loads(joined_args) + except json.JSONDecodeError: + self.fail( + "The concatenated tool call arguments are not valid JSON, parsing failed" + ) + + self.assertIn("a", args_obj, "Missing parameter 'a'") + self.assertIn("b", args_obj, "Missing parameter 'b'") + self.assertEqual( + args_obj["a"], + 5, + "Parameter a should be 5", 
+ ) + self.assertEqual(args_obj["b"], 7, "Parameter b should be 7") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 4bedf7439..23e028729 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -623,58 +623,6 @@ class TestOpenAIServerEBNF(unittest.TestCase): text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" ) - def test_function_calling_format(self): - - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - tools = [ - { - "type": "function", - "function": { - "name": "add", - "description": "Compute the sum of two numbers", - "parameters": { - "type": "object", - "properties": { - "a": { - "type": "int", - "description": "A number", - }, - "b": { - "type": "int", - "description": "A number", - }, - }, - "required": ["a", "b"], - }, - }, - } - ] - - messages = [{"role": "user", "content": "Compute (3+5)"}] - response = client.chat.completions.create( - model=self.model, - messages=messages, - temperature=0.8, - top_p=0.8, - stream=False, - tools=tools, - ) - - content = response.choices[0].message.content - tool_calls = response.choices[0].message.tool_calls - - assert ( - content is None - ), "When tools provided by the response, content should be None" - assert ( - isinstance(tool_calls, list) and len(tool_calls) > 0 - ), "Format not matched, tool_calls should be a list" - - function_name = tool_calls[0].function.name - assert ( - function_name == "add" - ), "Function name should be add for the above response" - class TestOpenAIEmbedding(unittest.TestCase): @classmethod
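The notebook's "How to support a new model?" section describes subclassing `BaseFormatDetector` and registering the new detector, but the diff stops short of a worked example. Below is a minimal sketch of those steps under stated assumptions: the `<tool>…</tool>` wrapper and the `newmodel` parser name are hypothetical placeholders (no real model is claimed to emit this format), and in the actual PR the registration would be done by editing `ToolCallParserEnum` in `python/sglang/srt/function_call_parser.py` rather than patching it at runtime.

```python
# Hypothetical detector for a model that emits: <tool>{"name": ..., "arguments": {...}}</tool>
# The tag and the "newmodel" parser name are placeholders for illustration only.
import json
import re
from typing import List

from sglang.srt.function_call_parser import (
    BaseFormatDetector,
    FunctionCallParser,
    ToolCallItem,
)


class NewModelDetector(BaseFormatDetector):
    """One-time (non-streaming) parsing for the hypothetical <tool>...</tool> format."""

    def __init__(self):
        super().__init__()
        self.bot_token = "<tool>"   # marks the start of a tool call
        self.eot_token = "</tool>"  # marks the end of a tool call

    def detect_and_parse(self, text: str, tools: List) -> List[ToolCallItem]:
        if self.bot_token not in text:
            return []
        calls = []
        for raw_call in re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL):
            # parse_base_json (from BaseFormatDetector) resolves the tool index
            # and serializes the arguments into a JSON string.
            calls.extend(self.parse_base_json(json.loads(raw_call), tools))
        return calls


# In the PR itself this mapping lives in function_call_parser.py and would be
# extended there; mutating it here only keeps the sketch self-contained.
FunctionCallParser.ToolCallParserEnum["newmodel"] = NewModelDetector
```

With that registration in place, `--tool-call-parser newmodel` (or `FunctionCallParser(tools, "newmodel")`) would route parsing through the new detector; streaming support would additionally require overriding `parse_streaming_increment`, as the built-in detectors do.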