feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (#6784)
This commit is contained in:
@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestLlama32Detector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test tools and detector for Mistral format testing."""
|
||||
self.tools = [
|
||||
Tool(
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_weather",
|
||||
description="Get weather information",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
),
|
||||
Tool(
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_tourist_attractions",
|
||||
description="Get tourist attractions",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
self.detector = Llama32Detector()
|
||||
|
||||
def test_single_json(self):
|
||||
text = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
assert len(result.calls) == 1
|
||||
assert result.calls[0].name == "get_weather"
|
||||
assert result.normal_text == ""
|
||||
|
||||
def test_multiple_json_with_separator(self):
|
||||
text = (
|
||||
'<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};'
|
||||
'{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}'
|
||||
)
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(len(result.calls), 2)
|
||||
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
|
||||
self.assertEqual(result.normal_text, "")
|
||||
|
||||
def test_multiple_json_with_separator_customized(self):
|
||||
text = (
|
||||
'<|python_tag|>{"name": "get_weather", "parameters": {}}'
|
||||
'<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}'
|
||||
)
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(len(result.calls), 2)
|
||||
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
|
||||
self.assertEqual(result.normal_text, "")
|
||||
|
||||
def test_json_with_trailing_text(self):
|
||||
text = '{"name": "get_weather", "parameters": {}} Some follow-up text'
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(len(result.calls), 1)
|
||||
self.assertIn("follow-up", result.normal_text)
|
||||
|
||||
def test_invalid_then_valid_json(self):
|
||||
text = (
|
||||
'{"name": "get_weather", "parameters": {' # malformed
|
||||
'{"name": "get_weather", "parameters": {}}'
|
||||
)
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(len(result.calls), 1)
|
||||
self.assertEqual(result.calls[0].name, "get_weather")
|
||||
|
||||
def test_plain_text_only(self):
|
||||
text = "This is just plain explanation text."
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(result.calls, [])
|
||||
self.assertEqual(result.normal_text, text)
|
||||
|
||||
def test_with_python_tag_prefix(self):
|
||||
text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}'
|
||||
result = self.detector.detect_and_parse(text, self.tools)
|
||||
self.assertEqual(len(result.calls), 1)
|
||||
self.assertTrue(result.normal_text.strip().startswith("Some intro."))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user