diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index 1df62a7a8..cdd7b4607 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -36,6 +36,7 @@ class BaseFormatDetector(ABC): ) # map what has been streamed for each tool so far to a list self.bot_token = "" self.eot_token = "" + self.tool_call_separator = ", " def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: tool_indices = { @@ -50,7 +51,7 @@ class BaseFormatDetector(ABC): if name and name in tool_indices: results.append( ToolCallItem( - tool_index=tool_indices[name], + tool_index=-1, # Caller should update this based on the actual tools array called name=name, parameters=json.dumps( act.get("parameters") or act.get("arguments", {}), @@ -106,7 +107,17 @@ class BaseFormatDetector(ABC): # Append new text to buffer self._buffer += new_text current_text = self._buffer - if not (self.bot_token in current_text or current_text.startswith("{")): + + # The current_text has tool_call if it is the start of a new tool call sequence + # or it is the start of a new tool call after a tool call separator, when there is a previous tool call + if not ( + self.bot_token in current_text + or current_text.startswith("{") + or ( + self.current_tool_id > 0 + and current_text.startswith(self.tool_call_separator + "{") + ) + ): # Only clear buffer if we're sure no tool call is starting if not self._ends_with_partial_token(self._buffer, self.bot_token): normal_text = self._buffer @@ -127,91 +138,73 @@ class BaseFormatDetector(ABC): } flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: - tool_call_arr = [] - is_complete = [] try: - start_idx = ( - len(self.bot_token) - if current_text.startswith(self.bot_token) - else 0 + if current_text.startswith(self.bot_token): + start_idx = len(self.bot_token) + elif self.current_tool_id > 0 and current_text.startswith( + self.tool_call_separator + ): + start_idx = len(self.tool_call_separator) + else: + start_idx = 0 + + if start_idx >= len(current_text): + return StreamingParseResult() + + (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags) + + is_current_complete = _is_complete_json( + current_text[start_idx : start_idx + end_idx] ) - while start_idx < len(current_text): - (obj, end_idx) = _partial_json_loads( - current_text[start_idx:], flags - ) - is_complete.append( - _is_complete_json(current_text[start_idx : start_idx + end_idx]) - ) - start_idx += end_idx + len("; ") - # Validate tool name if present - if "name" in obj and obj["name"] not in self._tool_indices: - # Invalid tool name - reset state - self._buffer = "" - self.current_tool_id = -1 - self.current_tool_name_sent = False - if self.streamed_args_for_tool: - self.streamed_args_for_tool.pop() - return StreamingParseResult() + # Validate tool name if present + if "name" in obj and obj["name"] not in self._tool_indices: + # Invalid tool name - reset state + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + if self.streamed_args_for_tool: + self.streamed_args_for_tool.pop() + return StreamingParseResult() - # Handle parameters/arguments consistency - if "parameters" in obj: - assert ( - "arguments" not in obj - ), "model generated both parameters and arguments" - obj["arguments"] = obj["parameters"] - tool_call_arr.append(obj) + # Handle parameters/arguments consistency + # NOTE: we assume here that the obj is always partial of a single tool call + if "parameters" in obj: + assert ( + "arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + + current_tool_call = obj except MalformedJSON: return StreamingParseResult() - if len(tool_call_arr) == 0: + if not current_tool_call: return StreamingParseResult() - current_tool_call: Dict = ( - tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} - ) - - # Handle new tool in array - if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: - if self.current_tool_id >= 0: - cur_arguments = current_tool_call.get("arguments") - if cur_arguments: - cur_args_json = json.dumps(cur_arguments) - sent = len(self.streamed_args_for_tool[self.current_tool_id]) - argument_diff = cur_args_json[sent:] - - res = StreamingParseResult( - calls=[ - ToolCallItem( - tool_index=self.current_tool_id, - name="", - parameters=argument_diff, - ) - ], - ) - self.streamed_args_for_tool[ - self.current_tool_id - ] += argument_diff - else: - res = StreamingParseResult() - else: - res = StreamingParseResult() - - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - return res - - # Handle tool name - elif not self.current_tool_name_sent: + # Case 1: Handle tool name streaming + # This happens when we encounter a tool but haven't sent its name yet + if not self.current_tool_name_sent: function_name = current_tool_call.get("name") + if function_name and function_name in self._tool_indices: + # If this is a new tool (current_tool_id was -1), initialize it + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.streamed_args_for_tool.append("") + # If this is a subsequent tool, ensure streamed_args_for_tool is large enough + elif self.current_tool_id >= len(self.streamed_args_for_tool): + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send the tool name with empty parameters res = StreamingParseResult( calls=[ ToolCallItem( - tool_index=self._tool_indices[function_name], + tool_index=self.current_tool_id, name=function_name, parameters="", ) @@ -221,47 +214,75 @@ class BaseFormatDetector(ABC): else: res = StreamingParseResult() - # Handle streaming arguments + # Case 2: Handle streaming arguments + # This happens when we've already sent the tool name and now need to stream arguments incrementally else: cur_arguments = current_tool_call.get("arguments") res = StreamingParseResult() if cur_arguments: + # Calculate how much of the arguments we've already streamed sent = len(self.streamed_args_for_tool[self.current_tool_id]) cur_args_json = json.dumps(cur_arguments) - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments" - ) + prev_arguments = None + if self.current_tool_id < len(self.prev_tool_call_arr): + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id + ].get("arguments") argument_diff = None - if is_complete[self.current_tool_id]: + + # If the current tool's JSON is complete, send all remaining arguments + if is_current_complete: argument_diff = cur_args_json[sent:] - self._buffer = "" - self.prev_tool_call_arr[self.current_tool_id].clear() + completing_tool_id = ( + self.current_tool_id + ) # Save the ID of the tool that's completing + + # Only remove the processed portion, keep unprocessed content + self._buffer = current_text[start_idx + end_idx :] + + if self.current_tool_id < len(self.prev_tool_call_arr): + self.prev_tool_call_arr[self.current_tool_id].clear() self.current_tool_name_sent = False self.streamed_args_for_tool[self.current_tool_id] = "" + self.current_tool_id += 1 + # If the tool is still being parsed, send incremental changes elif prev_arguments: prev_args_json = json.dumps(prev_arguments) if cur_args_json != prev_args_json: prefix = _find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] + # Send the argument diff if there's something new if argument_diff is not None: + # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing + tool_index_to_use = ( + completing_tool_id + if is_current_complete + else self.current_tool_id + ) res = StreamingParseResult( calls=[ ToolCallItem( - tool_index=self.current_tool_id, + tool_index=tool_index_to_use, parameters=argument_diff, ) ], ) - if not is_complete[self.current_tool_id]: + if not is_current_complete: self.streamed_args_for_tool[ self.current_tool_id ] += argument_diff - self.prev_tool_call_arr = tool_call_arr + # Update prev_tool_call_arr with current state + if self.current_tool_id >= 0: + # Ensure prev_tool_call_arr is large enough + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call + return res except Exception as e: diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py index 32670782c..a2aaba3fe 100644 --- a/python/sglang/srt/function_call/llama32_detector.py +++ b/python/sglang/srt/function_call/llama32_detector.py @@ -24,6 +24,11 @@ class Llama32Detector(BaseFormatDetector): def __init__(self): super().__init__() self.bot_token = "<|python_tag|>" + # NOTE: technically Llama3.2 doesn't support well with parallel tool calls + # They need specific prompt engineering to support parallel tool calls + # Here we use ';' as the separator, which might have compatibility issues + # if users define to use a different separator in their prompt + self.tool_call_separator = ";" def has_tool_call(self, text: str) -> bool: """Check if the text contains a Llama 3.2 format tool call.""" @@ -42,7 +47,11 @@ class Llama32Detector(BaseFormatDetector): normal_text, action_text = "", text # Split by semicolon and process each part - json_parts = [part.strip() for part in action_text.split(";") if part.strip()] + json_parts = [ + part.strip() + for part in action_text.split(self.tool_call_separator) + if part.strip() + ] all_actions = [] for part in json_parts: try: @@ -70,5 +79,5 @@ class Llama32Detector(BaseFormatDetector): return EBNFComposer.build_ebnf( tools, function_format="json", - tool_call_separator=",", + tool_call_separator=self.tool_call_separator, ) diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index 9e3260ffd..05d3bfead 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -30,6 +30,7 @@ class MistralDetector(BaseFormatDetector): self.bot_token = "[TOOL_CALLS] [" self.eot_token = "]" self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + self.tool_call_separator = ", " def has_tool_call(self, text: str) -> bool: """Check if the text contains a Mistral format tool call.""" @@ -126,5 +127,5 @@ class MistralDetector(BaseFormatDetector): sequence_start_token=self.bot_token, sequence_end_token=self.eot_token, function_format="json", - tool_call_separator=", ", + tool_call_separator=self.tool_call_separator, ) diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index 0a2f4bd5d..ad1317777 100644 --- a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -29,6 +29,7 @@ class Qwen25Detector(BaseFormatDetector): super().__init__() self.bot_token = "\n" self.eot_token = "\n" + self.tool_call_separator = "\n" self._normal_text_buffer = "" # Buffer for handling partial end tokens def has_tool_call(self, text: str) -> bool: @@ -104,7 +105,6 @@ class Qwen25Detector(BaseFormatDetector): return result def structure_info(self) -> _GetInfoFunc: - # TODO: Update the begin and end tokens with '\n' if necessary return lambda name: StructureInfo( begin='\n{"name":"' + name + '", "arguments":', end="}\n", diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py index e8a585bb2..c4da456f3 100644 --- a/python/sglang/srt/function_call/utils.py +++ b/python/sglang/srt/function_call/utils.py @@ -18,6 +18,23 @@ def _find_common_prefix(s1: str, s2: str) -> str: def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + """ + Parse incomplete or partial JSON strings commonly encountered during streaming. + + Args: + input_str (str): The potentially incomplete JSON string to parse. + flags (Allow): Bitwise flags controlling what types of partial data are allowed. + Common flags include: + - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo') + - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None}) + - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2]) + - Allow.ALL: Allow all types of partial data + + Returns: + Tuple[Any, int]: A tuple containing: + - parsed_object: The Python object parsed from the JSON + - consumed_length: Number of characters consumed from input_str + """ try: return (partial_json_parser.loads(input_str, flags), len(input_str)) except JSONDecodeError as e: diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 7212f9acd..27336dc75 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -1327,7 +1327,6 @@ def v1_chat_generate_response( tool_calls = [ ToolCall( id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", - index=call_info.tool_index, function=FunctionResponse( name=call_info.name, arguments=call_info.parameters ), diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 99c7c9dd7..1ac58d9f6 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -3,6 +3,7 @@ import unittest from xgrammar import GrammarCompiler, TokenizerInfo +from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector @@ -516,5 +517,237 @@ class TestEBNFGeneration(unittest.TestCase): self.fail(f"Failed to compile EBNF: {e}") +class TestBaseFormatDetector(unittest.TestCase): + """Test buffer management and sequential tool index assignment in BaseFormatDetector.""" + + def setUp(self): + """Set up test detector and tools.""" + + # Create a concrete implementation of BaseFormatDetector for testing + class TestFormatDetector(BaseFormatDetector): + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def detect_and_parse(self, text, tools): + # Not used in streaming tests + pass + + def has_tool_call(self, text): + return "" in text + + def structure_info(self): + # Not used in streaming tests + pass + + def build_ebnf(self, tools): + # Not used in streaming tests + pass + + self.detector = TestFormatDetector() + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + + def test_sequential_tool_index_assignment(self): + """Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...).""" + # Simulate streaming chunks for two consecutive tool calls + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "Paris"}}', + ", ", + '{"name": "get_tourist_attractions", ', + '"arguments": {"city": "London"}}', + "", + ] + + tool_indices_seen = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + if result.calls: + for call in result.calls: + if call.tool_index is not None: + tool_indices_seen.append(call.tool_index) + + # Verify we got sequential tool indices + unique_indices = sorted(set(tool_indices_seen)) + self.assertEqual( + unique_indices, + [0, 1], + f"Expected sequential tool indices [0, 1], got {unique_indices}", + ) + + def test_buffer_content_preservation(self): + """Test that buffer correctly preserves unprocessed content when tool completes.""" + # Test simpler scenario: tool completion followed by new tool start + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "Paris"}}', + ", ", + '{"name": "get_tourist_attractions", ', + '"arguments": {"city": "London"}} ', + ] + + tool_calls_seen = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if ( + call.name + ): # Only count calls with names (not just parameter updates) + tool_calls_seen.append(call.name) + + # Should see both tool names + self.assertIn("get_weather", tool_calls_seen, "Should process first tool") + self.assertIn( + "get_tourist_attractions", tool_calls_seen, "Should process second tool" + ) + + def test_current_tool_id_increment_on_completion(self): + """Test that current_tool_id increments when a tool completes.""" + # Initial state + self.assertEqual( + self.detector.current_tool_id, -1, "Should start with current_tool_id=-1" + ) + + # Process first tool completely + chunks = [ + "", + '{"name": "get_weather", ', + ] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + self.assertEqual( + self.detector.current_tool_id, 0, "current_tool_id should be 0" + ) + self.assertEqual( + result.calls[0].name, "get_weather", "The first tool should be get_weather" + ) + self.assertEqual( + result.calls[0].tool_index, 0, "The first tool index should be 0" + ) + + # Complete second tool name - this should show that current_tool_id is now 1 + result = self.detector.parse_streaming_increment( + '"arguments": {"city": "Paris"}}, {"name": "get_', self.tools + ) + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + + self.assertEqual( + self.detector.current_tool_id, + 1, + "current_tool_id should be 1 after first tool completes and second tool starts", + ) + + result = self.detector.parse_streaming_increment( + 'tourist_attractions", ', self.tools + ) + + # Second tool should have tool_index=1 + tourist_calls = [ + call for call in result.calls if call.name == "get_tourist_attractions" + ] + self.assertEqual( + tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" + ) + + def test_tool_name_streaming_with_correct_index(self): + """Test that tool names are streamed with correct tool_index values.""" + # Process first tool + self.detector.parse_streaming_increment("", self.tools) + result1 = self.detector.parse_streaming_increment( + '{"name": "get_weather", ', self.tools + ) + + # First tool name should have tool_index=0 + weather_calls = [call for call in result1.calls if call.name == "get_weather"] + self.assertEqual(len(weather_calls), 1, "Should have one weather call") + self.assertEqual( + weather_calls[0].tool_index, 0, "First tool should have tool_index=0" + ) + + # Complete first tool + self.detector.parse_streaming_increment( + '"arguments": {"city": "Paris"}}', self.tools + ) + + # Start second tool + self.detector.parse_streaming_increment(", ", self.tools) + result2 = self.detector.parse_streaming_increment( + '{"name": "get_tourist_attractions", ', self.tools + ) + + # Second tool name should have tool_index=1 + tourist_calls = [ + call for call in result2.calls if call.name == "get_tourist_attractions" + ] + self.assertEqual( + len(tourist_calls), 1, "Should have one tourist attractions call" + ) + self.assertEqual( + tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" + ) + + def test_buffer_reset_on_invalid_tool(self): + """Test that buffer and state are reset when an invalid tool name is encountered.""" + # Start fresh with an invalid tool name from the beginning + result = self.detector.parse_streaming_increment( + '{"name": "invalid_tool", ', self.tools + ) + + # Should return empty result and reset state + self.assertEqual(result.calls, [], "Should return no calls for invalid tool") + self.assertEqual( + self.detector.current_tool_id, + -1, + "current_tool_id should remain -1 for invalid tool", + ) + self.assertEqual( + self.detector._buffer, "", "Buffer should be cleared for invalid tool" + ) + + if __name__ == "__main__": unittest.main()