[Bugfix][Feat] Add XML-ish grammar in EBNFComposer and fix misc bugs in Qwen3 detector (#8357)

This commit is contained in:
Chang Su
2025-07-25 12:03:16 -07:00
committed by GitHub
parent 12cb760a37
commit f8260f2539
7 changed files with 578 additions and 87 deletions

View File

@@ -321,6 +321,10 @@ class BaseFormatDetector(ABC):
"""
raise NotImplementedError()
def supports_structural_tag(self) -> bool:
    """Return True if this detector supports structural tag format.

    This base implementation always returns True. A detector whose tool-call
    syntax cannot be expressed as a structural tag overrides this method to
    return False.
    """
    return True
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
"""

View File

@@ -1,51 +1,73 @@
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional
class EBNFComposer:
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
json_grammar_ebnf_str = r"""
json ::= basic_array | basic_object
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
# Shared primitive grammar rules used across all formats
BASE_PRIMITIVE_GRAMMAR = r"""
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9]{4}
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_boolean ::= "true" | "false"
basic_null ::= "null"
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""
pythonic_grammar_ebnf_str = r"""
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
basic_any ::= basic_number | basic_string | basic_array | basic_object
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""
# Format-specific extensions
json_grammar_ebnf_str = (
r"""
json ::= basic_array | basic_object
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
basic_boolean ::= "true" | "false"
basic_null ::= "null"
"""
+ BASE_PRIMITIVE_GRAMMAR
)
pythonic_grammar_ebnf_str = (
r"""
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
basic_any ::= basic_number | basic_string | basic_array | basic_object
basic_boolean ::= "True" | "False"
basic_null ::= "None"
"""
+ BASE_PRIMITIVE_GRAMMAR
)
xml_grammar_ebnf_str = (
r"""
xml ::= xml_element | xml_text
xml_element ::= basic_string | basic_number | basic_boolean | basic_null | basic_array | basic_object
xml_text ::= [^<>]*
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
basic_boolean ::= "true" | "false"
basic_null ::= "null"
"""
+ BASE_PRIMITIVE_GRAMMAR
)
CALL_RULE_MAP = {
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
"xml": 'call_{name} ::= "<function={name}>\\n" {arguments_rule} "\\n</function>"',
}
ARGUMENTS_RULE_MAP = {
"pythonic": "{arg_rules}",
"json": '"{{" {arg_rules} "}}"',
"xml": "{arg_rules}",
}
KEY_VALUE_RULE_MAP = {
"pythonic": '"{key}" "=" {valrule}',
"json": '"\\"{key}\\"" ":" {valrule}',
"xml": '"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
}
JSON_TYPE_MAPPING = {
# Base type mapping - most types are the same across formats
BASE_TYPE_MAPPING = {
"string": "basic_string",
"number": "basic_number",
"integer": "basic_number",
@@ -55,19 +77,20 @@ class EBNFComposer:
"object": "basic_object",
}
PYTHONIC_TYPE_MAPPING = {
"string": "basic_string",
"number": "basic_number",
"integer": "basic_number",
"boolean": '"True" | "False"',
"null": '"None"',
"array": "basic_array",
"object": "basic_object",
# Format-specific overrides for types that differ
FORMAT_TYPE_OVERRIDES = {
"pythonic": {
"boolean": '"True" | "False"',
"null": '"None"',
},
"xml": {
"string": "xml_text",
},
}
@staticmethod
def get_value_rule(
prop: dict, function_format: Literal["pythonic", "json"] = "json"
prop: dict, function_format: Literal["pythonic", "json", "xml"] = "json"
) -> str:
if "enum" in prop:
return EBNFComposer._handle_enum(prop, function_format)
@@ -83,48 +106,46 @@ class EBNFComposer:
enum_values = prop["enum"]
prop_type = prop.get("type", "string")
# Define formatters for different type/format combinations
formatters = {
("string", "json"): lambda v: f'"\\"{v}\\""',
("string", "pythonic"): lambda v: f'"\\"{v}\\""',
("number", "json"): str,
("number", "pythonic"): str,
("integer", "json"): str,
("integer", "pythonic"): str,
("boolean", "json"): lambda v: "true" if v else "false",
("boolean", "pythonic"): lambda v: "True" if v else "False",
}
def format_enum_val(v: Any) -> str:
if prop_type == "boolean":
if function_format == "json" or function_format == "xml":
return "true" if v else "false"
elif function_format == "pythonic":
return "True" if v else "False"
else:
return str(v) # fallback
# Get the formatter or default to string handling
formatter = formatters.get(
(prop_type, function_format),
formatters[("string", function_format)], # Default to string handling
)
if prop_type == "string":
if function_format == "xml":
return f'"{v}"'
else: # json or pythonic
return f'"\\"{v}\\""' # escape quote-wrapped string
formatted_values = [formatter(value) for value in enum_values]
# All other types (number, integer, etc.)
return str(v)
formatted_values = [format_enum_val(v) for v in enum_values]
enum_rule = " | ".join(formatted_values)
return f"({enum_rule})" if len(formatted_values) > 1 else enum_rule
# Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
if len(formatted_values) > 1:
enum_rule = f"({enum_rule})"
return enum_rule
@staticmethod
def get_type_mapping(function_format: str) -> Dict[str, str]:
    """Return the type-name -> grammar-rule table for ``function_format``.

    The result is a new dict built from ``BASE_TYPE_MAPPING`` with the
    format's entries from ``FORMAT_TYPE_OVERRIDES`` layered on top. A format
    without an overrides entry gets a plain copy of the base table.
    """
    format_overrides = EBNFComposer.FORMAT_TYPE_OVERRIDES.get(function_format, {})
    effective_overrides = {
        type_name: rule
        for type_name, rule in format_overrides.items()
        if rule is not None  # None entries are skipped, so the base rule stays
    }
    return {**EBNFComposer.BASE_TYPE_MAPPING, **effective_overrides}
@staticmethod
def _handle_type(prop: dict, function_format: str) -> str:
"""Handle type properties using the appropriate type mapping."""
prop_type = prop["type"]
type_mapping = (
EBNFComposer.PYTHONIC_TYPE_MAPPING
if function_format == "pythonic"
else EBNFComposer.JSON_TYPE_MAPPING
)
type_mapping = EBNFComposer.get_type_mapping(function_format)
if isinstance(prop_type, list):
type_rules = [
type_mapping[single_type]
type_mapping.get(single_type, function_format)
for single_type in prop_type
if single_type in type_mapping
]
return " | ".join(type_rules) if type_rules else function_format
@@ -133,7 +154,7 @@ class EBNFComposer:
@staticmethod
def build_ebnf(
tools,
function_format: Literal["pythonic", "json"] = "json",
function_format: Literal["pythonic", "json", "xml"] = "json",
# Parameters for wrapping the entire sequence of tool calls
sequence_start_token: Optional[str] = None,
sequence_end_token: Optional[str] = None,
@@ -143,6 +164,7 @@ class EBNFComposer:
# Parameter for separating multiple tool calls
tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None,
key_value_rule_fmt: Optional[str] = None,
):
"""
Generalized EBNF builder for all detectors.
@@ -157,6 +179,9 @@ class EBNFComposer:
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
format based on function_format will be used.
key_value_rule_fmt: Optional custom format string for key-value pairs. It should define how each parameter is formatted,
with placeholders {key} for the parameter name and {valrule} for the value rule. If None, a default format
based on function_format will be used.
"""
# =================================================================
# Step 1: Determine the root tool calls rule
@@ -200,7 +225,11 @@ class EBNFComposer:
else EBNFComposer.CALL_RULE_MAP[function_format]
)
args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
key_value_template = (
key_value_rule_fmt
if key_value_rule_fmt
else EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
)
# =================================================================
# Step 4: Build rules for each tool
@@ -292,10 +321,13 @@ class EBNFComposer:
# =================================================================
# Step 5: Add base grammar rules
# =================================================================
base_grammar = (
EBNFComposer.pythonic_grammar_ebnf_str
if function_format == "pythonic"
else EBNFComposer.json_grammar_ebnf_str
grammar_dict = {
"pythonic": EBNFComposer.pythonic_grammar_ebnf_str,
"json": EBNFComposer.json_grammar_ebnf_str,
"xml": EBNFComposer.xml_grammar_ebnf_str,
}
base_grammar = grammar_dict.get(
function_format, EBNFComposer.json_grammar_ebnf_str
)
ebnf_lines.append(base_grammar)

View File

@@ -14,7 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class FunctionCallParser:
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
"qwen3": Qwen3XMLDetector,
"qwen3_coder": Qwen3CoderDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
@@ -155,9 +155,9 @@ class FunctionCallParser:
or None if no constraint applies.
"""
# NOTE: structural_tag only supports JSON-compatible content between the begin and end.
# It cannot parse or validate Python syntax like function calls.
# It cannot parse or validate Pythonic or XML-ish function call syntax.
if (
not isinstance(self.detector, PythonicDetector)
self.detector.supports_structural_tag()
and tool_choice == "auto"
and any(tool.function.strict for tool in self.tools)
):

View File

@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
@@ -216,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
else:
raise ValueError("Tool call arguments must be literals")
def structure_info(self) -> _GetInfoFunc:
def info(name: str):
return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
def supports_structural_tag(self) -> bool:
    """Return False: this detector opts out of the structural tag format."""
    return False
return info
def structure_info(self) -> _GetInfoFunc:
    """Not implemented for this detector; always raises NotImplementedError."""
    raise NotImplementedError
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
return EBNFComposer.build_ebnf(

View File

@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
return raw
class Qwen3XMLDetector(BaseFormatDetector):
class Qwen3CoderDetector(BaseFormatDetector):
"""
Detector for Qwen3-Coder models.
Assumes function call format:
@@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector):
params[pname] = _safe_val(pval)
raw = {"name": fname, "arguments": params}
try:
# TODO: fix idx in function call, the index for a function
# call will always be -1 in parse_base_json
res.extend(self.parse_base_json(raw, tools))
except Exception:
logger.warning("invalid tool call for %s dropped", fname)
return res
def structure_info(self) -> _GetInfoFunc:
return lambda n: StructureInfo(
begin=f"{self.tool_call_start_token}\n<function={n}>",
end=f"</function>\n{self.tool_call_end_token}",
trigger=self.tool_call_start_token,
)
def supports_structural_tag(self) -> bool:
    """Return False: this detector opts out of the structural tag format."""
    return False
def structure_info(self) -> _GetInfoFunc:
    """Not implemented for this detector; always raises NotImplementedError."""
    raise NotImplementedError
# TODO: fake ebnf for xml + outlines backend
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
tool_call_separator="\\n",
function_format="json",
function_format="xml",
call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
)

View File

@@ -1099,10 +1099,10 @@ class ServerArgs:
"deepseekv3",
"pythonic",
"kimi_k2",
"qwen3",
"qwen3_coder",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
)
# Data parallelism