[Bugfix][Feat] Add XML-ish grammar in EBNFComposer and fix misc bugs in Qwen3 detector (#8357)
This commit is contained in:
@@ -321,6 +321,10 @@ class BaseFormatDetector(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def supports_structural_tag(self) -> bool:
|
||||||
|
"""Return True if this detector supports structural tag format."""
|
||||||
|
return True
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,51 +1,73 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
class EBNFComposer:
|
class EBNFComposer:
|
||||||
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
|
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
|
||||||
json_grammar_ebnf_str = r"""
|
# Shared primitive grammar rules used across all formats
|
||||||
json ::= basic_array | basic_object
|
BASE_PRIMITIVE_GRAMMAR = r"""
|
||||||
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
|
basic_string ::= (([\"] basic_string_1 [\"]))
|
||||||
|
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
||||||
|
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9]{4}
|
||||||
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
|
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
|
||||||
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
||||||
basic_string ::= (([\"] basic_string_1 [\"]))
|
|
||||||
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
|
||||||
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
|
||||||
basic_boolean ::= "true" | "false"
|
|
||||||
basic_null ::= "null"
|
|
||||||
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
|
||||||
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
|
||||||
ws ::= [ \n\t]*
|
|
||||||
"""
|
|
||||||
|
|
||||||
pythonic_grammar_ebnf_str = r"""
|
|
||||||
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
|
|
||||||
basic_any ::= basic_number | basic_string | basic_array | basic_object
|
|
||||||
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
|
|
||||||
basic_string ::= (([\"] basic_string_1 [\"]))
|
|
||||||
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
|
|
||||||
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
|
|
||||||
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
|
||||||
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
|
||||||
ws ::= [ \n\t]*
|
ws ::= [ \n\t]*
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Format-specific extensions
|
||||||
|
json_grammar_ebnf_str = (
|
||||||
|
r"""
|
||||||
|
json ::= basic_array | basic_object
|
||||||
|
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
|
||||||
|
basic_boolean ::= "true" | "false"
|
||||||
|
basic_null ::= "null"
|
||||||
|
"""
|
||||||
|
+ BASE_PRIMITIVE_GRAMMAR
|
||||||
|
)
|
||||||
|
|
||||||
|
pythonic_grammar_ebnf_str = (
|
||||||
|
r"""
|
||||||
|
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
|
||||||
|
basic_any ::= basic_number | basic_string | basic_array | basic_object
|
||||||
|
basic_boolean ::= "True" | "False"
|
||||||
|
basic_null ::= "None"
|
||||||
|
"""
|
||||||
|
+ BASE_PRIMITIVE_GRAMMAR
|
||||||
|
)
|
||||||
|
|
||||||
|
xml_grammar_ebnf_str = (
|
||||||
|
r"""
|
||||||
|
xml ::= xml_element | xml_text
|
||||||
|
xml_element ::= basic_string | basic_number | basic_boolean | basic_null | basic_array | basic_object
|
||||||
|
xml_text ::= [^<>]*
|
||||||
|
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
|
||||||
|
basic_boolean ::= "true" | "false"
|
||||||
|
basic_null ::= "null"
|
||||||
|
"""
|
||||||
|
+ BASE_PRIMITIVE_GRAMMAR
|
||||||
|
)
|
||||||
|
|
||||||
CALL_RULE_MAP = {
|
CALL_RULE_MAP = {
|
||||||
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
||||||
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
||||||
|
"xml": 'call_{name} ::= "<function={name}>\\n" {arguments_rule} "\\n</function>"',
|
||||||
}
|
}
|
||||||
|
|
||||||
ARGUMENTS_RULE_MAP = {
|
ARGUMENTS_RULE_MAP = {
|
||||||
"pythonic": "{arg_rules}",
|
"pythonic": "{arg_rules}",
|
||||||
"json": '"{{" {arg_rules} "}}"',
|
"json": '"{{" {arg_rules} "}}"',
|
||||||
|
"xml": "{arg_rules}",
|
||||||
}
|
}
|
||||||
|
|
||||||
KEY_VALUE_RULE_MAP = {
|
KEY_VALUE_RULE_MAP = {
|
||||||
"pythonic": '"{key}" "=" {valrule}',
|
"pythonic": '"{key}" "=" {valrule}',
|
||||||
"json": '"\\"{key}\\"" ":" {valrule}',
|
"json": '"\\"{key}\\"" ":" {valrule}',
|
||||||
|
"xml": '"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON_TYPE_MAPPING = {
|
# Base type mapping - most types are the same across formats
|
||||||
|
BASE_TYPE_MAPPING = {
|
||||||
"string": "basic_string",
|
"string": "basic_string",
|
||||||
"number": "basic_number",
|
"number": "basic_number",
|
||||||
"integer": "basic_number",
|
"integer": "basic_number",
|
||||||
@@ -55,19 +77,20 @@ class EBNFComposer:
|
|||||||
"object": "basic_object",
|
"object": "basic_object",
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTHONIC_TYPE_MAPPING = {
|
# Format-specific overrides for types that differ
|
||||||
"string": "basic_string",
|
FORMAT_TYPE_OVERRIDES = {
|
||||||
"number": "basic_number",
|
"pythonic": {
|
||||||
"integer": "basic_number",
|
"boolean": '"True" | "False"',
|
||||||
"boolean": '"True" | "False"',
|
"null": '"None"',
|
||||||
"null": '"None"',
|
},
|
||||||
"array": "basic_array",
|
"xml": {
|
||||||
"object": "basic_object",
|
"string": "xml_text",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_value_rule(
|
def get_value_rule(
|
||||||
prop: dict, function_format: Literal["pythonic", "json"] = "json"
|
prop: dict, function_format: Literal["pythonic", "json", "xml"] = "json"
|
||||||
) -> str:
|
) -> str:
|
||||||
if "enum" in prop:
|
if "enum" in prop:
|
||||||
return EBNFComposer._handle_enum(prop, function_format)
|
return EBNFComposer._handle_enum(prop, function_format)
|
||||||
@@ -83,48 +106,46 @@ class EBNFComposer:
|
|||||||
enum_values = prop["enum"]
|
enum_values = prop["enum"]
|
||||||
prop_type = prop.get("type", "string")
|
prop_type = prop.get("type", "string")
|
||||||
|
|
||||||
# Define formatters for different type/format combinations
|
def format_enum_val(v: Any) -> str:
|
||||||
formatters = {
|
if prop_type == "boolean":
|
||||||
("string", "json"): lambda v: f'"\\"{v}\\""',
|
if function_format == "json" or function_format == "xml":
|
||||||
("string", "pythonic"): lambda v: f'"\\"{v}\\""',
|
return "true" if v else "false"
|
||||||
("number", "json"): str,
|
elif function_format == "pythonic":
|
||||||
("number", "pythonic"): str,
|
return "True" if v else "False"
|
||||||
("integer", "json"): str,
|
else:
|
||||||
("integer", "pythonic"): str,
|
return str(v) # fallback
|
||||||
("boolean", "json"): lambda v: "true" if v else "false",
|
|
||||||
("boolean", "pythonic"): lambda v: "True" if v else "False",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the formatter or default to string handling
|
if prop_type == "string":
|
||||||
formatter = formatters.get(
|
if function_format == "xml":
|
||||||
(prop_type, function_format),
|
return f'"{v}"'
|
||||||
formatters[("string", function_format)], # Default to string handling
|
else: # json or pythonic
|
||||||
)
|
return f'"\\"{v}\\""' # escape quote-wrapped string
|
||||||
|
|
||||||
formatted_values = [formatter(value) for value in enum_values]
|
# All other types (number, integer, etc.)
|
||||||
|
return str(v)
|
||||||
|
|
||||||
|
formatted_values = [format_enum_val(v) for v in enum_values]
|
||||||
enum_rule = " | ".join(formatted_values)
|
enum_rule = " | ".join(formatted_values)
|
||||||
|
return f"({enum_rule})" if len(formatted_values) > 1 else enum_rule
|
||||||
|
|
||||||
# Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
|
@staticmethod
|
||||||
if len(formatted_values) > 1:
|
def get_type_mapping(function_format: str) -> Dict[str, str]:
|
||||||
enum_rule = f"({enum_rule})"
|
"""Get the complete type mapping for a given format."""
|
||||||
|
mapping = EBNFComposer.BASE_TYPE_MAPPING.copy()
|
||||||
return enum_rule
|
overrides = EBNFComposer.FORMAT_TYPE_OVERRIDES.get(function_format, {})
|
||||||
|
mapping.update({k: v for k, v in overrides.items() if v is not None})
|
||||||
|
return mapping
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_type(prop: dict, function_format: str) -> str:
|
def _handle_type(prop: dict, function_format: str) -> str:
|
||||||
"""Handle type properties using the appropriate type mapping."""
|
"""Handle type properties using the appropriate type mapping."""
|
||||||
prop_type = prop["type"]
|
prop_type = prop["type"]
|
||||||
type_mapping = (
|
type_mapping = EBNFComposer.get_type_mapping(function_format)
|
||||||
EBNFComposer.PYTHONIC_TYPE_MAPPING
|
|
||||||
if function_format == "pythonic"
|
|
||||||
else EBNFComposer.JSON_TYPE_MAPPING
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(prop_type, list):
|
if isinstance(prop_type, list):
|
||||||
type_rules = [
|
type_rules = [
|
||||||
type_mapping[single_type]
|
type_mapping.get(single_type, function_format)
|
||||||
for single_type in prop_type
|
for single_type in prop_type
|
||||||
if single_type in type_mapping
|
|
||||||
]
|
]
|
||||||
return " | ".join(type_rules) if type_rules else function_format
|
return " | ".join(type_rules) if type_rules else function_format
|
||||||
|
|
||||||
@@ -133,7 +154,7 @@ class EBNFComposer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def build_ebnf(
|
def build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
function_format: Literal["pythonic", "json"] = "json",
|
function_format: Literal["pythonic", "json", "xml"] = "json",
|
||||||
# Parameters for wrapping the entire sequence of tool calls
|
# Parameters for wrapping the entire sequence of tool calls
|
||||||
sequence_start_token: Optional[str] = None,
|
sequence_start_token: Optional[str] = None,
|
||||||
sequence_end_token: Optional[str] = None,
|
sequence_end_token: Optional[str] = None,
|
||||||
@@ -143,6 +164,7 @@ class EBNFComposer:
|
|||||||
# Parameter for separating multiple tool calls
|
# Parameter for separating multiple tool calls
|
||||||
tool_call_separator: Optional[str] = None,
|
tool_call_separator: Optional[str] = None,
|
||||||
call_rule_fmt: Optional[str] = None,
|
call_rule_fmt: Optional[str] = None,
|
||||||
|
key_value_rule_fmt: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generalized EBNF builder for all detectors.
|
Generalized EBNF builder for all detectors.
|
||||||
@@ -157,6 +179,9 @@ class EBNFComposer:
|
|||||||
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
||||||
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
||||||
format based on function_format will be used.
|
format based on function_format will be used.
|
||||||
|
key_value_rule_fmt: Optional custom format string for key-value pairs. It should define how each parameter is formatted,
|
||||||
|
with placeholders {key} for the parameter name and {valrule} for the value rule. If None, a default format
|
||||||
|
based on function_format will be used.
|
||||||
"""
|
"""
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Step 1: Determine the root tool calls rule
|
# Step 1: Determine the root tool calls rule
|
||||||
@@ -200,7 +225,11 @@ class EBNFComposer:
|
|||||||
else EBNFComposer.CALL_RULE_MAP[function_format]
|
else EBNFComposer.CALL_RULE_MAP[function_format]
|
||||||
)
|
)
|
||||||
args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
|
args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
|
||||||
key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
|
key_value_template = (
|
||||||
|
key_value_rule_fmt
|
||||||
|
if key_value_rule_fmt
|
||||||
|
else EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
|
||||||
|
)
|
||||||
|
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Step 4: Build rules for each tool
|
# Step 4: Build rules for each tool
|
||||||
@@ -292,10 +321,13 @@ class EBNFComposer:
|
|||||||
# =================================================================
|
# =================================================================
|
||||||
# Step 5: Add base grammar rules
|
# Step 5: Add base grammar rules
|
||||||
# =================================================================
|
# =================================================================
|
||||||
base_grammar = (
|
grammar_dict = {
|
||||||
EBNFComposer.pythonic_grammar_ebnf_str
|
"pythonic": EBNFComposer.pythonic_grammar_ebnf_str,
|
||||||
if function_format == "pythonic"
|
"json": EBNFComposer.json_grammar_ebnf_str,
|
||||||
else EBNFComposer.json_grammar_ebnf_str
|
"xml": EBNFComposer.xml_grammar_ebnf_str,
|
||||||
|
}
|
||||||
|
base_grammar = grammar_dict.get(
|
||||||
|
function_format, EBNFComposer.json_grammar_ebnf_str
|
||||||
)
|
)
|
||||||
ebnf_lines.append(base_grammar)
|
ebnf_lines.append(base_grammar)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
|||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
|
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -36,7 +36,7 @@ class FunctionCallParser:
|
|||||||
"deepseekv3": DeepSeekV3Detector,
|
"deepseekv3": DeepSeekV3Detector,
|
||||||
"pythonic": PythonicDetector,
|
"pythonic": PythonicDetector,
|
||||||
"kimi_k2": KimiK2Detector,
|
"kimi_k2": KimiK2Detector,
|
||||||
"qwen3": Qwen3XMLDetector,
|
"qwen3_coder": Qwen3CoderDetector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||||
@@ -155,9 +155,9 @@ class FunctionCallParser:
|
|||||||
or None if no constraint applies.
|
or None if no constraint applies.
|
||||||
"""
|
"""
|
||||||
# NOTE: structural_tag only supports JSON-compatible content between the begin and end.
|
# NOTE: structural_tag only supports JSON-compatible content between the begin and end.
|
||||||
# It cannot parse or validate Python syntax like function calls.
|
# It cannot parse or validate function call Pythonic or XML-ish syntax.
|
||||||
if (
|
if (
|
||||||
not isinstance(self.detector, PythonicDetector)
|
self.detector.supports_structural_tag()
|
||||||
and tool_choice == "auto"
|
and tool_choice == "auto"
|
||||||
and any(tool.function.strict for tool in self.tools)
|
and any(tool.function.strict for tool in self.tools)
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
|
|||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
StructureInfo,
|
|
||||||
ToolCallItem,
|
ToolCallItem,
|
||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
@@ -216,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Tool call arguments must be literals")
|
raise ValueError("Tool call arguments must be literals")
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def supports_structural_tag(self) -> bool:
|
||||||
def info(name: str):
|
return False
|
||||||
return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
|
|
||||||
|
|
||||||
return info
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
|
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
|
|||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
StructureInfo,
|
|
||||||
ToolCallItem,
|
ToolCallItem,
|
||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
|
|||||||
return raw
|
return raw
|
||||||
|
|
||||||
|
|
||||||
class Qwen3XMLDetector(BaseFormatDetector):
|
class Qwen3CoderDetector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
Detector for Qwen 3 models.
|
Detector for Qwen 3 models.
|
||||||
Assumes function call format:
|
Assumes function call format:
|
||||||
@@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector):
|
|||||||
params[pname] = _safe_val(pval)
|
params[pname] = _safe_val(pval)
|
||||||
raw = {"name": fname, "arguments": params}
|
raw = {"name": fname, "arguments": params}
|
||||||
try:
|
try:
|
||||||
|
# TODO: fix idx in function call, the index for a function
|
||||||
|
# call will always be -1 in parse_base_json
|
||||||
res.extend(self.parse_base_json(raw, tools))
|
res.extend(self.parse_base_json(raw, tools))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("invalid tool call for %s dropped", fname)
|
logger.warning("invalid tool call for %s dropped", fname)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def supports_structural_tag(self) -> bool:
|
||||||
return lambda n: StructureInfo(
|
return False
|
||||||
begin=f"{self.tool_call_start_token}\n<function={n}>",
|
|
||||||
end=f"</function>\n{self.tool_call_end_token}",
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
trigger=self.tool_call_start_token,
|
raise NotImplementedError
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: fake ebnf for xml + outlines backend
|
|
||||||
def build_ebnf(self, tools: List[Tool]):
|
def build_ebnf(self, tools: List[Tool]):
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
|
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
|
||||||
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
|
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
|
||||||
tool_call_separator="\\n",
|
tool_call_separator="\\n",
|
||||||
function_format="json",
|
function_format="xml",
|
||||||
|
call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
|
||||||
|
key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
|
||||||
)
|
)
|
||||||
@@ -1099,10 +1099,10 @@ class ServerArgs:
|
|||||||
"deepseekv3",
|
"deepseekv3",
|
||||||
"pythonic",
|
"pythonic",
|
||||||
"kimi_k2",
|
"kimi_k2",
|
||||||
"qwen3",
|
"qwen3_coder",
|
||||||
],
|
],
|
||||||
default=ServerArgs.tool_call_parser,
|
default=ServerArgs.tool_call_parser,
|
||||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
|
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
|||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
|
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
@@ -507,6 +508,7 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
self.llama32_detector = Llama32Detector()
|
self.llama32_detector = Llama32Detector()
|
||||||
self.mistral_detector = MistralDetector()
|
self.mistral_detector = MistralDetector()
|
||||||
self.qwen25_detector = Qwen25Detector()
|
self.qwen25_detector = Qwen25Detector()
|
||||||
|
self.qwen3_coder_detector = Qwen3CoderDetector()
|
||||||
self.kimik2_detector = KimiK2Detector()
|
self.kimik2_detector = KimiK2Detector()
|
||||||
|
|
||||||
def test_pythonic_detector_ebnf(self):
|
def test_pythonic_detector_ebnf(self):
|
||||||
@@ -620,6 +622,26 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
self.fail(f"Failed to compile EBNF: {e}")
|
self.fail(f"Failed to compile EBNF: {e}")
|
||||||
|
|
||||||
|
def test_qwen3_coder_detector_ebnf(self):
|
||||||
|
"""Test that the Qwen3CoderDetector generates valid EBNF."""
|
||||||
|
ebnf = self.qwen3_coder_detector.build_ebnf(self.tools)
|
||||||
|
self.assertIsNotNone(ebnf)
|
||||||
|
# Check that the EBNF contains expected patterns for XML format
|
||||||
|
self.assertIn("<tool_call>", ebnf)
|
||||||
|
self.assertIn("</tool_call>", ebnf)
|
||||||
|
self.assertIn('"<function=get_weather>\\n"', ebnf)
|
||||||
|
self.assertIn('"\\n</function>"', ebnf)
|
||||||
|
self.assertIn('"<parameter=location>\\n"', ebnf)
|
||||||
|
self.assertIn('"\\n</parameter>"', ebnf)
|
||||||
|
# Check that it uses xml_text for string parameters
|
||||||
|
self.assertIn("xml_text", ebnf)
|
||||||
|
# Validate that the EBNF can be compiled by GrammarCompiler
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
||||||
|
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Failed to compile EBNF: {e}")
|
||||||
|
|
||||||
def test_weather_function_optional_parameter_handling(self):
|
def test_weather_function_optional_parameter_handling(self):
|
||||||
"""Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
|
"""Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
|
||||||
# Create a weather tool with required location and optional unit
|
# Create a weather tool with required location and optional unit
|
||||||
@@ -1464,5 +1486,438 @@ class TestDeepSeekV3Detector(unittest.TestCase):
|
|||||||
self.assertEqual(params2["city"], "Beijing")
|
self.assertEqual(params2["city"], "Beijing")
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3CoderDetector(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Create sample tools for testing
|
||||||
|
self.tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_current_weather",
|
||||||
|
description="Get the current weather",
|
||||||
|
parameters={
|
||||||
|
"properties": {
|
||||||
|
"city": {"type": "string", "description": "The city name"},
|
||||||
|
"state": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The state code",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["fahrenheit", "celsius"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city", "state"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="calculate_area",
|
||||||
|
description="Calculate area of a shape",
|
||||||
|
parameters={
|
||||||
|
"properties": {
|
||||||
|
"shape": {"type": "string"},
|
||||||
|
"dimensions": {"type": "object"},
|
||||||
|
"precision": {"type": "integer"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.detector = Qwen3CoderDetector()
|
||||||
|
|
||||||
|
def test_has_tool_call(self):
|
||||||
|
"""Test detection of tool call markers."""
|
||||||
|
self.assertTrue(self.detector.has_tool_call("<tool_call>test</tool_call>"))
|
||||||
|
self.assertFalse(self.detector.has_tool_call("No tool call here"))
|
||||||
|
|
||||||
|
def test_detect_and_parse_no_tools(self):
|
||||||
|
"""Test parsing text without tool calls."""
|
||||||
|
model_output = "This is a test response without any tool calls"
|
||||||
|
result = self.detector.detect_and_parse(model_output, tools=[])
|
||||||
|
self.assertEqual(result.normal_text, model_output)
|
||||||
|
self.assertEqual(result.calls, [])
|
||||||
|
|
||||||
|
def test_detect_and_parse_single_tool(self):
|
||||||
|
"""Test parsing a single tool call."""
|
||||||
|
model_output = """<tool_call>
|
||||||
|
<function=get_current_weather>
|
||||||
|
<parameter=city>
|
||||||
|
Dallas
|
||||||
|
</parameter>
|
||||||
|
<parameter=state>
|
||||||
|
TX
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>"""
|
||||||
|
|
||||||
|
result = self.detector.detect_and_parse(model_output, tools=self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(result.normal_text, "")
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertEqual(result.calls[0].name, "get_current_weather")
|
||||||
|
|
||||||
|
params = json.loads(result.calls[0].parameters)
|
||||||
|
self.assertEqual(params["city"], "Dallas")
|
||||||
|
self.assertEqual(params["state"], "TX")
|
||||||
|
self.assertEqual(params["unit"], "fahrenheit")
|
||||||
|
|
||||||
|
def test_detect_and_parse_with_content(self):
|
||||||
|
"""Test parsing tool call with surrounding text."""
|
||||||
|
model_output = """Sure! Let me check the weather for you.<tool_call>
|
||||||
|
<function=get_current_weather>
|
||||||
|
<parameter=city>
|
||||||
|
Dallas
|
||||||
|
</parameter>
|
||||||
|
<parameter=state>
|
||||||
|
TX
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>"""
|
||||||
|
|
||||||
|
result = self.detector.detect_and_parse(model_output, tools=self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(result.normal_text, "Sure! Let me check the weather for you.")
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertEqual(result.calls[0].name, "get_current_weather")
|
||||||
|
|
||||||
|
def test_detect_and_parse_multiline_param(self):
|
||||||
|
"""Test parsing tool call with multiline parameter values."""
|
||||||
|
model_output = """<tool_call>
|
||||||
|
<function=calculate_area>
|
||||||
|
<parameter=shape>
|
||||||
|
rectangle
|
||||||
|
</parameter>
|
||||||
|
<parameter=dimensions>
|
||||||
|
{"width": 10,
|
||||||
|
"height": 20}
|
||||||
|
</parameter>
|
||||||
|
<parameter=precision>
|
||||||
|
2
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>"""
|
||||||
|
|
||||||
|
result = self.detector.detect_and_parse(model_output, tools=self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(len(result.calls), 1)
|
||||||
|
self.assertEqual(result.calls[0].name, "calculate_area")
|
||||||
|
|
||||||
|
params = json.loads(result.calls[0].parameters)
|
||||||
|
self.assertEqual(params["shape"], "rectangle")
|
||||||
|
self.assertEqual(params["dimensions"], {"width": 10, "height": 20})
|
||||||
|
self.assertEqual(params["precision"], 2)
|
||||||
|
|
||||||
|
def test_detect_and_parse_parallel_tools(self):
|
||||||
|
"""Test parsing multiple tool calls."""
|
||||||
|
model_output = """<tool_call>
|
||||||
|
<function=get_current_weather>
|
||||||
|
<parameter=city>
|
||||||
|
Dallas
|
||||||
|
</parameter>
|
||||||
|
<parameter=state>
|
||||||
|
TX
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
<function=get_current_weather>
|
||||||
|
<parameter=city>
|
||||||
|
Orlando
|
||||||
|
</parameter>
|
||||||
|
<parameter=state>
|
||||||
|
FL
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>"""
|
||||||
|
|
||||||
|
result = self.detector.detect_and_parse(model_output, tools=self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(result.normal_text, "\n")
|
||||||
|
self.assertEqual(len(result.calls), 2)
|
||||||
|
|
||||||
|
# First call
|
||||||
|
self.assertEqual(result.calls[0].name, "get_current_weather")
|
||||||
|
params1 = json.loads(result.calls[0].parameters)
|
||||||
|
self.assertEqual(params1["city"], "Dallas")
|
||||||
|
self.assertEqual(params1["state"], "TX")
|
||||||
|
|
||||||
|
# Second call
|
||||||
|
self.assertEqual(result.calls[1].name, "get_current_weather")
|
||||||
|
params2 = json.loads(result.calls[1].parameters)
|
||||||
|
self.assertEqual(params2["city"], "Orlando")
|
||||||
|
self.assertEqual(params2["state"], "FL")
|
||||||
|
|
||||||
|
def test_parse_streaming_simple(self):
    """Streaming chunks yield the leading normal text plus one complete call."""
    chunks = [
        "Sure! ",
        "Let me check ",
        "the weather.",
        "<tool_call>",
        "\n<function=get_current_weather>",
        "\n<parameter=city>",
        "\nDallas",
        "\n</parameter>",
        "\n<parameter=state>",
        "\nTX",
        "\n</parameter>",
        "\n</function>",
        "\n</tool_call>",
    ]

    streamed_text = ""
    streamed_calls = []
    for piece in chunks:
        increment = self.detector.parse_streaming_increment(piece, tools=self.tools)
        streamed_text += increment.normal_text
        streamed_calls.extend(increment.calls)

    self.assertEqual(streamed_text, "Sure! Let me check the weather.")
    self.assertEqual(len(streamed_calls), 1)
    self.assertEqual(streamed_calls[0].name, "get_current_weather")

    args = json.loads(streamed_calls[0].parameters)
    self.assertEqual(args["city"], "Dallas")
    self.assertEqual(args["state"], "TX")
|
||||||
|
|
||||||
|
def test_parse_streaming_incomplete(self):
    """An unterminated tool call emits no calls until the closing tags arrive."""
    # Send incomplete tool call: everything up to the second parameter value.
    chunks = [
        "<tool_call>",
        "\n<function=get_current_weather>",
        "\n<parameter=city>",
        "\nDallas",
        "\n</parameter>",
        "\n<parameter=state>",
        "\nTX",
        # Missing </parameter>, </function>, </tool_call>
    ]

    pending_calls = []
    for piece in chunks:
        pending_calls.extend(
            self.detector.parse_streaming_increment(piece, tools=self.tools).calls
        )

    # Should not have any complete calls yet
    self.assertEqual(len(pending_calls), 0)

    # Now complete it: the closing tags finish exactly one call.
    final = self.detector.parse_streaming_increment(
        "\n</parameter>\n</function>\n</tool_call>", tools=self.tools
    )
    self.assertEqual(len(final.calls), 1)
    self.assertEqual(final.calls[0].name, "get_current_weather")
|
||||||
|
|
||||||
|
def test_edge_case_no_parameters(self):
    """A function tag with no <parameter> children parses to empty arguments."""
    model_output = """<tool_call>
<function=get_current_weather>
</function>
</tool_call>"""

    parsed = self.detector.detect_and_parse(model_output, tools=self.tools)
    self.assertEqual(len(parsed.calls), 1)
    self.assertEqual(parsed.calls[0].name, "get_current_weather")
    self.assertEqual(json.loads(parsed.calls[0].parameters), {})
|
||||||
|
|
||||||
|
def test_edge_case_special_chars_in_value(self):
    """Angle brackets inside a parameter value must not confuse the parser."""
    model_output = """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas->TX
</parameter>
</function>
</tool_call>"""

    parsed = self.detector.detect_and_parse(model_output, tools=self.tools)
    self.assertEqual(len(parsed.calls), 1)

    args = json.loads(parsed.calls[0].parameters)
    self.assertEqual(args["city"], "Dallas->TX")
|
||||||
|
|
||||||
|
def test_extract_tool_calls_fallback_no_tags(self):
    """A bare <function=...> block without the <tool_call> wrapper still parses."""
    model_output = """<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>"""

    parsed = self.detector.detect_and_parse(model_output, tools=self.tools)

    # Fallback path: we only require that parsing does not fail outright.
    self.assertIsNotNone(parsed)
|
||||||
|
|
||||||
|
def test_extract_tool_calls_type_conversion(self):
    """Parameter values are coerced to the types declared in the tool schema."""
    typed_tool = Tool(
        type="function",
        function=Function(
            name="test_types",
            parameters={
                "type": "object",
                "properties": {
                    "int_param": {"type": "integer"},
                    "float_param": {"type": "float"},
                    "bool_param": {"type": "boolean"},
                    "str_param": {"type": "string"},
                    "obj_param": {"type": "object"},
                },
            },
        ),
    )

    model_output = """<tool_call>
<function=test_types>
<parameter=int_param>
42
</parameter>
<parameter=float_param>
3.14
</parameter>
<parameter=bool_param>
true
</parameter>
<parameter=str_param>
hello world
</parameter>
<parameter=obj_param>
{"key": "value"}
</parameter>
</function>
</tool_call>"""

    parsed = self.detector.detect_and_parse(model_output, tools=[typed_tool])

    self.assertEqual(len(parsed.calls), 1)
    args = json.loads(parsed.calls[0].parameters)
    self.assertEqual(args["int_param"], 42)
    self.assertEqual(args["float_param"], 3.14)
    self.assertEqual(args["bool_param"], True)
    self.assertEqual(args["str_param"], "hello world")
    self.assertEqual(args["obj_param"], {"key": "value"})
|
||||||
|
|
||||||
|
def test_parse_streaming_incremental(self):
    """Token-sized chunks (e.g. "<tool_call>" as one token) stream correctly.

    The original version defined ``model_output`` but never used it; the
    joined chunks are now asserted to reproduce it, so the fixture and the
    chunk list cannot silently drift apart.
    """
    model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""

    # Simulate more realistic token-based chunks where <tool_call> is a single token
    chunks = [
        "I'll check the weather.",
        "<tool_call>",
        "\n<function=get_current_weather>\n",
        "<parameter=city>\n",
        "Dallas\n",
        "</parameter>\n",
        "<parameter=state>\n",
        "TX\n",
        "</parameter>\n",
        "</function>\n",
        "</tool_call>",
    ]
    # Guard the fixture: the chunks must be an exact segmentation of the output.
    self.assertEqual("".join(chunks), model_output)

    accumulated_text = ""
    accumulated_calls = []
    chunks_count = 0

    for chunk in chunks:
        result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
        accumulated_text += result.normal_text
        accumulated_calls.extend(result.calls)
        chunks_count += 1

    # Sanity check that parsing actually happened incrementally.
    self.assertGreater(chunks_count, 3)

    # Verify the accumulated results
    self.assertIn("I'll check the weather.", accumulated_text)
    self.assertEqual(len(accumulated_calls), 1)
    self.assertEqual(accumulated_calls[0].name, "get_current_weather")

    params = json.loads(accumulated_calls[0].parameters)
    self.assertEqual(params["city"], "Dallas")
    self.assertEqual(params["state"], "TX")
|
||||||
|
|
||||||
|
def test_parse_streaming_multiple_tools(self):
    """Two tool calls separated by plain text survive arbitrary fixed-size chunking."""
    model_output = """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>
Some text in between.
<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 5}
</parameter>
</function>
</tool_call>"""

    # Simulate streaming by slicing the output into fixed-size chunks.
    step = 20
    chunks = [model_output[i : i + step] for i in range(0, len(model_output), step)]

    collected_text = ""
    collected_calls = []
    for piece in chunks:
        increment = self.detector.parse_streaming_increment(piece, tools=self.tools)
        collected_text += increment.normal_text
        collected_calls.extend(increment.calls)

    self.assertIn("Some text in between.", collected_text)
    self.assertEqual(len(collected_calls), 2)
    self.assertEqual(collected_calls[0].name, "get_current_weather")
    self.assertEqual(collected_calls[1].name, "calculate_area")

    # Verify parameters
    first_args = json.loads(collected_calls[0].parameters)
    self.assertEqual(first_args["city"], "Dallas")

    second_args = json.loads(collected_calls[1].parameters)
    self.assertEqual(second_args["shape"], "circle")
    self.assertEqual(second_args["dimensions"], {"radius": 5})
|
||||||
|
|
||||||
|
|
||||||
# Run the suite when executed directly. (The duplicated guard/`unittest.main()`
# lines from the mangled source are collapsed into a single entry point.)
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user