Preliminary Support for Qwen3XMLDetector (#8260)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +36,7 @@ class FunctionCallParser:
|
||||
"deepseekv3": DeepSeekV3Detector,
|
||||
"pythonic": PythonicDetector,
|
||||
"kimi_k2": KimiK2Detector,
|
||||
"qwen3": Qwen3XMLDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
|
||||
150
python/sglang/srt/function_call/qwen3_detector.py
Normal file
150
python/sglang/srt/function_call/qwen3_detector.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import ast
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_val(raw: str) -> Any:
|
||||
raw = html.unescape(raw.strip())
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except Exception:
|
||||
try:
|
||||
return ast.literal_eval(raw)
|
||||
except Exception:
|
||||
return raw
|
||||
|
||||
|
||||
class Qwen3XMLDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 3 models.
|
||||
Assumes function call format:
|
||||
<tool_call>
|
||||
<function=execute_bash>
|
||||
<parameter=command>
|
||||
pwd && ls
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_prefix: str = "<function="
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
|
||||
)
|
||||
self.tool_call_function_regex = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
|
||||
)
|
||||
self.tool_call_parameter_regex = re.compile(
|
||||
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
|
||||
)
|
||||
self._buf: str = ""
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
return self.tool_call_start_token in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
normal, calls = self._extract(text, tools)
|
||||
return StreamingParseResult(normal_text=normal, calls=calls)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
self._buf += new_text
|
||||
normal = ""
|
||||
calls: List[ToolCallItem] = []
|
||||
while True:
|
||||
if self.tool_call_start_token not in self._buf:
|
||||
normal += self._buf
|
||||
self._buf = ""
|
||||
break
|
||||
s = self._buf.find(self.tool_call_start_token)
|
||||
if s > 0:
|
||||
normal += self._buf[:s]
|
||||
self._buf = self._buf[s:]
|
||||
e = self._buf.find(self.tool_call_end_token)
|
||||
if e == -1:
|
||||
break
|
||||
block = self._buf[: e + len(self.tool_call_end_token)]
|
||||
self._buf = self._buf[e + len(self.tool_call_end_token) :]
|
||||
calls.extend(self._parse_block(block, tools))
|
||||
return StreamingParseResult(normal_text=normal, calls=calls)
|
||||
|
||||
def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
|
||||
normal_parts: List[str] = []
|
||||
calls: List[ToolCallItem] = []
|
||||
cursor = 0
|
||||
while True:
|
||||
s = text.find(self.tool_call_start_token, cursor)
|
||||
if s == -1:
|
||||
normal_parts.append(text[cursor:])
|
||||
break
|
||||
normal_parts.append(text[cursor:s])
|
||||
e = text.find(self.tool_call_end_token, s)
|
||||
if e == -1:
|
||||
normal_parts.append(text[s:])
|
||||
break
|
||||
block = text[s : e + len(self.tool_call_end_token)]
|
||||
cursor = e + len(self.tool_call_end_token)
|
||||
calls.extend(self._parse_block(block, tools))
|
||||
return "".join(normal_parts), calls
|
||||
|
||||
def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
|
||||
res: List[ToolCallItem] = []
|
||||
for m in self.tool_call_function_regex.findall(block):
|
||||
txt = m[0] if m[0] else m[1]
|
||||
if ">" not in txt:
|
||||
continue
|
||||
idx = txt.index(">")
|
||||
fname = txt[:idx].strip()
|
||||
body = txt[idx + 1 :]
|
||||
params: Dict[str, Any] = {}
|
||||
for pm in self.tool_call_parameter_regex.findall(body):
|
||||
ptxt = pm[0] if pm[0] else pm[1]
|
||||
if ">" not in ptxt:
|
||||
continue
|
||||
pidx = ptxt.index(">")
|
||||
pname = ptxt[:pidx].strip()
|
||||
pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
|
||||
params[pname] = _safe_val(pval)
|
||||
raw = {"name": fname, "arguments": params}
|
||||
try:
|
||||
res.extend(self.parse_base_json(raw, tools))
|
||||
except Exception:
|
||||
logger.warning("invalid tool call for %s dropped", fname)
|
||||
return res
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda n: StructureInfo(
|
||||
begin=f"{self.tool_call_start_token}\n<function={n}>",
|
||||
end=f"</function>\n{self.tool_call_end_token}",
|
||||
trigger=self.tool_call_start_token,
|
||||
)
|
||||
|
||||
# TODO: fake ebnf for xml + outlines backend
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
|
||||
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
|
||||
tool_call_separator="\\n",
|
||||
function_format="json",
|
||||
)
|
||||
@@ -1099,6 +1099,7 @@ class ServerArgs:
|
||||
"deepseekv3",
|
||||
"pythonic",
|
||||
"kimi_k2",
|
||||
"qwen3",
|
||||
],
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
|
||||
|
||||
Reference in New Issue
Block a user