diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py
index 1df62a7a8..cdd7b4607 100644
--- a/python/sglang/srt/function_call/base_format_detector.py
+++ b/python/sglang/srt/function_call/base_format_detector.py
@@ -36,6 +36,7 @@ class BaseFormatDetector(ABC):
) # map what has been streamed for each tool so far to a list
self.bot_token = ""
self.eot_token = ""
+ self.tool_call_separator = ", "
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = {
@@ -50,7 +51,7 @@ class BaseFormatDetector(ABC):
if name and name in tool_indices:
results.append(
ToolCallItem(
- tool_index=tool_indices[name],
+                        tool_index=-1,  # Caller should update this based on the tool's position in the actual sequence of tool calls
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
@@ -106,7 +107,17 @@ class BaseFormatDetector(ABC):
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
- if not (self.bot_token in current_text or current_text.startswith("{")):
+
+        # current_text contains a tool call if it is the start of a new tool call sequence,
+        # or the start of another tool call that follows the tool call separator (i.e., a previous tool call exists)
+ if not (
+ self.bot_token in current_text
+ or current_text.startswith("{")
+ or (
+ self.current_tool_id > 0
+ and current_text.startswith(self.tool_call_separator + "{")
+ )
+ ):
# Only clear buffer if we're sure no tool call is starting
if not self._ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer
@@ -127,91 +138,73 @@ class BaseFormatDetector(ABC):
}
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+
try:
- tool_call_arr = []
- is_complete = []
try:
- start_idx = (
- len(self.bot_token)
- if current_text.startswith(self.bot_token)
- else 0
+ if current_text.startswith(self.bot_token):
+ start_idx = len(self.bot_token)
+ elif self.current_tool_id > 0 and current_text.startswith(
+ self.tool_call_separator
+ ):
+ start_idx = len(self.tool_call_separator)
+ else:
+ start_idx = 0
+
+ if start_idx >= len(current_text):
+ return StreamingParseResult()
+
+ (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
+
+ is_current_complete = _is_complete_json(
+ current_text[start_idx : start_idx + end_idx]
)
- while start_idx < len(current_text):
- (obj, end_idx) = _partial_json_loads(
- current_text[start_idx:], flags
- )
- is_complete.append(
- _is_complete_json(current_text[start_idx : start_idx + end_idx])
- )
- start_idx += end_idx + len("; ")
- # Validate tool name if present
- if "name" in obj and obj["name"] not in self._tool_indices:
- # Invalid tool name - reset state
- self._buffer = ""
- self.current_tool_id = -1
- self.current_tool_name_sent = False
- if self.streamed_args_for_tool:
- self.streamed_args_for_tool.pop()
- return StreamingParseResult()
+ # Validate tool name if present
+ if "name" in obj and obj["name"] not in self._tool_indices:
+ # Invalid tool name - reset state
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ if self.streamed_args_for_tool:
+ self.streamed_args_for_tool.pop()
+ return StreamingParseResult()
- # Handle parameters/arguments consistency
- if "parameters" in obj:
- assert (
- "arguments" not in obj
- ), "model generated both parameters and arguments"
- obj["arguments"] = obj["parameters"]
- tool_call_arr.append(obj)
+ # Handle parameters/arguments consistency
+            # NOTE: we assume here that obj is always a partial parse of a single tool call
+ if "parameters" in obj:
+ assert (
+ "arguments" not in obj
+ ), "model generated both parameters and arguments"
+ obj["arguments"] = obj["parameters"]
+
+ current_tool_call = obj
except MalformedJSON:
return StreamingParseResult()
- if len(tool_call_arr) == 0:
+ if not current_tool_call:
return StreamingParseResult()
- current_tool_call: Dict = (
- tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
- )
-
- # Handle new tool in array
- if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
- if self.current_tool_id >= 0:
- cur_arguments = current_tool_call.get("arguments")
- if cur_arguments:
- cur_args_json = json.dumps(cur_arguments)
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
- argument_diff = cur_args_json[sent:]
-
- res = StreamingParseResult(
- calls=[
- ToolCallItem(
- tool_index=self.current_tool_id,
- name="",
- parameters=argument_diff,
- )
- ],
- )
- self.streamed_args_for_tool[
- self.current_tool_id
- ] += argument_diff
- else:
- res = StreamingParseResult()
- else:
- res = StreamingParseResult()
-
- self.current_tool_id = len(tool_call_arr) - 1
- self.current_tool_name_sent = False
- self.streamed_args_for_tool.append("")
- return res
-
- # Handle tool name
- elif not self.current_tool_name_sent:
+ # Case 1: Handle tool name streaming
+ # This happens when we encounter a tool but haven't sent its name yet
+ if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
+
if function_name and function_name in self._tool_indices:
+ # If this is a new tool (current_tool_id was -1), initialize it
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.streamed_args_for_tool.append("")
+ # If this is a subsequent tool, ensure streamed_args_for_tool is large enough
+ elif self.current_tool_id >= len(self.streamed_args_for_tool):
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Send the tool name with empty parameters
res = StreamingParseResult(
calls=[
ToolCallItem(
- tool_index=self._tool_indices[function_name],
+ tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
@@ -221,47 +214,75 @@ class BaseFormatDetector(ABC):
else:
res = StreamingParseResult()
- # Handle streaming arguments
+ # Case 2: Handle streaming arguments
+ # This happens when we've already sent the tool name and now need to stream arguments incrementally
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
if cur_arguments:
+ # Calculate how much of the arguments we've already streamed
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
- prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
- "arguments"
- )
+ prev_arguments = None
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ prev_arguments = self.prev_tool_call_arr[
+ self.current_tool_id
+ ].get("arguments")
argument_diff = None
- if is_complete[self.current_tool_id]:
+
+ # If the current tool's JSON is complete, send all remaining arguments
+ if is_current_complete:
argument_diff = cur_args_json[sent:]
- self._buffer = ""
- self.prev_tool_call_arr[self.current_tool_id].clear()
+ completing_tool_id = (
+ self.current_tool_id
+ ) # Save the ID of the tool that's completing
+
+ # Only remove the processed portion, keep unprocessed content
+ self._buffer = current_text[start_idx + end_idx :]
+
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
+ self.current_tool_id += 1
+ # If the tool is still being parsed, send incremental changes
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
+ # Send the argument diff if there's something new
if argument_diff is not None:
+ # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
+ tool_index_to_use = (
+ completing_tool_id
+ if is_current_complete
+ else self.current_tool_id
+ )
res = StreamingParseResult(
calls=[
ToolCallItem(
- tool_index=self.current_tool_id,
+ tool_index=tool_index_to_use,
parameters=argument_diff,
)
],
)
- if not is_complete[self.current_tool_id]:
+ if not is_current_complete:
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
- self.prev_tool_call_arr = tool_call_arr
+ # Update prev_tool_call_arr with current state
+ if self.current_tool_id >= 0:
+ # Ensure prev_tool_call_arr is large enough
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
+
return res
except Exception as e:
diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py
index 32670782c..a2aaba3fe 100644
--- a/python/sglang/srt/function_call/llama32_detector.py
+++ b/python/sglang/srt/function_call/llama32_detector.py
@@ -24,6 +24,11 @@ class Llama32Detector(BaseFormatDetector):
def __init__(self):
super().__init__()
self.bot_token = "<|python_tag|>"
+        # NOTE: technically, Llama 3.2 doesn't support parallel tool calls well;
+        # it needs specific prompt engineering to support them.
+        # Here we use ';' as the separator, which might cause compatibility issues
+        # if users define a different separator in their prompt
+ self.tool_call_separator = ";"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Llama 3.2 format tool call."""
@@ -42,7 +47,11 @@ class Llama32Detector(BaseFormatDetector):
normal_text, action_text = "", text
# Split by semicolon and process each part
- json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+ json_parts = [
+ part.strip()
+ for part in action_text.split(self.tool_call_separator)
+ if part.strip()
+ ]
all_actions = []
for part in json_parts:
try:
@@ -70,5 +79,5 @@ class Llama32Detector(BaseFormatDetector):
return EBNFComposer.build_ebnf(
tools,
function_format="json",
- tool_call_separator=",",
+ tool_call_separator=self.tool_call_separator,
)
diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py
index 9e3260ffd..05d3bfead 100644
--- a/python/sglang/srt/function_call/mistral_detector.py
+++ b/python/sglang/srt/function_call/mistral_detector.py
@@ -30,6 +30,7 @@ class MistralDetector(BaseFormatDetector):
self.bot_token = "[TOOL_CALLS] ["
self.eot_token = "]"
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
+ self.tool_call_separator = ", "
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Mistral format tool call."""
@@ -126,5 +127,5 @@ class MistralDetector(BaseFormatDetector):
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
function_format="json",
- tool_call_separator=", ",
+ tool_call_separator=self.tool_call_separator,
)
diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py
index 0a2f4bd5d..ad1317777 100644
--- a/python/sglang/srt/function_call/qwen25_detector.py
+++ b/python/sglang/srt/function_call/qwen25_detector.py
@@ -29,6 +29,7 @@ class Qwen25Detector(BaseFormatDetector):
super().__init__()
self.bot_token = "\n"
self.eot_token = "\n"
+ self.tool_call_separator = "\n"
self._normal_text_buffer = "" # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool:
@@ -104,7 +105,6 @@ class Qwen25Detector(BaseFormatDetector):
return result
def structure_info(self) -> _GetInfoFunc:
- # TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo(
begin='\n{"name":"' + name + '", "arguments":',
end="}\n",
diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py
index e8a585bb2..c4da456f3 100644
--- a/python/sglang/srt/function_call/utils.py
+++ b/python/sglang/srt/function_call/utils.py
@@ -18,6 +18,23 @@ def _find_common_prefix(s1: str, s2: str) -> str:
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+ """
+ Parse incomplete or partial JSON strings commonly encountered during streaming.
+
+ Args:
+ input_str (str): The potentially incomplete JSON string to parse.
+ flags (Allow): Bitwise flags controlling what types of partial data are allowed.
+ Common flags include:
+ - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
+ - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
+ - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
+ - Allow.ALL: Allow all types of partial data
+
+ Returns:
+ Tuple[Any, int]: A tuple containing:
+ - parsed_object: The Python object parsed from the JSON
+ - consumed_length: Number of characters consumed from input_str
+ """
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 7212f9acd..27336dc75 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -1327,7 +1327,6 @@ def v1_chat_generate_response(
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
- index=call_info.tool_index,
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py
index 99c7c9dd7..1ac58d9f6 100644
--- a/test/srt/test_function_call_parser.py
+++ b/test/srt/test_function_call_parser.py
@@ -3,6 +3,7 @@ import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
@@ -516,5 +517,237 @@ class TestEBNFGeneration(unittest.TestCase):
self.fail(f"Failed to compile EBNF: {e}")
+class TestBaseFormatDetector(unittest.TestCase):
+ """Test buffer management and sequential tool index assignment in BaseFormatDetector."""
+
+ def setUp(self):
+ """Set up test detector and tools."""
+
+ # Create a concrete implementation of BaseFormatDetector for testing
+ class TestFormatDetector(BaseFormatDetector):
+ def __init__(self):
+ super().__init__()
+ self.bot_token = ""
+ self.eot_token = ""
+
+ def detect_and_parse(self, text, tools):
+ # Not used in streaming tests
+ pass
+
+ def has_tool_call(self, text):
+ return "" in text
+
+ def structure_info(self):
+ # Not used in streaming tests
+ pass
+
+ def build_ebnf(self, tools):
+ # Not used in streaming tests
+ pass
+
+ self.detector = TestFormatDetector()
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "City name",
+ }
+ },
+ "required": ["city"],
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="get_tourist_attractions",
+ description="Get tourist attractions",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "City name",
+ }
+ },
+ "required": ["city"],
+ },
+ ),
+ ),
+ ]
+
+ def test_sequential_tool_index_assignment(self):
+ """Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...)."""
+ # Simulate streaming chunks for two consecutive tool calls
+ chunks = [
+ "",
+ '{"name": "get_weather", ',
+ '"arguments": {"city": "Paris"}}',
+ ", ",
+ '{"name": "get_tourist_attractions", ',
+ '"arguments": {"city": "London"}}',
+ "",
+ ]
+
+ tool_indices_seen = []
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+
+ if result.calls:
+ for call in result.calls:
+ if call.tool_index is not None:
+ tool_indices_seen.append(call.tool_index)
+
+ # Verify we got sequential tool indices
+ unique_indices = sorted(set(tool_indices_seen))
+ self.assertEqual(
+ unique_indices,
+ [0, 1],
+ f"Expected sequential tool indices [0, 1], got {unique_indices}",
+ )
+
+ def test_buffer_content_preservation(self):
+ """Test that buffer correctly preserves unprocessed content when tool completes."""
+ # Test simpler scenario: tool completion followed by new tool start
+ chunks = [
+ "",
+ '{"name": "get_weather", ',
+ '"arguments": {"city": "Paris"}}',
+ ", ",
+ '{"name": "get_tourist_attractions", ',
+ '"arguments": {"city": "London"}} ',
+ ]
+
+ tool_calls_seen = []
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ if result.calls:
+ for call in result.calls:
+ if (
+ call.name
+ ): # Only count calls with names (not just parameter updates)
+ tool_calls_seen.append(call.name)
+
+ # Should see both tool names
+ self.assertIn("get_weather", tool_calls_seen, "Should process first tool")
+ self.assertIn(
+ "get_tourist_attractions", tool_calls_seen, "Should process second tool"
+ )
+
+ def test_current_tool_id_increment_on_completion(self):
+ """Test that current_tool_id increments when a tool completes."""
+ # Initial state
+ self.assertEqual(
+ self.detector.current_tool_id, -1, "Should start with current_tool_id=-1"
+ )
+
+ # Process first tool completely
+ chunks = [
+ "",
+ '{"name": "get_weather", ',
+ ]
+
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+
+ self.assertEqual(
+ self.detector.current_tool_id, 0, "current_tool_id should be 0"
+ )
+ self.assertEqual(
+ result.calls[0].name, "get_weather", "The first tool should be get_weather"
+ )
+ self.assertEqual(
+ result.calls[0].tool_index, 0, "The first tool index should be 0"
+ )
+
+ # Complete second tool name - this should show that current_tool_id is now 1
+ result = self.detector.parse_streaming_increment(
+ '"arguments": {"city": "Paris"}}, {"name": "get_', self.tools
+ )
+ self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
+
+ self.assertEqual(
+ self.detector.current_tool_id,
+ 1,
+ "current_tool_id should be 1 after first tool completes and second tool starts",
+ )
+
+ result = self.detector.parse_streaming_increment(
+ 'tourist_attractions", ', self.tools
+ )
+
+ # Second tool should have tool_index=1
+ tourist_calls = [
+ call for call in result.calls if call.name == "get_tourist_attractions"
+ ]
+ self.assertEqual(
+ tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
+ )
+
+ def test_tool_name_streaming_with_correct_index(self):
+ """Test that tool names are streamed with correct tool_index values."""
+ # Process first tool
+ self.detector.parse_streaming_increment("", self.tools)
+ result1 = self.detector.parse_streaming_increment(
+ '{"name": "get_weather", ', self.tools
+ )
+
+ # First tool name should have tool_index=0
+ weather_calls = [call for call in result1.calls if call.name == "get_weather"]
+ self.assertEqual(len(weather_calls), 1, "Should have one weather call")
+ self.assertEqual(
+ weather_calls[0].tool_index, 0, "First tool should have tool_index=0"
+ )
+
+ # Complete first tool
+ self.detector.parse_streaming_increment(
+ '"arguments": {"city": "Paris"}}', self.tools
+ )
+
+ # Start second tool
+ self.detector.parse_streaming_increment(", ", self.tools)
+ result2 = self.detector.parse_streaming_increment(
+ '{"name": "get_tourist_attractions", ', self.tools
+ )
+
+ # Second tool name should have tool_index=1
+ tourist_calls = [
+ call for call in result2.calls if call.name == "get_tourist_attractions"
+ ]
+ self.assertEqual(
+ len(tourist_calls), 1, "Should have one tourist attractions call"
+ )
+ self.assertEqual(
+ tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
+ )
+
+ def test_buffer_reset_on_invalid_tool(self):
+ """Test that buffer and state are reset when an invalid tool name is encountered."""
+ # Start fresh with an invalid tool name from the beginning
+ result = self.detector.parse_streaming_increment(
+ '{"name": "invalid_tool", ', self.tools
+ )
+
+ # Should return empty result and reset state
+ self.assertEqual(result.calls, [], "Should return no calls for invalid tool")
+ self.assertEqual(
+ self.detector.current_tool_id,
+ -1,
+ "current_tool_id should remain -1 for invalid tool",
+ )
+ self.assertEqual(
+ self.detector._buffer, "", "Buffer should be cleared for invalid tool"
+ )
+
+
if __name__ == "__main__":
unittest.main()