update llama4 chat template and pythonic parser (#6679)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector):
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _text_strip(text: str) -> str:
|
||||
# Llama 4 model sometime will output <|python_start|> and <|python_end|> tokens
|
||||
# remove those tokens
|
||||
text = text.replace("<|python_start|>", "")
|
||||
text = text.replace("<|python_end|>", "")
|
||||
return text
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
return bool(self.tool_call_regex.search(text.strip()))
|
||||
return bool(self.tool_call_regex.search(self._text_strip(text.strip())))
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
# Try parsing the text as a Python list of function calls
|
||||
text = text.strip()
|
||||
|
||||
# Remove unexpected <|python_start|> and <|python_end|> for llama4
|
||||
text = self._text_strip(text)
|
||||
|
||||
match = self.tool_call_regex.search(text)
|
||||
if match is None:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
@@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector):
|
||||
return i
|
||||
return -1 # No matching bracket found
|
||||
|
||||
def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
Strip special tokens from buffer and split into safe_text and held_back_text.
|
||||
|
||||
Returns:
|
||||
tuple of (safe_text_to_output, text_to_hold_in_buffer)
|
||||
"""
|
||||
# Check if original buffer ends with a partial token at the end
|
||||
special_tokens = ["<|python_start|>", "<|python_end|>"]
|
||||
|
||||
for token in special_tokens:
|
||||
partial_length = self._ends_with_partial_token(buffer, token)
|
||||
if partial_length > 0:
|
||||
# Split buffer: safe part + held back partial token
|
||||
safe_text = buffer[:-partial_length]
|
||||
held_back = buffer[-partial_length:]
|
||||
# Strip complete special tokens from safe part only
|
||||
safe_text = self._text_strip(safe_text)
|
||||
return safe_text, held_back
|
||||
|
||||
# No partial tokens found, strip complete tokens from entire buffer
|
||||
safe_text = self._text_strip(buffer)
|
||||
return safe_text, ""
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
@@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector):
|
||||
then parses and emits any detected calls.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
start = self._buffer.find("[")
|
||||
|
||||
# Strip special tokens from entire buffer and handle partial tokens
|
||||
stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)
|
||||
|
||||
start = stripped_buffer.find("[")
|
||||
|
||||
if start == -1:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
# No tool call bracket found
|
||||
self._buffer = held_back
|
||||
return StreamingParseResult(normal_text=stripped_buffer)
|
||||
|
||||
normal_text = self._buffer[:start] if start > 0 else ""
|
||||
normal_text = stripped_buffer[:start] if start > 0 else ""
|
||||
|
||||
end = self._find_matching_bracket(self._buffer, start)
|
||||
end = self._find_matching_bracket(stripped_buffer, start)
|
||||
if end != -1:
|
||||
call_text = self._buffer[start : end + 1]
|
||||
# Found complete tool call
|
||||
call_text = stripped_buffer[start : end + 1]
|
||||
result = self.detect_and_parse(call_text, tools)
|
||||
self._buffer = self._buffer[end + 1 :]
|
||||
|
||||
# Update buffer with remaining text after tool call plus any held back text
|
||||
remaining_text = stripped_buffer[end + 1 :] + held_back
|
||||
self._buffer = remaining_text
|
||||
|
||||
# If we had normal text before the tool call, add it to the result
|
||||
if normal_text:
|
||||
@@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector):
|
||||
return result
|
||||
|
||||
# We have an opening bracket but no closing bracket yet
|
||||
# Put back everything from the bracket onwards plus held back text
|
||||
self._buffer = stripped_buffer[start:] + held_back
|
||||
|
||||
if normal_text:
|
||||
self._buffer = self._buffer[start:]
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# Otherwise, we're still accumulating a potential tool call
|
||||
|
||||
Reference in New Issue
Block a user