fix(tool call): Fix tool_index in PythonicDetector and issues with mixed output in non-streaming (#6678)
This commit is contained in:
@@ -72,7 +72,7 @@ class BaseFormatDetector(ABC):
|
|||||||
action = json.loads(text)
|
action = json.loads(text)
|
||||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||||
|
|
||||||
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
||||||
"""
|
"""
|
||||||
Check if buffer ends with a partial bot_token.
|
Check if buffer ends with a partial bot_token.
|
||||||
Return the length of the partial bot_token.
|
Return the length of the partial bot_token.
|
||||||
@@ -108,7 +108,7 @@ class BaseFormatDetector(ABC):
|
|||||||
current_text = self._buffer
|
current_text = self._buffer
|
||||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||||
# Only clear buffer if we're sure no tool call is starting
|
# Only clear buffer if we're sure no tool call is starting
|
||||||
if not self.ends_with_partial_token(self._buffer, self.bot_token):
|
if not self._ends_with_partial_token(self._buffer, self.bot_token):
|
||||||
normal_text = self._buffer
|
normal_text = self._buffer
|
||||||
self._buffer = ""
|
self._buffer = ""
|
||||||
if self.eot_token in normal_text:
|
if self.eot_token in normal_text:
|
||||||
|
|||||||
@@ -33,46 +33,67 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
return bool(self.tool_call_regex.match(text.strip()))
|
return bool(self.tool_call_regex.search(text.strip()))
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
# Try parsing the text as a Python list of function calls
|
# Try parsing the text as a Python list of function calls
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if not (text.startswith("[") and text.endswith("]")):
|
|
||||||
# Not a pythonic tool call format
|
match = self.tool_call_regex.search(text)
|
||||||
|
if match is None:
|
||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
|
|
||||||
|
# Extract the tool call part and any text before/after it
|
||||||
|
tool_call_start = match.start()
|
||||||
|
tool_call_end = match.end()
|
||||||
|
|
||||||
|
normal_text_before = text[:tool_call_start] if tool_call_start > 0 else ""
|
||||||
|
tool_call_text = text[tool_call_start:tool_call_end]
|
||||||
|
normal_text_after = text[tool_call_end:] if tool_call_end < len(text) else ""
|
||||||
|
|
||||||
|
# Combine normal text
|
||||||
|
normal_text = normal_text_before + normal_text_after
|
||||||
|
|
||||||
try:
|
try:
|
||||||
module = ast.parse(text)
|
module = ast.parse(tool_call_text)
|
||||||
parsed = getattr(module.body[0], "value", None)
|
parsed = getattr(module.body[0], "value", None)
|
||||||
if not (
|
if not (
|
||||||
isinstance(parsed, ast.List)
|
isinstance(parsed, ast.List)
|
||||||
and all(isinstance(e, ast.Call) for e in parsed.elts)
|
and all(isinstance(e, ast.Call) for e in parsed.elts)
|
||||||
):
|
):
|
||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
tool_indices = {
|
tool_indices = {
|
||||||
tool.function.name: i
|
tool.function.name: i
|
||||||
for i, tool in enumerate(tools)
|
for i, tool in enumerate(tools)
|
||||||
if tool.function.name
|
if tool.function.name
|
||||||
}
|
}
|
||||||
for call in parsed.elts:
|
for call_index, call in enumerate(parsed.elts):
|
||||||
if not isinstance(call.func, ast.Name):
|
if not isinstance(call.func, ast.Name):
|
||||||
continue
|
continue
|
||||||
function_name = call.func.id
|
function_name = call.func.id
|
||||||
|
# Validate that the function exists in the tools
|
||||||
|
if function_name not in tool_indices:
|
||||||
|
logger.warning(
|
||||||
|
f"Model attempted to call undefined function: {function_name}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
arguments = {}
|
arguments = {}
|
||||||
for keyword in call.keywords:
|
for keyword in call.keywords:
|
||||||
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
|
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
|
||||||
calls.append(
|
calls.append(
|
||||||
ToolCallItem(
|
ToolCallItem(
|
||||||
tool_index=tool_indices.get(function_name, -1),
|
tool_index=call_index, # Use the call index in the response, not tool position
|
||||||
name=function_name,
|
name=function_name,
|
||||||
parameters=json.dumps(arguments, ensure_ascii=False),
|
parameters=json.dumps(arguments, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return StreamingParseResult(normal_text="", calls=calls)
|
|
||||||
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error in pythonic tool call parsing.")
|
logger.exception("Error in pythonic tool call parsing.")
|
||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
|
|
||||||
def _find_matching_bracket(self, buffer: str, start: int) -> int:
|
def _find_matching_bracket(self, buffer: str, start: int) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
result.normal_text = cleaned_text
|
result.normal_text = cleaned_text
|
||||||
else:
|
else:
|
||||||
# Check if buffer might contain partial end token at the end
|
# Check if buffer might contain partial end token at the end
|
||||||
partial_match_len = self.ends_with_partial_token(
|
partial_match_len = self._ends_with_partial_token(
|
||||||
self._normal_text_buffer, end_token_without_newline
|
self._normal_text_buffer, end_token_without_newline
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user