update llama4 chat template and pythonic parser (#6679)

Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
Chao Yang
2025-05-30 17:01:22 -07:00
committed by GitHub
parent b581b22504
commit 4fac524b14
3 changed files with 165 additions and 73 deletions

View File

@@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector):
re.DOTALL,
)
@staticmethod
def _text_strip(text: str) -> str:
    """Return ``text`` with Llama 4 python marker tokens removed.

    Llama 4 models sometimes emit ``<|python_start|>`` and
    ``<|python_end|>`` around their output; every occurrence of either
    marker is deleted and the rest of the text is left untouched.
    """
    for marker in ("<|python_start|>", "<|python_end|>"):
        text = text.replace(marker, "")
    return text
def has_tool_call(self, text: str) -> bool:
    """Report whether ``text`` contains a pythonic tool call.

    Surrounding whitespace is trimmed and the Llama 4 marker tokens
    (``<|python_start|>`` / ``<|python_end|>``) are removed via
    ``_text_strip`` before the tool-call regex is searched, so a call
    wrapped in those markers is still detected.

    Args:
        text: Raw model output to inspect.

    Returns:
        True if ``self.tool_call_regex`` finds a match in the cleaned text.
    """
    # Only the stripped path is kept: an earlier return that searched the
    # unstripped text made the marker removal unreachable.
    cleaned = self._text_strip(text.strip())
    return bool(self.tool_call_regex.search(cleaned))
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
# Try parsing the text as a Python list of function calls
text = text.strip()
# Remove unexpected <|python_start|> and <|python_end|> for llama4
text = self._text_strip(text)
match = self.tool_call_regex.search(text)
if match is None:
return StreamingParseResult(normal_text=text, calls=[])
@@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector):
return i
return -1 # No matching bracket found
def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
    """Split ``buffer`` into emit-ready text and a withheld tail.

    The buffer is checked, marker by marker, for a trailing fragment that
    ``_ends_with_partial_token`` reports as the start of a special token.
    The first marker that yields a positive length decides the split: the
    fragment is withheld unchanged, and complete markers are removed only
    from the text in front of it. When no marker reports a partial match,
    the whole buffer is cleaned and nothing is withheld.

    Args:
        buffer: Accumulated streaming text, possibly ending mid-token.

    Returns:
        A ``(safe_text, held_back)`` pair; ``held_back`` is ``""`` when the
        buffer does not end with a partial marker.
    """
    for marker in ("<|python_start|>", "<|python_end|>"):
        tail_len = self._ends_with_partial_token(buffer, marker)
        if tail_len <= 0:
            continue
        # Withhold the unfinished marker; clean only what precedes it.
        return self._text_strip(buffer[:-tail_len]), buffer[-tail_len:]

    # Nothing pending at the end: the entire buffer can be cleaned.
    return self._text_strip(buffer), ""
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
@@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector):
then parses and emits any detected calls.
"""
self._buffer += new_text
start = self._buffer.find("[")
# Strip special tokens from entire buffer and handle partial tokens
stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)
start = stripped_buffer.find("[")
if start == -1:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# No tool call bracket found
self._buffer = held_back
return StreamingParseResult(normal_text=stripped_buffer)
normal_text = self._buffer[:start] if start > 0 else ""
normal_text = stripped_buffer[:start] if start > 0 else ""
end = self._find_matching_bracket(self._buffer, start)
end = self._find_matching_bracket(stripped_buffer, start)
if end != -1:
call_text = self._buffer[start : end + 1]
# Found complete tool call
call_text = stripped_buffer[start : end + 1]
result = self.detect_and_parse(call_text, tools)
self._buffer = self._buffer[end + 1 :]
# Update buffer with remaining text after tool call plus any held back text
remaining_text = stripped_buffer[end + 1 :] + held_back
self._buffer = remaining_text
# If we had normal text before the tool call, add it to the result
if normal_text:
@@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector):
return result
# We have an opening bracket but no closing bracket yet
# Put back everything from the bracket onwards plus held back text
self._buffer = stripped_buffer[start:] + held_back
if normal_text:
self._buffer = self._buffer[start:]
return StreamingParseResult(normal_text=normal_text)
# Otherwise, we're still accumulating a potential tool call