feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (#6784)
This commit is contained in:
@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
|
|
||||||
if "<|python_tag|>" in text:
|
if "<|python_tag|>" in text:
|
||||||
normal_text, action_text = text.split("<|python_tag|>")
|
normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
|
||||||
else:
|
else:
|
||||||
normal_text, action_text = "", text
|
normal_text, action_text = "", text
|
||||||
|
|
||||||
# Split by semicolon and process each part
|
decoder = json.JSONDecoder()
|
||||||
json_parts = [
|
idx = 0
|
||||||
part.strip()
|
safe_idx = idx # the index of the last valid JSON object
|
||||||
for part in action_text.split(self.tool_call_separator)
|
|
||||||
if part.strip()
|
|
||||||
]
|
|
||||||
all_actions = []
|
all_actions = []
|
||||||
for part in json_parts:
|
action_text_len = len(action_text)
|
||||||
|
while idx < action_text_len:
|
||||||
try:
|
try:
|
||||||
# Parse each individual JSON object
|
obj, end = decoder.raw_decode(action_text[idx:])
|
||||||
action = json.loads(part)
|
all_actions.append(obj)
|
||||||
all_actions.append(action)
|
idx += end + len(self.tool_call_separator)
|
||||||
|
safe_idx = idx
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"Failed to parse JSON part: {part}")
|
# Find where next `{"name"` appears and try again
|
||||||
logger.warning(f"JSON parse error: {str(e)}")
|
logger.warning(
|
||||||
|
f"Failed to parse JSON part: {action_text[idx:]}, JSON parse error: {str(e)}"
|
||||||
|
)
|
||||||
|
next_obj_start = action_text.find('{"name":', idx + 1)
|
||||||
|
if next_obj_start == -1:
|
||||||
|
break
|
||||||
|
idx = next_obj_start
|
||||||
continue
|
continue
|
||||||
calls = []
|
|
||||||
# Only process if we found valid JSON objects
|
# Only process if we found valid JSON objects
|
||||||
if all_actions:
|
calls = self.parse_base_json(all_actions, tools) if all_actions else []
|
||||||
calls = self.parse_base_json(all_actions, tools)
|
# Slice from safe_idx (the position just past the last successfully parsed JSON object) rather than idx, which may have moved into an invalid JSON fragment
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
trailing_text = (
|
||||||
|
action_text[safe_idx:].strip() if safe_idx < action_text_len else ""
|
||||||
|
)
|
||||||
|
return StreamingParseResult(
|
||||||
|
normal_text=normal_text + trailing_text, calls=calls
|
||||||
|
)
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
return lambda name: StructureInfo(
|
return lambda name: StructureInfo(
|
||||||
|
|||||||
@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlama32Detector(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test tools and detector for Llama 3.2 format testing."""
|
||||||
|
self.tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_tourist_attractions",
|
||||||
|
description="Get tourist attractions",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.detector = Llama32Detector()
|
||||||
|
|
||||||
|
def test_single_json(self):
|
||||||
|
text = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
assert len(result.calls) == 1
|
||||||
|
assert result.calls[0].name == "get_weather"
|
||||||
|
assert result.normal_text == ""
|
||||||
|
|
||||||
|
def test_multiple_json_with_separator(self):
|
||||||
|
text = (
|
||||||
|
'<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};'
|
||||||
|
'{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}'
|
||||||
|
)
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 2)
|
||||||
|
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
|
||||||
|
self.assertEqual(result.normal_text, "")
|
||||||
|
|
||||||
|
def test_multiple_json_with_separator_customized(self):
|
||||||
|
text = (
|
||||||
|
'<|python_tag|>{"name": "get_weather", "parameters": {}}'
|
||||||
|
'<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}'
|
||||||
|
)
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 2)
|
||||||
|
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
|
||||||
|
self.assertEqual(result.normal_text, "")
|
||||||
|
|
||||||
|
def test_json_with_trailing_text(self):
|
||||||
|
text = '{"name": "get_weather", "parameters": {}} Some follow-up text'
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertIn("follow-up", result.normal_text)
|
||||||
|
|
||||||
|
def test_invalid_then_valid_json(self):
|
||||||
|
text = (
|
||||||
|
'{"name": "get_weather", "parameters": {' # malformed
|
||||||
|
'{"name": "get_weather", "parameters": {}}'
|
||||||
|
)
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertEqual(result.calls[0].name, "get_weather")
|
||||||
|
|
||||||
|
def test_plain_text_only(self):
|
||||||
|
text = "This is just plain explanation text."
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(result.calls, [])
|
||||||
|
self.assertEqual(result.normal_text, text)
|
||||||
|
|
||||||
|
def test_with_python_tag_prefix(self):
|
||||||
|
text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}'
|
||||||
|
result = self.detector.detect_and_parse(text, self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertTrue(result.normal_text.strip().startswith("Some intro."))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user