Update qwen3_coder_detector.py for streaming (#8371)

This commit is contained in:
maocheng23
2025-08-08 14:51:03 -07:00
committed by GitHub
parent 7b7e56150e
commit b3359dc9bf
2 changed files with 348 additions and 67 deletions

View File

@@ -1707,17 +1707,37 @@ fahrenheit
accumulated_text = ""
accumulated_calls = []
tool_calls_by_index = {}
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
accumulated_text += result.normal_text
accumulated_calls.extend(result.calls)
# Track calls by tool_index to handle streaming properly
for call in result.calls:
if call.tool_index is not None:
if call.tool_index not in tool_calls_by_index:
tool_calls_by_index[call.tool_index] = {
"name": "",
"parameters": "",
}
if call.name:
tool_calls_by_index[call.tool_index]["name"] = call.name
if call.parameters:
tool_calls_by_index[call.tool_index][
"parameters"
] += call.parameters
self.assertEqual(accumulated_text, "Sure! Let me check the weather.")
self.assertEqual(len(accumulated_calls), 1)
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
self.assertEqual(len(tool_calls_by_index), 1)
params = json.loads(accumulated_calls[0].parameters)
# Get the complete tool call
tool_call = tool_calls_by_index[0]
self.assertEqual(tool_call["name"], "get_current_weather")
# Parse the accumulated parameters
params = json.loads(tool_call["parameters"])
self.assertEqual(params["city"], "Dallas")
self.assertEqual(params["state"], "TX")
@@ -1735,20 +1755,49 @@ fahrenheit
# Missing </parameter>, </function>, </tool_call>
]
accumulated_calls = []
tool_calls_by_index = {}
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
accumulated_calls.extend(result.calls)
# Should not have any complete calls yet
self.assertEqual(len(accumulated_calls), 0)
# Track calls by tool_index to handle streaming properly
for call in result.calls:
if call.tool_index is not None:
if call.tool_index not in tool_calls_by_index:
tool_calls_by_index[call.tool_index] = {
"name": "",
"parameters": "",
}
if call.name:
tool_calls_by_index[call.tool_index]["name"] = call.name
if call.parameters:
tool_calls_by_index[call.tool_index][
"parameters"
] += call.parameters
# Should have partial tool call with name but incomplete parameters
self.assertGreater(len(tool_calls_by_index), 0)
self.assertEqual(tool_calls_by_index[0]["name"], "get_current_weather")
# Parameters should be incomplete (no closing brace)
params_str = tool_calls_by_index[0]["parameters"]
self.assertTrue(params_str.startswith('{"city": "Dallas"'))
self.assertFalse(params_str.endswith("}"))
# Now complete it
result = self.detector.parse_streaming_increment(
"\n</parameter>\n</function>\n</tool_call>", tools=self.tools
)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_current_weather")
# Update the accumulated parameters
for call in result.calls:
if call.tool_index is not None and call.parameters:
tool_calls_by_index[call.tool_index]["parameters"] += call.parameters
# Now should have complete parameters
final_params = json.loads(tool_calls_by_index[0]["parameters"])
self.assertEqual(final_params["city"], "Dallas")
self.assertEqual(final_params["state"], "TX")
def test_edge_case_no_parameters(self):
"""Test tool call without parameters."""
@@ -1845,15 +1894,15 @@ hello world
def test_parse_streaming_incremental(self):
"""Test that streaming is truly incremental with very small chunks."""
model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
# Simulate more realistic token-based chunks where <tool_call> is a single token
chunks = [
@@ -1871,49 +1920,59 @@ TX
]
accumulated_text = ""
accumulated_calls = []
tool_calls = []
chunks_count = 0
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
result = self.detector.parse_streaming_increment(chunk, self.tools)
accumulated_text += result.normal_text
accumulated_calls.extend(result.calls)
chunks_count += 1
for tool_call_chunk in result.calls:
if (
hasattr(tool_call_chunk, "tool_index")
and tool_call_chunk.tool_index is not None
):
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] = tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertGreater(chunks_count, 3)
# Verify the accumulated results
self.assertIn("I'll check the weather.", accumulated_text)
self.assertEqual(len(accumulated_calls), 1)
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_current_weather")
params = json.loads(accumulated_calls[0].parameters)
self.assertEqual(params["city"], "Dallas")
self.assertEqual(params["state"], "TX")
params = json.loads(tool_calls[0]["parameters"])
self.assertEqual(params, {"city": "Dallas", "state": "TX"})
def test_parse_streaming_multiple_tools(self):
"""Test streaming with multiple tool calls."""
model_output = """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>
Some text in between.
<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 5}
</parameter>
</function>
</tool_call>"""
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>
Some text in between.
<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 5}
</parameter>
</function>
</tool_call>"""
# Simulate streaming by chunks
chunk_size = 20
@@ -1923,25 +1982,37 @@ circle
]
accumulated_text = ""
accumulated_calls = []
tool_calls = []
chunks_count = 0
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
result = self.detector.parse_streaming_increment(chunk, self.tools)
accumulated_text += result.normal_text
accumulated_calls.extend(result.calls)
chunks_count += 1
for tool_call_chunk in result.calls:
if (
hasattr(tool_call_chunk, "tool_index")
and tool_call_chunk.tool_index is not None
):
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] = tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertIn("Some text in between.", accumulated_text)
self.assertEqual(len(accumulated_calls), 2)
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
self.assertEqual(accumulated_calls[1].name, "calculate_area")
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_current_weather")
self.assertEqual(tool_calls[1]["name"], "calculate_area")
# Verify parameters
params1 = json.loads(accumulated_calls[0].parameters)
self.assertEqual(params1["city"], "Dallas")
params1 = json.loads(tool_calls[0]["parameters"])
self.assertEqual(params1, {"city": "Dallas", "state": "TX"})
params2 = json.loads(accumulated_calls[1].parameters)
self.assertEqual(params2["shape"], "circle")
self.assertEqual(params2["dimensions"], {"radius": 5})
params2 = json.loads(tool_calls[1]["parameters"])
self.assertEqual(params2, {"shape": "circle", "dimensions": {"radius": 5}})
class TestGlm4MoeDetector(unittest.TestCase):