Files
sglang/python/sglang/srt/function_call_parser.py
2025-03-04 21:23:47 -08:00

521 lines
19 KiB
Python

import json
import logging
import re
from json import JSONDecodeError, JSONDecoder
from typing import Any, Dict, List, Optional, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# Sentinel substrings that mark the start of a model-emitted tool/function
# call; presence of any of these in generated text suggests it contains a
# tool call that one of the detectors below can parse.
TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]",
]
class Function(BaseModel):
    """Function Tool Template."""

    # Human-readable description of what the function does.
    description: Optional[str] = Field(default=None, examples=[None])
    # Function name; matched against the "name" field of model-emitted calls.
    name: Optional[str] = None
    # Parameter specification (presumably a JSON-schema-like dict;
    # left as an opaque object here — confirm against callers).
    parameters: Optional[object] = None
class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    # Index of the tool: in one-shot parsing it is the position of the matched
    # tool in the request's tool list; in streaming mode it is the running
    # tool-call id (see BaseFormatDetector.current_tool_id).
    tool_index: int
    # Function name; empty string when this item only carries an argument delta.
    name: Optional[str] = None
    parameters: str  # JSON string (possibly a partial fragment when streaming)
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse possibly-incomplete JSON.

    Returns a ``(parsed_object, chars_consumed)`` tuple.  A clean partial
    parse consumes the whole string; when a complete JSON value is followed
    by trailing text ("Extra data"), a strict raw_decode reports exactly how
    many characters belong to the first value.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as exc:
        if "Extra data" not in exc.msg:
            raise
        # Complete value + trailing text: let the strict decoder tell us
        # where the first value ends.
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
class StreamingParseResult:
    """Result of streaming incremental parsing."""

    def __init__(
        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
    ):
        # Normalize None (and an empty list) to a fresh empty list.
        self.calls = [] if not calls else calls
        # Plain text that should be surfaced to the caller unchanged.
        self.normal_text = normal_text
class BaseFormatDetector:
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses set ``bot_token`` / ``eot_token`` to the model-specific
    begin/end sentinels and typically override ``detect_and_parse``.
    """

    def __init__(self):
        # Raw model text accumulated across streaming increments.
        self._buffer = ""
        # State used while parsing tool calls in streaming mode:
        # the tool-call dicts seen in the previous increment,
        self.prev_tool_call_arr: List[Dict] = []
        # index of the tool call currently being streamed (-1 = none yet),
        self.current_tool_id: int = -1
        # whether the current tool's name has already been emitted,
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        # Begin/end-of-tool-call sentinel tokens; set by subclasses.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
        """Convert parsed JSON action(s) into ToolCallItem objects.

        *action* is either one dict or a list of dicts shaped like
        ``{"name": ..., "parameters"/"arguments": {...}}``.  Calls naming a
        function not present in *tools* are dropped (single-dict case logs a
        warning).

        NOTE(review): *tools* is annotated ``List[Function]`` but this code
        reads ``tool.function.name`` — callers appear to pass wrapper objects
        exposing a ``.function`` attribute; confirm the annotation.
        """
        tool_indices = {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }
        if not isinstance(action, list):
            name = action.get("name")
            if not name or name not in tool_indices:
                logger.warning(f"Model attempted to call undefined function: {name}")
                return []

            return [
                ToolCallItem(
                    tool_index=tool_indices[name],
                    name=name,
                    # "parameters" takes precedence over "arguments" when both
                    # keys could appear; missing both yields "{}".
                    parameters=json.dumps(
                        action.get("parameters") or action.get("arguments", {}),
                        ensure_ascii=False,
                    ),
                )
            ]

        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
        return results

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-shot parsing: json-load the entire text and convert it into
        ToolCallItem objects via parse_base_json.

        Raises json.JSONDecodeError if *text* is not valid JSON; subclasses
        override this with format-specific detection and extraction.
        """
        action = json.loads(text)
        return self.parse_base_json(action, tools)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Function]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Buffers *new_text*, attempts a (possibly partial) JSON parse of the
        buffered tool-call text, and returns a StreamingParseResult carrying
        either pass-through normal text or incremental ToolCallItem deltas
        (first the tool name, then argument-JSON fragments).
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        # No tool-call marker and not raw JSON: treat the increment as plain
        # text, dropping the buffer and stripping any end-token remnant.
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built (cached on the instance)
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the tool name has been emitted, disallow partial strings so a
        # truncated name is never parsed and streamed.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            tool_call_arr = []
            is_complete = []
            try:
                # Skip the begin-of-tool-call token if the buffer starts with it.
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                # Parse consecutive JSON objects separated by "; ".
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Advance past this object plus the "; " separator.
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency: normalize
                    # "parameters" to "arguments" for the logic below.
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except partial_json_parser.core.exceptions.MalformedJSON:
                # Not enough text yet to parse anything; wait for more.
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # NOTE(review): when current_tool_id is still -1 this indexes the
            # LAST element via Python negative indexing; the branch below then
            # treats it as "new tool" — confirm this is intentional.
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array: a new object appeared beyond the one
            # currently being streamed.
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                # Flush any remaining argument text of the previous tool first.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()
                # Advance to the new tool and reset its per-tool state.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name: emit it once, as soon as it parses completely.
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments: emit only the not-yet-sent suffix.
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Tool call finished: flush the remainder and reset
                        # the streaming state for the next call.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only the stable common prefix is safe to stream;
                            # a partial parse may still change its tail.
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.

    Expected tool-call format:
        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """Set up the start/end sentinel tokens for Qwen 2.5 tool calls."""
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing of every <tool_call>...</tool_call> span in *text*.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: List of parsed ToolCallItem objects (empty when no calls found).
        """
        if "<tool_call>" not in text:
            return []
        calls: List[ToolCallItem] = []
        for match in re.finditer(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL):
            parsed_call = json.loads(match.group(1))
            calls.extend(self.parse_base_json(parsed_call, tools))
        return calls
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.

    Assumes function call format:
        [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        # Matches the JSON array of tool-call objects within the cleaned text.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def _clean_text(self, text: str) -> str:
        """
        clean text to only leave '[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
        text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
        return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]

        NOTE(review): the non-greedy '[.*?]' stops at the FIRST ']', so a
        tool call whose arguments themselves contain ']' (e.g. a JSON list
        value) would be truncated here — confirm against real model output.
        """
        find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
        if len(find_results) > 0:
            return find_results[0]
        else:
            return ""

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: List of parsed ToolCallItem objects (empty when no calls found).
        """
        text = self._clean_text(text)
        tool_content = text.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if len(raw_tool_calls) > 0:
            raw_tool_call = raw_tool_calls[0]
            function_call_arr = json.loads(raw_tool_call)
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return calls
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.

    Assumes function call format:
        <|python_tag|>{"name":"xxx", "arguments":{...}}

    Multiple JSON objects may follow the tag, separated by ";".
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """Parse function calls from text, handling multiple JSON objects.

        :param text: The complete model output to scan.
        :param tools: List of available tools.
        :return: List of parsed ToolCallItem objects (empty when none found).
        """
        if self.bot_token not in text:
            return []
        # Split only at the FIRST tag: two-target unpacking of a bare
        # text.split(tag) raised ValueError whenever the model emitted
        # <|python_tag|> more than once.
        _, action_text = text.split(self.bot_token, maxsplit=1)

        # Split by semicolon and process each part independently so one
        # malformed object does not discard the others.
        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                action = json.loads(part)
                all_actions.append(action)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON part: {part}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue

        # Only process if we found valid JSON objects
        if all_actions:
            return self.parse_base_json(all_actions, tools)
        return []
class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(self, text: str, tools: List[Function]):
        """
        One-time parsing: try each detector in order, keeping the first
        non-empty result.

        Return: (final_text, all_calls)
        - final_text: text not consumed by any detector (the input, unchanged)
        - all_calls: calls produced by the first detector that matched
        """
        matched_calls: List[ToolCallItem] = []
        for detector in self.detectors:
            candidate = detector.detect_and_parse(text, tools)
            if candidate:  # first detector with a successful parse wins
                matched_calls = candidate
                break
        # The input text is returned unchanged as the leftover/normal text.
        return text, matched_calls

    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
        """
        Streaming incremental parsing: feed new_text to each detector and
        merge the produced normal_text/calls.

        A detector that yields calls is treated as the authoritative parse
        and ends the loop; normal_text alone may come from a detector that
        simply passed the text through.
        """
        merged_text = ""
        merged_calls: List[ToolCallItem] = []
        for detector in self.detectors:
            result = detector.parse_streaming_increment(new_text, tools)
            if result.normal_text:
                merged_text = result.normal_text
            if result.calls:
                merged_calls.extend(result.calls)
                merged_text = result.normal_text
                break
        return merged_text, merged_calls
class FunctionCallParser:
    """
    Facade for streaming scenarios: each time new_text arrives it delegates to
    multi_format_parser.parse_streaming_increment and hands the resulting
    normal_text and calls to the upper layer (or SSE).
    """

    # Maps a parser name to the detector class that implements it.
    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
    }

    def __init__(self, tools: List[Function], tool_call_parser: str = None):
        # Guard clauses: a parser name is mandatory and must be known.
        if not tool_call_parser:
            raise ValueError("Tool Call Parser Not Given!")
        detector_class = self.ToolCallParserEnum.get(tool_call_parser)
        if detector_class is None:
            raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
        self.multi_format_parser = MultiFormatParser([detector_class()])
        self.tools = tools

    def parse_non_stream(self, full_text: str):
        """
        Non-streaming call: one-time parsing.
        Returns (full_normal_text, calls).
        """
        return self.multi_format_parser.parse_once(full_text, self.tools)

    def parse_stream_chunk(self, chunk_text: str):
        """
        Streaming call: incremental parsing.
        Returns (normal_text, calls) for this chunk.
        """
        return self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )