update llama4 chat template and pythonic parser (#6679)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
{# Copied from https://github.com/yeqcharlotte/vllm/blob/4fcf68a948bbe0498dc8a98feafa102cfb1dd210/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
|
||||
{# Copied from https://github.com/wukaixingxp/vllm/blob/8a32e2a6e452a03c0e8222e3876ad6086cbf581f/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
|
||||
{{- bos_token }}
|
||||
{%- if custom_tools is defined %}
|
||||
{%- if custom_tools is defined and custom_tools %}
|
||||
{%- set tools = custom_tools %}
|
||||
{%- endif %}
|
||||
{%- if not tools_in_user_message is defined %}
|
||||
{%- set tools_in_user_message = false %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- if tools is defined and tools %}
|
||||
{%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %}
|
||||
{%- else %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
|
||||
|
||||
{#- This block extracts the system message, so we can slot it into the right place. #}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set user_provided_system_message = true %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{%- set system_message = messages[0]['content']|trim %}
|
||||
{%- else %}
|
||||
@@ -19,68 +20,33 @@
|
||||
{%- endif %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- if tools is not none %}
|
||||
{#- Add default tool system message when tools are provided #}
|
||||
{%- set system_message = "You are a helpful assistant with tool calling "
|
||||
"capabilities. Only reply with a tool call if the function exists in the "
|
||||
"library provided by the user. If it doesn't exist, just reply directly in "
|
||||
"natural language. When you receive a tool call response, use the output to "
|
||||
"format an answer to the original user question." %}
|
||||
{%- if tools is not none %}
|
||||
{#- Since no system_message was provided by the user, if tools are provided, system_message is now the default tool system message #}
|
||||
{#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #}
|
||||
{%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. 
STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %}
|
||||
{%- else %}
|
||||
{%- set system_message = "" %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{#- System message if the user supplied one, or if tools are used (default tool system message) #}
|
||||
{#- Now writing the system message: use the user provided system message if user_provided_system_message, else the default tool system message if tools are present #}
|
||||
{%- if system_message %}
|
||||
{#- always use user provided system message to override default tool system message #}
|
||||
{{- "<|header_start|>system<|header_end|>\n\n" }}
|
||||
{{- system_message }}
|
||||
{%- if tools is not none and not tools_in_user_message %}
|
||||
{{- "Tools: You have access to the following tools. You might need to use one "
|
||||
"or more function/tool calls to fulfill the task. \n"
|
||||
"If none are needed, then proceed to the response.\n\n"
|
||||
"Tool Call Syntax: You can call tools using the following syntax:\n"
|
||||
"[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n"
|
||||
"Do not include anything else when calling the tools with the syntax above.\n\n"
|
||||
"Here is a list of functions in JSON format that you can invoke.\n " }}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{%- if user_provided_system_message and tools %}
|
||||
{{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }}
|
||||
{{- tool_definition -}}
|
||||
{%- elif tool_definition %}
|
||||
{{- tool_definition -}}
|
||||
{%- endif %}
|
||||
{{- "<|eot|>" }}
|
||||
{%- endif %}
|
||||
|
||||
{#- Custom tools are passed in a user message with some extra guidance #}
|
||||
{%- if tools_in_user_message and tools is not none %}
|
||||
{#- Extract the first user message so we can plug it in here #}
|
||||
{%- if messages | length != 0 %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{%- set first_user_message = messages[0]['content']|trim %}
|
||||
{%- else %}
|
||||
{%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
|
||||
{%- endif %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
||||
{%- endif %}
|
||||
{{- '<|header_start|>user<|header_end|>\n\n' -}}
|
||||
{{- first_user_message}}
|
||||
{{- "\nHere is a list of functions in JSON format that you can invoke:"}}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{{- "Should you decide to return the function call(s), put them in the format "
|
||||
"of [func_name1(params_name1=params_value1, params_name2=params_value2, "
|
||||
"...), ...]\nDo not include anything else when calling the tools with the "
|
||||
"syntax above." }}
|
||||
{%- endif %}
|
||||
|
||||
{#- Now deal with all other messages #}
|
||||
{%- for message in messages %}
|
||||
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
||||
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
|
||||
{#- Base case: messages that are not from the tool role and have an empty tool_calls list #}
|
||||
{%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %}
|
||||
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
@@ -92,10 +58,12 @@
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "<|eot|>" }}
|
||||
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
|
||||
{%- set tool_call = message.tool_calls[0].function %}
|
||||
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
|
||||
{{- "<|eot|>" }}
|
||||
{#- Tool case: the message has a non-empty tool_calls list and must come from the assistant #}
|
||||
{%- elif 'tool_calls' in message %}
|
||||
{#- assume tool_calls are always coming from assistant #}
|
||||
{%- if message.role == 'assistant' %}
|
||||
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
@@ -107,20 +75,24 @@
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "[" }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- tool_call.name + '(' -}}
|
||||
{{- tool_call.name + '(' -}}
|
||||
{%- for param in tool_call.arguments %}
|
||||
{{- param + '=' -}}
|
||||
{{- param + '="' -}}
|
||||
{{- "%s" | format(tool_call.arguments[param]) -}}
|
||||
{{- '"' -}}
|
||||
{% if not loop.last %}, {% endif %}
|
||||
{%- endfor %}
|
||||
{{- ')' -}}
|
||||
{% if not loop.last %}, {% endif %}
|
||||
{%- endfor %}
|
||||
{{- "<|eom|>" }}
|
||||
{{- "]<|eot|>" }}
|
||||
{%- endif %}
|
||||
{#- Tool_response case: messages are from tool_response #}
|
||||
{%- elif message.role == "tool" or message.role == "ipython" %}
|
||||
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
|
||||
{%- if message.content is string %}
|
||||
@@ -132,7 +104,7 @@
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "<|eom|>" }}
|
||||
{{- "<|eot|>" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
|
||||
@@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector):
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _text_strip(text: str) -> str:
|
||||
# Llama 4 model sometime will output <|python_start|> and <|python_end|> tokens
|
||||
# remove those tokens
|
||||
text = text.replace("<|python_start|>", "")
|
||||
text = text.replace("<|python_end|>", "")
|
||||
return text
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
    """Report whether ``text`` contains a pythonic tool call.

    Surrounding whitespace is trimmed and the ``<|python_start|>`` /
    ``<|python_end|>`` marker tokens are removed with ``_text_strip``
    before ``tool_call_regex`` is searched, so the check looks at the same
    cleaned text that ``detect_and_parse`` works on.

    Args:
        text: Raw model output to inspect.

    Returns:
        True if ``tool_call_regex`` finds a match in the cleaned text,
        False otherwise.
    """
    # A single return: an earlier unconditional return on the raw text
    # would make this marker-stripping check unreachable.
    return bool(self.tool_call_regex.search(self._text_strip(text.strip())))
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
# Try parsing the text as a Python list of function calls
|
||||
text = text.strip()
|
||||
|
||||
# Remove unexpected <|python_start|> and <|python_end|> for llama4
|
||||
text = self._text_strip(text)
|
||||
|
||||
match = self.tool_call_regex.search(text)
|
||||
if match is None:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
@@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector):
|
||||
return i
|
||||
return -1 # No matching bracket found
|
||||
|
||||
def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
    """Split ``buffer`` into text that can be emitted now and a held-back tail.

    The tail is kept back when the buffer ends with what may be the
    beginning of a ``<|python_start|>`` or ``<|python_end|>`` marker, so an
    incomplete marker is not emitted as text. Complete markers are removed
    from the emitted part with ``_text_strip``.

    Args:
        buffer: Accumulated streaming text, possibly ending mid-marker.

    Returns:
        ``(safe_text, held_back)``. ``held_back`` is the unmodified trailing
        fragment, or ``""`` when no partial marker was reported.
    """
    for marker in ("<|python_start|>", "<|python_end|>"):
        # NOTE(review): _ends_with_partial_token is defined outside this
        # view; presumably it returns how many trailing characters of
        # ``buffer`` form a prefix of ``marker`` (0 when none) - confirm.
        tail_len = self._ends_with_partial_token(buffer, marker)
        if tail_len <= 0:
            continue

        # The first marker that reports a partial tail decides the split.
        head, tail = buffer[:-tail_len], buffer[-tail_len:]
        # Only the head is cleaned; the tail stays raw so it can be
        # completed by later input.
        return self._text_strip(head), tail

    # No partial marker at the end: clean the whole buffer, hold nothing.
    return self._text_strip(buffer), ""
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
@@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector):
|
||||
then parses and emits any detected calls.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
start = self._buffer.find("[")
|
||||
|
||||
# Strip special tokens from entire buffer and handle partial tokens
|
||||
stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)
|
||||
|
||||
start = stripped_buffer.find("[")
|
||||
|
||||
if start == -1:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
# No tool call bracket found
|
||||
self._buffer = held_back
|
||||
return StreamingParseResult(normal_text=stripped_buffer)
|
||||
|
||||
normal_text = self._buffer[:start] if start > 0 else ""
|
||||
normal_text = stripped_buffer[:start] if start > 0 else ""
|
||||
|
||||
end = self._find_matching_bracket(self._buffer, start)
|
||||
end = self._find_matching_bracket(stripped_buffer, start)
|
||||
if end != -1:
|
||||
call_text = self._buffer[start : end + 1]
|
||||
# Found complete tool call
|
||||
call_text = stripped_buffer[start : end + 1]
|
||||
result = self.detect_and_parse(call_text, tools)
|
||||
self._buffer = self._buffer[end + 1 :]
|
||||
|
||||
# Update buffer with remaining text after tool call plus any held back text
|
||||
remaining_text = stripped_buffer[end + 1 :] + held_back
|
||||
self._buffer = remaining_text
|
||||
|
||||
# If we had normal text before the tool call, add it to the result
|
||||
if normal_text:
|
||||
@@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector):
|
||||
return result
|
||||
|
||||
# We have an opening bracket but no closing bracket yet
|
||||
# Put back everything from the bracket onwards plus held back text
|
||||
self._buffer = stripped_buffer[start:] + held_back
|
||||
|
||||
if normal_text:
|
||||
self._buffer = self._buffer[start:]
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# Otherwise, we're still accumulating a potential tool call
|
||||
|
||||
@@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase):
|
||||
self.assertEqual(params["location"], "Tokyo")
|
||||
self.assertEqual(params["data"], [1, 2, 3])
|
||||
|
||||
def test_parse_streaming_with_python_start_and_end_token(self):
    """Stream a tool call wrapped in <|python_start|> / <|python_end|> markers.

    Two scenarios run against the same detector: the wrapped call split
    over several chunks (with the start marker itself cut in two), then the
    same text delivered as a single chunk.
    """

    def feed(pieces):
        # Push each piece through the detector and accumulate what it emits;
        # the last result is returned as well for follow-up assertions.
        collected_text = ""
        collected_name = ""
        collected_params = ""
        outcome = None
        for piece in pieces:
            outcome = self.detector.parse_streaming_increment(piece, self.tools)
            if outcome.normal_text:
                collected_text += outcome.normal_text
            if outcome.calls:
                collected_name += outcome.calls[0].name
                collected_params += outcome.calls[0].parameters
        return collected_text, collected_name, collected_params, outcome

    # Scenario 1: markers and call spread across chunk boundaries.
    split_stream = [
        "Here's a call: ",
        "<|python_",
        "start|>[get_weather(location=",
        "'Tokyo', data=[1, 2",
        ", 3])]<|python_end|>",
    ]
    normal_text, call_name, parameters, last_result = feed(split_stream)

    self.assertEqual(normal_text, "Here's a call: ")
    self.assertEqual(call_name, "get_weather")
    self.assertEqual(self.detector._buffer, "")
    self.assertEqual(
        last_result.normal_text, "", "Final result should have no normal text"
    )

    decoded = json.loads(parameters)
    self.assertEqual(decoded["location"], "Tokyo")
    self.assertEqual(decoded["data"], [1, 2, 3])

    # Scenario 2: the identical text arrives in one chunk.
    single_stream = [
        "Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
    ]
    normal_text, call_name, parameters, _ = feed(single_stream)

    self.assertEqual(normal_text, "Here's a call: ")
    self.assertEqual(call_name, "get_weather")
    self.assertEqual(self.detector._buffer, "")

    decoded = json.loads(parameters)
    self.assertEqual(decoded["location"], "Tokyo")
    self.assertEqual(decoded["data"], [1, 2, 3])
|
||||
|
||||
def test_detect_and_parse_with_python_start_and_end_token(self):
    """One-shot parse of text whose tool call is wrapped in python markers.

    The call sits between <|python_start|> and <|python_end|> with prose on
    both sides; the markers must not appear in the returned normal text and
    the call must be extracted with its arguments.
    """
    raw_output = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
    parsed = self.detector.detect_and_parse(raw_output, self.tools)

    self.assertEqual(
        parsed.normal_text,
        "User wants to get the weather in Mars. In this way we will get the weather in Mars.",
    )
    self.assertEqual(len(parsed.calls), 1)
    only_call = parsed.calls[0]
    self.assertEqual(only_call.name, "get_weather")
    self.assertEqual(self.detector._buffer, "")

    # The call parameters are a JSON string; decode before comparing.
    arguments = json.loads(only_call.parameters)
    self.assertEqual(arguments["location"], "Mars")
    self.assertEqual(arguments["unit"], "celsius")
|
||||
|
||||
|
||||
class TestMistralDetector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user