230 lines
8.5 KiB
Python
230 lines
8.5 KiB
Python
import ast
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import List, Optional
|
|
|
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|
from sglang.srt.function_call.core_types import (
|
|
StreamingParseResult,
|
|
StructureInfo,
|
|
ToolCallItem,
|
|
_GetInfoFunc,
|
|
)
|
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.

    Assumes function call format:
        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]

    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        """Compile the regex used to locate a pythonic tool call list in text."""
        super().__init__()
        # NOTE(review): the pattern nests `.*` inside repeated groups, which can
        # backtrack heavily on long input that almost matches -- consider a
        # bounded alternative if untrusted output sizes are large.
        # NOTE(review): the first alternative ends its last-argument group with
        # `\s` while the second uses `\s*`; confirm the asymmetry is intended.
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    @staticmethod
    def _text_strip(text: str) -> str:
        """Return ``text`` with every complete Llama 4 python marker removed."""
        # Llama 4 model sometimes will output <|python_start|> and <|python_end|>
        # tokens; remove those tokens.
        text = text.replace("<|python_start|>", "")
        text = text.replace("<|python_end|>", "")
        return text

    def has_tool_call(self, text: str) -> bool:
        """Return True when the stripped text contains a pythonic tool call list."""
        return bool(self.tool_call_regex.search(self._text_strip(text.strip())))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse a complete piece of model output into tool calls.

        Args:
            text: Model output that may contain a ``[name(arg=value, ...)]`` list.
            tools: Tools the model is allowed to call; calls to any other name
                are logged and skipped.

        Returns:
            A StreamingParseResult whose ``normal_text`` is the input with the
            matched tool call list removed and whose ``calls`` holds one
            ToolCallItem per recognised call. If nothing matches, ``normal_text``
            is the whole (stripped) input. If the matched text cannot be parsed,
            ``calls`` is empty and the error is logged rather than raised.
        """
        # Trim surrounding whitespace first, then remove unexpected
        # <|python_start|> and <|python_end|> markers for llama4.
        text = self._text_strip(text.strip())

        match = self.tool_call_regex.search(text)
        if match is None:
            return StreamingParseResult(normal_text=text, calls=[])

        # Everything outside the matched span is passed through as normal text.
        tool_call_text = match.group(0)
        normal_text = text[: match.start()] + text[match.end() :]

        try:
            # The call list is parsed as Python source and only inspected as an
            # AST; it is never evaluated.
            module = ast.parse(tool_call_text)
            parsed = getattr(module.body[0], "value", None)
            is_call_list = isinstance(parsed, ast.List) and all(
                isinstance(element, ast.Call) for element in parsed.elts
            )
            if not is_call_list:
                return StreamingParseResult(normal_text=normal_text, calls=[])

            known_tool_names = {
                tool.function.name for tool in tools if tool.function.name
            }

            calls = []
            for call_index, call in enumerate(parsed.elts):
                # Only plain names are accepted as callees (not attributes etc.).
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                # Validate that the function exists in the tools
                if function_name not in known_tool_names:
                    logger.warning(
                        "Model attempted to call undefined function: %s",
                        function_name,
                    )
                    continue
                # Only keyword arguments are collected; positional arguments
                # are ignored.
                arguments = {
                    keyword.arg: self._get_parameter_value(keyword.value)
                    for keyword in call.keywords
                }
                calls.append(
                    ToolCallItem(
                        # Use the call index in the response, not tool position
                        tool_index=call_index,
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )

            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception:
            # Best effort: malformed model output must not crash the caller.
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=normal_text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets and brackets inside quoted arguments.

        Quotes are only honoured while inside a call's parentheses, so that an
        apostrophe in ordinary bracketed prose does not hide the closing bracket.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        bracket_depth = 0
        paren_depth = 0
        open_quote = None  # quote character of the string being scanned, if any
        escaped = False  # True when the previous string character was a backslash

        for position in range(start, len(buffer)):
            char = buffer[position]

            if open_quote is not None:
                # Inside a string literal: brackets and parentheses are data.
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == open_quote:
                    open_quote = None
                continue

            if char == "[":
                bracket_depth += 1
            elif char == "]":
                bracket_depth -= 1
                if bracket_depth == 0:
                    return position
            elif char == "(":
                paren_depth += 1
            elif char == ")":
                paren_depth = max(paren_depth - 1, 0)
            elif char in ("'", '"') and paren_depth > 0:
                open_quote = char

        return -1  # No matching bracket found

    def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
        """
        Strip special tokens from buffer and split into safe_text and held_back_text.

        ``_ends_with_partial_token`` is not defined in this class (presumably
        inherited from the base detector); its result is used here as the number
        of trailing characters that form an incomplete special token.

        Returns:
            tuple of (safe_text_to_output, text_to_hold_in_buffer)
        """
        # Check if original buffer ends with a partial token at the end
        special_tokens = ["<|python_start|>", "<|python_end|>"]

        for token in special_tokens:
            partial_length = self._ends_with_partial_token(buffer, token)
            if partial_length > 0:
                # Split buffer: safe part + held back partial token
                safe_text = buffer[:-partial_length]
                held_back = buffer[-partial_length:]
                # Strip complete special tokens from safe part only
                return self._text_strip(safe_text), held_back

        # No partial tokens found, strip complete tokens from entire buffer
        return self._text_strip(buffer), ""

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.

        Args:
            new_text: The newly generated text to append to the internal buffer.
            tools: Tools the model is allowed to call.

        Returns:
            A StreamingParseResult carrying any text that is safe to emit now
            and any tool calls completed by this increment.
        """
        self._buffer += new_text

        # Strip special tokens from entire buffer and handle partial tokens
        stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)

        start = stripped_buffer.find("[")
        if start == -1:
            # No tool call bracket found: emit everything except a possible
            # partial special token, which stays buffered.
            self._buffer = held_back
            return StreamingParseResult(normal_text=stripped_buffer)

        normal_text = stripped_buffer[:start]

        end = self._find_matching_bracket(stripped_buffer, start)
        if end != -1:
            # Found complete tool call
            call_text = stripped_buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)

            # Update buffer with remaining text after tool call plus any held back text
            self._buffer = stripped_buffer[end + 1 :] + held_back

            # If we had normal text before the tool call, add it to the result
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")

            return result

        # We have an opening bracket but no closing bracket yet.
        # Put back everything from the bracket onwards plus held back text.
        self._buffer = stripped_buffer[start:] + held_back

        # Text before the bracket can be emitted right away; otherwise we're
        # still accumulating a potential tool call and emit nothing.
        return StreamingParseResult(normal_text=normal_text)

    def _get_parameter_value(self, val):
        """
        Convert an AST node for a tool call argument into a plain Python value.

        Supports constants, signed numeric literals (e.g. ``-1``), dicts with
        constant keys, lists and tuples (tuples are returned as lists).

        Args:
            val: The ``ast`` node holding the argument value.

        Returns:
            The corresponding Python value.

        Raises:
            ValueError: If the node is not one of the supported literal forms.
        """
        if isinstance(val, ast.Constant):
            return val.value

        # Negative literals are parsed as UnaryOp(USub, Constant), not Constant.
        if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
            operand = self._get_parameter_value(val.operand)
            # bool is a subclass of int; a signed bool is not a numeric literal.
            if isinstance(operand, (int, float)) and not isinstance(operand, bool):
                return -operand if isinstance(val.op, ast.USub) else operand
            raise ValueError("Tool call arguments must be literals")

        if isinstance(val, ast.Dict):
            converted = {}
            for key_node, value_node in zip(val.keys, val.values):
                # key_node is None for `**mapping` unpacking inside a dict.
                if not isinstance(key_node, ast.Constant):
                    raise ValueError("Tool call arguments must be literals")
                converted[key_node.value] = self._get_parameter_value(value_node)
            return converted

        if isinstance(val, (ast.List, ast.Tuple)):
            return [self._get_parameter_value(element) for element in val.elts]

        raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its StructureInfo markers."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build the EBNF grammar for pythonic calls to ``tools`` via EBNFComposer."""
        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token="[",
            sequence_end_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )
|