Feature/function calling update (#2700)
Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com>
@@ -4,62 +4,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tool and Function Calling\n",
"\n",
"This guide demonstrates how to use SGLang's **Tool Calling** functionality."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Compatible API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launching the Server"
]
},
{
@@ -69,7 +30,47 @@
"outputs": [],
"source": [
"from openai import OpenAI\n",
"import json\n",
"from sglang.utils import (\n",
" execute_shell_command,\n",
" wait_for_server,\n",
" terminate_process,\n",
" print_highlight,\n",
")\n",
"\n",
"server_process = execute_shell_command(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\"  # llama3\n",
")\n",
"wait_for_server(\"http://localhost:30333\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n",
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
]
},
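For instance, a matching model and parser from the list above could be paired like this (an illustrative command, not executed in this notebook; the port value is arbitrary):

```shell
# Hypothetical launch pairing a Qwen 2.5 model with its qwen25 parser.
python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct \
    --tool-call-parser qwen25 --port 30333 --host 0.0.0.0
```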
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Tools for Function Call\n",
"Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and parameters defined through properties."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
@@ -79,22 +80,264 @@
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"state\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The two-letter abbreviation for the state that the city is\"\n",
" \" in, e.g. 'CA' which would mean 'California'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"state\", \"unit\"],\n",
" },\n",
" },\n",
" }\n",
"]"
]
},
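As a side note, the `parameters` block above is ordinary JSON Schema, so arguments a model returns can be sanity-checked against it with plain Python. This is a minimal hand-rolled sketch, not SGLang functionality; `check_arguments` and `weather_schema` are names invented here:

```python
import json

# The parameter schema from the tool definition above, restated as a plain dict.
weather_schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
        "state": {"type": "string"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
    },
    "required": ["city", "state", "unit"],
}


def check_arguments(arguments_json: str, schema: dict) -> list:
    # Hand-rolled sanity check (not a full JSON Schema validator): report
    # missing required fields and values outside an enum.
    args = json.loads(arguments_json)
    problems = [f"missing: {key}" for key in schema["required"] if key not in args]
    for key, spec in schema["properties"].items():
        if key in args and "enum" in spec and args[key] not in spec["enum"]:
            problems.append(f"{key} must be one of {spec['enum']}")
    return problems


print(check_arguments('{"city": "Boston", "state": "MA", "unit": "celsius"}', weather_schema))  # []
print(check_arguments('{"city": "Boston"}', weather_schema))  # ['missing: state', 'missing: unit']
```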
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
" }\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the Client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
"model_name = client.models.list().data[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Non-streaming mode test\n",
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.8,\n",
" top_p=0.8,\n",
" stream=False,  # Non-streaming\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(response_non_stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming mode test\n",
"print_highlight(\"Streaming response:\")\n",
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.8,\n",
" top_p=0.8,\n",
" stream=True,  # Enable streaming\n",
" tools=tools,\n",
")\n",
"\n",
"chunks = []\n",
"for chunk in response_stream:\n",
" chunks.append(chunk)\n",
" if chunk.choices[0].delta.tool_calls:\n",
" print(chunk.choices[0].delta.tool_calls[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handle Tool Calls\n",
"\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Non-Streaming Request**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
"arguments_non_stream = (\n",
" response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
")\n",
"\n",
"print_highlight(f\"Non-stream function call name: {name_non_stream}\")\n",
"print_highlight(f\"Non-stream function call arguments: {arguments_non_stream}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Streaming Request**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parse and combine function call arguments\n",
"arguments = []\n",
"for chunk in chunks:\n",
" choice = chunk.choices[0]\n",
" delta = choice.delta\n",
" if delta.tool_calls:\n",
" tool_call = delta.tool_calls[0]\n",
" if tool_call.function.name:\n",
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
"\n",
" if tool_call.function.arguments:\n",
" arguments.append(tool_call.function.arguments)\n",
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
"\n",
"# Combine all fragments into a single JSON string\n",
"full_arguments = \"\".join(arguments)\n",
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
]
},
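The key point of the cell above is that individual streamed argument fragments are generally not valid JSON on their own; only their concatenation parses. A small self-contained illustration with invented fragments (no server required):

```python
import json

# Hypothetical fragments, shaped like the streamed `function.arguments` deltas above.
fragments = ['{"city": "Bos', 'ton", "state": "MA"', ', "unit": "celsius"}']

# Each fragment alone would fail json.loads; the joined string is well-formed JSON.
full_arguments = "".join(fragments)
args = json.loads(full_arguments)
print(args)  # {'city': 'Boston', 'state': 'MA', 'unit': 'celsius'}
```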
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a Tool Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is a demonstration; define a real function according to your usage.\n",
"def get_current_weather(city: str, state: str, unit: str):\n",
" return (\n",
" f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
" \"partly cloudy, with highs in the 90's.\"\n",
" )\n",
"\n",
"\n",
"available_tools = {\"get_current_weather\": get_current_weather}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Execute the Tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"call_data = json.loads(full_arguments)\n",
"\n",
"# Record the tool call on an assistant message (tool calls are made by the assistant)\n",
"messages.append(\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"\",\n",
" \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
" }\n",
")\n",
"\n",
"# Call the corresponding tool function\n",
"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
"tool_to_call = available_tools[tool_name]\n",
"result = tool_to_call(**call_data)\n",
"print_highlight(f\"Function call result: {result}\")\n",
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
"\n",
"print_highlight(f\"Updated message history: {messages}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Send Results Back to Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_response = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.8,\n",
@@ -102,17 +345,56 @@
" stream=False,\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import requests\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"\n",
"messages = get_messages()\n",
"\n",
"input = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" tools=tools,\n",
")\n",
"\n",
"gen_url = \"http://localhost:30333/generate\"\n",
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print(gen_response)\n",
"\n",
"# parse the response\n",
"parse_url = \"http://localhost:30333/function_call\"\n",
"\n",
"function_call_input = {\n",
" \"text\": gen_response,\n",
" \"tool_call_parser\": \"llama3\",\n",
" \"tools\": tools,\n",
"}\n",
"\n",
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
"function_call_response_json = function_call_response.json()\n",
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
]
},
{
@@ -128,11 +410,98 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n",
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n",
"\n",
"sampling_params = {\n",
" \"max_new_tokens\": 128,\n",
" \"temperature\": 0.3,\n",
" \"top_p\": 0.95,\n",
" \"skip_special_tokens\": False,\n",
"}\n",
"\n",
"# 1) Offline generation\n",
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
"generated_text = result[\"text\"]  # Assume there is only one prompt\n",
"\n",
"print(\"=== Offline Engine Output Text ===\")\n",
"print(generated_text)\n",
"\n",
"\n",
"# 2) Parse using FunctionCallParser\n",
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
" function_dict = tool_dict.get(\"function\", {})\n",
" return Tool(\n",
" type=tool_dict.get(\"type\", \"function\"),\n",
" function=Function(\n",
" name=function_dict.get(\"name\"),\n",
" description=function_dict.get(\"description\"),\n",
" parameters=function_dict.get(\"parameters\"),\n",
" ),\n",
" )\n",
"\n",
"\n",
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
"\n",
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
"\n",
"print(\"\\n=== Parsing Result ===\")\n",
"print(\"Normal text portion:\", normal_text)\n",
"print(\"Function call portion:\")\n",
"for call in calls:\n",
" # call: ToolCallItem\n",
" print(f\" - tool name: {call.name}\")\n",
" print(f\" parameters: {call.parameters}\")\n",
"\n",
"# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to support a new model?\n",
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model's tool tags. Currently supported tags include:\n",
"```\n",
"\tTOOLS_TAG_LIST = [\n",
"\t \"<|plugin|>\",\n",
"\t \"<function=\",\n",
"\t \"<tool_call>\",\n",
"\t \"<|python_tag|>\",\n",
"\t \"[TOOL_CALLS]\"\n",
"\t]\n",
"```\n",
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model's specific function call format. For example:\n",
"```\n",
" class NewModelDetector(BaseFormatDetector):\n",
"```\n",
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
]
},
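To make the detector idea in step 2 concrete, here is a self-contained sketch for a hypothetical model that wraps calls in `<new_tool>...</new_tool>` tags. `NewModelDetector` and its method below are invented for illustration and deliberately independent of SGLang's real `BaseFormatDetector` API:

```python
import json
import re


class NewModelDetector:
    # Hypothetical tag format; a real detector would use the model's actual tags.
    TOOL_CALL_RE = re.compile(r"<new_tool>(.*?)</new_tool>", re.DOTALL)

    def parse_non_stream(self, text: str):
        # Split model output into the normal text portion and parsed tool calls.
        calls = []
        for match in self.TOOL_CALL_RE.finditer(text):
            payload = json.loads(match.group(1))
            calls.append({"name": payload["name"], "parameters": payload["arguments"]})
        normal_text = self.TOOL_CALL_RE.sub("", text).strip()
        return normal_text, calls


detector = NewModelDetector()
output = 'Checking the weather. <new_tool>{"name": "get_current_weather", "arguments": {"city": "Boston"}}</new_tool>'
normal_text, calls = detector.parse_non_stream(output)
print(normal_text)  # Checking the weather.
print(calls)  # [{'name': 'get_current_weather', 'parameters': {'city': 'Boston'}}]
```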
],