fix(tool call): Fix tool_index in PythonicDetector and handle normal text mixed with tool calls in non-streaming parsing (#6678)
This commit is contained in:
@@ -72,7 +72,7 @@ class BaseFormatDetector(ABC):
|
||||
action = json.loads(text)
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
||||
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
||||
"""
|
||||
Check if buffer ends with a partial bot_token.
|
||||
Return the length of the partial bot_token.
|
||||
@@ -108,7 +108,7 @@ class BaseFormatDetector(ABC):
|
||||
current_text = self._buffer
|
||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||
# Only clear buffer if we're sure no tool call is starting
|
||||
if not self.ends_with_partial_token(self._buffer, self.bot_token):
|
||||
if not self._ends_with_partial_token(self._buffer, self.bot_token):
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
if self.eot_token in normal_text:
|
||||
|
||||
@@ -33,46 +33,67 @@ class PythonicDetector(BaseFormatDetector):
|
||||
)
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
return bool(self.tool_call_regex.match(text.strip()))
|
||||
return bool(self.tool_call_regex.search(text.strip()))
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
# Try parsing the text as a Python list of function calls
|
||||
text = text.strip()
|
||||
if not (text.startswith("[") and text.endswith("]")):
|
||||
# Not a pythonic tool call format
|
||||
|
||||
match = self.tool_call_regex.search(text)
|
||||
if match is None:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
# Extract the tool call part and any text before/after it
|
||||
tool_call_start = match.start()
|
||||
tool_call_end = match.end()
|
||||
|
||||
normal_text_before = text[:tool_call_start] if tool_call_start > 0 else ""
|
||||
tool_call_text = text[tool_call_start:tool_call_end]
|
||||
normal_text_after = text[tool_call_end:] if tool_call_end < len(text) else ""
|
||||
|
||||
# Combine normal text
|
||||
normal_text = normal_text_before + normal_text_after
|
||||
|
||||
try:
|
||||
module = ast.parse(text)
|
||||
module = ast.parse(tool_call_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not (
|
||||
isinstance(parsed, ast.List)
|
||||
and all(isinstance(e, ast.Call) for e in parsed.elts)
|
||||
):
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
calls = []
|
||||
tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function.name
|
||||
}
|
||||
for call in parsed.elts:
|
||||
for call_index, call in enumerate(parsed.elts):
|
||||
if not isinstance(call.func, ast.Name):
|
||||
continue
|
||||
function_name = call.func.id
|
||||
# Validate that the function exists in the tools
|
||||
if function_name not in tool_indices:
|
||||
logger.warning(
|
||||
f"Model attempted to call undefined function: {function_name}"
|
||||
)
|
||||
continue
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=tool_indices.get(function_name, -1),
|
||||
tool_index=call_index, # Use the call index in the response, not tool position
|
||||
name=function_name,
|
||||
parameters=json.dumps(arguments, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
return StreamingParseResult(normal_text="", calls=calls)
|
||||
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
except Exception:
|
||||
logger.exception("Error in pythonic tool call parsing.")
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
def _find_matching_bracket(self, buffer: str, start: int) -> int:
|
||||
"""
|
||||
|
||||
@@ -86,7 +86,7 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
result.normal_text = cleaned_text
|
||||
else:
|
||||
# Check if buffer might contain partial end token at the end
|
||||
partial_match_len = self.ends_with_partial_token(
|
||||
partial_match_len = self._ends_with_partial_token(
|
||||
self._normal_text_buffer, end_token_without_newline
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user