[tool call] Fix prev_tool_call_arr management in base_format_detector.py (#11367)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -265,12 +265,6 @@ class BaseFormatDetector(ABC):
|
|||||||
# Only remove the processed portion, keep unprocessed content
|
# Only remove the processed portion, keep unprocessed content
|
||||||
self._buffer = current_text[start_idx + end_idx :]
|
self._buffer = current_text[start_idx + end_idx :]
|
||||||
|
|
||||||
if self.current_tool_id < len(self.prev_tool_call_arr):
|
|
||||||
self.prev_tool_call_arr[self.current_tool_id].clear()
|
|
||||||
self.current_tool_name_sent = False
|
|
||||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
|
||||||
self.current_tool_id += 1
|
|
||||||
|
|
||||||
# If the tool is still being parsed, send incremental changes
|
# If the tool is still being parsed, send incremental changes
|
||||||
elif prev_arguments:
|
elif prev_arguments:
|
||||||
prev_args_json = json.dumps(prev_arguments)
|
prev_args_json = json.dumps(prev_arguments)
|
||||||
@@ -278,6 +272,20 @@ class BaseFormatDetector(ABC):
|
|||||||
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
||||||
argument_diff = prefix[sent:]
|
argument_diff = prefix[sent:]
|
||||||
|
|
||||||
|
# Update prev_tool_call_arr with current state
|
||||||
|
if self.current_tool_id >= 0:
|
||||||
|
# Ensure prev_tool_call_arr is large enough
|
||||||
|
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||||
|
self.prev_tool_call_arr.append({})
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id] = (
|
||||||
|
current_tool_call
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance to next tool if complete
|
||||||
|
if is_current_complete:
|
||||||
|
self.current_tool_name_sent = False
|
||||||
|
self.current_tool_id += 1
|
||||||
|
|
||||||
# Send the argument diff if there's something new
|
# Send the argument diff if there's something new
|
||||||
if argument_diff is not None:
|
if argument_diff is not None:
|
||||||
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
|
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
|
||||||
@@ -294,17 +302,7 @@ class BaseFormatDetector(ABC):
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
if not is_current_complete:
|
self.streamed_args_for_tool[tool_index_to_use] += argument_diff
|
||||||
self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id
|
|
||||||
] += argument_diff
|
|
||||||
|
|
||||||
# Update prev_tool_call_arr with current state
|
|
||||||
if self.current_tool_id >= 0:
|
|
||||||
# Ensure prev_tool_call_arr is large enough
|
|
||||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
|
||||||
self.prev_tool_call_arr.append({})
|
|
||||||
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user