Feature/function calling update (#2700)
Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -4,62 +4,23 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Function Calling\n",
|
"# Tool and Function Calling\n",
|
||||||
"\n",
|
"\n",
|
||||||
"This notebook provides a quick-start guide to use function tooling using SGLang chat completions API\n",
|
"This guide demonstrates how to use SGLang’s **Tool Calling** functionality."
|
||||||
"\n",
|
|
||||||
"## Supported Models\n",
|
|
||||||
"\n",
|
|
||||||
"Currently, we added the support for tools calling in the following models:\n",
|
|
||||||
" - Llama 3.2 models\n",
|
|
||||||
" - Llama 3.1 models\n",
|
|
||||||
" - Qwen 2.5 models\n",
|
|
||||||
" - InternLM Models"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Usage\n",
|
"## OpenAI Compatible API"
|
||||||
"\n",
|
|
||||||
"### Launch a server\n",
|
|
||||||
"\n",
|
|
||||||
"This code block is equivalent to executing\n",
|
|
||||||
"\n",
|
|
||||||
"`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
|
||||||
"--port 30000 --host 0.0.0.0`\n",
|
|
||||||
"in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sglang.utils import (\n",
|
|
||||||
" execute_shell_command,\n",
|
|
||||||
" wait_for_server,\n",
|
|
||||||
" terminate_process,\n",
|
|
||||||
" print_highlight,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"server_process = execute_shell_command(\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"wait_for_server(\"http://localhost:30000\")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Single Round Invocation"
|
"### Launching the Server"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -69,7 +30,47 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai import OpenAI\n",
|
"from openai import OpenAI\n",
|
||||||
|
"import json\n",
|
||||||
|
"from sglang.utils import (\n",
|
||||||
|
" execute_shell_command,\n",
|
||||||
|
" wait_for_server,\n",
|
||||||
|
" terminate_process,\n",
|
||||||
|
" print_highlight,\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"server_process = execute_shell_command(\n",
|
||||||
|
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n",
|
||||||
|
")\n",
|
||||||
|
"wait_for_server(\"http://localhost:30333\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
|
||||||
|
"\n",
|
||||||
|
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
||||||
|
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||||
|
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||||
|
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
|
||||||
|
]
|
||||||
|
},
|
||||||
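As an aside, the model-to-parser mapping above can be expressed as a small helper. This is an illustrative sketch, not an SGLang API; `pick_tool_call_parser` is an invented name, and the matching is a simple substring heuristic over the model path.

```python
def pick_tool_call_parser(model_path: str) -> str:
    """Pick a --tool-call-parser value from a model path.

    Illustrative heuristic based on the parser list above
    (llama3 / mistral / qwen25); not part of SGLang.
    """
    path = model_path.lower()
    if "llama-3" in path or "llama3" in path:
        return "llama3"
    if "mistral" in path:
        return "mistral"
    if "qwen2.5" in path:
        return "qwen25"
    raise ValueError(f"no known tool-call parser for {model_path}")


assert pick_tool_call_parser("meta-llama/Llama-3.1-8B-Instruct") == "llama3"
assert pick_tool_call_parser("Qwen/Qwen2.5-7B-Instruct") == "qwen25"
```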
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Define Tools for Function Call\n",
|
||||||
|
"Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Define tools\n",
|
||||||
"tools = [\n",
|
"tools = [\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"type\": \"function\",\n",
|
" \"type\": \"function\",\n",
|
||||||
@@ -79,22 +80,264 @@
|
|||||||
" \"parameters\": {\n",
|
" \"parameters\": {\n",
|
||||||
" \"type\": \"object\",\n",
|
" \"type\": \"object\",\n",
|
||||||
" \"properties\": {\n",
|
" \"properties\": {\n",
|
||||||
" \"location\": {\n",
|
" \"city\": {\n",
|
||||||
" \"type\": \"string\",\n",
|
" \"type\": \"string\",\n",
|
||||||
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"state\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
|
||||||
|
" \" in, e.g. 'CA' which would mean 'California'\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The unit to fetch the temperature in\",\n",
|
||||||
|
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
|
||||||
" },\n",
|
" },\n",
|
||||||
" \"required\": [\"location\"],\n",
|
" \"required\": [\"city\", \"state\", \"unit\"],\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
"]\n",
|
"]"
|
||||||
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
|
]
|
||||||
|
},
|
||||||
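Since the `parameters` field is expected to be a JSON-Schema-style object, a quick structural check can catch typos such as a `required` key that has no matching property. This is an optional sketch for illustration; `check_tool` is an invented helper, not part of SGLang.

```python
def check_tool(tool: dict) -> None:
    """Minimal structural check for an OpenAI-style tool definition.

    A sketch for illustration; not part of SGLang.
    """
    assert tool["type"] == "function"
    fn = tool["function"]
    assert isinstance(fn["name"], str) and fn["name"]
    params = fn["parameters"]
    assert params["type"] == "object"
    for req in params.get("required", []):
        assert req in params["properties"], f"required key {req!r} missing from properties"


# Applied to a tool shaped like the notebook's get_current_weather:
check_tool(
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
)
```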
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Define Messages"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_messages():\n",
|
||||||
|
" return [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
|
||||||
|
" }\n",
|
||||||
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n",
|
"\n",
|
||||||
"model_name = client.models.list().data[0].id\n",
|
"messages = get_messages()"
|
||||||
"response = client.chat.completions.create(\n",
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Initialize the Client"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Initialize OpenAI-like client\n",
|
||||||
|
"client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
|
||||||
|
"model_name = client.models.list().data[0].id"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Non-Streaming Request"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Non-streaming mode test\n",
|
||||||
|
"response_non_stream = client.chat.completions.create(\n",
|
||||||
|
" model=model_name,\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" temperature=0.8,\n",
|
||||||
|
" top_p=0.8,\n",
|
||||||
|
" stream=False, # Non-streaming\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
")\n",
|
||||||
|
"print_highlight(\"Non-stream response:\")\n",
|
||||||
|
"print(response_non_stream)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Streaming Request"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Streaming mode test\n",
|
||||||
|
"print_highlight(\"Streaming response:\")\n",
|
||||||
|
"response_stream = client.chat.completions.create(\n",
|
||||||
|
" model=model_name,\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" temperature=0.8,\n",
|
||||||
|
" top_p=0.8,\n",
|
||||||
|
" stream=True, # Enable streaming\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"chunks = []\n",
|
||||||
|
"for chunk in response_stream:\n",
|
||||||
|
" chunks.append(chunk)\n",
|
||||||
|
" if chunk.choices[0].delta.tool_calls:\n",
|
||||||
|
" print(chunk.choices[0].delta.tool_calls[0])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"### Handle Tool Calls\n",
|
||||||
|
"\n",
|
||||||
|
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||||
|
]
|
||||||
|
},
|
||||||
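The parse-and-invoke cycle described above can be sketched as one standalone helper that folds streamed deltas into a complete call. The chunk objects below are `SimpleNamespace` stand-ins mimicking the shape of the OpenAI client's streaming deltas; `accumulate_tool_call` and `_chunk` are illustrative names, not SGLang or OpenAI APIs.

```python
import json
from types import SimpleNamespace as NS


def accumulate_tool_call(chunks):
    """Combine streamed tool-call deltas into (name, arguments_json).

    Assumes OpenAI-style chunk objects where each delta carries at most
    one tool-call fragment, as in the streaming loop above.
    """
    name, fragments = None, []
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if not delta.tool_calls:
            continue
        call = delta.tool_calls[0]
        if call.function.name:  # the name arrives once, in full
            name = call.function.name
        if call.function.arguments:  # arguments arrive as JSON fragments
            fragments.append(call.function.arguments)
    return name, "".join(fragments)


# Minimal stand-in chunks for illustration:
def _chunk(name=None, arguments=None):
    fn = NS(name=name, arguments=arguments)
    return NS(choices=[NS(delta=NS(tool_calls=[NS(function=fn)]))])


chunks = [
    _chunk(name="get_current_weather"),
    _chunk(arguments='{"city": "Bos'),
    _chunk(arguments='ton", "state": "MA", "unit": "fahrenheit"}'),
]
name, args = accumulate_tool_call(chunks)
assert json.loads(args)["city"] == "Boston"
```

Once the fragments form complete JSON, `json.loads` yields the keyword arguments for the actual tool invocation, as the notebook does below.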
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**Non-Streaming Request**"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
|
||||||
|
"arguments_non_stream = (\n",
|
||||||
|
" response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
|
||||||
|
"print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**Streaming Request**"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Parse and combine function call arguments\n",
|
||||||
|
"arguments = []\n",
|
||||||
|
"for chunk in chunks:\n",
|
||||||
|
" choice = chunk.choices[0]\n",
|
||||||
|
" delta = choice.delta\n",
|
||||||
|
" if delta.tool_calls:\n",
|
||||||
|
" tool_call = delta.tool_calls[0]\n",
|
||||||
|
" if tool_call.function.name:\n",
|
||||||
|
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
||||||
|
"\n",
|
||||||
|
" if tool_call.function.arguments:\n",
|
||||||
|
" arguments.append(tool_call.function.arguments)\n",
|
||||||
|
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Combine all fragments into a single JSON string\n",
|
||||||
|
"full_arguments = \"\".join(arguments)\n",
|
||||||
|
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Define a Tool Function"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is a demonstration, define real function according to your usage.\n",
|
||||||
|
"def get_current_weather(city: str, state: str, unit: \"str\"):\n",
|
||||||
|
" return (\n",
|
||||||
|
" f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
|
||||||
|
" \"partly cloudly, with highs in the 90's.\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"available_tools = {\"get_current_weather\": get_current_weather}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"## Execute the Tool"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"call_data = json.loads(full_arguments)\n",
|
||||||
|
"\n",
|
||||||
|
"messages.append(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": \"\",\n",
|
||||||
|
" \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
|
||||||
|
" }\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Call the corresponding tool function\n",
|
||||||
|
"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
|
||||||
|
"tool_to_call = available_tools[tool_name]\n",
|
||||||
|
"result = tool_to_call(**call_data)\n",
|
||||||
|
"print_highlight(f\"Function call result: {result}\")\n",
|
||||||
|
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(f\"Updated message history: {messages}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Send Results Back to Model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"final_response = client.chat.completions.create(\n",
|
||||||
" model=model_name,\n",
|
" model=model_name,\n",
|
||||||
" messages=messages,\n",
|
" messages=messages,\n",
|
||||||
" temperature=0.8,\n",
|
" temperature=0.8,\n",
|
||||||
@@ -102,17 +345,56 @@
|
|||||||
" stream=False,\n",
|
" stream=False,\n",
|
||||||
" tools=tools,\n",
|
" tools=tools,\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
"print_highlight(\"Non-stream response:\")\n",
|
||||||
|
"print(final_response)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Native API and SGLang Runtime (SRT)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from transformers import AutoTokenizer\n",
|
||||||
|
"import requests\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(response)\n",
|
"# generate an answer\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\"\"\"\n",
|
"messages = get_messages()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n",
|
"input = tokenizer.apply_chat_template(\n",
|
||||||
"role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n",
|
" messages,\n",
|
||||||
"matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n",
|
" tokenize=False,\n",
|
||||||
"usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n",
|
" add_generation_prompt=True,\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\"\"\""
|
"gen_url = \"http://localhost:30333/generate\"\n",
|
||||||
|
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
|
||||||
|
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
|
||||||
|
"print(gen_response)\n",
|
||||||
|
"\n",
|
||||||
|
"# parse the response\n",
|
||||||
|
"parse_url = \"http://localhost:30333/function_call\"\n",
|
||||||
|
"\n",
|
||||||
|
"function_call_input = {\n",
|
||||||
|
" \"text\": gen_response,\n",
|
||||||
|
" \"tool_call_parser\": \"llama3\",\n",
|
||||||
|
" \"tools\": tools,\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
|
||||||
|
"function_call_response_json = function_call_response.json()\n",
|
||||||
|
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
|
||||||
|
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -128,11 +410,98 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## How to support a new model?\n",
|
"## Offline Engine API"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sglang as sgl\n",
|
||||||
|
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
||||||
|
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||||
"\n",
|
"\n",
|
||||||
"For adding support of more different models:\n",
|
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||||
" 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n",
|
"tokenizer = llm.tokenizer_manager.tokenizer\n",
|
||||||
" 2. Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n"
|
"input_ids = tokenizer.apply_chat_template(\n",
|
||||||
|
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"sampling_params = {\n",
|
||||||
|
" \"max_new_tokens\": 128,\n",
|
||||||
|
" \"temperature\": 0.3,\n",
|
||||||
|
" \"top_p\": 0.95,\n",
|
||||||
|
" \"skip_special_tokens\": False,\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# 1) Offline generation\n",
|
||||||
|
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
|
||||||
|
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"=== Offline Engine Output Text ===\")\n",
|
||||||
|
"print(generated_text)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# 2) Parse using FunctionCallParser\n",
|
||||||
|
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
|
||||||
|
" function_dict = tool_dict.get(\"function\", {})\n",
|
||||||
|
" return Tool(\n",
|
||||||
|
" type=tool_dict.get(\"type\", \"function\"),\n",
|
||||||
|
" function=Function(\n",
|
||||||
|
" name=function_dict.get(\"name\"),\n",
|
||||||
|
" description=function_dict.get(\"description\"),\n",
|
||||||
|
" parameters=function_dict.get(\"parameters\"),\n",
|
||||||
|
" ),\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
|
||||||
|
"\n",
|
||||||
|
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
|
||||||
|
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n=== Parsing Result ===\")\n",
|
||||||
|
"print(\"Normal text portion:\", normal_text)\n",
|
||||||
|
"print(\"Function call portion:\")\n",
|
||||||
|
"for call in calls:\n",
|
||||||
|
" # call: ToolCallItem\n",
|
||||||
|
" print(f\" - tool name: {call.name}\")\n",
|
||||||
|
" print(f\" parameters: {call.parameters}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm.shutdown()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## How to support a new model?\n",
|
||||||
|
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
|
||||||
|
"```\n",
|
||||||
|
"\tTOOLS_TAG_LIST = [\n",
|
||||||
|
"\t “<|plugin|>“,\n",
|
||||||
|
"\t “<function=“,\n",
|
||||||
|
"\t “<tool_call>“,\n",
|
||||||
|
"\t “<|python_tag|>“,\n",
|
||||||
|
"\t “[TOOL_CALLS]”\n",
|
||||||
|
"\t]\n",
|
||||||
|
"```\n",
|
||||||
|
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
|
||||||
|
"```\n",
|
||||||
|
" class NewModelDetector(BaseFormatDetector):\n",
|
||||||
|
"```\n",
|
||||||
|
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
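The three steps above can be sketched end to end. Everything here is invented for illustration: the `<new_tool>` tag, the stand-in base class, and the plain-dict return value. The real `BaseFormatDetector` lives in `sglang/srt/function_call_parser.py` and returns `ToolCallItem` objects, with a streaming interface as well.

```python
import json
import re


class BaseFormatDetector:  # stand-in for the SGLang base class
    def detect_and_parse(self, text, tools):
        raise NotImplementedError


class NewModelDetector(BaseFormatDetector):
    """Hypothetical detector for a model that wraps calls in <new_tool> tags."""

    TAG_RE = re.compile(r"<new_tool>(.*?)</new_tool>", re.DOTALL)

    def detect_and_parse(self, text, tools):
        # Each match is assumed to be a JSON object of the form
        # {"name": ..., "arguments": {...}}.
        return [json.loads(m) for m in self.TAG_RE.findall(text)]


calls = NewModelDetector().detect_and_parse(
    'ok <new_tool>{"name": "get_current_weather", "arguments": {"city": "Boston"}}</new_tool>',
    tools=[],
)
assert calls[0]["name"] == "get_current_weather"
```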
|
|||||||
@@ -39,10 +39,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
|
from sglang.srt.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
|
FunctionCallReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
@@ -369,6 +371,28 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/function_call")
|
||||||
|
async def function_call_request(obj: FunctionCallReqInput, request: Request):
|
||||||
|
"""
|
||||||
|
A native API endpoint to parse function calls from a text.
|
||||||
|
"""
|
||||||
|
# 1) Initialize the parser based on the request body
|
||||||
|
parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)
|
||||||
|
|
||||||
|
# 2) Call the non-stream parsing method (non-stream)
|
||||||
|
normal_text, calls = parser.parse_non_stream(obj.text)
|
||||||
|
|
||||||
|
# 3) Organize the response content
|
||||||
|
response_data = {
|
||||||
|
"normal_text": normal_text,
|
||||||
|
"calls": [
|
||||||
|
call.model_dump() for call in calls
|
||||||
|
], # Convert pydantic objects to dictionaries
|
||||||
|
}
|
||||||
|
|
||||||
|
return ORJSONResponse(content=response_data, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
494
python/sglang/srt/function_call_parser.py
Normal file
494
python/sglang/srt/function_call_parser.py
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from json import JSONDecodeError, JSONDecoder
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import partial_json_parser
|
||||||
|
from partial_json_parser.core.options import Allow
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
TOOLS_TAG_LIST = [
|
||||||
|
"<|plugin|>",
|
||||||
|
"<function=",
|
||||||
|
"<tool_call>",
|
||||||
|
"<|python_tag|>",
|
||||||
|
"[TOOL_CALLS]",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
"""Function Tool Template."""
|
||||||
|
|
||||||
|
description: Optional[str] = Field(default=None, examples=[None])
|
||||||
|
name: Optional[str] = None
|
||||||
|
parameters: Optional[object] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallItem(BaseModel):
|
||||||
|
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||||
|
|
||||||
|
tool_index: int
|
||||||
|
name: Optional[str] = None
|
||||||
|
parameters: str # JSON string
|
||||||
|
|
||||||
|
|
||||||
|
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||||
|
prefix = ""
|
||||||
|
min_length = min(len(s1), len(s2))
|
||||||
|
for i in range(0, min_length):
|
||||||
|
if s1[i] == s2[i]:
|
||||||
|
prefix += s1[i]
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
|
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||||
|
try:
|
||||||
|
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
if "Extra data" in e.msg:
|
||||||
|
dec = JSONDecoder()
|
||||||
|
return dec.raw_decode(input_str)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _is_complete_json(input_str: str) -> bool:
|
||||||
|
try:
|
||||||
|
json.loads(input_str)
|
||||||
|
return True
|
||||||
|
except JSONDecodeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingParseResult:
|
||||||
|
"""Result of streaming incremental parsing."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
|
||||||
|
):
|
||||||
|
self.normal_text = normal_text
|
||||||
|
self.calls = calls or []
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFormatDetector:
|
||||||
|
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# initialize properties used for state when parsing tool calls in
|
||||||
|
self._buffer = ""
|
||||||
|
# streaming mode
|
||||||
|
self.prev_tool_call_arr: List[Dict] = []
|
||||||
|
self.current_tool_id: int = -1
|
||||||
|
self.current_tool_name_sent: bool = False
|
||||||
|
self.streamed_args_for_tool: List[str] = (
|
||||||
|
[]
|
||||||
|
) # map what has been streamed for each tool so far to a list
|
||||||
|
self.bot_token = ""
|
||||||
|
self.eot_token = ""
|
||||||
|
|
||||||
|
def parse_base_json(self, action: Dict, tools: List[Function]):
|
||||||
|
name, parameters = action["name"], json.dumps(
|
||||||
|
action.get("parameters", action.get("arguments", {})),
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
tool_index = [tool.function.name for tool in tools].index(name)
|
||||||
|
tool_call_item = ToolCallItem(
|
||||||
|
tool_index=tool_index, name=name, parameters=parameters
|
||||||
|
)
|
||||||
|
calls = [tool_call_item]
|
||||||
|
return calls
|
||||||
|
|
||||||
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||||
|
"""
|
||||||
|
Parses the text in one go and returns the parsed tool calls as a list of ToolCallItem.
|
||||||
|
"""
|
||||||
|
action = json.loads(text)
|
||||||
|
return self.parse_base_json(action, tools)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
|
||||||
|
self, new_text: str, tools: List[Function]
|
||||||
|
) -> StreamingParseResult:
|
||||||
|
"""
|
||||||
|
Streaming incremental parsing, referencing the logic of Llama32Detector.
|
||||||
|
We partially parse JSON within <tool_call>...</tool_call>, and handle
|
||||||
|
incremental argument output.
|
||||||
|
"""
|
||||||
|
# Append new text to buffer
|
||||||
|
self._buffer += new_text
|
||||||
|
current_text = self._buffer
|
||||||
|
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||||
|
self._buffer = ""
|
||||||
|
if self.eot_token in new_text:
|
||||||
|
new_text = new_text.replace(self.eot_token, "")
|
||||||
|
return StreamingParseResult(normal_text=new_text)
|
||||||
|
|
||||||
|
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||||
|
# sent yet, don't allow sending
|
||||||
|
# an incomplete string since OpenAI only ever (as far as I have
|
||||||
|
# seen) allows sending the entire tool/ function name at once.
|
||||||
|
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||||
|
try:
|
||||||
|
tool_call_arr = []
|
||||||
|
is_complete = []
|
||||||
|
try:
|
||||||
|
# depending on the prompt format the Llama model may or may not
|
||||||
|
# prefix the output with the <|python_tag|> token
|
||||||
|
start_idx = (
|
||||||
|
len(self.bot_token)
|
||||||
|
if current_text.startswith(self.bot_token)
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
while start_idx < len(current_text):
|
||||||
|
(obj, end_idx) = _partial_json_loads(
|
||||||
|
current_text[start_idx:], flags
|
||||||
|
)
|
||||||
|
is_complete.append(
|
||||||
|
                    _is_complete_json(current_text[start_idx : start_idx + end_idx])
                )
                start_idx += end_idx + len("; ")
                # depending on the prompt Llama can use
                # either arguments or parameters
                if "parameters" in obj:
                    assert (
                        "arguments" not in obj
                    ), "model generated both parameters and arguments"
                    obj["arguments"] = obj["parameters"]
                tool_call_arr.append(obj)

            except partial_json_parser.core.exceptions.MalformedJSON:
                # not enough tokens to parse into JSON yet
                return StreamingParseResult()

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            normal_text=None,
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                print("starting on new tool %d" % self.current_tool_id)
                return res

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    res = StreamingParseResult(
                        normal_text=None,
                        calls=[
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

                self.prev_tool_call_arr = tool_call_arr
                return res

        except Exception as e:
            print(e)
            # Skipping chunk as a result of tool streaming extraction error
            return StreamingParseResult()

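The incremental path above only streams the portion of the (auto-completed) argument JSON that has not been sent yet, diffing the previous and current snapshots via a longest-common-prefix helper. A minimal sketch of that idea, with a hypothetical `find_common_prefix` standing in for the module's `_find_common_prefix` and made-up snapshot values:

```python
def find_common_prefix(s1: str, s2: str) -> str:
    # Longest shared prefix of two partially parsed JSON strings.
    prefix = []
    for a, b in zip(s1, s2):
        if a != b:
            break
        prefix.append(a)
    return "".join(prefix)


# Two consecutive snapshots of the same arguments JSON; the partial-JSON
# parser auto-closes the first one early.
prev = '{"location": "Bos"}'
cur = '{"location": "Boston"}'

sent = len('{"location": "')  # 14 characters already streamed to the client
prefix = find_common_prefix(prev, cur)  # '{"location": "Bos'
argument_diff = prefix[sent:]  # only the unsent, stable part: 'Bos'
```

Only characters that are common to both snapshots are safe to emit, since the auto-completed suffix (the closing `"}`) may still change.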
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: A list of parsed ToolCallItem objects; empty if no tool call is found.
        """
        if "<tool_call>" not in text:
            return []
        pattern = r"<tool_call>(.*?)</tool_call>"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        calls = []
        for match_result in match_result_list:
            match_result = json.loads(match_result)
            calls.extend(self.parse_base_json(match_result, tools))
        return calls

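The detector boils down to a non-greedy regex plus `json.loads` on each match. A standalone sketch with an invented tool name and payload:

```python
import json
import re

text = (
    "Sure, calling the tool now. "
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
)

# Non-greedy (.*?) so that multiple <tool_call> blocks each match separately.
matches = re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
calls = [json.loads(m) for m in matches]
```

Text outside the `<tool_call>...</tool_call>` tags is left untouched and can be returned as normal content.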
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
        [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def _clean_text(self, text: str) -> str:
        """
        Clean the text so that only '[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' is left.
        For example,
        text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
        returns '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]
        """
        find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
        if len(find_results) > 0:
            return find_results[0]
        else:
            return ""

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: A list of parsed ToolCallItem objects; empty if no tool call is found.
        """
        text = self._clean_text(text)
        tool_content = text.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if len(raw_tool_calls) > 0:
            raw_tool_call = raw_tool_calls[0]
            function_call_arr = json.loads(raw_tool_call)
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return calls

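`_clean_text` can be exercised on its own; a small sketch with illustrative values, using the same regex as above:

```python
import re

# Model output in the style the detector expects (values are illustrative).
text = (
    '[TOOL_CALLS] [{"name": "get_current_weather", '
    '"arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
    "\n\nToday's weather in Boston is sunny."
)

# Keep only the "[TOOL_CALLS] [...]" span; any trailing prose is dropped.
found = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
cleaned = found[0] if found else ""
```

The lazy `.*?` stops at the first `]`, which here closes the tool-call array, so the conversational tail never reaches `json.loads`.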
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
        <|python_tag|>{"name":"xxx", "arguments":{...}}
    Does not require a closing tag; relies on json.loads(...) success
    to determine whether the JSON is complete.
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: A list of parsed ToolCallItem objects; empty if no tool call is found.
        """
        if "<|python_tag|>" not in text:
            return []
        _, action = text.split("<|python_tag|>")
        action = json.loads(action)
        return self.parse_base_json(action, tools)

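The `<|python_tag|>` parse can be sketched standalone (the tool name and arguments are made up, and unlike the class above this helper does not resolve names against a tool list):

```python
import json


def parse_llama32(text: str):
    # Everything after <|python_tag|> is treated as the JSON call;
    # json.loads raising is the signal that the JSON is not complete yet.
    if "<|python_tag|>" not in text:
        return None
    _, action = text.split("<|python_tag|>")
    return json.loads(action)


call = parse_llama32(
    '<|python_tag|>{"name": "get_weather", "arguments": {"city": "Boston"}}'
)
```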
class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(self, text: str, tools: List[Function]):
        """
        One-time parsing: Loop through detectors until there are no new matches or text is exhausted
        Return: (final_text, all_calls)
        - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
        - all_calls: All calls parsed by the Detectors
        """
        final_calls = []
        final_normal_text = text
        for detector in self.detectors:
            tool_call_list = detector.detect_and_parse(text, tools)
            if len(tool_call_list) > 0:  # parsed successfully
                final_calls = tool_call_list
                break

        # leftover_text is the normal text not consumed by any Detector
        return final_normal_text, final_calls

    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
        """
        Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
        and merge their produced normal_text/calls to return.
        (The logic here can be "priority-based" or "parallel parsing" based on your needs)
        """
        final_normal_text = ""
        final_calls = []

        for detector in self.detectors:
            sp_result = detector.parse_streaming_increment(new_text, tools)
            # Merge normal_text and calls.
            # If one sp_result contains calls, this should be a successful parse.
            # If one sp_result only contains normal_text, this can either be a successful
            # parse or it is not using the desired parsing tool.
            if sp_result.normal_text:
                final_normal_text = sp_result.normal_text
            if sp_result.calls:
                final_calls.extend(sp_result.calls)
                final_normal_text = sp_result.normal_text
                break

        return final_normal_text, final_calls

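`parse_once` is a first-match-wins loop over the detectors. A standalone sketch with stand-in detectors implemented as plain functions:

```python
def parse_once(detectors, text, tools):
    # The first detector that yields any calls wins; the rest are skipped.
    final_calls = []
    for detect in detectors:
        tool_call_list = detect(text, tools)
        if len(tool_call_list) > 0:
            final_calls = tool_call_list
            break
    return text, final_calls


def qwen_stub(text, tools):
    return ["qwen"] if "<tool_call>" in text else []


def llama_stub(text, tools):
    return ["llama"] if "<|python_tag|>" in text else []


_, calls = parse_once([qwen_stub, llama_stub], '<|python_tag|>{"name": "f"}', [])
```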
class FunctionCallParser:
    """
    In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
    and returns the resulting normal_text and calls to the upper layer (or SSE).
    """

    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
    }

    def __init__(self, tools: List[Function], tool_call_parser: str = None):
        detectors = []
        if tool_call_parser:
            detector_class = self.ToolCallParserEnum.get(tool_call_parser)
            if detector_class:
                detectors.append(detector_class())
            else:
                raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
        else:
            raise ValueError("Tool Call Parser Not Given!")

        self.multi_format_parser = MultiFormatParser(detectors)
        self.tools = tools

    def parse_non_stream(self, full_text: str):
        """
        Non-streaming call: one-time parsing
        """
        full_normal_text, calls = self.multi_format_parser.parse_once(
            full_text, self.tools
        )
        return full_normal_text, calls

    def parse_stream_chunk(self, chunk_text: str):
        """
        Streaming call: incremental parsing
        """
        normal_text, calls = self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )
        return normal_text, calls
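The constructor's registry lookup and error handling can be sketched with stub detector classes (the stub names are invented):

```python
class Llama3Stub:
    pass


class Qwen25Stub:
    pass


class MistralStub:
    pass


PARSER_REGISTRY = {"llama3": Llama3Stub, "qwen25": Qwen25Stub, "mistral": MistralStub}


def make_detector(tool_call_parser):
    if not tool_call_parser:
        raise ValueError("Tool Call Parser Not Given!")
    detector_class = PARSER_REGISTRY.get(tool_call_parser)
    if detector_class is None:
        raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
    return detector_class()


detector = make_detector("qwen25")
```

Keeping the mapping as a class attribute makes adding a new model format a one-line change plus a detector class.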
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """

 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union

@@ -540,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
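The `field(default_factory=list)` used for `FunctionCallReqInput.tools` above matters: dataclasses reject a bare mutable default, and a shared default list would leak state between requests. A minimal illustration with a simplified stand-in class:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ReqInput:
    text: str
    tools: List[str] = field(default_factory=list)  # a fresh list per instance


a = ReqInput(text="first")
b = ReqInput(text="second")
a.tools.append("add")
# b.tools stays empty because each instance gets its own list
```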
@@ -20,7 +20,7 @@ import os
 import time
 import uuid
 from http import HTTPStatus
-from typing import Dict, List
+from typing import Dict, List, Optional

 from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
+from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
 from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
-from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             ret,
             to_file=True,
             cache_report=tokenizer_manager.server_args.enable_cache_report,
+            tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
         )
     else:
         responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
         tools = None
         if request.tools and request.tool_choice != "none":
             request.skip_special_tokens = False
-            if request.stream:
-                logger.warning("Streaming is not supported with tools.")
-                request.stream = False
             if not isinstance(request.tool_choice, str):
                 tools = [
                     item.function.model_dump()
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
                 openai_compatible_messages = openai_compatible_messages[:-1]
             else:
                 assistant_prefix = None
-            prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
-                openai_compatible_messages,
-                tokenize=True,
-                add_generation_prompt=True,
-                tools=tools,
-            )
+            try:
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+            except Exception:
+                # This except branch will be triggered when the chosen model
+                # has a different tools input format that is not compatible
+                # with OpenAI's apply_chat_template tool_call format, like Mistral.
+                tools = [t if "function" in t else {"function": t} for t in tools]
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+
             if assistant_prefix:
                 prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
             stop = request.stop
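The fallback branch above wraps each tool spec in a `{"function": ...}` envelope before retrying `apply_chat_template`. The wrapping itself is a one-liner and leaves already-wrapped entries untouched; a sketch with a made-up tool spec:

```python
# Plain function specs, as produced by item.function.model_dump().
tools = [{"name": "add", "description": "Compute the sum of two numbers"}]

# Chat templates such as Mistral's expect each entry wrapped in a
# {"function": ...} envelope; already-wrapped entries pass through unchanged.
wrapped = [t if "function" in t else {"function": t} for t in tools]
```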
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


-def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
+def v1_chat_generate_response(
+    request, ret, to_file=False, cache_report=False, tool_call_parser=None
+):
     choices = []

     for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
             if finish_reason == "stop":
                 finish_reason = "tool_calls"
             try:
-                text, call_info_list = parse_tool_response(text, tools)  # noqa
+                parser = FunctionCallParser(tools, tool_call_parser)
+                full_normal_text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = [
                     ToolCall(
-                        id=str(call_info[0]),
+                        id=str(call_info.tool_index),
                         function=FunctionResponse(
-                            name=call_info[1], arguments=call_info[2]
+                            name=call_info.name, arguments=call_info.parameters
                         ),
                     )
                     for call_info in call_info_list
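The non-streaming path above maps each parsed call item's `tool_index`/`name`/`parameters` onto OpenAI-style `tool_calls` entries. A sketch of that mapping using a `SimpleNamespace` stand-in for `ToolCallItem` (values are illustrative):

```python
import json
from types import SimpleNamespace

# Hypothetical parsed call items with the same attributes as ToolCallItem.
call_info_list = [
    SimpleNamespace(tool_index=0, name="add", parameters='{"a": 1, "b": 2}')
]

tool_calls = [
    {
        "id": str(call_info.tool_index),
        "function": {"name": call_info.name, "arguments": call_info.parameters},
    }
    for call_info in call_info_list
]
```

Note that `arguments` stays a JSON-encoded string, matching the OpenAI response shape.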
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

     if adapted_request.stream:
+        parser_dict = {}

         async def generate_stream_resp():
             is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 adapted_request, raw_request
             ):
                 index = content.get("index", 0)
+                text = content["text"]

                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

                 text = content["text"]
                 delta = text[len(stream_buffer) :]
-                stream_buffer = stream_buffer + delta
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=index,
-                    delta=DeltaMessage(content=delta),
-                    finish_reason=(finish_reason["type"] if finish_reason else ""),
-                    matched_stop=(
-                        finish_reason["matched"]
-                        if finish_reason and "matched" in finish_reason
-                        else None
-                    ),
-                    logprobs=choice_logprobs,
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-
-                is_firsts[index] = is_first
-                stream_buffers[index] = stream_buffer
-                n_prev_tokens[index] = n_prev_token
-
-                yield f"data: {chunk.model_dump_json()}\n\n"
+                new_stream_buffer = stream_buffer + delta
+
+                if request.tool_choice != "none" and request.tools:
+                    if index not in parser_dict:
+                        parser_dict[index] = FunctionCallParser(
+                            tools=request.tools,
+                            tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                        )
+                    parser = parser_dict[index]
+
+                    # parse_increment => returns (normal_text, calls)
+                    normal_text, calls = parser.parse_stream_chunk(delta)
+
+                    # 1) if there's normal_text, output it as normal content
+                    if normal_text:
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(content=normal_text),
+                            finish_reason=(
+                                finish_reason["type"] if finish_reason else ""
+                            ),
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    # 2) if we found calls, we output them as separate chunk(s)
+                    for call_item in calls:
+                        # transform call_item -> FunctionResponse + ToolCall
+                        if (
+                            content["meta_info"]["finish_reason"]
+                            and content["meta_info"]["finish_reason"]["type"]
+                            == "stop"
+                        ):
+                            latest_delta_len = 0
+                            if isinstance(call_item.parameters, str):
+                                latest_delta_len = len(call_item.parameters)
+
+                            expected_call = json.dumps(
+                                parser.multi_format_parser.detectors[0]
+                                .prev_tool_call_arr[index]
+                                .get("arguments", {}),
+                                ensure_ascii=False,
+                            )
+                            actual_call = parser.multi_format_parser.detectors[
+                                0
+                            ].streamed_args_for_tool[index]
+                            if latest_delta_len > 0:
+                                actual_call = actual_call[:-latest_delta_len]
+                            remaining_call = expected_call.replace(
+                                actual_call, "", 1
+                            )
+                            call_item.parameters = remaining_call
+
+                        tool_call = ToolCall(
+                            id=str(call_item.tool_index),
+                            function=FunctionResponse(
+                                name=call_item.name,
+                                arguments=call_item.parameters,
+                            ),
+                        )
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(
+                                role="assistant", tool_calls=[tool_call]
+                            ),
+                            finish_reason="tool_call",
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    stream_buffers[index] = new_stream_buffer
+                    is_firsts[index] = is_first
+
+                else:
+                    # No tool calls => just treat this as normal text
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=index,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=(
+                            finish_reason["type"] if finish_reason else ""
+                        ),
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
+                        ),
+                        logprobs=choice_logprobs,
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        choices=[choice_data],
+                        model=request.model,
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+                    stream_buffers[index] = new_stream_buffer
+                    is_firsts[index] = is_first
             if request.stream_options and request.stream_options.include_usage:
                 total_prompt_tokens = sum(
                     tokens
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ret = [ret]

     response = v1_chat_generate_response(
-        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+        request,
+        ret,
+        cache_report=tokenizer_manager.server_args.enable_cache_report,
+        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
     )

     return response
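Each streamed chunk above is emitted as a server-sent-events frame: a `data:` line holding the JSON-serialized chunk, terminated by a blank line. A minimal sketch of that framing:

```python
import json


def sse_chunk(payload: dict) -> str:
    # One server-sent-events frame: a "data:" line followed by a blank line.
    return f"data: {json.dumps(payload)}\n\n"


chunk = sse_chunk(
    {"choices": [{"index": 0, "delta": {"role": "assistant", "tool_calls": []}}]}
)
```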
@@ -262,7 +262,7 @@ class Function(BaseModel):
     """Function descriptions."""

     description: Optional[str] = Field(default=None, examples=[None])
-    name: str
+    name: Optional[str] = None
     parameters: Optional[object] = None


@@ -276,7 +276,7 @@ class Tool(BaseModel):
 class ToolChoiceFuncName(BaseModel):
     """The name of tool choice function."""

-    name: str
+    name: Optional[str] = None


 class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
 class FunctionResponse(BaseModel):
     """Function response."""

-    name: str
-    arguments: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None


 class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionResponseStreamChoice(BaseModel):
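For reference, this is the rough wire shape of a streaming delta once the new `tool_calls` field on `DeltaMessage` is populated (field names follow the OpenAI chat-completions streaming format; values are illustrative):

```python
import json

# A streaming delta carrying a tool call; "arguments" is a JSON-encoded string.
delta = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}

encoded = json.dumps(delta)
```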
@@ -161,6 +161,7 @@ class ServerArgs:

     # Custom logit processor
     enable_custom_logit_processor: bool = False
+    tool_call_parser: str = None

     def __post_init__(self):
         # Set missing default values
@@ -877,6 +878,14 @@ class ServerArgs:
             action="store_true",
             help="Enable users to pass custom logit processors to the server (disabled by default for security)",
         )
+        # Function Calling
+        parser.add_argument(
+            "--tool-call-parser",
+            type=str,
+            choices=["qwen25", "mistral", "llama3"],
+            default=ServerArgs.tool_call_parser,
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
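The `--tool-call-parser` flag above is ordinary argparse with a `choices` constraint, so an unknown parser name is rejected at startup rather than at request time. A self-contained sketch of the same pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--tool-call-parser",
    type=str,
    choices=["qwen25", "mistral", "llama3"],
    default=None,
    help="Parser used to interpret tool-call output from the model.",
)

args = parser.parse_args(["--tool-call-parser", "llama3"])
```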
@@ -1243,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
     return str(data)


-TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
-
-
-def parse_tool_response(text, tools, **kwargs):
-    """Parse model response containing tool information.
-
-    Args:
-        text(str): model response in string format
-        tools(List): tools from user request
-    """
-    if "<|plugin|>" in text:  # internlm2
-        text, action = text.split("<|action_start|><|plugin|>")
-        action = action.split("<|action_end|>".strip())[0]
-        action = action[action.find("{") :]
-        action = json.loads(action)
-        name, parameters = action["name"], json.dumps(
-            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
-        )
-        call_info_list = [(name, parameters)]
-    elif "<function=" in text:  # llama3.1
-        action, _ = text.split("</function>")
-        parameters = action[action.find("{") :]
-        name = action.split("<function=")[1].split(">{")[0]
-        call_info_list = [(name, parameters)]
-    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
-        # get tool_call in text
-        pattern = r"<tool_call>(.*?)</tool_call>"
-        match_result_list = re.findall(pattern, text, re.DOTALL)
-        call_info_list = []
-        for match_result in match_result_list:
-            action = json.loads(match_result)
-            call_info_list.append(
-                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
-            )
-        # get text outside of tags
-        if not text.startswith("<tool_call>"):
-            text = text[: text.find("<tool_call>")]
-        elif not text.endswith("</tool_call>"):
-            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
-        else:
-            text = ""
-    elif "<|python_tag|>" in text:  # llama3.2
-        _, action = text.split("<|python_tag|>")
-        action = json.loads(action)
-        name, parameters = action["name"], json.dumps(
-            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
-        )
-        call_info_list = [(name, parameters)]
-    else:
-        raise RuntimeError(f"Unexpected model response: {text}")
-
-    call_info_list = [
-        (
-            [tool.function.name for tool in tools].index(call_info[0]),
-            call_info[0],
-            call_info[1],
-        )
-        for call_info in call_info_list
-    ]
-    return text, call_info_list


 def permute_weight(x: torch.Tensor) -> torch.Tensor:
     b_ = x.shape[0]
     n_ = x.shape[1]
249
test/srt/test_function_calling.py
Normal file
249
test/srt/test_function_calling.py
Normal file
@@ -0,0 +1,249 @@
import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOpenAIServerFunctionCalling(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Replace with the model name needed for testing; if not required,
        # reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Start the local OpenAI-compatible server. If necessary, you can add
        # other parameters such as --enable-tools.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                # If your server needs extra parameters to test function
                # calling, add them here.
                "--tool-call-parser",
                "llama3",
            ],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_function_calling_format(self):
        """
        Test: Whether the function call format returned by the model is correct.
        When returning a tool call, message.content should be None, and
        tool_calls should be a non-empty list.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Compute the sum of two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {
                                "type": "int",
                                "description": "A number",
                            },
                            "b": {
                                "type": "int",
                                "description": "A number",
                            },
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]

        messages = [{"role": "user", "content": "Compute (3+5)"}]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=False,
            tools=tools,
        )

        content = response.choices[0].message.content
        tool_calls = response.choices[0].message.tool_calls

        assert content is None, (
            "When the function call is successful, message.content should be None, "
            f"but got: {content}"
        )
        assert (
            isinstance(tool_calls, list) and len(tool_calls) > 0
        ), "tool_calls should be a non-empty list"

        function_name = tool_calls[0].function.name
        assert function_name == "add", "Function name should be 'add'"

    def test_function_calling_streaming_simple(self):
        """
        Test: Whether the function name can be correctly recognized in streaming mode.
        - Expect a function call to be found and the function name to be correct.
        - Verify that streaming mode returns at least one chunk.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The city to find the weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Weather unit (celsius or fahrenheit)",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["city", "unit"],
                    },
                },
            }
        ]

        messages = [{"role": "user", "content": "What is the temperature in Paris?"}]

        response_stream = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=True,
            tools=tools,
        )

        chunks = list(response_stream)
        self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")

        found_function_name = False
        for chunk in chunks:
            choice = chunk.choices[0]
            # Check whether the current chunk contains tool_calls
            if choice.delta.tool_calls:
                tool_call = choice.delta.tool_calls[0]
                if tool_call.function.name:
                    self.assertEqual(
                        tool_call.function.name,
                        "get_current_weather",
                        "Function name should be 'get_current_weather'",
                    )
                    found_function_name = True
                    break

        self.assertTrue(
            found_function_name,
            "Target function name 'get_current_weather' was not found in the streaming chunks",
        )

    def test_function_calling_streaming_args_parsing(self):
        """
        Test: Whether the function call arguments returned in streaming mode
        can be correctly concatenated into valid JSON.
        - The user request requires multiple parameters.
        - The model may return the arguments in chunks that need to be concatenated.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Compute the sum of two integers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {
                                "type": "int",
                                "description": "First integer",
                            },
                            "b": {
                                "type": "int",
                                "description": "Second integer",
                            },
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]

        messages = [
            {"role": "user", "content": "Please sum 5 and 7, just call the function."}
        ]

        response_stream = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.9,
            top_p=0.9,
            stream=True,
            tools=tools,
        )

        argument_fragments = []
        function_name = None
        for chunk in response_stream:
            choice = chunk.choices[0]
            if choice.delta.tool_calls:
                tool_call = choice.delta.tool_calls[0]
                # Record the function name on first occurrence
                function_name = tool_call.function.name or function_name
                # In case of multiple chunks, JSON fragments may need to be concatenated
                if tool_call.function.arguments:
                    argument_fragments.append(tool_call.function.arguments)

        self.assertEqual(function_name, "add", "Function name should be 'add'")
        joined_args = "".join(argument_fragments)
        self.assertTrue(
            len(joined_args) > 0,
            "No parameter fragments were returned in the function call",
        )

        # Check whether the concatenated JSON is valid
        try:
            args_obj = json.loads(joined_args)
        except json.JSONDecodeError:
            self.fail(
                "The concatenated tool call arguments are not valid JSON, parsing failed"
            )

        self.assertIn("a", args_obj, "Missing parameter 'a'")
        self.assertIn("b", args_obj, "Missing parameter 'b'")
        self.assertEqual(args_obj["a"], 5, "Parameter a should be 5")
        self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")


if __name__ == "__main__":
    unittest.main()
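The streaming-args test above relies on the fact that tool-call arguments may arrive split across several chunks and only form valid JSON once concatenated. A minimal standalone sketch of that concatenation step, with made-up fragments standing in for streamed deltas:

```python
import json

# Made-up argument fragments, as they might arrive across streaming chunks.
argument_fragments = ['{"a"', ": 5, ", '"b": 7}']

# Each fragment alone is invalid JSON; only the joined string parses.
joined_args = "".join(argument_fragments)
args_obj = json.loads(joined_args)

print(joined_args)  # {"a": 5, "b": 7}
print(args_obj)     # {'a': 5, 'b': 7}
```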
@@ -623,58 +623,6 @@ class TestOpenAIServerEBNF(unittest.TestCase):
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )

    def test_function_calling_format(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Compute the sum of two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {
                                "type": "int",
                                "description": "A number",
                            },
                            "b": {
                                "type": "int",
                                "description": "A number",
                            },
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]

        messages = [{"role": "user", "content": "Compute (3+5)"}]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=False,
            tools=tools,
        )

        content = response.choices[0].message.content
        tool_calls = response.choices[0].message.tool_calls

        assert (
            content is None
        ), "When tools provided by the response, content should be None"
        assert (
            isinstance(tool_calls, list) and len(tool_calls) > 0
        ), "Format not matched, tool_calls should be a list"

        function_name = tool_calls[0].function.name
        assert (
            function_name == "add"
        ), "Function name should be add for the above response"
class TestOpenAIEmbedding(unittest.TestCase):
    @classmethod