update llama4 chat template and pythonic parser (#6679)

Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
Chao Yang
2025-05-30 17:01:22 -07:00
committed by GitHub
parent b581b22504
commit 4fac524b14
3 changed files with 165 additions and 73 deletions

View File

@@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase):
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
def test_parse_streaming_with_python_start_and_end_token(self):
"""Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks."""
chunks = [
"Here's a call: ",
"<|python_",
"start|>[get_weather(location=",
"'Tokyo', data=[1, 2",
", 3])]<|python_end|>",
]
normal_text = ""
call_name = ""
parameters = ""
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
if result.normal_text:
normal_text += result.normal_text
if result.calls:
call_name += result.calls[0].name
parameters += result.calls[0].parameters
self.assertEqual(normal_text, "Here's a call: ")
self.assertEqual(call_name, "get_weather")
self.assertEqual(self.detector._buffer, "")
self.assertEqual(
result.normal_text, "", "Final result should have no normal text"
)
# Check the parameters
params = json.loads(parameters)
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
chunks = [
"Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
]
normal_text = ""
call_name = ""
parameters = ""
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
if result.normal_text:
normal_text += result.normal_text
if result.calls:
call_name += result.calls[0].name
parameters += result.calls[0].parameters
self.assertEqual(normal_text, "Here's a call: ")
self.assertEqual(call_name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(parameters)
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
def test_detect_and_parse_with_python_start_and_end_token(self):
"""Test parsing a message that starts with <|python_start|> and contains a valid tool call."""
text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(
result.normal_text,
"User wants to get the weather in Mars. In this way we will get the weather in Mars.",
)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["location"], "Mars")
self.assertEqual(params["unit"], "celsius")
class TestMistralDetector(unittest.TestCase):
def setUp(self):