diff --git a/examples/chat_template/tool_chat_template_llama4_pythonic.jinja b/examples/chat_template/tool_chat_template_llama4_pythonic.jinja index 3b38f0ee0..74b315c31 100644 --- a/examples/chat_template/tool_chat_template_llama4_pythonic.jinja +++ b/examples/chat_template/tool_chat_template_llama4_pythonic.jinja @@ -1,17 +1,18 @@ -{# Copied from https://github.com/yeqcharlotte/vllm/blob/4fcf68a948bbe0498dc8a98feafa102cfb1dd210/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #} +{# Copied from https://github.com/wukaixingxp/vllm/blob/8a32e2a6e452a03c0e8222e3876ad6086cbf581f/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #} {{- bos_token }} -{%- if custom_tools is defined %} +{%- if custom_tools is defined and custom_tools %} {%- set tools = custom_tools %} {%- endif %} -{%- if not tools_in_user_message is defined %} - {%- set tools_in_user_message = false %} -{%- endif %} -{%- if not tools is defined %} +{%- if tools is defined and tools %} + {%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %} +{%- else %} {%- set tools = none %} {%- endif %} + {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} + {%- set user_provided_system_message = true %} {%- if messages[0]['content'] is string %} {%- set system_message = messages[0]['content']|trim %} {%- else %} @@ -19,68 +20,33 @@ {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- if tools is not none %} - {#- Add default tool system message when tools are provided #} - {%- set system_message = "You are a helpful assistant with tool calling " - "capabilities. Only reply with a tool call if the function exists in the " - "library provided by the user. If it doesn't exist, just reply directly in " - "natural language. When you receive a tool call response, use the output to " - "format an answer to the original user question." 
%} + {%- if tools is not none %} + {#- Since no system_message was provided by the user, if tools are provided, system_message defaults to the tool system message #} + {#- This system message is from the Llama website: https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #} + {%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". 
Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %} {%- else %} {%- set system_message = "" %} {%- endif %} {%- endif %} - -{#- System message if the user supplied one, or if tools are used (default tool system message) #} +{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #} {%- if system_message %} {#- always use user provided system message to override default tool system message #} {{- "<|header_start|>system<|header_end|>\n\n" }} {{- system_message }} - {%- if tools is not none and not tools_in_user_message %} - {{- "Tools: You have access to the following tools. You might need to use one " - "or more function/tool calls to fulfill the task. 
\n" - "If none are needed, then proceed to the response.\n\n" - "Tool Call Syntax: You can call tools using the following syntax:\n" - "[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n" - "Do not include anything else when calling the tools with the syntax above.\n\n" - "Here is a list of functions in JSON format that you can invoke.\n " }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + {%- if user_provided_system_message and tools %} + {{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }} + {{- tool_definition -}} + {%- elif tool_definition %} + {{- tool_definition -}} {%- endif %} {{- "<|eot|>" }} {%- endif %} -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and tools is not none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|header_start|>user<|header_end|>\n\n' -}} - {{- first_user_message}} - {{- "\nHere is a list of functions in JSON format that you can invoke:"}} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- "Should you decide to return the function call(s), put them in the format " - "of [func_name1(params_name1=params_value1, params_name2=params_value2, " - "...), ...]\nDo not include anything else when calling the tools with the " - "syntax above." 
}} -{%- endif %} - +{#- Now deal with all other messages #} {%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {#- Base case: messages that are not from the tool role and have an empty tool_calls list #} + {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -92,10 +58,12 @@ {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eot|>" }} - {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {{- "<|eot|>" }} + {#- Tool case: the message has a non-empty tool_calls list and must come from the assistant #} + {%- elif 'tool_calls' in message %} + {#- assume tool_calls always come from the assistant #} + {%- if message.role == 'assistant' %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -107,20 +75,24 @@ {%- endif %} {%- endfor %} {%- endif %} + {{- "[" }} {%- for tool_call in message.tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} - {{- tool_call.name + '(' -}} + {{- tool_call.name + '(' -}} {%- for param in tool_call.arguments %} - {{- param + '=' -}} + {{- param + '="' -}} {{- "%s" | format(tool_call.arguments[param]) -}} + {{- '"' -}} {% if not loop.last %}, {% endif %} {%- endfor %} {{- ')' -}} {% if not loop.last %}, {% endif %} {%- endfor %} - {{- "<|eom|>" }} + {{- "]<|eot|>" }} +{%- endif %} +{#- Tool response case: messages whose role is tool or ipython #} {%- elif message.role == "tool" or message.role == "ipython" %} {{- 
"<|header_start|>ipython<|header_end|>\n\n" }} {%- if message.content is string %} @@ -132,7 +104,7 @@ {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eom|>" }} + {{- "<|eot|>" }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index b9a7bd3a8..4ef2e3db5 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector): re.DOTALL, ) + @staticmethod + def _text_strip(text: str) -> str: + # Llama 4 models sometimes output <|python_start|> and <|python_end|> tokens; + # remove those tokens from the text + text = text.replace("<|python_start|>", "") + text = text.replace("<|python_end|>", "") + return text + def has_tool_call(self, text: str) -> bool: - return bool(self.tool_call_regex.search(text.strip())) + return bool(self.tool_call_regex.search(self._text_strip(text.strip()))) def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: # Try parsing the text as a Python list of function calls text = text.strip() + # Remove unexpected <|python_start|> and <|python_end|> for llama4 + text = self._text_strip(text) + match = self.tool_call_regex.search(text) if match is None: return StreamingParseResult(normal_text=text, calls=[]) @@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector): return i return -1 # No matching bracket found + def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]: + """ + Strip special tokens from buffer and split into safe_text and held_back_text. 
+ + Returns: + tuple of (safe_text_to_output, text_to_hold_in_buffer) + """ + # Check if original buffer ends with a partial token at the end + special_tokens = ["<|python_start|>", "<|python_end|>"] + + for token in special_tokens: + partial_length = self._ends_with_partial_token(buffer, token) + if partial_length > 0: + # Split buffer: safe part + held back partial token + safe_text = buffer[:-partial_length] + held_back = buffer[-partial_length:] + # Strip complete special tokens from safe part only + safe_text = self._text_strip(safe_text) + return safe_text, held_back + + # No partial tokens found, strip complete tokens from entire buffer + safe_text = self._text_strip(buffer) + return safe_text, "" + def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: @@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector): then parses and emits any detected calls. """ self._buffer += new_text - start = self._buffer.find("[") + + # Strip special tokens from entire buffer and handle partial tokens + stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer) + + start = stripped_buffer.find("[") if start == -1: - normal_text = self._buffer - self._buffer = "" - return StreamingParseResult(normal_text=normal_text) + # No tool call bracket found + self._buffer = held_back + return StreamingParseResult(normal_text=stripped_buffer) - normal_text = self._buffer[:start] if start > 0 else "" + normal_text = stripped_buffer[:start] if start > 0 else "" - end = self._find_matching_bracket(self._buffer, start) + end = self._find_matching_bracket(stripped_buffer, start) if end != -1: - call_text = self._buffer[start : end + 1] + # Found complete tool call + call_text = stripped_buffer[start : end + 1] result = self.detect_and_parse(call_text, tools) - self._buffer = self._buffer[end + 1 :] + + # Update buffer with remaining text after tool call plus any held back text + remaining_text = stripped_buffer[end + 1 :] + 
held_back + self._buffer = remaining_text # If we had normal text before the tool call, add it to the result if normal_text: @@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector): return result # We have an opening bracket but no closing bracket yet + # Put back everything from the bracket onwards plus held back text + self._buffer = stripped_buffer[start:] + held_back + if normal_text: - self._buffer = self._buffer[start:] return StreamingParseResult(normal_text=normal_text) # Otherwise, we're still accumulating a potential tool call diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 1ac58d9f6..d418fcd90 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase): self.assertEqual(params["location"], "Tokyo") self.assertEqual(params["data"], [1, 2, 3]) + def test_parse_streaming_with_python_start_and_end_token(self): + """Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks.""" + chunks = [ + "Here's a call: ", + "<|python_", + "start|>[get_weather(location=", + "'Tokyo', data=[1, 2", + ", 3])]<|python_end|>", + ] + + normal_text = "" + call_name = "" + parameters = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + normal_text += result.normal_text + if result.calls: + call_name += result.calls[0].name + parameters += result.calls[0].parameters + + self.assertEqual(normal_text, "Here's a call: ") + self.assertEqual(call_name, "get_weather") + self.assertEqual(self.detector._buffer, "") + self.assertEqual( + result.normal_text, "", "Final result should have no normal text" + ) + + # Check the parameters + params = json.loads(parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + chunks = [ + "Here's a call: 
<|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>" + ] + + normal_text = "" + call_name = "" + parameters = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + normal_text += result.normal_text + if result.calls: + call_name += result.calls[0].name + parameters += result.calls[0].parameters + + self.assertEqual(normal_text, "Here's a call: ") + self.assertEqual(call_name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + def test_detect_and_parse_with_python_start_and_end_token(self): + """Test parsing a message that starts with <|python_start|> and contains a valid tool call.""" + text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars." + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual( + result.normal_text, + "User wants to get the weather in Mars. In this way we will get the weather in Mars.", + ) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Mars") + self.assertEqual(params["unit"], "celsius") + class TestMistralDetector(unittest.TestCase): def setUp(self):