feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (#6784)
This commit is contained in:
@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector):
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
if "<|python_tag|>" in text:
|
||||
normal_text, action_text = text.split("<|python_tag|>")
|
||||
normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
|
||||
else:
|
||||
normal_text, action_text = "", text
|
||||
|
||||
# Split by semicolon and process each part
|
||||
json_parts = [
|
||||
part.strip()
|
||||
for part in action_text.split(self.tool_call_separator)
|
||||
if part.strip()
|
||||
]
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
safe_idx = idx # the index of the last valid JSON object
|
||||
all_actions = []
|
||||
for part in json_parts:
|
||||
action_text_len = len(action_text)
|
||||
while idx < action_text_len:
|
||||
try:
|
||||
# Parse each individual JSON object
|
||||
action = json.loads(part)
|
||||
all_actions.append(action)
|
||||
obj, end = decoder.raw_decode(action_text[idx:])
|
||||
all_actions.append(obj)
|
||||
idx += end + len(self.tool_call_separator)
|
||||
safe_idx = idx
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON part: {part}")
|
||||
logger.warning(f"JSON parse error: {str(e)}")
|
||||
# Find where next `{"name"` appears and try again
|
||||
logger.warning(
|
||||
f"Failed to parse JSON part: {action_text[idx:]}, JSON parse error: {str(e)}"
|
||||
)
|
||||
next_obj_start = action_text.find('{"name":', idx + 1)
|
||||
if next_obj_start == -1:
|
||||
break
|
||||
idx = next_obj_start
|
||||
continue
|
||||
calls = []
|
||||
|
||||
# Only process if we found valid JSON objects
|
||||
if all_actions:
|
||||
calls = self.parse_base_json(all_actions, tools)
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
calls = self.parse_base_json(all_actions, tools) if all_actions else []
|
||||
# Use safe_idx to avoid idx containing the last part of an invalid JSON object
|
||||
trailing_text = (
|
||||
action_text[safe_idx:].strip() if safe_idx < action_text_len else ""
|
||||
)
|
||||
return StreamingParseResult(
|
||||
normal_text=normal_text + trailing_text, calls=calls
|
||||
)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
|
||||
Reference in New Issue
Block a user