diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index d5bb9dc89..1df62a7a8 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -72,7 +72,7 @@ class BaseFormatDetector(ABC): action = json.loads(text) return StreamingParseResult(calls=self.parse_base_json(action, tools)) - def ends_with_partial_token(self, buffer: str, bot_token: str) -> int: + def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int: """ Check if buffer ends with a partial bot_token. Return the length of the partial bot_token. @@ -108,7 +108,7 @@ class BaseFormatDetector(ABC): current_text = self._buffer if not (self.bot_token in current_text or current_text.startswith("{")): # Only clear buffer if we're sure no tool call is starting - if not self.ends_with_partial_token(self._buffer, self.bot_token): + if not self._ends_with_partial_token(self._buffer, self.bot_token): normal_text = self._buffer self._buffer = "" if self.eot_token in normal_text: diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index 2ee802284..b9a7bd3a8 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -33,46 +33,67 @@ class PythonicDetector(BaseFormatDetector): ) def has_tool_call(self, text: str) -> bool: - return bool(self.tool_call_regex.match(text.strip())) + return bool(self.tool_call_regex.search(text.strip())) def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: # Try parsing the text as a Python list of function calls text = text.strip() - if not (text.startswith("[") and text.endswith("]")): - # Not a pythonic tool call format + + match = self.tool_call_regex.search(text) + if match is None: return StreamingParseResult(normal_text=text, calls=[]) + + # Extract the tool call part and any text before/after it + tool_call_start = match.start() + tool_call_end = match.end() + + normal_text_before = text[:tool_call_start] if tool_call_start > 0 else "" + tool_call_text = text[tool_call_start:tool_call_end] + normal_text_after = text[tool_call_end:] if tool_call_end < len(text) else "" + + # Combine normal text + normal_text = normal_text_before + normal_text_after + try: - module = ast.parse(text) + module = ast.parse(tool_call_text) parsed = getattr(module.body[0], "value", None) if not ( isinstance(parsed, ast.List) and all(isinstance(e, ast.Call) for e in parsed.elts) ): - return StreamingParseResult(normal_text=text, calls=[]) + return StreamingParseResult(normal_text=normal_text, calls=[]) + calls = [] tool_indices = { tool.function.name: i for i, tool in enumerate(tools) if tool.function.name } - for call in parsed.elts: + for call_index, call in enumerate(parsed.elts): if not isinstance(call.func, ast.Name): continue function_name = call.func.id + # Validate that the function exists in the tools + if function_name not in tool_indices: + logger.warning( + f"Model attempted to call undefined function: {function_name}" + ) + continue arguments = {} for keyword in call.keywords: arguments[keyword.arg] = self._get_parameter_value(keyword.value) calls.append( ToolCallItem( - tool_index=tool_indices.get(function_name, -1), + tool_index=call_index, # Use the call index in the response, not tool position name=function_name, parameters=json.dumps(arguments, ensure_ascii=False), ) ) - return StreamingParseResult(normal_text="", calls=calls) + + return StreamingParseResult(normal_text=normal_text, calls=calls) except Exception: logger.exception("Error in pythonic tool call parsing.") - return StreamingParseResult(normal_text=text, calls=[]) + return StreamingParseResult(normal_text=normal_text, calls=[]) def _find_matching_bracket(self, buffer: str, start: int) -> int: """ diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index 1e2254bb8..0a2f4bd5d 100644 --- a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -86,7 +86,7 @@ class Qwen25Detector(BaseFormatDetector): result.normal_text = cleaned_text else: # Check if buffer might contain partial end token at the end - partial_match_len = self.ends_with_partial_token( + partial_match_len = self._ends_with_partial_token( self._normal_text_buffer, end_token_without_newline )