diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index d9ac71253..39bb92f5f 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -321,6 +321,10 @@ class BaseFormatDetector(ABC): """ raise NotImplementedError() + def supports_structural_tag(self) -> bool: + """Return True if this detector supports structural tag format.""" + return True + @abstractmethod def structure_info(self) -> _GetInfoFunc: """ diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py index 60035e05d..85d6039bb 100644 --- a/python/sglang/srt/function_call/ebnf_composer.py +++ b/python/sglang/srt/function_call/ebnf_composer.py @@ -1,51 +1,73 @@ -from typing import Literal, Optional +from typing import Any, Dict, Literal, Optional class EBNFComposer: # Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers - json_grammar_ebnf_str = r""" - json ::= basic_array | basic_object - basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object + # Shared primitive grammar rules used across all formats + BASE_PRIMITIVE_GRAMMAR = r""" + basic_string ::= (([\"] basic_string_1 [\"])) + basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 + escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9]{4} basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? - basic_string ::= (([\"] basic_string_1 [\"])) - basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 - escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] - basic_boolean ::= "true" | "false" - basic_null ::= "null" - basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" - basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" - ws ::= [ \n\t]* - """ - - pythonic_grammar_ebnf_str = r""" - pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None" - basic_any ::= basic_number | basic_string | basic_array | basic_object - basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? - basic_string ::= (([\"] basic_string_1 [\"])) - basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 - escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" ws ::= [ \n\t]* """ + # Format-specific extensions + json_grammar_ebnf_str = ( + r""" + json ::= basic_array | basic_object + basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object + basic_boolean ::= "true" | "false" + basic_null ::= "null" + """ + + BASE_PRIMITIVE_GRAMMAR + ) + + pythonic_grammar_ebnf_str = ( + r""" + pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None" + basic_any ::= basic_number | basic_string | basic_array | basic_object + basic_boolean ::= "True" | "False" + basic_null ::= "None" + """ + + BASE_PRIMITIVE_GRAMMAR + ) + + xml_grammar_ebnf_str = ( + r""" + xml ::= xml_element | xml_text + xml_element ::= basic_string | basic_number | basic_boolean | basic_null | basic_array | basic_object + xml_text ::= [^<>]* + basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object + basic_boolean ::= "true" | "false" + basic_null ::= "null" + """ + + BASE_PRIMITIVE_GRAMMAR + ) + CALL_RULE_MAP = { "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"', "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"', + "xml": 'call_{name} ::= "\\n" {arguments_rule} "\\n"', } ARGUMENTS_RULE_MAP = { "pythonic": "{arg_rules}", "json": '"{{" {arg_rules} "}}"', + "xml": "{arg_rules}", } KEY_VALUE_RULE_MAP = { "pythonic": '"{key}" "=" {valrule}', "json": '"\\"{key}\\"" ":" {valrule}', + "xml": '"\\n" {valrule} "\\n"', } - JSON_TYPE_MAPPING = { + # Base type mapping - most types are the same across formats + BASE_TYPE_MAPPING = { "string": "basic_string", "number": "basic_number", "integer": "basic_number", @@ -55,19 +77,20 @@ class EBNFComposer: "object": "basic_object", } - PYTHONIC_TYPE_MAPPING = { - "string": "basic_string", - "number": "basic_number", - "integer": "basic_number", - "boolean": '"True" | "False"', - "null": '"None"', - "array": "basic_array", - "object": "basic_object", + # Format-specific overrides for types that differ + FORMAT_TYPE_OVERRIDES = { + "pythonic": { + "boolean": '"True" | "False"', + "null": '"None"', + }, + "xml": { + "string": "xml_text", + }, } @staticmethod def get_value_rule( - prop: dict, function_format: Literal["pythonic", "json"] = "json" + prop: dict, function_format: Literal["pythonic", "json", "xml"] = "json" ) -> str: if "enum" in prop: return EBNFComposer._handle_enum(prop, function_format) @@ -83,48 +106,46 @@ class EBNFComposer: enum_values = prop["enum"] prop_type = prop.get("type", "string") - # Define formatters for different type/format combinations - formatters = { - ("string", "json"): lambda v: f'"\\"{v}\\""', - ("string", "pythonic"): lambda v: f'"\\"{v}\\""', - ("number", "json"): str, - ("number", "pythonic"): str, - ("integer", "json"): str, - ("integer", "pythonic"): str, - ("boolean", "json"): lambda v: "true" if v else "false", - ("boolean", "pythonic"): lambda v: "True" if v else "False", - } + def format_enum_val(v: Any) -> str: + if prop_type == "boolean": + if function_format == "json" or function_format == "xml": + return "true" if v else "false" + elif function_format == "pythonic": + return "True" if v else "False" + else: + return str(v) # fallback - # Get the formatter or default to string handling - formatter = formatters.get( - (prop_type, function_format), - formatters[("string", function_format)], # Default to string handling - ) + if prop_type == "string": + if function_format == "xml": + return f'"{v}"' + else: # json or pythonic + return f'"\\"{v}\\""' # escape quote-wrapped string - formatted_values = [formatter(value) for value in enum_values] + # All other types (number, integer, etc.) + return str(v) + + formatted_values = [format_enum_val(v) for v in enum_values] enum_rule = " | ".join(formatted_values) + return f"({enum_rule})" if len(formatted_values) > 1 else enum_rule - # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence - if len(formatted_values) > 1: - enum_rule = f"({enum_rule})" - - return enum_rule + @staticmethod + def get_type_mapping(function_format: str) -> Dict[str, str]: + """Get the complete type mapping for a given format.""" + mapping = EBNFComposer.BASE_TYPE_MAPPING.copy() + overrides = EBNFComposer.FORMAT_TYPE_OVERRIDES.get(function_format, {}) + mapping.update({k: v for k, v in overrides.items() if v is not None}) + return mapping @staticmethod def _handle_type(prop: dict, function_format: str) -> str: """Handle type properties using the appropriate type mapping.""" prop_type = prop["type"] - type_mapping = ( - EBNFComposer.PYTHONIC_TYPE_MAPPING - if function_format == "pythonic" - else EBNFComposer.JSON_TYPE_MAPPING - ) + type_mapping = EBNFComposer.get_type_mapping(function_format) if isinstance(prop_type, list): type_rules = [ - type_mapping[single_type] + type_mapping.get(single_type, function_format) for single_type in prop_type - if single_type in type_mapping ] return " | ".join(type_rules) if type_rules else function_format @@ -133,7 +154,7 @@ class EBNFComposer: @staticmethod def build_ebnf( tools, - function_format: Literal["pythonic", "json"] = "json", + function_format: Literal["pythonic", "json", "xml"] = "json", # Parameters for wrapping the entire sequence of tool calls sequence_start_token: Optional[str] = None, sequence_end_token: Optional[str] = None, @@ -143,6 +164,7 @@ class EBNFComposer: # Parameter for separating multiple tool calls tool_call_separator: Optional[str] = None, call_rule_fmt: Optional[str] = None, + key_value_rule_fmt: Optional[str] = None, ): """ Generalized EBNF builder for all detectors. @@ -157,6 +179,9 @@ class EBNFComposer: call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default format based on function_format will be used. + key_value_rule_fmt: Optional custom format string for key-value pairs. It should define how each parameter is formatted, + with placeholders {key} for the parameter name and {valrule} for the value rule. If None, a default format + based on function_format will be used. """ # ================================================================= # Step 1: Determine the root tool calls rule @@ -200,7 +225,11 @@ class EBNFComposer: else EBNFComposer.CALL_RULE_MAP[function_format] ) args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format] - key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format] + key_value_template = ( + key_value_rule_fmt + if key_value_rule_fmt + else EBNFComposer.KEY_VALUE_RULE_MAP[function_format] + ) # ================================================================= # Step 4: Build rules for each tool @@ -292,10 +321,13 @@ class EBNFComposer: # ================================================================= # Step 5: Add base grammar rules # ================================================================= - base_grammar = ( - EBNFComposer.pythonic_grammar_ebnf_str - if function_format == "pythonic" - else EBNFComposer.json_grammar_ebnf_str + grammar_dict = { + "pythonic": EBNFComposer.pythonic_grammar_ebnf_str, + "json": EBNFComposer.json_grammar_ebnf_str, + "xml": EBNFComposer.xml_grammar_ebnf_str, + } + base_grammar = grammar_dict.get( + function_format, EBNFComposer.json_grammar_ebnf_str ) ebnf_lines.append(base_grammar) diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 4c38d9d4f..fde00f303 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,7 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector -from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector +from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "pythonic": PythonicDetector, "kimi_k2": KimiK2Detector, - "qwen3": Qwen3XMLDetector, + "qwen3_coder": Qwen3CoderDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): @@ -155,9 +155,9 @@ class FunctionCallParser: or None if no constraint applies. """ # NOTE: structural_tag only supports JSON-compatible content between the begin and end. - # It cannot parse or validate Python syntax like function calls. + # It cannot parse or validate function call Pythonic or XML-ish syntax. if ( - not isinstance(self.detector, PythonicDetector) + self.detector.supports_structural_tag() and tool_choice == "auto" and any(tool.function.strict for tool in self.tools) ): diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index 85c3cd135..be183c6bf 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, - StructureInfo, ToolCallItem, _GetInfoFunc, ) @@ -216,11 +215,11 @@ class PythonicDetector(BaseFormatDetector): else: raise ValueError("Tool call arguments must be literals") - def structure_info(self) -> _GetInfoFunc: - def info(name: str): - return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(") + def supports_structural_tag(self) -> bool: + return False - return info + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError def build_ebnf(self, tools: List[Tool]) -> Optional[str]: return EBNFComposer.build_ebnf( diff --git a/python/sglang/srt/function_call/qwen3_detector.py b/python/sglang/srt/function_call/qwen3_coder_detector.py similarity index 91% rename from python/sglang/srt/function_call/qwen3_detector.py rename to python/sglang/srt/function_call/qwen3_coder_detector.py index 5c6ac698e..641c86806 100644 --- a/python/sglang/srt/function_call/qwen3_detector.py +++ b/python/sglang/srt/function_call/qwen3_coder_detector.py @@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, - StructureInfo, ToolCallItem, _GetInfoFunc, ) @@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any: return raw -class Qwen3XMLDetector(BaseFormatDetector): +class Qwen3CoderDetector(BaseFormatDetector): """ Detector for Qwen 3 models. Assumes function call format: @@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector): params[pname] = _safe_val(pval) raw = {"name": fname, "arguments": params} try: + # TODO: fix idx in function call, the index for a function + # call will always be -1 in parse_base_json res.extend(self.parse_base_json(raw, tools)) except Exception: logger.warning("invalid tool call for %s dropped", fname) return res - def structure_info(self) -> _GetInfoFunc: - return lambda n: StructureInfo( - begin=f"{self.tool_call_start_token}\n", - end=f"\n{self.tool_call_end_token}", - trigger=self.tool_call_start_token, - ) + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError - # TODO: fake ebnf for xml + outlines backend def build_ebnf(self, tools: List[Tool]): return EBNFComposer.build_ebnf( tools, individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"), individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"), tool_call_separator="\\n", - function_format="json", + function_format="xml", + call_rule_fmt='"\\n" {arguments_rule} "\\n"', + key_value_rule_fmt='"\\n" {valrule} "\\n"', ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1625f2c3a..b48cbf725 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1099,10 +1099,10 @@ class ServerArgs: "deepseekv3", "pythonic", "kimi_k2", - "qwen3", + "qwen3_coder", ], default=ServerArgs.tool_call_parser, - help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.", + help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.", ) # Data parallelism diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 26dd24fbb..511020651 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -10,6 +10,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -507,6 +508,7 @@ class TestEBNFGeneration(unittest.TestCase): self.llama32_detector = Llama32Detector() self.mistral_detector = MistralDetector() self.qwen25_detector = Qwen25Detector() + self.qwen3_coder_detector = Qwen3CoderDetector() self.kimik2_detector = KimiK2Detector() def test_pythonic_detector_ebnf(self): @@ -620,6 +622,26 @@ class TestEBNFGeneration(unittest.TestCase): except RuntimeError as e: self.fail(f"Failed to compile EBNF: {e}") + def test_qwen3_coder_detector_ebnf(self): + """Test that the Qwen3CoderDetector generates valid EBNF.""" + ebnf = self.qwen3_coder_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + # Check that the EBNF contains expected patterns for XML format + self.assertIn("", ebnf) + self.assertIn("", ebnf) + self.assertIn('"\\n"', ebnf) + self.assertIn('"\\n"', ebnf) + self.assertIn('"\\n"', ebnf) + self.assertIn('"\\n"', ebnf) + # Check that it uses xml_text for string parameters + self.assertIn("xml_text", ebnf) + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + def test_weather_function_optional_parameter_handling(self): """Test that weather function with optional unit parameter generates correct EBNF without trailing commas.""" # Create a weather tool with required location and optional unit @@ -1464,5 +1486,438 @@ class TestDeepSeekV3Detector(unittest.TestCase): self.assertEqual(params2["city"], "Beijing") +class TestQwen3CoderDetector(unittest.TestCase): + def setUp(self): + # Create sample tools for testing + self.tools = [ + Tool( + type="function", + function=Function( + name="get_current_weather", + description="Get the current weather", + parameters={ + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": { + "type": "string", + "description": "The state code", + }, + "unit": { + "type": "string", + "enum": ["fahrenheit", "celsius"], + }, + }, + "required": ["city", "state"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="calculate_area", + description="Calculate area of a shape", + parameters={ + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + } + }, + ), + ), + ] + self.detector = Qwen3CoderDetector() + + def test_has_tool_call(self): + """Test detection of tool call markers.""" + self.assertTrue(self.detector.has_tool_call("test")) + self.assertFalse(self.detector.has_tool_call("No tool call here")) + + def test_detect_and_parse_no_tools(self): + """Test parsing text without tool calls.""" + model_output = "This is a test response without any tool calls" + result = self.detector.detect_and_parse(model_output, tools=[]) + self.assertEqual(result.normal_text, model_output) + self.assertEqual(result.calls, []) + + def test_detect_and_parse_single_tool(self): + """Test parsing a single tool call.""" + model_output = """ + + +Dallas + + +TX + + +fahrenheit + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Dallas") + self.assertEqual(params["state"], "TX") + self.assertEqual(params["unit"], "fahrenheit") + + def test_detect_and_parse_with_content(self): + """Test parsing tool call with surrounding text.""" + model_output = """Sure! Let me check the weather for you. + + +Dallas + + +TX + + +fahrenheit + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + + self.assertEqual(result.normal_text, "Sure! Let me check the weather for you.") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + + def test_detect_and_parse_multiline_param(self): + """Test parsing tool call with multiline parameter values.""" + model_output = """ + + +rectangle + + +{"width": 10, + "height": 20} + + +2 + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "calculate_area") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["shape"], "rectangle") + self.assertEqual(params["dimensions"], {"width": 10, "height": 20}) + self.assertEqual(params["precision"], 2) + + def test_detect_and_parse_parallel_tools(self): + """Test parsing multiple tool calls.""" + model_output = """ + + +Dallas + + +TX + + +fahrenheit + + + + + + +Orlando + + +FL + + +fahrenheit + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + + self.assertEqual(result.normal_text, "\n") + self.assertEqual(len(result.calls), 2) + + # First call + self.assertEqual(result.calls[0].name, "get_current_weather") + params1 = json.loads(result.calls[0].parameters) + self.assertEqual(params1["city"], "Dallas") + self.assertEqual(params1["state"], "TX") + + # Second call + self.assertEqual(result.calls[1].name, "get_current_weather") + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(params2["city"], "Orlando") + self.assertEqual(params2["state"], "FL") + + def test_parse_streaming_simple(self): + """Test basic streaming parsing.""" + chunks = [ + "Sure! ", + "Let me check ", + "the weather.", + "", + "\n", + "\n", + "\nDallas", + "\n", + "\n", + "\nTX", + "\n", + "\n", + "\n", + ] + + accumulated_text = "" + accumulated_calls = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + accumulated_text += result.normal_text + accumulated_calls.extend(result.calls) + + self.assertEqual(accumulated_text, "Sure! Let me check the weather.") + self.assertEqual(len(accumulated_calls), 1) + self.assertEqual(accumulated_calls[0].name, "get_current_weather") + + params = json.loads(accumulated_calls[0].parameters) + self.assertEqual(params["city"], "Dallas") + self.assertEqual(params["state"], "TX") + + def test_parse_streaming_incomplete(self): + """Test streaming with incomplete tool call.""" + # Send incomplete tool call + chunks = [ + "", + "\n", + "\n", + "\nDallas", + "\n", + "\n", + "\nTX", + # Missing , , + ] + + accumulated_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + accumulated_calls.extend(result.calls) + + # Should not have any complete calls yet + self.assertEqual(len(accumulated_calls), 0) + + # Now complete it + result = self.detector.parse_streaming_increment( + "\n\n\n", tools=self.tools + ) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + + def test_edge_case_no_parameters(self): + """Test tool call without parameters.""" + model_output = """ + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + self.assertEqual(json.loads(result.calls[0].parameters), {}) + + def test_edge_case_special_chars_in_value(self): + """Test parameter with special characters in value.""" + model_output = """ + + +Dallas->TX + + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + self.assertEqual(len(result.calls), 1) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Dallas->TX") + + def test_extract_tool_calls_fallback_no_tags(self): + """Test fallback parsing when XML tags are missing (just function without tool_call wrapper).""" + model_output = """ + +Dallas + + +TX + +""" + + result = self.detector.detect_and_parse(model_output, tools=self.tools) + + self.assertIsNotNone(result) + + def test_extract_tool_calls_type_conversion(self): + """Test parameter type conversion based on tool schema.""" + test_tool = Tool( + type="function", + function=Function( + name="test_types", + parameters={ + "type": "object", + "properties": { + "int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + ), + ) + + model_output = """ + + +42 + + +3.14 + + +true + + +hello world + + +{"key": "value"} + + +""" + + result = self.detector.detect_and_parse(model_output, tools=[test_tool]) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["int_param"], 42) + self.assertEqual(params["float_param"], 3.14) + self.assertEqual(params["bool_param"], True) + self.assertEqual(params["str_param"], "hello world") + self.assertEqual(params["obj_param"], {"key": "value"}) + + def test_parse_streaming_incremental(self): + """Test that streaming is truly incremental with very small chunks.""" + model_output = """I'll check the weather. + + +Dallas + + +TX + + +""" + + # Simulate more realistic token-based chunks where is a single token + chunks = [ + "I'll check the weather.", + "", + "\n\n", + "\n", + "Dallas\n", + "\n", + "\n", + "TX\n", + "\n", + "\n", + "", + ] + + accumulated_text = "" + accumulated_calls = [] + chunks_count = 0 + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + accumulated_text += result.normal_text + accumulated_calls.extend(result.calls) + chunks_count += 1 + + self.assertGreater(chunks_count, 3) + + # Verify the accumulated results + self.assertIn("I'll check the weather.", accumulated_text) + self.assertEqual(len(accumulated_calls), 1) + self.assertEqual(accumulated_calls[0].name, "get_current_weather") + + params = json.loads(accumulated_calls[0].parameters) + self.assertEqual(params["city"], "Dallas") + self.assertEqual(params["state"], "TX") + + def test_parse_streaming_multiple_tools(self): + """Test streaming with multiple tool calls.""" + model_output = """ + + +Dallas + + +TX + + + +Some text in between. + + + +circle + + +{"radius": 5} + + +""" + + # Simulate streaming by chunks + chunk_size = 20 + chunks = [ + model_output[i : i + chunk_size] + for i in range(0, len(model_output), chunk_size) + ] + + accumulated_text = "" + accumulated_calls = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + accumulated_text += result.normal_text + accumulated_calls.extend(result.calls) + + self.assertIn("Some text in between.", accumulated_text) + self.assertEqual(len(accumulated_calls), 2) + self.assertEqual(accumulated_calls[0].name, "get_current_weather") + self.assertEqual(accumulated_calls[1].name, "calculate_area") + + # Verify parameters + params1 = json.loads(accumulated_calls[0].parameters) + self.assertEqual(params1["city"], "Dallas") + + params2 = json.loads(accumulated_calls[1].parameters) + self.assertEqual(params2["shape"], "circle") + self.assertEqual(params2["dimensions"], {"radius": 5}) + + if __name__ == "__main__": unittest.main()