Update qwen3_coder_detector.py for streaming (#8371)
This commit is contained in:
@@ -1707,17 +1707,37 @@ fahrenheit
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_calls = []
|
||||
tool_calls_by_index = {}
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
|
||||
accumulated_text += result.normal_text
|
||||
accumulated_calls.extend(result.calls)
|
||||
|
||||
# Track calls by tool_index to handle streaming properly
|
||||
for call in result.calls:
|
||||
if call.tool_index is not None:
|
||||
if call.tool_index not in tool_calls_by_index:
|
||||
tool_calls_by_index[call.tool_index] = {
|
||||
"name": "",
|
||||
"parameters": "",
|
||||
}
|
||||
|
||||
if call.name:
|
||||
tool_calls_by_index[call.tool_index]["name"] = call.name
|
||||
if call.parameters:
|
||||
tool_calls_by_index[call.tool_index][
|
||||
"parameters"
|
||||
] += call.parameters
|
||||
|
||||
self.assertEqual(accumulated_text, "Sure! Let me check the weather.")
|
||||
self.assertEqual(len(accumulated_calls), 1)
|
||||
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
|
||||
self.assertEqual(len(tool_calls_by_index), 1)
|
||||
|
||||
params = json.loads(accumulated_calls[0].parameters)
|
||||
# Get the complete tool call
|
||||
tool_call = tool_calls_by_index[0]
|
||||
self.assertEqual(tool_call["name"], "get_current_weather")
|
||||
|
||||
# Parse the accumulated parameters
|
||||
params = json.loads(tool_call["parameters"])
|
||||
self.assertEqual(params["city"], "Dallas")
|
||||
self.assertEqual(params["state"], "TX")
|
||||
|
||||
@@ -1735,20 +1755,49 @@ fahrenheit
|
||||
# Missing </parameter>, </function>, </tool_call>
|
||||
]
|
||||
|
||||
accumulated_calls = []
|
||||
tool_calls_by_index = {}
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
|
||||
accumulated_calls.extend(result.calls)
|
||||
|
||||
# Should not have any complete calls yet
|
||||
self.assertEqual(len(accumulated_calls), 0)
|
||||
# Track calls by tool_index to handle streaming properly
|
||||
for call in result.calls:
|
||||
if call.tool_index is not None:
|
||||
if call.tool_index not in tool_calls_by_index:
|
||||
tool_calls_by_index[call.tool_index] = {
|
||||
"name": "",
|
||||
"parameters": "",
|
||||
}
|
||||
|
||||
if call.name:
|
||||
tool_calls_by_index[call.tool_index]["name"] = call.name
|
||||
if call.parameters:
|
||||
tool_calls_by_index[call.tool_index][
|
||||
"parameters"
|
||||
] += call.parameters
|
||||
|
||||
# Should have partial tool call with name but incomplete parameters
|
||||
self.assertGreater(len(tool_calls_by_index), 0)
|
||||
self.assertEqual(tool_calls_by_index[0]["name"], "get_current_weather")
|
||||
|
||||
# Parameters should be incomplete (no closing brace)
|
||||
params_str = tool_calls_by_index[0]["parameters"]
|
||||
self.assertTrue(params_str.startswith('{"city": "Dallas"'))
|
||||
self.assertFalse(params_str.endswith("}"))
|
||||
|
||||
# Now complete it
|
||||
result = self.detector.parse_streaming_increment(
|
||||
"\n</parameter>\n</function>\n</tool_call>", tools=self.tools
|
||||
)
|
||||
self.assertEqual(len(result.calls), 1)
|
||||
self.assertEqual(result.calls[0].name, "get_current_weather")
|
||||
|
||||
# Update the accumulated parameters
|
||||
for call in result.calls:
|
||||
if call.tool_index is not None and call.parameters:
|
||||
tool_calls_by_index[call.tool_index]["parameters"] += call.parameters
|
||||
|
||||
# Now should have complete parameters
|
||||
final_params = json.loads(tool_calls_by_index[0]["parameters"])
|
||||
self.assertEqual(final_params["city"], "Dallas")
|
||||
self.assertEqual(final_params["state"], "TX")
|
||||
|
||||
def test_edge_case_no_parameters(self):
|
||||
"""Test tool call without parameters."""
|
||||
@@ -1845,15 +1894,15 @@ hello world
|
||||
def test_parse_streaming_incremental(self):
|
||||
"""Test that streaming is truly incremental with very small chunks."""
|
||||
model_output = """I'll check the weather.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
# Simulate more realistic token-based chunks where <tool_call> is a single token
|
||||
chunks = [
|
||||
@@ -1871,49 +1920,59 @@ TX
|
||||
]
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_calls = []
|
||||
tool_calls = []
|
||||
chunks_count = 0
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
accumulated_text += result.normal_text
|
||||
accumulated_calls.extend(result.calls)
|
||||
chunks_count += 1
|
||||
for tool_call_chunk in result.calls:
|
||||
if (
|
||||
hasattr(tool_call_chunk, "tool_index")
|
||||
and tool_call_chunk.tool_index is not None
|
||||
):
|
||||
while len(tool_calls) <= tool_call_chunk.tool_index:
|
||||
tool_calls.append({"name": "", "parameters": ""})
|
||||
tc = tool_calls[tool_call_chunk.tool_index]
|
||||
if tool_call_chunk.name:
|
||||
tc["name"] = tool_call_chunk.name
|
||||
if tool_call_chunk.parameters:
|
||||
tc["parameters"] += tool_call_chunk.parameters
|
||||
|
||||
self.assertGreater(chunks_count, 3)
|
||||
|
||||
# Verify the accumulated results
|
||||
self.assertIn("I'll check the weather.", accumulated_text)
|
||||
self.assertEqual(len(accumulated_calls), 1)
|
||||
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
|
||||
self.assertEqual(len(tool_calls), 1)
|
||||
self.assertEqual(tool_calls[0]["name"], "get_current_weather")
|
||||
|
||||
params = json.loads(accumulated_calls[0].parameters)
|
||||
self.assertEqual(params["city"], "Dallas")
|
||||
self.assertEqual(params["state"], "TX")
|
||||
params = json.loads(tool_calls[0]["parameters"])
|
||||
self.assertEqual(params, {"city": "Dallas", "state": "TX"})
|
||||
|
||||
def test_parse_streaming_multiple_tools(self):
|
||||
"""Test streaming with multiple tool calls."""
|
||||
model_output = """<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
Some text in between.
|
||||
<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"radius": 5}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
Some text in between.
|
||||
<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"radius": 5}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
# Simulate streaming by chunks
|
||||
chunk_size = 20
|
||||
@@ -1923,25 +1982,37 @@ circle
|
||||
]
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_calls = []
|
||||
tool_calls = []
|
||||
chunks_count = 0
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, tools=self.tools)
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
accumulated_text += result.normal_text
|
||||
accumulated_calls.extend(result.calls)
|
||||
chunks_count += 1
|
||||
for tool_call_chunk in result.calls:
|
||||
if (
|
||||
hasattr(tool_call_chunk, "tool_index")
|
||||
and tool_call_chunk.tool_index is not None
|
||||
):
|
||||
while len(tool_calls) <= tool_call_chunk.tool_index:
|
||||
tool_calls.append({"name": "", "parameters": ""})
|
||||
tc = tool_calls[tool_call_chunk.tool_index]
|
||||
if tool_call_chunk.name:
|
||||
tc["name"] = tool_call_chunk.name
|
||||
if tool_call_chunk.parameters:
|
||||
tc["parameters"] += tool_call_chunk.parameters
|
||||
|
||||
self.assertIn("Some text in between.", accumulated_text)
|
||||
self.assertEqual(len(accumulated_calls), 2)
|
||||
self.assertEqual(accumulated_calls[0].name, "get_current_weather")
|
||||
self.assertEqual(accumulated_calls[1].name, "calculate_area")
|
||||
self.assertEqual(len(tool_calls), 2)
|
||||
self.assertEqual(tool_calls[0]["name"], "get_current_weather")
|
||||
self.assertEqual(tool_calls[1]["name"], "calculate_area")
|
||||
|
||||
# Verify parameters
|
||||
params1 = json.loads(accumulated_calls[0].parameters)
|
||||
self.assertEqual(params1["city"], "Dallas")
|
||||
params1 = json.loads(tool_calls[0]["parameters"])
|
||||
self.assertEqual(params1, {"city": "Dallas", "state": "TX"})
|
||||
|
||||
params2 = json.loads(accumulated_calls[1].parameters)
|
||||
self.assertEqual(params2["shape"], "circle")
|
||||
self.assertEqual(params2["dimensions"], {"radius": 5})
|
||||
params2 = json.loads(tool_calls[1]["parameters"])
|
||||
self.assertEqual(params2, {"shape": "circle", "dimensions": {"radius": 5}})
|
||||
|
||||
|
||||
class TestGlm4MoeDetector(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user