import ast
import json
import logging
import re
from typing import List, Optional

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
    StreamingParseResult,
    StructureInfo,
    ToolCallItem,
    _GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer

logger = logging.getLogger(__name__)


class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        """Compile the regex used to locate a pythonic tool-call list in text."""
        super().__init__()
        # Matches a bracketed, comma-separated list of `name(key=value, ...)`
        # calls. DOTALL lets argument values span multiple lines.
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    @staticmethod
    def _text_strip(text: str) -> str:
        """Return `text` with every complete `<|python_start|>` /
        `<|python_end|>` token removed."""
        # Llama 4 models sometimes wrap tool calls in these tokens; they carry
        # no content, so drop them before matching.
        text = text.replace("<|python_start|>", "")
        text = text.replace("<|python_end|>", "")
        return text

    def has_tool_call(self, text: str) -> bool:
        """Return True if `text` contains something shaped like a pythonic
        tool-call list (after stripping whitespace and special tokens)."""
        return bool(self.tool_call_regex.search(self._text_strip(text.strip())))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse a complete text for a pythonic tool-call list.

        Args:
            text: Full model output to inspect.
            tools: Tools the model is allowed to call; calls to any other
                function name are logged and skipped.

        Returns:
            A StreamingParseResult whose `normal_text` is the input with the
            matched tool-call span removed and whose `calls` holds one
            ToolCallItem per recognised call. If no tool-call list is found,
            `normal_text` is the whole (stripped) input and `calls` is empty.
            If parsing the matched span fails, `calls` is empty.
        """
        # Remove surrounding whitespace, then unexpected <|python_start|> and
        # <|python_end|> tokens emitted by Llama 4.
        text = self._text_strip(text.strip())

        match = self.tool_call_regex.search(text)
        if match is None:
            return StreamingParseResult(normal_text=text, calls=[])

        # Everything outside the matched span is passed through as normal text.
        tool_call_text = text[match.start() : match.end()]
        normal_text = text[: match.start()] + text[match.end() :]

        try:
            module = ast.parse(tool_call_text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=normal_text, calls=[])

            known_tool_names = {
                tool.function.name for tool in tools if tool.function.name
            }

            calls = []
            for call_index, call in enumerate(parsed.elts):
                # Only plain `name(...)` calls are supported (no attribute
                # access such as `module.name(...)`).
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id

                # Validate that the function exists in the tools
                if function_name not in known_tool_names:
                    logger.warning(
                        "Model attempted to call undefined function: %s",
                        function_name,
                    )
                    continue

                arguments = {
                    keyword.arg: self._get_parameter_value(keyword.value)
                    for keyword in call.keywords
                }
                calls.append(
                    ToolCallItem(
                        # Use the call index in the response, not tool position
                        tool_index=call_index,
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )

            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception:
            # Deliberately broad: malformed model output must never crash the
            # serving path. The failure is logged with its traceback.
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=normal_text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets.

        NOTE: brackets are counted wherever they appear, including inside
        string-literal argument values.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        bracket_count = 0
        for i in range(start, len(buffer)):
            if buffer[i] == "[":
                bracket_count += 1
            elif buffer[i] == "]":
                bracket_count -= 1
                if bracket_count == 0:
                    return i
        return -1  # No matching bracket found

    def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
        """
        Strip special tokens from buffer and split into safe_text and held_back_text.

        Returns:
            tuple of (safe_text_to_output, text_to_hold_in_buffer)
        """
        # Check if original buffer ends with a partial special token.
        # NOTE(review): `_ends_with_partial_token` is not defined in this file
        # (presumably inherited); it is used here as returning the length of
        # the trailing partial token, 0 when there is none - confirm.
        special_tokens = ["<|python_start|>", "<|python_end|>"]

        for token in special_tokens:
            partial_length = self._ends_with_partial_token(buffer, token)
            if partial_length > 0:
                # Split buffer: safe part + held back partial token
                safe_text = buffer[:-partial_length]
                held_back = buffer[-partial_length:]
                # Strip complete special tokens from safe part only
                safe_text = self._text_strip(safe_text)
                return safe_text, held_back

        # No partial tokens found, strip complete tokens from entire buffer
        safe_text = self._text_strip(buffer)
        return safe_text, ""

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.
        """
        self._buffer += new_text

        # Strip special tokens from entire buffer and handle partial tokens
        stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)

        start = stripped_buffer.find("[")

        if start == -1:
            # No tool call bracket found: emit everything except a possible
            # partial special token, which stays buffered.
            self._buffer = held_back
            return StreamingParseResult(normal_text=stripped_buffer)

        normal_text = stripped_buffer[:start] if start > 0 else ""

        end = self._find_matching_bracket(stripped_buffer, start)
        if end != -1:
            # Found complete tool call
            call_text = stripped_buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)

            # Update buffer with remaining text after tool call plus any held back text
            remaining_text = stripped_buffer[end + 1 :] + held_back
            self._buffer = remaining_text

            # If we had normal text before the tool call, add it to the result
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")

            return result

        # We have an opening bracket but no closing bracket yet
        # Put back everything from the bracket onwards plus held back text
        self._buffer = stripped_buffer[start:] + held_back

        if normal_text:
            return StreamingParseResult(normal_text=normal_text)

        # Otherwise, we're still accumulating a potential tool call
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """
        Convert an AST node for a literal argument into a plain Python value.

        Supports constants, signed numbers (e.g. ``-1``, ``+2.5``), dicts with
        constant keys, and lists/tuples (both returned as lists), nested
        arbitrarily.

        Args:
            val: The `ast` expression node of a keyword argument's value.

        Returns:
            The corresponding Python value.

        Raises:
            ValueError: If the node is not one of the supported literal forms.
        """
        if isinstance(val, ast.Constant):
            return val.value

        # Signed numbers are parsed as UnaryOp(USub/UAdd, Constant), not as a
        # Constant, so they need explicit handling.
        if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
            operand = self._get_parameter_value(val.operand)
            # bool is a subclass of int; reject `-True` and friends.
            if isinstance(operand, bool) or not isinstance(operand, (int, float)):
                raise ValueError("Tool call arguments must be literals")
            return -operand if isinstance(val.op, ast.USub) else +operand

        if isinstance(val, ast.Dict):
            converted = {}
            for key, value in zip(val.keys, val.values):
                # A key is None for `**mapping` unpacking; any non-constant
                # key is not a literal either.
                if not isinstance(key, ast.Constant):
                    raise ValueError("Tool call arguments must be literals")
                converted[key.value] = self._get_parameter_value(value)
            return converted

        if isinstance(val, (ast.List, ast.Tuple)):
            return [self._get_parameter_value(v) for v in val.elts]

        raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a function that, given a tool name, builds the StructureInfo
        (begin/end/trigger strings) for a single pythonic call to that tool."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build the EBNF grammar for pythonic tool calls over `tools` by
        delegating to EBNFComposer with the `[`, `]`, `,` delimiters."""
        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token="[",
            sequence_end_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )