From 2b06484bd198cf090d7153e7aad5daffa8250355 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Tue, 29 Apr 2025 17:30:44 -0700 Subject: [PATCH] feat: support pythonic tool call and index in tool call streaming (#5725) --- docs/backend/function_calling.ipynb | 167 ++++++++++++++++++ .../tool_chat_template_llama4_pythonic.jinja | 140 +++++++++++++++ python/sglang/srt/function_call_parser.py | 97 ++++++++++ python/sglang/srt/openai_api/adapter.py | 1 + python/sglang/srt/openai_api/protocol.py | 1 + python/sglang/srt/server_args.py | 4 +- test/srt/run_suite.py | 2 +- test/srt/test_function_calling.py | 132 ++++++++++++++ 8 files changed, 541 insertions(+), 3 deletions(-) create mode 100644 examples/chat_template/tool_chat_template_llama4_pythonic.jinja diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 99d11f50a..2fece950f 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -503,6 +503,173 @@ "llm.shutdown()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pythonic Tool Call Format (Llama-3.2 / Llama-3.3 / Llama-4)\n", + "\n", + "Some Llama models (such as Llama-3.2-1B, Llama-3.2-3B, Llama-3.3-70B, and Llama-4) support a \"pythonic\" tool call format, where the model outputs function calls as Python code, e.g.:\n", + "\n", + "```python\n", + "[get_current_weather(city=\"San Francisco\", state=\"CA\", unit=\"celsius\")]\n", + "```\n", + "\n", + "- The output is a Python list of function calls, with arguments as Python literals (not JSON).\n", + "- Multiple tool calls can be returned in the same list:\n", + "```python\n", + "[get_current_weather(city=\"San Francisco\", state=\"CA\", unit=\"celsius\"),\n", + " get_current_weather(city=\"New York\", state=\"NY\", unit=\"fahrenheit\")]\n", + "```\n", + "\n", + "For more information, refer to Meta’s documentation on [Zero shot function calling](https://github.com/meta-llama/llama-models/blob/main/models/llama4/prompt_format.md#zero-shot-function-calling---system-message).\n", + "\n", + "### How to enable\n", + "- Launch the server with `--tool-call-parser pythonic`\n", + "- You may also specify --chat-template with the improved template for the model (e.g., `--chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja`).\n", + "This is recommended because the model expects a special prompt format to reliably produce valid pythonic tool call outputs. The template ensures that the prompt structure (e.g., special tokens, message boundaries like `<|eom|>`, and function call delimiters) matches what the model was trained or fine-tuned on. If you do not use the correct chat template, tool calling may fail or produce inconsistent results.\n", + "\n", + "#### Forcing Pythonic Tool Call Output Without a Chat Template\n", + "If you don't want to specify a chat template, you must give the model extremely explicit instructions in your messages to enforce pythonic output. For example, for `Llama-3.2-1B-Instruct`, you need:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", + "from sglang.test.test_utils import is_in_ci\n", + "\n", + "\n", + "if is_in_ci():\n", + " from patch import launch_server_cmd\n", + "else:\n", + " from sglang.utils import launch_server_cmd\n", + "\n", + "server_process, port = launch_server_cmd(\n", + " \" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1\" # llama-3.2-1b-instruct\n", + ")\n", + "wait_for_server(f\"http://localhost:{port}\")\n", + "\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_weather\",\n", + " \"description\": \"Get the current weather for a given location.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The name of the city or location.\",\n", + " }\n", + " },\n", + " \"required\": [\"location\"],\n", + " },\n", + " },\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_tourist_attractions\",\n", + " \"description\": \"Get a list of top tourist attractions for a given city.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The name of the city to find attractions for.\",\n", + " }\n", + " },\n", + " \"required\": [\"city\"],\n", + " },\n", + " },\n", + " },\n", + "]\n", + "\n", + "\n", + "def get_messages():\n", + " return [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": (\n", + " \"You are a travel assistant. \"\n", + " \"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, \"\n", + " \"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. \"\n", + " \"Do NOT use JSON, do NOT use variables, do NOT use any other format. \"\n", + " \"Here is an example:\\n\"\n", + " '[get_weather(location=\"Paris\"), get_tourist_attractions(city=\"Paris\")]'\n", + " ),\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": (\n", + " \"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? \"\n", + " \"Propose parallel tool calls at once, using the python list of function calls format as shown above.\"\n", + " ),\n", + " },\n", + " ]\n", + "\n", + "\n", + "messages = get_messages()\n", + "\n", + "client = openai.Client(base_url=f\"http://localhost:{port}/v1\", api_key=\"xxxxxx\")\n", + "model_name = client.models.list().data[0].id\n", + "\n", + "\n", + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False, # Non-streaming\n", + " tools=tools,\n", + ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(response_non_stream)\n", + "\n", + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=True,\n", + " tools=tools,\n", + ")\n", + "texts = \"\"\n", + "tool_calls = []\n", + "name = \"\"\n", + "arguments = \"\"\n", + "\n", + "for chunk in response_stream:\n", + " if chunk.choices[0].delta.content:\n", + " texts += chunk.choices[0].delta.content\n", + " if chunk.choices[0].delta.tool_calls:\n", + " tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n", + "\n", + "print_highlight(\"Streaming Response:\")\n", + "print_highlight(\"==== Text ====\")\n", + "print(texts)\n", + "\n", + "print_highlight(\"==== Tool Call ====\")\n", + "for tool_call in tool_calls:\n", + " print(tool_call)\n", + "\n", + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note:** \n", + "> The model may still default to JSON if it was heavily finetuned on that format. Prompt engineering (including examples) is the only way to increase the chance of pythonic output if you are not using a chat template." + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/examples/chat_template/tool_chat_template_llama4_pythonic.jinja b/examples/chat_template/tool_chat_template_llama4_pythonic.jinja new file mode 100644 index 000000000..3b38f0ee0 --- /dev/null +++ b/examples/chat_template/tool_chat_template_llama4_pythonic.jinja @@ -0,0 +1,140 @@ +{# Copied from https://github.com/yeqcharlotte/vllm/blob/4fcf68a948bbe0498dc8a98feafa102cfb1dd210/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #} +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} +{%- else %} + {%- if tools is not none %} + {#- Add default tool system message when tools are provided #} + {%- set system_message = "You are a helpful assistant with tool calling " + "capabilities. Only reply with a tool call if the function exists in the " + "library provided by the user. If it doesn't exist, just reply directly in " + "natural language. When you receive a tool call response, use the output to " + "format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} +{%- endif %} + +{#- System message if the user supplied one, or if tools are used (default tool system message) #} +{%- if system_message %} + {#- always use user provided system message to override default tool system message #} + {{- "<|header_start|>system<|header_end|>\n\n" }} + {{- system_message }} + {%- if tools is not none and not tools_in_user_message %} + {{- "Tools: You have access to the following tools. You might need to use one " + "or more function/tool calls to fulfill the task. \n" + "If none are needed, then proceed to the response.\n\n" + "Tool Call Syntax: You can call tools using the following syntax:\n" + "[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n" + "Do not include anything else when calling the tools with the syntax above.\n\n" + "Here is a list of functions in JSON format that you can invoke.\n " }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- "<|eot|>" }} +{%- endif %} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and tools is not none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|header_start|>user<|header_end|>\n\n' -}} + {{- first_user_message}} + {{- "\nHere is a list of functions in JSON format that you can invoke:"}} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- "Should you decide to return the function call(s), put them in the format " + "of [func_name1(params_name1=params_value1, params_name2=params_value2, " + "...), ...]\nDo not include anything else when calling the tools with the " + "syntax above." }} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot|>" }} + {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '=' -}} + {{- "%s" | format(tool_call.arguments[param]) -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- "<|eom|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|header_start|>ipython<|header_end|>\n\n" }} + {%- if message.content is string %} + {{- message.content | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- content['text'] | tojson }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eom|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py index abc6cf650..ef7b51058 100644 --- a/python/sglang/srt/function_call_parser.py +++ b/python/sglang/srt/function_call_parser.py @@ -1,3 +1,4 @@ +import ast import json import logging import re @@ -664,6 +665,101 @@ class MultiFormatParser: return final_normal_text, final_calls +class PythonicDetector(BaseFormatDetector): + """ + Detector for Llama-3.2 and Llama-4 models with pythonic tool call format. + Assumes function call format: + [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] + Arguments are Python literals (not JSON). + """ + + def __init__(self): + super().__init__() + self.tool_call_regex = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL, + ) + + def has_tool_call(self, text: str) -> bool: + return bool(self.tool_call_regex.match(text.strip())) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + # Try parsing the text as a Python list of function calls + text = text.strip() + if not (text.startswith("[") and text.endswith("]")): + # Not a pythonic tool call format + return StreamingParseResult(normal_text=text, calls=[]) + try: + module = ast.parse(text) + parsed = getattr(module.body[0], "value", None) + if not ( + isinstance(parsed, ast.List) + and all(isinstance(e, ast.Call) for e in parsed.elts) + ): + return StreamingParseResult(normal_text=text, calls=[]) + calls = [] + tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function.name + } + for call in parsed.elts: + if not isinstance(call.func, ast.Name): + continue + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = self._get_parameter_value(keyword.value) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(function_name, -1), + name=function_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + ) + return StreamingParseResult(normal_text="", calls=calls) + except Exception: + logger.exception("Error in pythonic tool call parsing.") + return StreamingParseResult(normal_text=text, calls=[]) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing for pythonic tool calls. + Buffers input until a complete pythonic tool call (from [ to ]) is found, + then parses and emits any detected calls. + """ + self._buffer += new_text + start = self._buffer.find("[") + end = self._buffer.find("]", start) + if start != -1 and end != -1: + call_text = self._buffer[start : end + 1] + result = self.detect_and_parse(call_text, tools) + self._buffer = self._buffer[end + 1 :] + return result + return StreamingParseResult(normal_text="") + + def _get_parameter_value(self, val): + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + return { + k.value: self._get_parameter_value(v) + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [self._get_parameter_value(v) for v in val.elts] + else: + raise ValueError("Tool call arguments must be literals") + + def structure_info(self) -> _GetInfoFunc: + def info(name: str): + return StructureInfo(begin="[", end="]", trigger="") + + return info + + class FunctionCallParser: """ In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment @@ -675,6 +771,7 @@ class FunctionCallParser: "qwen25": Qwen25Detector, "mistral": MistralDetector, "deepseekv3": DeepSeekV3Detector, + "pythonic": PythonicDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 619f1d404..adf2cebc8 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -1618,6 +1618,7 @@ async def v1_chat_completions( tool_call = ToolCall( id=str(call_item.tool_index), + index=call_item.tool_index, function=FunctionResponse( name=call_item.name, arguments=call_item.parameters, diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 39e25a57c..88d8873d1 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -389,6 +389,7 @@ class ToolCall(BaseModel): """Tool call response.""" id: str + index: Optional[int] = None type: Literal["function"] = "function" function: FunctionResponse diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 29585a7f9..a23ee4ad5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1107,9 +1107,9 @@ class ServerArgs: parser.add_argument( "--tool-call-parser", type=str, - choices=["qwen25", "mistral", "llama3", "deepseekv3"], + choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"], default=ServerArgs.tool_call_parser, - help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", + help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.", ) parser.add_argument( "--enable-hierarchical-cache", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ee6f0cbad..9ae90b6d4 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -36,7 +36,7 @@ suites = { TestFile("test_fa3.py", 376), TestFile("test_fim_completion.py", 40), TestFile("test_fp8_kernel.py", 8), - TestFile("test_function_calling.py", 35), + TestFile("test_function_calling.py", 60), TestFile("test_fused_moe.py", 30), TestFile("test_hicache.py", 116), TestFile("test_hicache_mla.py", 254), diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py index ebe052aec..3b436ff92 100644 --- a/test/srt/test_function_calling.py +++ b/test/srt/test_function_calling.py @@ -296,5 +296,137 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7") +class TestOpenAIPythonicFunctionCalling(CustomTestCase): + PYTHONIC_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location.", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_tourist_attractions", + "description": "Get a list of top tourist attractions for a given city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to find attractions for.", + } + }, + "required": ["city"], + }, + }, + }, + ] + + PYTHONIC_MESSAGES = [ + { + "role": "system", + "content": ( + "You are a travel assistant. " + "When asked to call functions, ALWAYS respond ONLY with a python list of function calls, " + "using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. " + "Do NOT use JSON, do NOT use variables, do NOT use any other format. " + "Here is an example:\n" + '[get_weather(location="Paris"), get_tourist_attractions(city="Paris")]' + ), + }, + { + "role": "user", + "content": ( + "I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? " + "Propose parallel tool calls at once, using the python list of function calls format as shown above." + ), + }, + ] + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--tool-call-parser", + "pythonic", + ], + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_pythonic_tool_call_prompt(self): + """ + Test: Explicit prompt for pythonic tool call format without chat template. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.chat.completions.create( + model=self.model, + messages=self.PYTHONIC_MESSAGES, + tools=self.PYTHONIC_TOOLS, + temperature=0.1, + stream=False, + ) + tool_calls = response.choices[0].message.tool_calls + self.assertIsInstance(tool_calls, list) + self.assertGreaterEqual(len(tool_calls), 1) + names = [tc.function.name for tc in tool_calls] + self.assertIn("get_weather", names) + self.assertIn("get_tourist_attractions", names) + + def test_pythonic_tool_call_streaming(self): + """ + Test: Streaming pythonic tool call format; assert tool_call index is present. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response_stream = client.chat.completions.create( + model=self.model, + messages=self.PYTHONIC_MESSAGES, + tools=self.PYTHONIC_TOOLS, + temperature=0.1, + stream=True, + ) + found_tool_calls = False + found_index = False + found_names = set() + for chunk in response_stream: + choice = chunk.choices[0] + if getattr(choice.delta, "tool_calls", None): + found_tool_calls = True + tool_call = choice.delta.tool_calls[0] + if hasattr(tool_call, "index") or ( + isinstance(tool_call, dict) and "index" in tool_call + ): + found_index = True + found_names.add(str(tool_call.function.name)) + + self.assertTrue(found_tool_calls, "No tool_calls found in streaming response") + self.assertTrue(found_index, "No index field found in any streamed tool_call") + self.assertIn("get_weather", found_names) + self.assertIn("get_tourist_attractions", found_names) + + if __name__ == "__main__": unittest.main()