diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py
index d9ac71253..39bb92f5f 100644
--- a/python/sglang/srt/function_call/base_format_detector.py
+++ b/python/sglang/srt/function_call/base_format_detector.py
@@ -321,6 +321,10 @@ class BaseFormatDetector(ABC):
"""
raise NotImplementedError()
+ def supports_structural_tag(self) -> bool:
+ """Return True if this detector supports structural tag format."""
+ return True
+
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
"""
diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py
index 60035e05d..85d6039bb 100644
--- a/python/sglang/srt/function_call/ebnf_composer.py
+++ b/python/sglang/srt/function_call/ebnf_composer.py
@@ -1,51 +1,73 @@
-from typing import Literal, Optional
+from typing import Any, Dict, Literal, Optional
class EBNFComposer:
# Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
- json_grammar_ebnf_str = r"""
- json ::= basic_array | basic_object
- basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+ # Shared primitive grammar rules used across all formats
+ BASE_PRIMITIVE_GRAMMAR = r"""
+ basic_string ::= (([\"] basic_string_1 [\"]))
+ basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+ escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9]{4}
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
- basic_string ::= (([\"] basic_string_1 [\"]))
- basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
- escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
- basic_boolean ::= "true" | "false"
- basic_null ::= "null"
- basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
- basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
- ws ::= [ \n\t]*
- """
-
- pythonic_grammar_ebnf_str = r"""
- pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
- basic_any ::= basic_number | basic_string | basic_array | basic_object
- basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
- basic_string ::= (([\"] basic_string_1 [\"]))
- basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
- escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""
+ # Format-specific extensions
+ json_grammar_ebnf_str = (
+ r"""
+ json ::= basic_array | basic_object
+ basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+ basic_boolean ::= "true" | "false"
+ basic_null ::= "null"
+ """
+ + BASE_PRIMITIVE_GRAMMAR
+ )
+
+ pythonic_grammar_ebnf_str = (
+ r"""
+ pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
+ basic_any ::= basic_number | basic_string | basic_array | basic_object
+ basic_boolean ::= "True" | "False"
+ basic_null ::= "None"
+ """
+ + BASE_PRIMITIVE_GRAMMAR
+ )
+
+ xml_grammar_ebnf_str = (
+ r"""
+ xml ::= xml_element | xml_text
+ xml_element ::= basic_string | basic_number | basic_boolean | basic_null | basic_array | basic_object
+ xml_text ::= [^<>]*
+ basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+ basic_boolean ::= "true" | "false"
+ basic_null ::= "null"
+ """
+ + BASE_PRIMITIVE_GRAMMAR
+ )
+
CALL_RULE_MAP = {
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
+ "xml": 'call_{name} ::= "\\n" {arguments_rule} "\\n"',
}
ARGUMENTS_RULE_MAP = {
"pythonic": "{arg_rules}",
"json": '"{{" {arg_rules} "}}"',
+ "xml": "{arg_rules}",
}
KEY_VALUE_RULE_MAP = {
"pythonic": '"{key}" "=" {valrule}',
"json": '"\\"{key}\\"" ":" {valrule}',
+ "xml": '"\\n" {valrule} "\\n"',
}
- JSON_TYPE_MAPPING = {
+ # Base type mapping - most types are the same across formats
+ BASE_TYPE_MAPPING = {
"string": "basic_string",
"number": "basic_number",
"integer": "basic_number",
@@ -55,19 +77,20 @@ class EBNFComposer:
"object": "basic_object",
}
- PYTHONIC_TYPE_MAPPING = {
- "string": "basic_string",
- "number": "basic_number",
- "integer": "basic_number",
- "boolean": '"True" | "False"',
- "null": '"None"',
- "array": "basic_array",
- "object": "basic_object",
+ # Format-specific overrides for types that differ
+ FORMAT_TYPE_OVERRIDES = {
+ "pythonic": {
+ "boolean": '"True" | "False"',
+ "null": '"None"',
+ },
+ "xml": {
+ "string": "xml_text",
+ },
}
@staticmethod
def get_value_rule(
- prop: dict, function_format: Literal["pythonic", "json"] = "json"
+ prop: dict, function_format: Literal["pythonic", "json", "xml"] = "json"
) -> str:
if "enum" in prop:
return EBNFComposer._handle_enum(prop, function_format)
@@ -83,48 +106,46 @@ class EBNFComposer:
enum_values = prop["enum"]
prop_type = prop.get("type", "string")
- # Define formatters for different type/format combinations
- formatters = {
- ("string", "json"): lambda v: f'"\\"{v}\\""',
- ("string", "pythonic"): lambda v: f'"\\"{v}\\""',
- ("number", "json"): str,
- ("number", "pythonic"): str,
- ("integer", "json"): str,
- ("integer", "pythonic"): str,
- ("boolean", "json"): lambda v: "true" if v else "false",
- ("boolean", "pythonic"): lambda v: "True" if v else "False",
- }
+ def format_enum_val(v: Any) -> str:
+ if prop_type == "boolean":
+ if function_format == "json" or function_format == "xml":
+ return "true" if v else "false"
+ elif function_format == "pythonic":
+ return "True" if v else "False"
+ else:
+ return str(v) # fallback
- # Get the formatter or default to string handling
- formatter = formatters.get(
- (prop_type, function_format),
- formatters[("string", function_format)], # Default to string handling
- )
+ if prop_type == "string":
+ if function_format == "xml":
+ return f'"{v}"'
+ else: # json or pythonic
+ return f'"\\"{v}\\""' # escape quote-wrapped string
- formatted_values = [formatter(value) for value in enum_values]
+ # All other types (number, integer, etc.)
+ return str(v)
+
+ formatted_values = [format_enum_val(v) for v in enum_values]
enum_rule = " | ".join(formatted_values)
+ return f"({enum_rule})" if len(formatted_values) > 1 else enum_rule
- # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
- if len(formatted_values) > 1:
- enum_rule = f"({enum_rule})"
-
- return enum_rule
+ @staticmethod
+ def get_type_mapping(function_format: str) -> Dict[str, str]:
+ """Get the complete type mapping for a given format."""
+ mapping = EBNFComposer.BASE_TYPE_MAPPING.copy()
+ overrides = EBNFComposer.FORMAT_TYPE_OVERRIDES.get(function_format, {})
+ mapping.update({k: v for k, v in overrides.items() if v is not None})
+ return mapping
@staticmethod
def _handle_type(prop: dict, function_format: str) -> str:
"""Handle type properties using the appropriate type mapping."""
prop_type = prop["type"]
- type_mapping = (
- EBNFComposer.PYTHONIC_TYPE_MAPPING
- if function_format == "pythonic"
- else EBNFComposer.JSON_TYPE_MAPPING
- )
+ type_mapping = EBNFComposer.get_type_mapping(function_format)
if isinstance(prop_type, list):
type_rules = [
- type_mapping[single_type]
+ type_mapping.get(single_type, function_format)
for single_type in prop_type
- if single_type in type_mapping
]
return " | ".join(type_rules) if type_rules else function_format
@@ -133,7 +154,7 @@ class EBNFComposer:
@staticmethod
def build_ebnf(
tools,
- function_format: Literal["pythonic", "json"] = "json",
+ function_format: Literal["pythonic", "json", "xml"] = "json",
# Parameters for wrapping the entire sequence of tool calls
sequence_start_token: Optional[str] = None,
sequence_end_token: Optional[str] = None,
@@ -143,6 +164,7 @@ class EBNFComposer:
# Parameter for separating multiple tool calls
tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None,
+ key_value_rule_fmt: Optional[str] = None,
):
"""
Generalized EBNF builder for all detectors.
@@ -157,6 +179,9 @@ class EBNFComposer:
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
format based on function_format will be used.
+ key_value_rule_fmt: Optional custom format string for key-value pairs. It should define how each parameter is formatted,
+ with placeholders {key} for the parameter name and {valrule} for the value rule. If None, a default format
+ based on function_format will be used.
"""
# =================================================================
# Step 1: Determine the root tool calls rule
@@ -200,7 +225,11 @@ class EBNFComposer:
else EBNFComposer.CALL_RULE_MAP[function_format]
)
args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
- key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
+ key_value_template = (
+ key_value_rule_fmt
+ if key_value_rule_fmt
+ else EBNFComposer.KEY_VALUE_RULE_MAP[function_format]
+ )
# =================================================================
# Step 4: Build rules for each tool
@@ -292,10 +321,13 @@ class EBNFComposer:
# =================================================================
# Step 5: Add base grammar rules
# =================================================================
- base_grammar = (
- EBNFComposer.pythonic_grammar_ebnf_str
- if function_format == "pythonic"
- else EBNFComposer.json_grammar_ebnf_str
+ grammar_dict = {
+ "pythonic": EBNFComposer.pythonic_grammar_ebnf_str,
+ "json": EBNFComposer.json_grammar_ebnf_str,
+ "xml": EBNFComposer.xml_grammar_ebnf_str,
+ }
+ base_grammar = grammar_dict.get(
+ function_format, EBNFComposer.json_grammar_ebnf_str
)
ebnf_lines.append(base_grammar)
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index 4c38d9d4f..fde00f303 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -14,7 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
-from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
+from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class FunctionCallParser:
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
- "qwen3": Qwen3XMLDetector,
+ "qwen3_coder": Qwen3CoderDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
@@ -155,9 +155,9 @@ class FunctionCallParser:
or None if no constraint applies.
"""
# NOTE: structural_tag only supports JSON-compatible content between the begin and end.
- # It cannot parse or validate Python syntax like function calls.
+ # It cannot parse or validate function call Pythonic or XML-ish syntax.
if (
- not isinstance(self.detector, PythonicDetector)
+ self.detector.supports_structural_tag()
and tool_choice == "auto"
and any(tool.function.strict for tool in self.tools)
):
diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py
index 85c3cd135..be183c6bf 100644
--- a/python/sglang/srt/function_call/pythonic_detector.py
+++ b/python/sglang/srt/function_call/pythonic_detector.py
@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
- StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
@@ -216,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
else:
raise ValueError("Tool call arguments must be literals")
- def structure_info(self) -> _GetInfoFunc:
- def info(name: str):
- return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+ def supports_structural_tag(self) -> bool:
+ return False
- return info
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
return EBNFComposer.build_ebnf(
diff --git a/python/sglang/srt/function_call/qwen3_detector.py b/python/sglang/srt/function_call/qwen3_coder_detector.py
similarity index 91%
rename from python/sglang/srt/function_call/qwen3_detector.py
rename to python/sglang/srt/function_call/qwen3_coder_detector.py
index 5c6ac698e..641c86806 100644
--- a/python/sglang/srt/function_call/qwen3_detector.py
+++ b/python/sglang/srt/function_call/qwen3_coder_detector.py
@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
- StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
return raw
-class Qwen3XMLDetector(BaseFormatDetector):
+class Qwen3CoderDetector(BaseFormatDetector):
"""
Detector for Qwen 3 models.
Assumes function call format:
@@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector):
params[pname] = _safe_val(pval)
raw = {"name": fname, "arguments": params}
try:
+ # TODO: fix idx in function call, the index for a function
+ # call will always be -1 in parse_base_json
res.extend(self.parse_base_json(raw, tools))
except Exception:
logger.warning("invalid tool call for %s dropped", fname)
return res
- def structure_info(self) -> _GetInfoFunc:
- return lambda n: StructureInfo(
- begin=f"{self.tool_call_start_token}\n",
- end=f"\n{self.tool_call_end_token}",
- trigger=self.tool_call_start_token,
- )
+ def supports_structural_tag(self) -> bool:
+ return False
+
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError
- # TODO: fake ebnf for xml + outlines backend
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
tool_call_separator="\\n",
- function_format="json",
+ function_format="xml",
+ call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
+ key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 1625f2c3a..b48cbf725 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1099,10 +1099,10 @@ class ServerArgs:
"deepseekv3",
"pythonic",
"kimi_k2",
- "qwen3",
+ "qwen3_coder",
],
default=ServerArgs.tool_call_parser,
- help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
+ help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
)
# Data parallelism
diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py
index 26dd24fbb..511020651 100644
--- a/test/srt/test_function_call_parser.py
+++ b/test/srt/test_function_call_parser.py
@@ -10,6 +10,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
+from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -507,6 +508,7 @@ class TestEBNFGeneration(unittest.TestCase):
self.llama32_detector = Llama32Detector()
self.mistral_detector = MistralDetector()
self.qwen25_detector = Qwen25Detector()
+ self.qwen3_coder_detector = Qwen3CoderDetector()
self.kimik2_detector = KimiK2Detector()
def test_pythonic_detector_ebnf(self):
@@ -620,6 +622,26 @@ class TestEBNFGeneration(unittest.TestCase):
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
+ def test_qwen3_coder_detector_ebnf(self):
+ """Test that the Qwen3CoderDetector generates valid EBNF."""
+ ebnf = self.qwen3_coder_detector.build_ebnf(self.tools)
+ self.assertIsNotNone(ebnf)
+ # Check that the EBNF contains expected patterns for XML format
+ self.assertIn("", ebnf)
+ self.assertIn("", ebnf)
+ self.assertIn('"\\n"', ebnf)
+ self.assertIn('"\\n"', ebnf)
+ self.assertIn('"\\n"', ebnf)
+ self.assertIn('"\\n"', ebnf)
+ # Check that it uses xml_text for string parameters
+ self.assertIn("xml_text", ebnf)
+ # Validate that the EBNF can be compiled by GrammarCompiler
+ try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf)
+ self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
+ except RuntimeError as e:
+ self.fail(f"Failed to compile EBNF: {e}")
+
def test_weather_function_optional_parameter_handling(self):
"""Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
# Create a weather tool with required location and optional unit
@@ -1464,5 +1486,438 @@ class TestDeepSeekV3Detector(unittest.TestCase):
self.assertEqual(params2["city"], "Beijing")
+class TestQwen3CoderDetector(unittest.TestCase):
+ def setUp(self):
+ # Create sample tools for testing
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_current_weather",
+ description="Get the current weather",
+ parameters={
+ "properties": {
+ "city": {"type": "string", "description": "The city name"},
+ "state": {
+ "type": "string",
+ "description": "The state code",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["fahrenheit", "celsius"],
+ },
+ },
+ "required": ["city", "state"],
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="calculate_area",
+ description="Calculate area of a shape",
+ parameters={
+ "properties": {
+ "shape": {"type": "string"},
+ "dimensions": {"type": "object"},
+ "precision": {"type": "integer"},
+ }
+ },
+ ),
+ ),
+ ]
+ self.detector = Qwen3CoderDetector()
+
+ def test_has_tool_call(self):
+ """Test detection of tool call markers."""
+ self.assertTrue(self.detector.has_tool_call("<tool_call>test"))
+ self.assertFalse(self.detector.has_tool_call("No tool call here"))
+
+ def test_detect_and_parse_no_tools(self):
+ """Test parsing text without tool calls."""
+ model_output = "This is a test response without any tool calls"
+ result = self.detector.detect_and_parse(model_output, tools=[])
+ self.assertEqual(result.normal_text, model_output)
+ self.assertEqual(result.calls, [])
+
+ def test_detect_and_parse_single_tool(self):
+ """Test parsing a single tool call."""
+ model_output = """
+
+
+Dallas
+
+
+TX
+
+
+fahrenheit
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+
+ self.assertEqual(result.normal_text, "")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_current_weather")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["city"], "Dallas")
+ self.assertEqual(params["state"], "TX")
+ self.assertEqual(params["unit"], "fahrenheit")
+
+ def test_detect_and_parse_with_content(self):
+ """Test parsing tool call with surrounding text."""
+ model_output = """Sure! Let me check the weather for you.
+
+
+Dallas
+
+
+TX
+
+
+fahrenheit
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+
+ self.assertEqual(result.normal_text, "Sure! Let me check the weather for you.")
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_current_weather")
+
+ def test_detect_and_parse_multiline_param(self):
+ """Test parsing tool call with multiline parameter values."""
+ model_output = """
+
+
+rectangle
+
+
+{"width": 10,
+ "height": 20}
+
+
+2
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "calculate_area")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["shape"], "rectangle")
+ self.assertEqual(params["dimensions"], {"width": 10, "height": 20})
+ self.assertEqual(params["precision"], 2)
+
+ def test_detect_and_parse_parallel_tools(self):
+ """Test parsing multiple tool calls."""
+ model_output = """
+
+
+Dallas
+
+
+TX
+
+
+fahrenheit
+
+
+
+
+
+
+Orlando
+
+
+FL
+
+
+fahrenheit
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+
+ self.assertEqual(result.normal_text, "\n")
+ self.assertEqual(len(result.calls), 2)
+
+ # First call
+ self.assertEqual(result.calls[0].name, "get_current_weather")
+ params1 = json.loads(result.calls[0].parameters)
+ self.assertEqual(params1["city"], "Dallas")
+ self.assertEqual(params1["state"], "TX")
+
+ # Second call
+ self.assertEqual(result.calls[1].name, "get_current_weather")
+ params2 = json.loads(result.calls[1].parameters)
+ self.assertEqual(params2["city"], "Orlando")
+ self.assertEqual(params2["state"], "FL")
+
+ def test_parse_streaming_simple(self):
+ """Test basic streaming parsing."""
+ chunks = [
+ "Sure! ",
+ "Let me check ",
+ "the weather.",
+ "",
+ "\n",
+ "\n",
+ "\nDallas",
+ "\n",
+ "\n",
+ "\nTX",
+ "\n",
+ "\n",
+ "\n",
+ ]
+
+ accumulated_text = ""
+ accumulated_calls = []
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
+ accumulated_text += result.normal_text
+ accumulated_calls.extend(result.calls)
+
+ self.assertEqual(accumulated_text, "Sure! Let me check the weather.")
+ self.assertEqual(len(accumulated_calls), 1)
+ self.assertEqual(accumulated_calls[0].name, "get_current_weather")
+
+ params = json.loads(accumulated_calls[0].parameters)
+ self.assertEqual(params["city"], "Dallas")
+ self.assertEqual(params["state"], "TX")
+
+ def test_parse_streaming_incomplete(self):
+ """Test streaming with incomplete tool call."""
+ # Send incomplete tool call
+ chunks = [
+ "",
+ "\n",
+ "\n",
+ "\nDallas",
+ "\n",
+ "\n",
+ "\nTX",
+ # Missing , ,
+ ]
+
+ accumulated_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
+ accumulated_calls.extend(result.calls)
+
+ # Should not have any complete calls yet
+ self.assertEqual(len(accumulated_calls), 0)
+
+ # Now complete it
+ result = self.detector.parse_streaming_increment(
+ "\n\n\n", tools=self.tools
+ )
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_current_weather")
+
+ def test_edge_case_no_parameters(self):
+ """Test tool call without parameters."""
+ model_output = """
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_current_weather")
+ self.assertEqual(json.loads(result.calls[0].parameters), {})
+
+ def test_edge_case_special_chars_in_value(self):
+ """Test parameter with special characters in value."""
+ model_output = """
+
+
+Dallas->TX
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+ self.assertEqual(len(result.calls), 1)
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["city"], "Dallas->TX")
+
+ def test_extract_tool_calls_fallback_no_tags(self):
+ """Test fallback parsing when XML tags are missing (just function without tool_call wrapper)."""
+ model_output = """
+
+Dallas
+
+
+TX
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=self.tools)
+
+ self.assertIsNotNone(result)
+
+ def test_extract_tool_calls_type_conversion(self):
+ """Test parameter type conversion based on tool schema."""
+ test_tool = Tool(
+ type="function",
+ function=Function(
+ name="test_types",
+ parameters={
+ "type": "object",
+ "properties": {
+ "int_param": {"type": "integer"},
+ "float_param": {"type": "float"},
+ "bool_param": {"type": "boolean"},
+ "str_param": {"type": "string"},
+ "obj_param": {"type": "object"},
+ },
+ },
+ ),
+ )
+
+ model_output = """
+
+
+42
+
+
+3.14
+
+
+true
+
+
+hello world
+
+
+{"key": "value"}
+
+
+"""
+
+ result = self.detector.detect_and_parse(model_output, tools=[test_tool])
+
+ self.assertEqual(len(result.calls), 1)
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["int_param"], 42)
+ self.assertEqual(params["float_param"], 3.14)
+ self.assertEqual(params["bool_param"], True)
+ self.assertEqual(params["str_param"], "hello world")
+ self.assertEqual(params["obj_param"], {"key": "value"})
+
+ def test_parse_streaming_incremental(self):
+ """Test that streaming is truly incremental with very small chunks."""
+ model_output = """I'll check the weather.
+
+
+Dallas
+
+
+TX
+
+
+"""
+
+ # Simulate more realistic token-based chunks where <tool_call> is a single token
+ chunks = [
+ "I'll check the weather.",
+ "<tool_call>",
+ "\n<function=get_current_weather>\n",
+ "<parameter=city>\n",
+ "Dallas\n",
+ "</parameter>\n",
+ "<parameter=state>\n",
+ "TX\n",
+ "</parameter>\n",
+ "</function>\n",
+ "</tool_call>",
+ ]
+
+ accumulated_text = ""
+ accumulated_calls = []
+ chunks_count = 0
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
+ accumulated_text += result.normal_text
+ accumulated_calls.extend(result.calls)
+ chunks_count += 1
+
+ self.assertGreater(chunks_count, 3)
+
+ # Verify the accumulated results
+ self.assertIn("I'll check the weather.", accumulated_text)
+ self.assertEqual(len(accumulated_calls), 1)
+ self.assertEqual(accumulated_calls[0].name, "get_current_weather")
+
+ params = json.loads(accumulated_calls[0].parameters)
+ self.assertEqual(params["city"], "Dallas")
+ self.assertEqual(params["state"], "TX")
+
+ def test_parse_streaming_multiple_tools(self):
+ """Test streaming with multiple tool calls."""
+ model_output = """
+
+
+Dallas
+
+
+TX
+
+
+
+Some text in between.
+
+
+
+circle
+
+
+{"radius": 5}
+
+
+"""
+
+ # Simulate streaming by chunks
+ chunk_size = 20
+ chunks = [
+ model_output[i : i + chunk_size]
+ for i in range(0, len(model_output), chunk_size)
+ ]
+
+ accumulated_text = ""
+ accumulated_calls = []
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
+ accumulated_text += result.normal_text
+ accumulated_calls.extend(result.calls)
+
+ self.assertIn("Some text in between.", accumulated_text)
+ self.assertEqual(len(accumulated_calls), 2)
+ self.assertEqual(accumulated_calls[0].name, "get_current_weather")
+ self.assertEqual(accumulated_calls[1].name, "calculate_area")
+
+ # Verify parameters
+ params1 = json.loads(accumulated_calls[0].parameters)
+ self.assertEqual(params1["city"], "Dallas")
+
+ params2 = json.loads(accumulated_calls[1].parameters)
+ self.assertEqual(params2["shape"], "circle")
+ self.assertEqual(params2["dimensions"], {"radius": 5})
+
+
if __name__ == "__main__":
unittest.main()