68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
|
|
import json
|
||
|
|
import re
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||
|
|
from sglang.srt.function_call.core_types import (
|
||
|
|
StreamingParseResult,
|
||
|
|
StructureInfo,
|
||
|
|
_GetInfoFunc,
|
||
|
|
)
|
||
|
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||
|
|
from sglang.srt.openai_api.protocol import Tool
|
||
|
|
|
||
|
|
|
||
|
|
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.

    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        # Markers that wrap each JSON-encoded tool call in the model output.
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult. ``normal_text`` is the text before the
            first ``<tool_call>`` marker (stripped), or the whole unmodified text
            when no marker is present. ``calls`` holds the calls produced by
            ``parse_base_json`` for every ``<tool_call>...</tool_call>`` segment
            whose body is valid JSON. Segments with invalid JSON are logged and
            skipped rather than aborting the whole parse. Text after the first
            marker is not included in ``normal_text``.
        """
        import logging

        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        # Escape the markers so the pattern stays correct even if the tokens
        # are ever changed to contain regex metacharacters.
        pattern = re.escape(self.bot_token) + r"(.*?)" + re.escape(self.eot_token)

        calls = []
        for raw_call in re.findall(pattern, text, re.DOTALL):
            try:
                parsed_call = json.loads(raw_call)
            except json.JSONDecodeError as err:
                # One malformed call must not discard the other valid calls.
                logging.getLogger(__name__).warning(
                    "Skipping malformed tool call %r: %s", raw_call, err
                )
                continue
            calls.extend(self.parse_base_json(parsed_call, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """
        Return a callable mapping a function name to the StructureInfo that
        brackets a Qwen 2.5 tool call for that function.

        ``begin`` opens the call object up to the ``"arguments":`` key, and
        ``end`` closes that object and the ``</tool_call>`` marker. The
        arguments value is expected to be generated in between.
        """
        return lambda name: StructureInfo(
            begin=self.bot_token + '{"name":"' + name + '", "arguments":',
            end="}" + self.eot_token,
            trigger=self.bot_token,
        )

    def build_ebnf(self, tools: List[Tool]):
        """
        Build an EBNF grammar for ``tools`` via ``EBNFComposer.build_ebnf``.

        It uses this detector's begin/end markers and the JSON function format.
        """
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
|