Preliminary Support for Qwen3XMLDetector (#8260)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
|||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
|
from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -35,6 +36,7 @@ class FunctionCallParser:
|
|||||||
"deepseekv3": DeepSeekV3Detector,
|
"deepseekv3": DeepSeekV3Detector,
|
||||||
"pythonic": PythonicDetector,
|
"pythonic": PythonicDetector,
|
||||||
"kimi_k2": KimiK2Detector,
|
"kimi_k2": KimiK2Detector,
|
||||||
|
"qwen3": Qwen3XMLDetector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||||
|
|||||||
150
python/sglang/srt/function_call/qwen3_detector.py
Normal file
150
python/sglang/srt/function_call/qwen3_detector.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import ast
|
||||||
|
import html
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
StructureInfo,
|
||||||
|
ToolCallItem,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
|
||||||
|
# Module-level logger; Qwen3XMLDetector uses it to report dropped tool calls.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_val(raw: str) -> Any:
|
||||||
|
raw = html.unescape(raw.strip())
|
||||||
|
try:
|
||||||
|
return json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(raw)
|
||||||
|
except Exception:
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3XMLDetector(BaseFormatDetector):
    """
    Detector for Qwen 3 models.

    Assumes function call format:
        <tool_call>
        <function=execute_bash>
        <parameter=command>
        pwd && ls
        </parameter>
        </function>
        </tool_call>
    """

    def __init__(self):
        super().__init__()
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        # Each regex has two alternatives: a properly closed element, or an
        # unterminated one running to the end of the text (streaming tail).
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
        )
        # Streaming buffer: text received so far that has not yet been
        # emitted as normal text or consumed as a complete tool-call block.
        self._buf: str = ""

    def has_tool_call(self, text: str) -> bool:
        """Return True if *text* contains the tool-call start token."""
        return self.tool_call_start_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """One-shot parse: split *text* into normal text and parsed tool calls."""
        normal, calls = self._extract(text, tools)
        return StreamingParseResult(normal_text=normal, calls=calls)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Incremental parse.

        Accumulates *new_text* into the internal buffer, emits everything
        before a tool-call block as normal text, and emits a completed
        block's calls once its end token has fully arrived.
        """
        self._buf += new_text
        normal = ""
        calls: List[ToolCallItem] = []
        while True:
            if self.tool_call_start_token not in self._buf:
                # Bug fix: hold back any buffer suffix that could be the
                # beginning of a start token split across chunks (e.g.
                # "...<tool_"), so it is not prematurely emitted as normal
                # text and the subsequent tool call lost.
                keep = self._partial_start_token_len()
                if keep:
                    normal += self._buf[:-keep]
                    self._buf = self._buf[-keep:]
                else:
                    normal += self._buf
                    self._buf = ""
                break
            s = self._buf.find(self.tool_call_start_token)
            if s > 0:
                normal += self._buf[:s]
                self._buf = self._buf[s:]
            e = self._buf.find(self.tool_call_end_token)
            if e == -1:
                # Block not finished yet; wait for more text.
                break
            block = self._buf[: e + len(self.tool_call_end_token)]
            self._buf = self._buf[e + len(self.tool_call_end_token) :]
            calls.extend(self._parse_block(block, tools))
        return StreamingParseResult(normal_text=normal, calls=calls)

    def _partial_start_token_len(self) -> int:
        """Length of the longest buffer suffix that is a proper prefix of the
        tool-call start token; 0 when the buffer can be flushed entirely."""
        token = self.tool_call_start_token
        for k in range(min(len(token) - 1, len(self._buf)), 0, -1):
            if self._buf.endswith(token[:k]):
                return k
        return 0

    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
        """Scan *text* once, returning (normal_text, tool_calls).

        An unterminated trailing <tool_call> is kept verbatim as normal text.
        """
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        cursor = 0
        while True:
            s = text.find(self.tool_call_start_token, cursor)
            if s == -1:
                normal_parts.append(text[cursor:])
                break
            normal_parts.append(text[cursor:s])
            e = text.find(self.tool_call_end_token, s)
            if e == -1:
                # No closing token: keep the tail as normal text.
                normal_parts.append(text[s:])
                break
            block = text[s : e + len(self.tool_call_end_token)]
            cursor = e + len(self.tool_call_end_token)
            calls.extend(self._parse_block(block, tools))
        return "".join(normal_parts), calls

    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
        """Parse one <tool_call> block into ToolCallItem entries.

        Calls rejected by parse_base_json are dropped with a warning rather
        than raised, so one bad call does not abort the whole response.
        """
        res: List[ToolCallItem] = []
        for m in self.tool_call_function_regex.findall(block):
            txt = m[0] if m[0] else m[1]
            if ">" not in txt:
                continue
            idx = txt.index(">")
            fname = txt[:idx].strip()
            body = txt[idx + 1 :]
            params: Dict[str, Any] = {}
            for pm in self.tool_call_parameter_regex.findall(body):
                ptxt = pm[0] if pm[0] else pm[1]
                if ">" not in ptxt:
                    continue
                pidx = ptxt.index(">")
                pname = ptxt[:pidx].strip()
                # Strip only the template's surrounding newlines; inner
                # whitespace is part of the value.
                pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
                params[pname] = _safe_val(pval)
            raw = {"name": fname, "arguments": params}
            try:
                res.extend(self.parse_base_json(raw, tools))
            except Exception:
                logger.warning("invalid tool call for %s dropped", fname)
        return res

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its StructureInfo."""
        return lambda n: StructureInfo(
            begin=f"{self.tool_call_start_token}\n<function={n}>",
            end=f"</function>\n{self.tool_call_end_token}",
            trigger=self.tool_call_start_token,
        )

    # TODO: fake ebnf for xml + outlines backend
    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar (JSON-style placeholder for the XML format)."""
        return EBNFComposer.build_ebnf(
            tools,
            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
            tool_call_separator="\\n",
            function_format="json",
        )
|
||||||
@@ -1099,6 +1099,7 @@ class ServerArgs:
|
|||||||
"deepseekv3",
|
"deepseekv3",
|
||||||
"pythonic",
|
"pythonic",
|
||||||
"kimi_k2",
|
"kimi_k2",
|
||||||
|
"qwen3",
|
||||||
],
|
],
|
||||||
default=ServerArgs.tool_call_parser,
|
default=ServerArgs.tool_call_parser,
|
||||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
|
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
|
||||||
|
|||||||
Reference in New Issue
Block a user