update llama4 chat template and pythonic parser (#6679)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,17 +1,18 @@
|
|||||||
{# Copied from https://github.com/yeqcharlotte/vllm/blob/4fcf68a948bbe0498dc8a98feafa102cfb1dd210/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
|
{# Copied from https://github.com/wukaixingxp/vllm/blob/8a32e2a6e452a03c0e8222e3876ad6086cbf581f/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
|
||||||
{{- bos_token }}
|
{{- bos_token }}
|
||||||
{%- if custom_tools is defined %}
|
{%- if custom_tools is defined and custom_tools %}
|
||||||
{%- set tools = custom_tools %}
|
{%- set tools = custom_tools %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- if not tools_in_user_message is defined %}
|
{%- if tools is defined and tools %}
|
||||||
{%- set tools_in_user_message = false %}
|
{%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %}
|
||||||
{%- endif %}
|
{%- else %}
|
||||||
{%- if not tools is defined %}
|
|
||||||
{%- set tools = none %}
|
{%- set tools = none %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
|
||||||
|
|
||||||
{#- This block extracts the system message, so we can slot it into the right place. #}
|
{#- This block extracts the system message, so we can slot it into the right place. #}
|
||||||
{%- if messages[0]['role'] == 'system' %}
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{%- set user_provided_system_message = true %}
|
||||||
{%- if messages[0]['content'] is string %}
|
{%- if messages[0]['content'] is string %}
|
||||||
{%- set system_message = messages[0]['content']|trim %}
|
{%- set system_message = messages[0]['content']|trim %}
|
||||||
{%- else %}
|
{%- else %}
|
||||||
@@ -19,68 +20,33 @@
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- set messages = messages[1:] %}
|
{%- set messages = messages[1:] %}
|
||||||
{%- else %}
|
{%- else %}
|
||||||
{%- if tools is not none %}
|
{%- if tools is not none %}
|
||||||
{#- Add default tool system message when tools are provided #}
|
{#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #}
|
||||||
{%- set system_message = "You are a helpful assistant with tool calling "
|
{#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #}
|
||||||
"capabilities. Only reply with a tool call if the function exists in the "
|
{%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %}
|
||||||
"library provided by the user. If it doesn't exist, just reply directly in "
|
|
||||||
"natural language. When you receive a tool call response, use the output to "
|
|
||||||
"format an answer to the original user question." %}
|
|
||||||
{%- else %}
|
{%- else %}
|
||||||
{%- set system_message = "" %}
|
{%- set system_message = "" %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #}
|
||||||
{#- System message if the user supplied one, or if tools are used (default tool system message) #}
|
|
||||||
{%- if system_message %}
|
{%- if system_message %}
|
||||||
{#- always use user provided system message to override default tool system message #}
|
{#- always use user provided system message to override default tool system message #}
|
||||||
{{- "<|header_start|>system<|header_end|>\n\n" }}
|
{{- "<|header_start|>system<|header_end|>\n\n" }}
|
||||||
{{- system_message }}
|
{{- system_message }}
|
||||||
{%- if tools is not none and not tools_in_user_message %}
|
{%- if user_provided_system_message and tools %}
|
||||||
{{- "Tools: You have access to the following tools. You might need to use one "
|
{{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }}
|
||||||
"or more function/tool calls to fulfill the task. \n"
|
{{- tool_definition -}}
|
||||||
"If none are needed, then proceed to the response.\n\n"
|
{%- elif tool_definition %}
|
||||||
"Tool Call Syntax: You can call tools using the following syntax:\n"
|
{{- tool_definition -}}
|
||||||
"[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n"
|
|
||||||
"Do not include anything else when calling the tools with the syntax above.\n\n"
|
|
||||||
"Here is a list of functions in JSON format that you can invoke.\n " }}
|
|
||||||
{%- for t in tools %}
|
|
||||||
{{- t | tojson(indent=4) }}
|
|
||||||
{{- "\n\n" }}
|
|
||||||
{%- endfor %}
|
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{{- "<|eot|>" }}
|
{{- "<|eot|>" }}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
|
||||||
{#- Custom tools are passed in a user message with some extra guidance #}
|
{#- Now deal with all other messages #}
|
||||||
{%- if tools_in_user_message and tools is not none %}
|
|
||||||
{#- Extract the first user message so we can plug it in here #}
|
|
||||||
{%- if messages | length != 0 %}
|
|
||||||
{%- if messages[0]['content'] is string %}
|
|
||||||
{%- set first_user_message = messages[0]['content']|trim %}
|
|
||||||
{%- else %}
|
|
||||||
{%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
|
|
||||||
{%- endif %}
|
|
||||||
{%- set messages = messages[1:] %}
|
|
||||||
{%- else %}
|
|
||||||
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
|
||||||
{%- endif %}
|
|
||||||
{{- '<|header_start|>user<|header_end|>\n\n' -}}
|
|
||||||
{{- first_user_message}}
|
|
||||||
{{- "\nHere is a list of functions in JSON format that you can invoke:"}}
|
|
||||||
{%- for t in tools %}
|
|
||||||
{{- t | tojson(indent=4) }}
|
|
||||||
{{- "\n\n" }}
|
|
||||||
{%- endfor %}
|
|
||||||
{{- "Should you decide to return the function call(s), put them in the format "
|
|
||||||
"of [func_name1(params_name1=params_value1, params_name2=params_value2, "
|
|
||||||
"...), ...]\nDo not include anything else when calling the tools with the "
|
|
||||||
"syntax above." }}
|
|
||||||
{%- endif %}
|
|
||||||
|
|
||||||
{%- for message in messages %}
|
{%- for message in messages %}
|
||||||
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
{#- Base case: messages that are not from tool role and has empty tool_call list #}
|
||||||
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
|
{%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %}
|
||||||
|
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
|
||||||
{%- if message['content'] is string %}
|
{%- if message['content'] is string %}
|
||||||
{{- message['content'] }}
|
{{- message['content'] }}
|
||||||
{%- else %}
|
{%- else %}
|
||||||
@@ -92,10 +58,12 @@
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{{- "<|eot|>" }}
|
{{- "<|eot|>" }}
|
||||||
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
|
{#- Tool case: messages has non-empty tool_call list, must from assistant #}
|
||||||
{%- set tool_call = message.tool_calls[0].function %}
|
{%- elif 'tool_calls' in message %}
|
||||||
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
|
{#- assume tool_calls are always coming from assistant #}
|
||||||
|
{%- if message.role == 'assistant' %}
|
||||||
|
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
|
||||||
{%- if message['content'] is string %}
|
{%- if message['content'] is string %}
|
||||||
{{- message['content'] }}
|
{{- message['content'] }}
|
||||||
{%- else %}
|
{%- else %}
|
||||||
@@ -107,20 +75,24 @@
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
{{- "[" }}
|
||||||
{%- for tool_call in message.tool_calls %}
|
{%- for tool_call in message.tool_calls %}
|
||||||
{%- if tool_call.function is defined %}
|
{%- if tool_call.function is defined %}
|
||||||
{%- set tool_call = tool_call.function %}
|
{%- set tool_call = tool_call.function %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{{- tool_call.name + '(' -}}
|
{{- tool_call.name + '(' -}}
|
||||||
{%- for param in tool_call.arguments %}
|
{%- for param in tool_call.arguments %}
|
||||||
{{- param + '=' -}}
|
{{- param + '="' -}}
|
||||||
{{- "%s" | format(tool_call.arguments[param]) -}}
|
{{- "%s" | format(tool_call.arguments[param]) -}}
|
||||||
|
{{- '"' -}}
|
||||||
{% if not loop.last %}, {% endif %}
|
{% if not loop.last %}, {% endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{{- ')' -}}
|
{{- ')' -}}
|
||||||
{% if not loop.last %}, {% endif %}
|
{% if not loop.last %}, {% endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{{- "<|eom|>" }}
|
{{- "]<|eot|>" }}
|
||||||
|
{%- endif %}
|
||||||
|
{#- Tool_response case: messages are from tool_response #}
|
||||||
{%- elif message.role == "tool" or message.role == "ipython" %}
|
{%- elif message.role == "tool" or message.role == "ipython" %}
|
||||||
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
|
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
|
||||||
{%- if message.content is string %}
|
{%- if message.content is string %}
|
||||||
@@ -132,7 +104,7 @@
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{{- "<|eom|>" }}
|
{{- "<|eot|>" }}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
{%- if add_generation_prompt %}
|
{%- if add_generation_prompt %}
|
||||||
|
|||||||
@@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
re.DOTALL,
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _text_strip(text: str) -> str:
|
||||||
|
# Llama 4 model sometime will output <|python_start|> and <|python_end|> tokens
|
||||||
|
# remove those tokens
|
||||||
|
text = text.replace("<|python_start|>", "")
|
||||||
|
text = text.replace("<|python_end|>", "")
|
||||||
|
return text
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
return bool(self.tool_call_regex.search(text.strip()))
|
return bool(self.tool_call_regex.search(self._text_strip(text.strip())))
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
# Try parsing the text as a Python list of function calls
|
# Try parsing the text as a Python list of function calls
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
|
|
||||||
|
# Remove unexpected <|python_start|> and <|python_end|> for llama4
|
||||||
|
text = self._text_strip(text)
|
||||||
|
|
||||||
match = self.tool_call_regex.search(text)
|
match = self.tool_call_regex.search(text)
|
||||||
if match is None:
|
if match is None:
|
||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
@@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
return i
|
return i
|
||||||
return -1 # No matching bracket found
|
return -1 # No matching bracket found
|
||||||
|
|
||||||
|
def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Strip special tokens from buffer and split into safe_text and held_back_text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of (safe_text_to_output, text_to_hold_in_buffer)
|
||||||
|
"""
|
||||||
|
# Check if original buffer ends with a partial token at the end
|
||||||
|
special_tokens = ["<|python_start|>", "<|python_end|>"]
|
||||||
|
|
||||||
|
for token in special_tokens:
|
||||||
|
partial_length = self._ends_with_partial_token(buffer, token)
|
||||||
|
if partial_length > 0:
|
||||||
|
# Split buffer: safe part + held back partial token
|
||||||
|
safe_text = buffer[:-partial_length]
|
||||||
|
held_back = buffer[-partial_length:]
|
||||||
|
# Strip complete special tokens from safe part only
|
||||||
|
safe_text = self._text_strip(safe_text)
|
||||||
|
return safe_text, held_back
|
||||||
|
|
||||||
|
# No partial tokens found, strip complete tokens from entire buffer
|
||||||
|
safe_text = self._text_strip(buffer)
|
||||||
|
return safe_text, ""
|
||||||
|
|
||||||
def parse_streaming_increment(
|
def parse_streaming_increment(
|
||||||
self, new_text: str, tools: List[Tool]
|
self, new_text: str, tools: List[Tool]
|
||||||
) -> StreamingParseResult:
|
) -> StreamingParseResult:
|
||||||
@@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
then parses and emits any detected calls.
|
then parses and emits any detected calls.
|
||||||
"""
|
"""
|
||||||
self._buffer += new_text
|
self._buffer += new_text
|
||||||
start = self._buffer.find("[")
|
|
||||||
|
# Strip special tokens from entire buffer and handle partial tokens
|
||||||
|
stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)
|
||||||
|
|
||||||
|
start = stripped_buffer.find("[")
|
||||||
|
|
||||||
if start == -1:
|
if start == -1:
|
||||||
normal_text = self._buffer
|
# No tool call bracket found
|
||||||
self._buffer = ""
|
self._buffer = held_back
|
||||||
return StreamingParseResult(normal_text=normal_text)
|
return StreamingParseResult(normal_text=stripped_buffer)
|
||||||
|
|
||||||
normal_text = self._buffer[:start] if start > 0 else ""
|
normal_text = stripped_buffer[:start] if start > 0 else ""
|
||||||
|
|
||||||
end = self._find_matching_bracket(self._buffer, start)
|
end = self._find_matching_bracket(stripped_buffer, start)
|
||||||
if end != -1:
|
if end != -1:
|
||||||
call_text = self._buffer[start : end + 1]
|
# Found complete tool call
|
||||||
|
call_text = stripped_buffer[start : end + 1]
|
||||||
result = self.detect_and_parse(call_text, tools)
|
result = self.detect_and_parse(call_text, tools)
|
||||||
self._buffer = self._buffer[end + 1 :]
|
|
||||||
|
# Update buffer with remaining text after tool call plus any held back text
|
||||||
|
remaining_text = stripped_buffer[end + 1 :] + held_back
|
||||||
|
self._buffer = remaining_text
|
||||||
|
|
||||||
# If we had normal text before the tool call, add it to the result
|
# If we had normal text before the tool call, add it to the result
|
||||||
if normal_text:
|
if normal_text:
|
||||||
@@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# We have an opening bracket but no closing bracket yet
|
# We have an opening bracket but no closing bracket yet
|
||||||
|
# Put back everything from the bracket onwards plus held back text
|
||||||
|
self._buffer = stripped_buffer[start:] + held_back
|
||||||
|
|
||||||
if normal_text:
|
if normal_text:
|
||||||
self._buffer = self._buffer[start:]
|
|
||||||
return StreamingParseResult(normal_text=normal_text)
|
return StreamingParseResult(normal_text=normal_text)
|
||||||
|
|
||||||
# Otherwise, we're still accumulating a potential tool call
|
# Otherwise, we're still accumulating a potential tool call
|
||||||
|
|||||||
@@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase):
|
|||||||
self.assertEqual(params["location"], "Tokyo")
|
self.assertEqual(params["location"], "Tokyo")
|
||||||
self.assertEqual(params["data"], [1, 2, 3])
|
self.assertEqual(params["data"], [1, 2, 3])
|
||||||
|
|
||||||
|
def test_parse_streaming_with_python_start_and_end_token(self):
|
||||||
|
"""Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks."""
|
||||||
|
chunks = [
|
||||||
|
"Here's a call: ",
|
||||||
|
"<|python_",
|
||||||
|
"start|>[get_weather(location=",
|
||||||
|
"'Tokyo', data=[1, 2",
|
||||||
|
", 3])]<|python_end|>",
|
||||||
|
]
|
||||||
|
|
||||||
|
normal_text = ""
|
||||||
|
call_name = ""
|
||||||
|
parameters = ""
|
||||||
|
for chunk in chunks:
|
||||||
|
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||||
|
if result.normal_text:
|
||||||
|
normal_text += result.normal_text
|
||||||
|
if result.calls:
|
||||||
|
call_name += result.calls[0].name
|
||||||
|
parameters += result.calls[0].parameters
|
||||||
|
|
||||||
|
self.assertEqual(normal_text, "Here's a call: ")
|
||||||
|
self.assertEqual(call_name, "get_weather")
|
||||||
|
self.assertEqual(self.detector._buffer, "")
|
||||||
|
self.assertEqual(
|
||||||
|
result.normal_text, "", "Final result should have no normal text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the parameters
|
||||||
|
params = json.loads(parameters)
|
||||||
|
self.assertEqual(params["location"], "Tokyo")
|
||||||
|
self.assertEqual(params["data"], [1, 2, 3])
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
"Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
|
||||||
|
]
|
||||||
|
|
||||||
|
normal_text = ""
|
||||||
|
call_name = ""
|
||||||
|
parameters = ""
|
||||||
|
for chunk in chunks:
|
||||||
|
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||||
|
if result.normal_text:
|
||||||
|
normal_text += result.normal_text
|
||||||
|
if result.calls:
|
||||||
|
call_name += result.calls[0].name
|
||||||
|
parameters += result.calls[0].parameters
|
||||||
|
|
||||||
|
self.assertEqual(normal_text, "Here's a call: ")
|
||||||
|
self.assertEqual(call_name, "get_weather")
|
||||||
|
self.assertEqual(self.detector._buffer, "")
|
||||||
|
|
||||||
|
# Check the parameters
|
||||||
|
params = json.loads(parameters)
|
||||||
|
self.assertEqual(params["location"], "Tokyo")
|
||||||
|
self.assertEqual(params["data"], [1, 2, 3])
|
||||||
|
|
||||||
|
def test_detect_and_parse_with_python_start_and_end_token(self):
|
||||||
|
"""Test parsing a message that starts with <|python_start|> and contains a valid tool call."""
|
||||||
|
text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
result.normal_text,
|
||||||
|
"User wants to get the weather in Mars. In this way we will get the weather in Mars.",
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertEqual(result.calls[0].name, "get_weather")
|
||||||
|
self.assertEqual(self.detector._buffer, "")
|
||||||
|
|
||||||
|
# Check the parameters
|
||||||
|
params = json.loads(result.calls[0].parameters)
|
||||||
|
self.assertEqual(params["location"], "Mars")
|
||||||
|
self.assertEqual(params["unit"], "celsius")
|
||||||
|
|
||||||
|
|
||||||
class TestMistralDetector(unittest.TestCase):
|
class TestMistralDetector(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user