diff --git a/python/sglang/srt/function_call/qwen3_coder_detector.py b/python/sglang/srt/function_call/qwen3_coder_detector.py index 674a189a7..454f5048e 100644 --- a/python/sglang/srt/function_call/qwen3_coder_detector.py +++ b/python/sglang/srt/function_call/qwen3_coder_detector.py @@ -57,6 +57,15 @@ class Qwen3CoderDetector(BaseFormatDetector): ) self._buf: str = "" + # Streaming state variables + self._current_function_name: str = "" + self._current_parameters: Dict[str, Any] = {} + self._streamed_parameters: Dict[str, str] = ( + {} + ) # Track what parameter content we've streamed + self._in_tool_call: bool = False + self._function_name_sent: bool = False + def has_tool_call(self, text: str) -> bool: return self.tool_call_start_token in text @@ -70,23 +79,224 @@ class Qwen3CoderDetector(BaseFormatDetector): self._buf += new_text normal = "" calls: List[ToolCallItem] = [] + + # Build tool indices for validation + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + while True: - if self.tool_call_start_token not in self._buf: + # If we're not in a tool call and don't see a start token, return normal text + if not self._in_tool_call and self.tool_call_start_token not in self._buf: normal += self._buf self._buf = "" break - s = self._buf.find(self.tool_call_start_token) - if s > 0: + + # Look for tool call start + if not self._in_tool_call: + s = self._buf.find(self.tool_call_start_token) + if s == -1: + normal += self._buf + self._buf = "" + break + normal += self._buf[:s] self._buf = self._buf[s:] - e = self._buf.find(self.tool_call_end_token) - if e == -1: - break - block = self._buf[: e + len(self.tool_call_end_token)] - self._buf = self._buf[e + len(self.tool_call_end_token) :] - calls.extend(self._parse_block(block, tools)) + + self._in_tool_call = True + self._function_name_sent = False + self._current_function_name = "" + self._current_parameters = {} + self._streamed_parameters = {} + + # Remove the start token + self._buf = self._buf[len(self.tool_call_start_token) :] + continue + + # We're in a tool call, try to parse function name if not sent yet + if not self._function_name_sent: + # Look for function name pattern: + function_match = re.search(r"]+)>", self._buf) + if function_match: + function_name = function_match.group(1).strip() + + # Validate function name + if function_name in self._tool_indices: + self._current_function_name = function_name + self._function_name_sent = True + + # Initialize tool call tracking + if self.current_tool_id == -1: + self.current_tool_id = 0 + + # Ensure tracking arrays are large enough + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Store tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": function_name, + "arguments": {}, + } + + # Send tool name with empty parameters + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ) + + # Remove the processed function declaration + self._buf = self._buf[function_match.end() :] + continue + else: + # Invalid function name, reset state + logger.warning(f"Invalid function name: {function_name}") + self._reset_streaming_state() + normal += self._buf + self._buf = "" + break + else: + # Function name not complete yet, wait for more text + break + + # Parse parameters incrementally + if self._function_name_sent: + # Process parameters and get any calls to emit + parameter_calls = self._parse_and_stream_parameters(self._buf) + calls.extend(parameter_calls) + + # Check if tool call is complete + if self.tool_call_end_token in self._buf: + end_pos = self._buf.find(self.tool_call_end_token) + + # Add closing brace to complete the JSON object + current_streamed = self.streamed_args_for_tool[self.current_tool_id] + if current_streamed: + # Count opening and closing braces to check if JSON is complete + open_braces = current_streamed.count("{") + close_braces = current_streamed.count("}") + if open_braces > close_braces: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = ( + current_streamed + "}" + ) + + # Complete the tool call + self._buf = self._buf[end_pos + len(self.tool_call_end_token) :] + self._reset_streaming_state() + self.current_tool_id += 1 + continue + else: + # Tool call not complete yet, wait for more text + break + return StreamingParseResult(normal_text=normal, calls=calls) + def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]: + """ + Parse complete parameter blocks from text and return any tool call items to emit. + + This method: + 1. Finds all complete blocks + 2. Parses them into a dictionary + 3. Compares with current parameters and generates diff if needed + 4. Updates internal state + + Args: + text_to_parse: The text to search for parameter blocks + + Returns: + List of ToolCallItem objects to emit (may be empty) + """ + calls: List[ToolCallItem] = [] + + # Find all complete parameter patterns + param_matches = list( + re.finditer( + r"]+)>(.*?)", text_to_parse, re.DOTALL + ) + ) + + # Build new parameters dictionary + new_params = {} + for match in param_matches: + param_name = match.group(1).strip() + param_value = match.group(2) + new_params[param_name] = _safe_val(param_value) + + # Calculate parameter diff to stream with proper incremental JSON building + if new_params != self._current_parameters: + previous_args_json = self.streamed_args_for_tool[self.current_tool_id] + + # Build incremental JSON properly + if not self._current_parameters: + # First parameter(s) - start JSON object but don't close it yet + items = [] + for key, value in new_params.items(): + items.append( + f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}" + ) + json_fragment = "{" + ", ".join(items) + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = json_fragment + + else: + # Additional parameters - add them incrementally + new_keys = set(new_params.keys()) - set(self._current_parameters.keys()) + if new_keys: + # Build the continuation part (no closing brace yet) + continuation_parts = [] + for key in new_keys: + value = new_params[key] + continuation_parts.append( + f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}" + ) + + json_fragment = ", " + ", ".join(continuation_parts) + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = ( + previous_args_json + json_fragment + ) + + # Update current state + self._current_parameters = new_params + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params + + return calls + + def _reset_streaming_state(self): + """Reset streaming state for the next tool call""" + self._in_tool_call = False + self._function_name_sent = False + self._current_function_name = "" + self._current_parameters = {} + self._streamed_parameters = {} + self.current_tool_name_sent = False + def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]: normal_parts: List[str] = [] calls: List[ToolCallItem] = [] diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index cc4521622..afbba82e3 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -1707,17 +1707,37 @@ fahrenheit accumulated_text = "" accumulated_calls = [] + tool_calls_by_index = {} for chunk in chunks: result = self.detector.parse_streaming_increment(chunk, tools=self.tools) accumulated_text += result.normal_text - accumulated_calls.extend(result.calls) + + # Track calls by tool_index to handle streaming properly + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters self.assertEqual(accumulated_text, "Sure! Let me check the weather.") - self.assertEqual(len(accumulated_calls), 1) - self.assertEqual(accumulated_calls[0].name, "get_current_weather") + self.assertEqual(len(tool_calls_by_index), 1) - params = json.loads(accumulated_calls[0].parameters) + # Get the complete tool call + tool_call = tool_calls_by_index[0] + self.assertEqual(tool_call["name"], "get_current_weather") + + # Parse the accumulated parameters + params = json.loads(tool_call["parameters"]) self.assertEqual(params["city"], "Dallas") self.assertEqual(params["state"], "TX") @@ -1735,20 +1755,49 @@ fahrenheit # Missing , , ] - accumulated_calls = [] + tool_calls_by_index = {} for chunk in chunks: result = self.detector.parse_streaming_increment(chunk, tools=self.tools) - accumulated_calls.extend(result.calls) - # Should not have any complete calls yet - self.assertEqual(len(accumulated_calls), 0) + # Track calls by tool_index to handle streaming properly + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + # Should have partial tool call with name but incomplete parameters + self.assertGreater(len(tool_calls_by_index), 0) + self.assertEqual(tool_calls_by_index[0]["name"], "get_current_weather") + + # Parameters should be incomplete (no closing brace) + params_str = tool_calls_by_index[0]["parameters"] + self.assertTrue(params_str.startswith('{"city": "Dallas"')) + self.assertFalse(params_str.endswith("}")) # Now complete it result = self.detector.parse_streaming_increment( "\n\n\n", tools=self.tools ) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_current_weather") + + # Update the accumulated parameters + for call in result.calls: + if call.tool_index is not None and call.parameters: + tool_calls_by_index[call.tool_index]["parameters"] += call.parameters + + # Now should have complete parameters + final_params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(final_params["city"], "Dallas") + self.assertEqual(final_params["state"], "TX") def test_edge_case_no_parameters(self): """Test tool call without parameters.""" @@ -1845,15 +1894,15 @@ hello world def test_parse_streaming_incremental(self): """Test that streaming is truly incremental with very small chunks.""" model_output = """I'll check the weather. - - -Dallas - - -TX - - -""" + + + Dallas + + + TX + + + """ # Simulate more realistic token-based chunks where is a single token chunks = [ @@ -1871,49 +1920,59 @@ TX ] accumulated_text = "" - accumulated_calls = [] + tool_calls = [] chunks_count = 0 for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + result = self.detector.parse_streaming_increment(chunk, self.tools) accumulated_text += result.normal_text - accumulated_calls.extend(result.calls) chunks_count += 1 + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters self.assertGreater(chunks_count, 3) # Verify the accumulated results self.assertIn("I'll check the weather.", accumulated_text) - self.assertEqual(len(accumulated_calls), 1) - self.assertEqual(accumulated_calls[0].name, "get_current_weather") + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_current_weather") - params = json.loads(accumulated_calls[0].parameters) - self.assertEqual(params["city"], "Dallas") - self.assertEqual(params["state"], "TX") + params = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params, {"city": "Dallas", "state": "TX"}) def test_parse_streaming_multiple_tools(self): """Test streaming with multiple tool calls.""" model_output = """ - - -Dallas - - -TX - - - -Some text in between. - - - -circle - - -{"radius": 5} - - -""" + + + Dallas + + + TX + + + + Some text in between. + + + + circle + + + {"radius": 5} + + + """ # Simulate streaming by chunks chunk_size = 20 @@ -1923,25 +1982,37 @@ circle ] accumulated_text = "" - accumulated_calls = [] + tool_calls = [] + chunks_count = 0 for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + result = self.detector.parse_streaming_increment(chunk, self.tools) accumulated_text += result.normal_text - accumulated_calls.extend(result.calls) + chunks_count += 1 + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters self.assertIn("Some text in between.", accumulated_text) - self.assertEqual(len(accumulated_calls), 2) - self.assertEqual(accumulated_calls[0].name, "get_current_weather") - self.assertEqual(accumulated_calls[1].name, "calculate_area") + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_current_weather") + self.assertEqual(tool_calls[1]["name"], "calculate_area") # Verify parameters - params1 = json.loads(accumulated_calls[0].parameters) - self.assertEqual(params1["city"], "Dallas") + params1 = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params1, {"city": "Dallas", "state": "TX"}) - params2 = json.loads(accumulated_calls[1].parameters) - self.assertEqual(params2["shape"], "circle") - self.assertEqual(params2["dimensions"], {"radius": 5}) + params2 = json.loads(tool_calls[1]["parameters"]) + self.assertEqual(params2, {"shape": "circle", "dimensions": {"radius": 5}}) class TestGlm4MoeDetector(unittest.TestCase):