Feature/function calling update (#2700)
Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -4,62 +4,23 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Function Calling\n",
|
||||
"# Tool and Function Calling\n",
|
||||
"\n",
|
||||
"This notebook provides a quick-start guide to use function tooling using SGLang chat completions API\n",
|
||||
"\n",
|
||||
"## Supported Models\n",
|
||||
"\n",
|
||||
"Currently, we added the support for tools calling in the following models:\n",
|
||||
" - Llama 3.2 models\n",
|
||||
" - Llama 3.1 models\n",
|
||||
" - Qwen 2.5 models\n",
|
||||
" - InternLM Models"
|
||||
"This guide demonstrates how to use SGLang’s **Tool Calling** functionality."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Usage\n",
|
||||
"\n",
|
||||
"### Launch a server\n",
|
||||
"\n",
|
||||
"This code block is equivalent to executing\n",
|
||||
"\n",
|
||||
"`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||
"--port 30000 --host 0.0.0.0`\n",
|
||||
"in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sglang.utils import (\n",
|
||||
" execute_shell_command,\n",
|
||||
" wait_for_server,\n",
|
||||
" terminate_process,\n",
|
||||
" print_highlight,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"server_process = execute_shell_command(\n",
|
||||
" \"\"\"\n",
|
||||
" python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n",
|
||||
"\"\"\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(\"http://localhost:30000\")"
|
||||
"## OpenAI Compatible API"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Single Round Invocation"
|
||||
"### Launching the Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -69,7 +30,47 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from openai import OpenAI\n",
|
||||
"import json\n",
|
||||
"from sglang.utils import (\n",
|
||||
" execute_shell_command,\n",
|
||||
" wait_for_server,\n",
|
||||
" terminate_process,\n",
|
||||
" print_highlight,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"server_process = execute_shell_command(\n",
|
||||
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n",
|
||||
")\n",
|
||||
"wait_for_server(\"http://localhost:30333\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
|
||||
"\n",
|
||||
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
||||
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define Tools for Function Call\n",
|
||||
"Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define tools\n",
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
@@ -79,22 +80,264 @@
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\n",
|
||||
" \"city\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
||||
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
|
||||
" },\n",
|
||||
" \"state\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
|
||||
" \" in, e.g. 'CA' which would mean 'California'\",\n",
|
||||
" },\n",
|
||||
" \"unit\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The unit to fetch the temperature in\",\n",
|
||||
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||||
" },\n",
|
||||
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"],\n",
|
||||
" \"required\": [\"city\", \"state\", \"unit\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define Messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_messages():\n",
|
||||
" return [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n",
|
||||
"model_name = client.models.list().data[0].id\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
"\n",
|
||||
"messages = get_messages()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initialize the Client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize OpenAI-like client\n",
|
||||
"client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
|
||||
"model_name = client.models.list().data[0].id"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Non-Streaming Request"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Non-streaming mode test\n",
|
||||
"response_non_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" stream=False, # Non-streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"print_highlight(\"Non-stream response:\")\n",
|
||||
"print(response_non_stream)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming Request"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming mode test\n",
|
||||
"print_highlight(\"Streaming response:\")\n",
|
||||
"response_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" stream=True, # Enable streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chunks = []\n",
|
||||
"for chunk in response_stream:\n",
|
||||
" chunks.append(chunk)\n",
|
||||
" if chunk.choices[0].delta.tool_calls:\n",
|
||||
" print(chunk.choices[0].delta.tool_calls[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"### Handle Tool Calls\n",
|
||||
"\n",
|
||||
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Non-Streaming Request**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
|
||||
"arguments_non_stream = (\n",
|
||||
" response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
|
||||
"print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Streaming Request**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse and combine function call arguments\n",
|
||||
"arguments = []\n",
|
||||
"for chunk in chunks:\n",
|
||||
" choice = chunk.choices[0]\n",
|
||||
" delta = choice.delta\n",
|
||||
" if delta.tool_calls:\n",
|
||||
" tool_call = delta.tool_calls[0]\n",
|
||||
" if tool_call.function.name:\n",
|
||||
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
||||
"\n",
|
||||
" if tool_call.function.arguments:\n",
|
||||
" arguments.append(tool_call.function.arguments)\n",
|
||||
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
|
||||
"\n",
|
||||
"# Combine all fragments into a single JSON string\n",
|
||||
"full_arguments = \"\".join(arguments)\n",
|
||||
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define a Tool Function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is a demonstration, define real function according to your usage.\n",
|
||||
"def get_current_weather(city: str, state: str, unit: \"str\"):\n",
|
||||
" return (\n",
|
||||
" f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
|
||||
" \"partly cloudly, with highs in the 90's.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"available_tools = {\"get_current_weather\": get_current_weather}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"## Execute the Tool"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"call_data = json.loads(full_arguments)\n",
|
||||
"\n",
|
||||
"messages.append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"\",\n",
|
||||
" \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Call the corresponding tool function\n",
|
||||
"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
|
||||
"tool_to_call = available_tools[tool_name]\n",
|
||||
"result = tool_to_call(**call_data)\n",
|
||||
"print_highlight(f\"Function call result: {result}\")\n",
|
||||
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
|
||||
"\n",
|
||||
"print_highlight(f\"Updated message history: {messages}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Send Results Back to Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"final_response = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
@@ -102,17 +345,56 @@
|
||||
" stream=False,\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"print_highlight(\"Non-stream response:\")\n",
|
||||
"print(final_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Native API and SGLang Runtime (SRT)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"print(response)\n",
|
||||
"# generate an answer\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"messages = get_messages()\n",
|
||||
"\n",
|
||||
"ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n",
|
||||
"role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n",
|
||||
"matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n",
|
||||
"usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n",
|
||||
"input = tokenizer.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=False,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\"\"\""
|
||||
"gen_url = \"http://localhost:30333/generate\"\n",
|
||||
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
|
||||
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
|
||||
"print(gen_response)\n",
|
||||
"\n",
|
||||
"# parse the response\n",
|
||||
"parse_url = \"http://localhost:30333/function_call\"\n",
|
||||
"\n",
|
||||
"function_call_input = {\n",
|
||||
" \"text\": gen_response,\n",
|
||||
" \"tool_call_parser\": \"llama3\",\n",
|
||||
" \"tools\": tools,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
|
||||
"function_call_response_json = function_call_response.json()\n",
|
||||
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
|
||||
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -128,11 +410,98 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## How to support a new model?\n",
|
||||
"## Offline Engine API"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sglang as sgl\n",
|
||||
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
||||
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||
"\n",
|
||||
"For adding support of more different models:\n",
|
||||
" 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n",
|
||||
" 2. Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n"
|
||||
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||
"tokenizer = llm.tokenizer_manager.tokenizer\n",
|
||||
"input_ids = tokenizer.apply_chat_template(\n",
|
||||
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"sampling_params = {\n",
|
||||
" \"max_new_tokens\": 128,\n",
|
||||
" \"temperature\": 0.3,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" \"skip_special_tokens\": False,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# 1) Offline generation\n",
|
||||
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
|
||||
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
|
||||
"\n",
|
||||
"print(\"=== Offline Engine Output Text ===\")\n",
|
||||
"print(generated_text)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 2) Parse using FunctionCallParser\n",
|
||||
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
|
||||
" function_dict = tool_dict.get(\"function\", {})\n",
|
||||
" return Tool(\n",
|
||||
" type=tool_dict.get(\"type\", \"function\"),\n",
|
||||
" function=Function(\n",
|
||||
" name=function_dict.get(\"name\"),\n",
|
||||
" description=function_dict.get(\"description\"),\n",
|
||||
" parameters=function_dict.get(\"parameters\"),\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
|
||||
"\n",
|
||||
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
|
||||
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
|
||||
"\n",
|
||||
"print(\"\\n=== Parsing Result ===\")\n",
|
||||
"print(\"Normal text portion:\", normal_text)\n",
|
||||
"print(\"Function call portion:\")\n",
|
||||
"for call in calls:\n",
|
||||
" # call: ToolCallItem\n",
|
||||
" print(f\" - tool name: {call.name}\")\n",
|
||||
" print(f\" parameters: {call.parameters}\")\n",
|
||||
"\n",
|
||||
"# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm.shutdown()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## How to support a new model?\n",
|
||||
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
|
||||
"```\n",
|
||||
"\tTOOLS_TAG_LIST = [\n",
|
||||
"\t “<|plugin|>“,\n",
|
||||
"\t “<function=“,\n",
|
||||
"\t “<tool_call>“,\n",
|
||||
"\t “<|python_tag|>“,\n",
|
||||
"\t “[TOOL_CALLS]”\n",
|
||||
"\t]\n",
|
||||
"```\n",
|
||||
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
|
||||
"```\n",
|
||||
" class NewModelDetector(BaseFormatDetector):\n",
|
||||
"```\n",
|
||||
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -39,10 +39,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
FunctionCallReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -369,6 +371,28 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/function_call")
|
||||
async def function_call_request(obj: FunctionCallReqInput, request: Request):
|
||||
"""
|
||||
A native API endpoint to parse function calls from a text.
|
||||
"""
|
||||
# 1) Initialize the parser based on the request body
|
||||
parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)
|
||||
|
||||
# 2) Call the non-stream parsing method (non-stream)
|
||||
normal_text, calls = parser.parse_non_stream(obj.text)
|
||||
|
||||
# 3) Organize the response content
|
||||
response_data = {
|
||||
"normal_text": normal_text,
|
||||
"calls": [
|
||||
call.model_dump() for call in calls
|
||||
], # Convert pydantic objects to dictionaries
|
||||
}
|
||||
|
||||
return ORJSONResponse(content=response_data, status_code=200)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
|
||||
|
||||
|
||||
494
python/sglang/srt/function_call_parser.py
Normal file
494
python/sglang/srt/function_call_parser.py
Normal file
@@ -0,0 +1,494 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOOLS_TAG_LIST = [
|
||||
"<|plugin|>",
|
||||
"<function=",
|
||||
"<tool_call>",
|
||||
"<|python_tag|>",
|
||||
"[TOOL_CALLS]",
|
||||
]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function Tool Template."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||
|
||||
tool_index: int
|
||||
name: Optional[str] = None
|
||||
parameters: str # JSON string
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def _is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
def __init__(
|
||||
self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
|
||||
):
|
||||
self.normal_text = normal_text
|
||||
self.calls = calls or []
|
||||
|
||||
|
||||
class BaseFormatDetector:
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(self):
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
self._buffer = ""
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = (
|
||||
[]
|
||||
) # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = ""
|
||||
self.eot_token = ""
|
||||
|
||||
def parse_base_json(self, action: Dict, tools: List[Function]):
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
tool_index = [tool.function.name for tool in tools].index(name)
|
||||
tool_call_item = ToolCallItem(
|
||||
tool_index=tool_index, name=name, parameters=parameters
|
||||
)
|
||||
calls = [tool_call_item]
|
||||
return calls
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
"""
|
||||
action = json.loads(text)
|
||||
return self.parse_base_json(action, tools)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing, referencing the logic of Llama32Detector.
|
||||
We partially parse JSON within <tool_call>...</tool_call>, and handle
|
||||
incremental argument output.
|
||||
"""
|
||||
# Append new text to buffer
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||
self._buffer = ""
|
||||
if self.eot_token in new_text:
|
||||
new_text = new_text.replace(self.eot_token, "")
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = (
|
||||
len(self.bot_token)
|
||||
if current_text.startswith(self.bot_token)
|
||||
else 0
|
||||
)
|
||||
while start_idx < len(current_text):
|
||||
(obj, end_idx) = _partial_json_loads(
|
||||
current_text[start_idx:], flags
|
||||
)
|
||||
is_complete.append(
|
||||
_is_complete_json(current_text[start_idx : start_idx + end_idx])
|
||||
)
|
||||
start_idx += end_idx + len("; ")
|
||||
# depending on the prompt Llama can use
|
||||
# either arguments or parameters
|
||||
if "parameters" in obj:
|
||||
assert (
|
||||
"arguments" not in obj
|
||||
), "model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
# not enough tokens to parse into JSON yet
|
||||
return StreamingParseResult()
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: Dict = (
|
||||
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
||||
)
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return StreamingParseResult()
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (
|
||||
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
|
||||
):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
res = StreamingParseResult(
|
||||
normal_text=None,
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name="",
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
print("starting on new tool %d", self.current_tool_id)
|
||||
return res
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
res = StreamingParseResult(
|
||||
normal_text=None,
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name=function_name,
|
||||
parameters="",
|
||||
)
|
||||
],
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
res = StreamingParseResult()
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments"
|
||||
)
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
self._buffer = ""
|
||||
self.prev_tool_call_arr[self.current_tool_id].clear()
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
||||
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
res = StreamingParseResult(
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name="",
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
if not is_complete[self.current_tool_id]:
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Skipping chunk as a result of tool streaming extraction error
|
||||
return StreamingParseResult()
|
||||
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 2.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<tool_call>"
|
||||
self.eot_token = "</tool_call>"
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
if "<tool_call>" not in text:
|
||||
return []
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
calls = []
|
||||
for match_result in match_result_list:
|
||||
match_result = json.loads(match_result)
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
<|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "[TOOL_CALLS] ["
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||
for example,
|
||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
||||
The key pattern is [TOOL_CALLS] [...]
|
||||
"""
|
||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
||||
if len(find_results) > 0:
|
||||
return find_results[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
calls = []
|
||||
if len(raw_tool_calls) > 0:
|
||||
raw_tool_call = raw_tool_calls[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama 3.2 models.
|
||||
Assumes function call format:
|
||||
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
||||
Does not require a closing tag "</python_tag|>",
|
||||
relies on json.loads(...) success to determine if JSON is complete.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<|python_tag|>"
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
|
||||
if "<|python_tag|>" not in text:
|
||||
return []
|
||||
_, action = text.split("<|python_tag|>")
|
||||
action = json.loads(action)
|
||||
return self.parse_base_json(action, tools)
|
||||
|
||||
|
||||
class MultiFormatParser:
|
||||
def __init__(self, detectors: List[BaseFormatDetector]):
|
||||
"""
|
||||
:param detectors: A series of available Detector instances passed in
|
||||
"""
|
||||
self.detectors = detectors
|
||||
|
||||
def parse_once(self, text: str, tools: List[Function]):
|
||||
"""
|
||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||
Return: (final_text, all_calls)
|
||||
- final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
|
||||
- all_calls: All calls parsed by the Detectors
|
||||
"""
|
||||
final_calls = []
|
||||
final_normal_text = text
|
||||
for detector in self.detectors:
|
||||
tool_call_list = detector.detect_and_parse(text, tools)
|
||||
if len(tool_call_list) > 0: # parsed successfully
|
||||
final_calls = tool_call_list
|
||||
break
|
||||
|
||||
# leftover_text is the normal text not consumed by any Detector
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def parse_streaming_increment(self, new_text: str, tools: List[Function]):
|
||||
"""
|
||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||
and merge their produced normal_text/calls to return.
|
||||
(The logic here can be "priority-based" or "parallel parsing" based on your needs)
|
||||
"""
|
||||
final_normal_text = ""
|
||||
final_calls = []
|
||||
|
||||
for detector in self.detectors:
|
||||
sp_result = detector.parse_streaming_increment(new_text, tools)
|
||||
# Merge normal_text and calls
|
||||
# If one sp_result contains result call, this should be a successful parse
|
||||
# If one sp_result only contains normal_text, this can either be a successful
|
||||
# parse or it is not using the desired parsing tool.
|
||||
if sp_result.normal_text:
|
||||
final_normal_text = sp_result.normal_text
|
||||
if sp_result.calls:
|
||||
final_calls.extend(sp_result.calls)
|
||||
final_normal_text = sp_result.normal_text
|
||||
break
|
||||
|
||||
return final_normal_text, final_calls
|
||||
|
||||
|
||||
class FunctionCallParser:
|
||||
"""
|
||||
In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
|
||||
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
||||
"""
|
||||
|
||||
ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
|
||||
"llama3": Llama32Detector,
|
||||
"qwen25": Qwen25Detector,
|
||||
"mistral": MistralDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Function], tool_call_parser: str = None):
|
||||
detectors = []
|
||||
if tool_call_parser:
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
if detector_class:
|
||||
detectors.append(detector_class())
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
|
||||
else:
|
||||
raise ValueError("Tool Call Parser Not Given!")
|
||||
|
||||
self.multi_format_parser = MultiFormatParser(detectors)
|
||||
self.tools = tools
|
||||
|
||||
def parse_non_stream(self, full_text: str):
|
||||
"""
|
||||
Non-streaming call: one-time parsing
|
||||
"""
|
||||
full_normal_text, calls = self.multi_format_parser.parse_once(
|
||||
full_text, self.tools
|
||||
)
|
||||
return full_normal_text, calls
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str):
|
||||
"""
|
||||
Streaming call: incremental parsing
|
||||
"""
|
||||
normal_text, calls = self.multi_format_parser.parse_streaming_increment(
|
||||
chunk_text, self.tools
|
||||
)
|
||||
return normal_text, calls
|
||||
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -540,3 +540,27 @@ class CloseSessionReqInput:
|
||||
class OpenSessionReqOutput:
|
||||
session_id: Optional[str]
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
description: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
function: Function
|
||||
type: Optional[str] = "function"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallReqInput:
|
||||
text: str # The text to parse.
|
||||
tools: List[Tool] = field(
|
||||
default_factory=list
|
||||
) # A list of available function tools (name, parameters, etc.).
|
||||
tool_call_parser: Optional[str] = (
|
||||
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
|
||||
generate_chat_conv,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
||||
ret,
|
||||
to_file=True,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
else:
|
||||
responses = v1_generate_response(
|
||||
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if request.stream:
|
||||
logger.warning("Streaming is not supported with tools.")
|
||||
request.stream = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
|
||||
openai_compatible_messages = openai_compatible_messages[:-1]
|
||||
else:
|
||||
assistant_prefix = None
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
try:
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
except:
|
||||
# This except branch will be triggered when the chosen model
|
||||
# has a different tools input format that is not compatiable
|
||||
# with openAI's apply_chat_template tool_call format, like Mistral.
|
||||
tools = [t if "function" in t else {"function": t} for t in tools]
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if assistant_prefix:
|
||||
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
||||
stop = request.stop
|
||||
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
|
||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||
|
||||
|
||||
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
def v1_chat_generate_response(
|
||||
request, ret, to_file=False, cache_report=False, tool_call_parser=None
|
||||
):
|
||||
choices = []
|
||||
|
||||
for idx, ret_item in enumerate(ret):
|
||||
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
if finish_reason == "stop":
|
||||
finish_reason = "tool_calls"
|
||||
try:
|
||||
text, call_info_list = parse_tool_response(text, tools) # noqa
|
||||
parser = FunctionCallParser(tools, tool_call_parser)
|
||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info[0]),
|
||||
id=str(call_info.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_info[1], arguments=call_info[2]
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
)
|
||||
for call_info in call_info_list
|
||||
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||
|
||||
if adapted_request.stream:
|
||||
parser_dict = {}
|
||||
|
||||
async def generate_stream_resp():
|
||||
is_firsts = {}
|
||||
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
text = content["text"]
|
||||
|
||||
is_first = is_firsts.get(index, True)
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
text = content["text"]
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
new_stream_buffer = stream_buffer + delta
|
||||
|
||||
is_firsts[index] = is_first
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
if index not in parser_dict:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
parser = parser_dict[index]
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
# parse_increment => returns (normal_text, calls)
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
|
||||
# 1) if there's normal_text, output it as normal content
|
||||
if normal_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=normal_text),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# 2) if we found calls, we output them as separate chunk(s)
|
||||
for call_item in calls:
|
||||
# transform call_item -> FunctionResponse + ToolCall
|
||||
|
||||
if (
|
||||
content["meta_info"]["finish_reason"]
|
||||
and content["meta_info"]["finish_reason"]["type"]
|
||||
== "stop"
|
||||
):
|
||||
latest_delta_len = 0
|
||||
if isinstance(call_item.parameters, str):
|
||||
latest_delta_len = len(call_item.parameters)
|
||||
|
||||
expected_call = json.dumps(
|
||||
parser.multi_format_parser.detectors[0]
|
||||
.prev_tool_call_arr[index]
|
||||
.get("arguments", {}),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
actual_call = parser.multi_format_parser.detectors[
|
||||
0
|
||||
].streamed_args_for_tool[index]
|
||||
if latest_delta_len > 0:
|
||||
actual_call = actual_call[:-latest_delta_len]
|
||||
remaining_call = expected_call.replace(
|
||||
actual_call, "", 1
|
||||
)
|
||||
call_item.parameters = remaining_call
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=str(call_item.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_item.name,
|
||||
arguments=call_item.parameters,
|
||||
),
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(
|
||||
role="assistant", tool_calls=[tool_call]
|
||||
),
|
||||
finish_reason="tool_call",
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
|
||||
else:
|
||||
# No tool calls => just treat this as normal text
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
total_prompt_tokens = sum(
|
||||
tokens
|
||||
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
ret = [ret]
|
||||
|
||||
response = v1_chat_generate_response(
|
||||
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
|
||||
request,
|
||||
ret,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -262,7 +262,7 @@ class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
@@ -276,7 +276,7 @@ class Tool(BaseModel):
|
||||
class ToolChoiceFuncName(BaseModel):
|
||||
"""The name of tool choice function."""
|
||||
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
|
||||
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
|
||||
@@ -161,6 +161,7 @@ class ServerArgs:
|
||||
|
||||
# Custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
tool_call_parser: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -877,6 +878,14 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||
)
|
||||
# Function Calling
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["qwen25", "mistral", "llama3"],
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
@@ -1243,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
||||
return str(data)
|
||||
|
||||
|
||||
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
|
||||
|
||||
|
||||
def parse_tool_response(text, tools, **kwargs):
|
||||
"""Parse model response containing tool information.
|
||||
|
||||
Args:
|
||||
text(str): model response in string format
|
||||
tools(List): tools from user request
|
||||
"""
|
||||
if "<|plugin|>" in text: # internlm2
|
||||
text, action = text.split("<|action_start|><|plugin|>")
|
||||
action = action.split("<|action_end|>".strip())[0]
|
||||
action = action[action.find("{") :]
|
||||
action = json.loads(action)
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
|
||||
)
|
||||
call_info_list = [(name, parameters)]
|
||||
elif "<function=" in text: # llama3.1
|
||||
action, _ = text.split("</function>")
|
||||
parameters = action[action.find("{") :]
|
||||
name = action.split("<function=")[1].split(">{")[0]
|
||||
call_info_list = [(name, parameters)]
|
||||
elif "<tool_call>" in text and "</tool_call>" in text: # qwen2.5
|
||||
# get tool_call in text
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
call_info_list = []
|
||||
for match_result in match_result_list:
|
||||
action = json.loads(match_result)
|
||||
call_info_list.append(
|
||||
(action["name"], json.dumps(action["arguments"], ensure_ascii=False))
|
||||
)
|
||||
# get text outside of tags
|
||||
if not text.startswith("<tool_call>"):
|
||||
text = text[: text.find("<tool_call>")]
|
||||
elif not text.endswith("</tool_call>"):
|
||||
text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
|
||||
else:
|
||||
text = ""
|
||||
elif "<|python_tag|>" in text: # llama3.2
|
||||
_, action = text.split("<|python_tag|>")
|
||||
action = json.loads(action)
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
|
||||
)
|
||||
call_info_list = [(name, parameters)]
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected model response: {text}")
|
||||
|
||||
call_info_list = [
|
||||
(
|
||||
[tool.function.name for tool in tools].index(call_info[0]),
|
||||
call_info[0],
|
||||
call_info[1],
|
||||
)
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return text, call_info_list
|
||||
|
||||
|
||||
def permute_weight(x: torch.Tensor) -> torch.Tensor:
|
||||
b_ = x.shape[0]
|
||||
n_ = x.shape[1]
|
||||
|
||||
249
test/srt/test_function_calling.py
Normal file
249
test/srt/test_function_calling.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServerFunctionCalling(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
# If your server needs extra parameters to test function calling, please add them here.
|
||||
"--tool-call-parser",
|
||||
"llama3",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_function_calling_format(self):
|
||||
"""
|
||||
Test: Whether the function call format returned by the AI is correct.
|
||||
When returning a tool call, message.content should be None, and tool_calls should be a list.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Compute the sum of two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
"b": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "Compute (3+5)"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
|
||||
assert content is None, (
|
||||
"When function call is successful, message.content should be None, "
|
||||
f"but got: {content}"
|
||||
)
|
||||
assert (
|
||||
isinstance(tool_calls, list) and len(tool_calls) > 0
|
||||
), "tool_calls should be a non-empty list"
|
||||
|
||||
function_name = tool_calls[0].function.name
|
||||
assert function_name == "add", "Function name should be 'add'"
|
||||
|
||||
def test_function_calling_streaming_simple(self):
|
||||
"""
|
||||
Test: Whether the function name can be correctly recognized in streaming mode.
|
||||
- Expect a function call to be found, and the function name to be correct.
|
||||
- Verify that streaming mode returns at least multiple chunks.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "Weather unit (celsius or fahrenheit)",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the temperature in Paris?"}]
|
||||
|
||||
response_stream = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
chunks = list(response_stream)
|
||||
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
|
||||
|
||||
found_function_name = False
|
||||
for chunk in chunks:
|
||||
choice = chunk.choices[0]
|
||||
# Check whether the current chunk contains tool_calls
|
||||
if choice.delta.tool_calls:
|
||||
tool_call = choice.delta.tool_calls[0]
|
||||
if tool_call.function.name:
|
||||
self.assertEqual(
|
||||
tool_call.function.name,
|
||||
"get_current_weather",
|
||||
"Function name should be 'get_current_weather'",
|
||||
)
|
||||
found_function_name = True
|
||||
break
|
||||
|
||||
self.assertTrue(
|
||||
found_function_name,
|
||||
"Target function name 'get_current_weather' was not found in the streaming chunks",
|
||||
)
|
||||
|
||||
def test_function_calling_streaming_args_parsing(self):
|
||||
"""
|
||||
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
|
||||
- The user request requires multiple parameters.
|
||||
- AI may return the arguments in chunks that need to be concatenated.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Compute the sum of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "int",
|
||||
"description": "First integer",
|
||||
},
|
||||
"b": {
|
||||
"type": "int",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Please sum 5 and 7, just call the function."}
|
||||
]
|
||||
|
||||
response_stream = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.9,
|
||||
top_p=0.9,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
argument_fragments = []
|
||||
function_name = None
|
||||
for chunk in response_stream:
|
||||
choice = chunk.choices[0]
|
||||
if choice.delta.tool_calls:
|
||||
tool_call = choice.delta.tool_calls[0]
|
||||
# Record the function name on first occurrence
|
||||
function_name = tool_call.function.name or function_name
|
||||
# In case of multiple chunks, JSON fragments may need to be concatenated
|
||||
if tool_call.function.arguments:
|
||||
argument_fragments.append(tool_call.function.arguments)
|
||||
|
||||
self.assertEqual(function_name, "add", "Function name should be 'add'")
|
||||
joined_args = "".join(argument_fragments)
|
||||
self.assertTrue(
|
||||
len(joined_args) > 0,
|
||||
"No parameter fragments were returned in the function call",
|
||||
)
|
||||
|
||||
# Check whether the concatenated JSON is valid
|
||||
try:
|
||||
args_obj = json.loads(joined_args)
|
||||
except json.JSONDecodeError:
|
||||
self.fail(
|
||||
"The concatenated tool call arguments are not valid JSON, parsing failed"
|
||||
)
|
||||
|
||||
self.assertIn("a", args_obj, "Missing parameter 'a'")
|
||||
self.assertIn("b", args_obj, "Missing parameter 'b'")
|
||||
self.assertEqual(
|
||||
args_obj["a"],
|
||||
5,
|
||||
"Parameter a should be 5",
|
||||
)
|
||||
self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -623,58 +623,6 @@ class TestOpenAIServerEBNF(unittest.TestCase):
|
||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||
)
|
||||
|
||||
def test_function_calling_format(self):
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Compute the sum of two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
"b": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "Compute (3+5)"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
|
||||
assert (
|
||||
content is None
|
||||
), "When tools provided by the response, content should be None"
|
||||
assert (
|
||||
isinstance(tool_calls, list) and len(tool_calls) > 0
|
||||
), "Format not matched, tool_calls should be a list"
|
||||
|
||||
function_name = tool_calls[0].function.name
|
||||
assert (
|
||||
function_name == "add"
|
||||
), "Function name should be add for the above response"
|
||||
|
||||
|
||||
class TestOpenAIEmbedding(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user