122 lines
4.8 KiB
Python
122 lines
4.8 KiB
Python
import json
|
|
import logging
|
|
import re
|
|
from typing import List
|
|
|
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|
from sglang.srt.function_call.core_types import (
|
|
StreamingParseResult,
|
|
StructureInfo,
|
|
_GetInfoFunc,
|
|
)
|
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
|
from sglang.srt.openai_api.protocol import Tool
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Qwen25Detector(BaseFormatDetector):
|
|
"""
|
|
Detector for Qwen 2.5 models.
|
|
Assumes function call format:
|
|
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""
|
|
Initializes the detector with necessary state variables.
|
|
"""
|
|
super().__init__()
|
|
self.bot_token = "<tool_call>\n"
|
|
self.eot_token = "\n</tool_call>"
|
|
self._normal_text_buffer = "" # Buffer for handling partial end tokens
|
|
|
|
def has_tool_call(self, text: str) -> bool:
|
|
"""Check if the text contains a Qwen 2.5 format tool call."""
|
|
return self.bot_token in text
|
|
|
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
|
"""
|
|
One-time parsing: Detects and parses tool calls in the provided text.
|
|
|
|
:param text: The complete text to parse.
|
|
:param tools: List of available tools.
|
|
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
|
"""
|
|
idx = text.find(self.bot_token)
|
|
normal_text = text[:idx].strip() if idx != -1 else text
|
|
if self.bot_token not in text:
|
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
|
|
|
# Find all <tool_call>\n...\n</tool_call> blocks
|
|
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
|
|
match_result_list = re.findall(pattern, text, re.DOTALL)
|
|
calls = []
|
|
for match_result in match_result_list:
|
|
try:
|
|
parsed_call = json.loads(match_result.strip())
|
|
calls.extend(self.parse_base_json(parsed_call, tools))
|
|
except json.JSONDecodeError as e:
|
|
logger.warning(
|
|
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
|
|
)
|
|
continue
|
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
|
|
|
def parse_streaming_increment(
|
|
self, new_text: str, tools: List[Tool]
|
|
) -> StreamingParseResult:
|
|
"""
|
|
Streaming incremental parsing for Qwen 2.5 tool calls.
|
|
Uses base class implementation with buffering to handle partial end tokens.
|
|
"""
|
|
result = super().parse_streaming_increment(new_text, tools)
|
|
|
|
# Handle partial end tokens that are streamed character by character
|
|
if result.normal_text:
|
|
self._normal_text_buffer += result.normal_text
|
|
|
|
# Check if buffer contains complete end token (without leading newline)
|
|
end_token_without_newline = self.eot_token[1:] # "</tool_call>"
|
|
if end_token_without_newline in self._normal_text_buffer:
|
|
cleaned_text = self._normal_text_buffer.replace(
|
|
end_token_without_newline, ""
|
|
)
|
|
self._normal_text_buffer = ""
|
|
result.normal_text = cleaned_text
|
|
else:
|
|
# Check if buffer might contain partial end token at the end
|
|
partial_match_len = self.ends_with_partial_token(
|
|
self._normal_text_buffer, end_token_without_newline
|
|
)
|
|
|
|
if partial_match_len:
|
|
# Keep potential partial match in buffer, return the rest
|
|
result.normal_text = self._normal_text_buffer[:-partial_match_len]
|
|
self._normal_text_buffer = self._normal_text_buffer[
|
|
-partial_match_len:
|
|
]
|
|
else:
|
|
# No partial match, return all buffered text
|
|
result.normal_text = self._normal_text_buffer
|
|
self._normal_text_buffer = ""
|
|
|
|
return result
|
|
|
|
def structure_info(self) -> _GetInfoFunc:
|
|
# TODO: Update the begin and end tokens with '\n' if necessary
|
|
return lambda name: StructureInfo(
|
|
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
|
end="}</tool_call>",
|
|
trigger="<tool_call>",
|
|
)
|
|
|
|
def build_ebnf(self, tools: List[Tool]):
|
|
return EBNFComposer.build_ebnf(
|
|
tools,
|
|
individual_call_start_token=self.bot_token.replace("\n", "\\n"),
|
|
individual_call_end_token=self.eot_token.replace("\n", "\\n"),
|
|
tool_call_separator="\\n",
|
|
function_format="json",
|
|
)
|