update llama4 chat template and pythonic parser (#6679)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase):
|
||||
self.assertEqual(params["location"], "Tokyo")
|
||||
self.assertEqual(params["data"], [1, 2, 3])
|
||||
|
||||
def test_parse_streaming_with_python_start_and_end_token(self):
|
||||
"""Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks."""
|
||||
chunks = [
|
||||
"Here's a call: ",
|
||||
"<|python_",
|
||||
"start|>[get_weather(location=",
|
||||
"'Tokyo', data=[1, 2",
|
||||
", 3])]<|python_end|>",
|
||||
]
|
||||
|
||||
normal_text = ""
|
||||
call_name = ""
|
||||
parameters = ""
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
if result.normal_text:
|
||||
normal_text += result.normal_text
|
||||
if result.calls:
|
||||
call_name += result.calls[0].name
|
||||
parameters += result.calls[0].parameters
|
||||
|
||||
self.assertEqual(normal_text, "Here's a call: ")
|
||||
self.assertEqual(call_name, "get_weather")
|
||||
self.assertEqual(self.detector._buffer, "")
|
||||
self.assertEqual(
|
||||
result.normal_text, "", "Final result should have no normal text"
|
||||
)
|
||||
|
||||
# Check the parameters
|
||||
params = json.loads(parameters)
|
||||
self.assertEqual(params["location"], "Tokyo")
|
||||
self.assertEqual(params["data"], [1, 2, 3])
|
||||
|
||||
chunks = [
|
||||
"Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
|
||||
]
|
||||
|
||||
normal_text = ""
|
||||
call_name = ""
|
||||
parameters = ""
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
if result.normal_text:
|
||||
normal_text += result.normal_text
|
||||
if result.calls:
|
||||
call_name += result.calls[0].name
|
||||
parameters += result.calls[0].parameters
|
||||
|
||||
self.assertEqual(normal_text, "Here's a call: ")
|
||||
self.assertEqual(call_name, "get_weather")
|
||||
self.assertEqual(self.detector._buffer, "")
|
||||
|
||||
# Check the parameters
|
||||
params = json.loads(parameters)
|
||||
self.assertEqual(params["location"], "Tokyo")
|
||||
self.assertEqual(params["data"], [1, 2, 3])
|
||||
|
||||
def test_detect_and_parse_with_python_start_and_end_token(self):
|
||||
"""Test parsing a message that starts with <|python_start|> and contains a valid tool call."""
|
||||
text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
|
||||
self.assertEqual(
|
||||
result.normal_text,
|
||||
"User wants to get the weather in Mars. In this way we will get the weather in Mars.",
|
||||
)
|
||||
self.assertEqual(len(result.calls), 1)
|
||||
self.assertEqual(result.calls[0].name, "get_weather")
|
||||
self.assertEqual(self.detector._buffer, "")
|
||||
|
||||
# Check the parameters
|
||||
params = json.loads(result.calls[0].parameters)
|
||||
self.assertEqual(params["location"], "Mars")
|
||||
self.assertEqual(params["unit"], "celsius")
|
||||
|
||||
|
||||
class TestMistralDetector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user