diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 7b0f6f867..bd9f9a98f 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -412,6 +412,8 @@ class OpenAIServingChat(OpenAIServingBase): is_firsts = {} stream_buffers = {} n_prev_tokens = {} + has_tool_calls = {} + finish_reasons = {} # Usage tracking prompt_tokens = {} @@ -443,6 +445,10 @@ class OpenAIServingChat(OpenAIServingBase): finish_reason = content["meta_info"]["finish_reason"] finish_reason_type = finish_reason["type"] if finish_reason else None + # Track finish_reason for each index + if finish_reason_type: + finish_reasons[index] = finish_reason + # First chunk with role if is_firsts.get(index, True): is_firsts[index] = False @@ -450,13 +456,8 @@ class OpenAIServingChat(OpenAIServingBase): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=delta, - finish_reason=finish_reason_type, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, + finish_reason=None, + logprobs=None, ) chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], @@ -483,7 +484,7 @@ class OpenAIServingChat(OpenAIServingBase): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(reasoning_content=reasoning_text), - finish_reason=finish_reason_type, + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], @@ -495,40 +496,34 @@ class OpenAIServingChat(OpenAIServingBase): # Handle tool calls if request.tool_choice != "none" and request.tools: - async for ( - chunk, - tool_call_finish_reason_type, - ) in self._process_tool_call_stream( + async for chunk in self._process_tool_call_stream( index, delta, parser_dict, content, request, - finish_reason_type, + has_tool_calls, ): if chunk: yield chunk - finish_reason_type = 
tool_call_finish_reason_type + + # Send any remaining tool call arguments when generation finishes + if finish_reason_type is not None and index in parser_dict: + parser = parser_dict[index] + remaining_chunk = self._check_for_unstreamed_tool_args( + parser, content, request, index + ) + if remaining_chunk: + yield remaining_chunk else: # Regular content - if delta or not ( - request.stream_options and request.stream_options.include_usage - ): + if delta: choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(content=delta if delta else None), - finish_reason=( - None - if request.stream_options - and request.stream_options.include_usage - else finish_reason_type - ), - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), + finish_reason=None, + matched_stop=None, logprobs=choice_logprobs, ) chunk = ChatCompletionStreamResponse( @@ -539,26 +534,36 @@ class OpenAIServingChat(OpenAIServingBase): ) yield f"data: {chunk.model_dump_json()}\n\n" - # Final chunk with finish_reason - finish_reason_chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=int(time.time()), - choices=[ - ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(), - finish_reason=finish_reason_type, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - ) - ], - model=request.model, - usage=None, - ) - yield f"data: {finish_reason_chunk.model_dump_json()}\n\n" + # Send finish_reason chunks for each index that completed + for idx, finish_reason_data in finish_reasons.items(): + finish_reason_type = finish_reason_data["type"] + + # Change finish_reason to "tool_calls" if we had tool calls and stopped naturally + final_finish_reason = finish_reason_type + if has_tool_calls.get(idx, False) and finish_reason_type == "stop": + final_finish_reason = "tool_calls" + + finish_reason_chunk = ChatCompletionStreamResponse( + 
id=content["meta_info"][ + "id" + ], # NOTE: openai uses the same chatcmpl-id for all indices + created=int(time.time()), + choices=[ + ChatCompletionResponseStreamChoice( + index=idx, + delta=DeltaMessage(), + finish_reason=final_finish_reason, + matched_stop=( + finish_reason_data["matched"] + if "matched" in finish_reason_data + else None + ), + ) + ], + model=request.model, + usage=None, + ) + yield f"data: {finish_reason_chunk.model_dump_json()}\n\n" # Send hidden states if requested if request.return_hidden_states and hidden_states: @@ -578,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase): delta=DeltaMessage( hidden_states=last_token_hidden_states ), - finish_reason=finish_reason_type, + finish_reason=None, # Hidden states don't need finish_reason ) ], model=request.model, @@ -857,7 +862,7 @@ class OpenAIServingChat(OpenAIServingBase): parser_dict: Dict[int, FunctionCallParser], content: Dict[str, Any], request: ChatCompletionRequest, - finish_reason_type: Optional[str], + has_tool_calls: Dict[int, bool], ): """Process tool calls in streaming response""" if index not in parser_dict: @@ -874,7 +879,7 @@ class OpenAIServingChat(OpenAIServingBase): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(content=normal_text), - finish_reason=finish_reason_type, + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], @@ -882,10 +887,13 @@ class OpenAIServingChat(OpenAIServingBase): choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type + yield f"data: {chunk.model_dump_json()}\n\n" # Yield tool calls for call_item in calls: + # Mark that this choice has tool calls + has_tool_calls[index] = True + # Tool call ID should be generated only once per tool call if call_item.name: # First chunk: include ID and function name @@ -896,23 +904,6 @@ class OpenAIServingChat(OpenAIServingBase): tool_call_id = None function_name = None - if 
finish_reason_type == "stop": - # Handle remaining arguments - latest_delta_len = 0 - if isinstance(call_item.parameters, str): - latest_delta_len = len(call_item.parameters) - - expected_call = json.dumps( - parser.detector.prev_tool_call_arr[index].get("arguments", {}), - ensure_ascii=False, - ) - actual_call = parser.detector.streamed_args_for_tool[index] - if latest_delta_len > 0: - actual_call = actual_call[:-latest_delta_len] - remaining_call = expected_call.replace(actual_call, "", 1) - call_item.parameters = remaining_call - finish_reason_type = "tool_calls" - tool_call = ToolCall( id=tool_call_id, index=call_item.tool_index, @@ -925,11 +916,7 @@ class OpenAIServingChat(OpenAIServingBase): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(tool_calls=[tool_call]), - finish_reason=( - None - if request.stream_options and request.stream_options.include_usage - else finish_reason_type - ), + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], @@ -937,7 +924,76 @@ class OpenAIServingChat(OpenAIServingBase): choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type + yield f"data: {chunk.model_dump_json()}\n\n" - if finish_reason_type == "stop": - yield None, "tool_calls" + def _check_for_unstreamed_tool_args( + self, + parser: FunctionCallParser, + content: Dict[str, Any], + request: ChatCompletionRequest, + index: int, + ) -> Optional[str]: + """ + Check for any remaining tool call arguments that need to be streamed + when generation finishes. This ensures tool calls are properly completed + even if the model generates the final arguments in the last chunk. 
+ """ + # Only check if we have tool calls and the parser has tracked data + if ( + not hasattr(parser.detector, "prev_tool_call_arr") + or not parser.detector.prev_tool_call_arr + ): + return None + + if ( + not hasattr(parser.detector, "streamed_args_for_tool") + or not parser.detector.streamed_args_for_tool + ): + return None + + # Get the last tool call that was being processed + tool_index = len(parser.detector.prev_tool_call_arr) - 1 + if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool): + return None + + # Get expected vs actual arguments + expected_args = parser.detector.prev_tool_call_arr[tool_index].get( + "arguments", {} + ) + expected_call = json.dumps(expected_args, ensure_ascii=False) + actual_call = parser.detector.streamed_args_for_tool[tool_index] + + # Check if there are remaining arguments to send + remaining_call = ( + expected_call.replace(actual_call, "", 1) + if actual_call in expected_call + else "" + ) + + if remaining_call: + # Create tool call chunk with remaining arguments + tool_call = ToolCall( + id=None, # No ID for argument deltas + index=tool_index, + function=FunctionResponse( + name=None, # No name for argument deltas + arguments=remaining_call, + ), + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(tool_calls=[tool_call]), + finish_reason=None, # Don't send finish_reason with this chunk + ) + + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + + return f"data: {chunk.model_dump_json()}\n\n" + + return None diff --git a/test/srt/openai_server/basic/test_openai_server.py b/test/srt/openai_server/basic/test_openai_server.py index deafaad3c..f42039bff 100644 --- a/test/srt/openai_server/basic/test_openai_server.py +++ b/test/srt/openai_server/basic/test_openai_server.py @@ -233,6 +233,7 @@ class TestOpenAIServer(CustomTestCase): is_firsts = {} is_finished = {} + 
finish_reason_counts = {} for response in generator: usage = response.usage if usage is not None: @@ -245,6 +246,7 @@ class TestOpenAIServer(CustomTestCase): finish_reason = response.choices[0].finish_reason if finish_reason is not None: is_finished[index] = True + finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1 data = response.choices[0].delta @@ -284,6 +286,15 @@ class TestOpenAIServer(CustomTestCase): index, True ), f"index {index} is not found in the response" + # Verify that each choice gets exactly one finish_reason chunk + for index in range(parallel_sample_num): + assert ( + index in finish_reason_counts + ), f"No finish_reason found for index {index}" + assert ( + finish_reason_counts[index] == 1 + ), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}" + def test_completion(self): for echo in [False, True]: for logprobs in [None, 5]: @@ -420,91 +431,6 @@ The SmartHome Mini is a compact smart home assistant available in black or white client.models.retrieve("non-existent-model") -# ------------------------------------------------------------------------- -# EBNF Test Class: TestOpenAIServerEBNF -# Launches the server with xgrammar, has only EBNF tests -# ------------------------------------------------------------------------- -class TestOpenAIServerEBNF(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - - # passing xgrammar specifically - other_args = ["--grammar-backend", "xgrammar"] - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - api_key=cls.api_key, - other_args=other_args, - ) - cls.base_url += "/v1" - cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_ebnf(self): - """ - Ensure we can pass `ebnf` to the local 
openai server - and that it enforces the grammar. - """ - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - ebnf_grammar = r""" - root ::= "Hello" | "Hi" | "Hey" - """ - pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$") - - response = client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful EBNF test bot."}, - {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."}, - ], - temperature=0, - max_tokens=32, - extra_body={"ebnf": ebnf_grammar}, - ) - text = response.choices[0].message.content.strip() - self.assertTrue(len(text) > 0, "Got empty text from EBNF generation") - self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices") - - def test_ebnf_strict_json(self): - """ - A stricter EBNF that produces exactly {"name":"Alice"} format - with no trailing punctuation or extra fields. - """ - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - ebnf_grammar = r""" - root ::= "{" pair "}" - pair ::= "\"name\"" ":" string - string ::= "\"" [A-Za-z]+ "\"" - """ - pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$') - - response = client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "EBNF mini-JSON generator."}, - { - "role": "user", - "content": "Generate single key JSON with only letters.", - }, - ], - temperature=0, - max_tokens=64, - extra_body={"ebnf": ebnf_grammar}, - ) - text = response.choices[0].message.content.strip() - self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test") - self.assertRegex( - text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" - ) - - class TestOpenAIV1Rerank(CustomTestCase): @classmethod def setUpClass(cls): diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py index 7108b405d..262f8b8bd 100644 --- a/test/srt/openai_server/basic/test_serving_chat.py +++ 
b/test/srt/openai_server/basic/test_serving_chat.py @@ -197,6 +197,134 @@ class ServingChatTestCase(unittest.TestCase): self.assertEqual(params["min_new_tokens"], 5) self.assertEqual(params["stop"], [""]) + def test_unstreamed_tool_args_completion(self): + """Test that remaining tool call arguments are sent when generation finishes.""" + + # Mock FunctionCallParser with detector that has partial tool call data + mock_parser = Mock() + mock_detector = Mock() + + # Simulate a tool call that was partially streamed + mock_detector.prev_tool_call_arr = [ + { + "name": "get_weather", + "arguments": {"location": "San Francisco", "unit": "celsius"}, + } + ] + mock_detector.streamed_args_for_tool = [ + '{"location": "San Francisco"' # Partial arguments streamed so far + ] + mock_parser.detector = mock_detector + + content = { + "meta_info": { + "id": "chatcmpl-test123", + } + } + + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + ) + + # Test the completion method + result = self.chat._check_for_unstreamed_tool_args( + parser=mock_parser, + content=content, + request=request, + index=0, + ) + + # Should return a chunk with remaining arguments + self.assertIsNotNone(result, "Should return chunk with remaining arguments") + self.assertIn('"arguments":', result, "Should contain arguments field") + self.assertIn( + ', "unit": "celsius"}', result, "Should contain remaining arguments" + ) + self.assertIn( + '"finish_reason":null', + result, + "Should not include finish_reason in completion chunk", + ) + + def test_unstreamed_tool_args_no_completion_needed(self): + """Test that no completion chunk is sent when all arguments were already streamed.""" + + # Mock FunctionCallParser with detector that has complete tool call data + mock_parser = Mock() + mock_detector = Mock() + + # Simulate a tool call that was 
completely streamed + mock_detector.prev_tool_call_arr = [ + {"name": "get_weather", "arguments": {"location": "San Francisco"}} + ] + mock_detector.streamed_args_for_tool = [ + '{"location": "San Francisco"}' # All arguments already streamed + ] + mock_parser.detector = mock_detector + + content = { + "meta_info": { + "id": "chatcmpl-test123", + } + } + + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + ) + + # Test the completion method + result = self.chat._check_for_unstreamed_tool_args( + parser=mock_parser, + content=content, + request=request, + index=0, + ) + + # Should return None since no completion is needed + self.assertIsNone(result, "Should return None when no completion is needed") + + def test_unstreamed_tool_args_no_parser_data(self): + """Test that no completion chunk is sent when parser has no tool call data.""" + + # Mock FunctionCallParser with empty detector + mock_parser = Mock() + mock_detector = Mock() + mock_detector.prev_tool_call_arr = [] + mock_detector.streamed_args_for_tool = [] + mock_parser.detector = mock_detector + + content = { + "meta_info": { + "id": "chatcmpl-test123", + } + } + + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + ) + + # Test the completion method + result = self.chat._check_for_unstreamed_tool_args( + parser=mock_parser, + content=content, + request=request, + index=0, + ) + + # Should return None since there's no parser data + self.assertIsNone( + result, "Should return None when parser has no tool call data" + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/srt/openai_server/function_call/test_openai_function_calling.py 
b/test/srt/openai_server/function_call/test_openai_function_calling.py index cd6d767b5..714514dd7 100644 --- a/test/srt/openai_server/function_call/test_openai_function_calling.py +++ b/test/srt/openai_server/function_call/test_openai_function_calling.py @@ -16,6 +16,20 @@ from sglang.test.test_utils import ( class TestOpenAIServerFunctionCalling(CustomTestCase): + # NOTE: this system_message is for Llama3.2 system prompt. Without this, + # sometimes Llama3.2 gives a different tool call format such as: + # '<|python_tag|>{"type": "function", "function": "add", "parameters": {"a": "3", "b": "5"}}' + SYSTEM_MESSAGE = ( + "You are a helpful assistant with tool calling capabilities. " + "Only reply with a tool call if the function exists in the library provided by the user. " + "If it doesn't exist, just reply directly in natural language. " + "When you receive a tool call response, use the output to format an answer to the original user question. " + "You have access to the following functions. " + "To call a function, please respond with JSON for a function call. " + 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. 
' + "Do not use variables.\n\n" + ) + @classmethod def setUpClass(cls): # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -73,7 +87,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): } ] - messages = [{"role": "user", "content": "Compute (3+5)"}] + messages = [ + {"role": "system", "content": self.SYSTEM_MESSAGE}, + {"role": "user", "content": "Compute (3+5)"}, + ] response = client.chat.completions.create( model=self.model, max_tokens=2048, @@ -205,7 +222,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ] messages = [ - {"role": "user", "content": "What is the temperature in Paris in celsius?"} + {"role": "system", "content": self.SYSTEM_MESSAGE}, + {"role": "user", "content": "What is the temperature in Paris?"}, ] response_stream = client.chat.completions.create( @@ -248,74 +266,6 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): "Final response of function calling should have finish_reason 'tool_calls'", ) - # TODO: There is a bug in sglang preventing this UT from passing. We are working on it. Once done, we will add this UT back. - def _test_function_calling_streaming_no_tool_call(self): - """ - Test: Whether the finish_reason is stop in streaming mode when no tool call is given. - - Expect no function call to be found. 
- - Verify that finish_reason is stop - """ - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city to find the weather for", - }, - "unit": { - "type": "string", - "description": "Weather unit (celsius or fahrenheit)", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["city", "unit"], - }, - }, - } - ] - - messages = [{"role": "user", "content": "Who are you?"}] - - response_stream = client.chat.completions.create( - model=self.model, - max_tokens=2048, - messages=messages, - temperature=0.8, - top_p=0.8, - stream=True, - tools=tools, - tool_choice="none", - ) - - chunks = list(response_stream) - self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk") - - found_tool_call = False - for chunk in chunks: - choice = chunk.choices[0] - # Check whether the current chunk contains tool_calls - found_tool_call = choice.delta.tool_calls is not None - - self.assertFalse( - found_tool_call, - "Shouldn't have any tool_call in the streaming chunks", - ) - - finish_reason = chunks[-1].choices[0].finish_reason - self.assertEqual( - finish_reason, - "stop", - "Final response of no function calling should have finish_reason 'stop'", - ) - def test_function_calling_streaming_args_parsing(self): """ Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. 
@@ -350,7 +300,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ] messages = [ - {"role": "user", "content": "Please sum 5 and 7, just call the function."} + {"role": "system", "content": self.SYSTEM_MESSAGE}, + {"role": "user", "content": "Please sum 5 and 7, just call the function."}, ] response_stream = client.chat.completions.create( @@ -617,6 +568,212 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ) self.assertIn("city", args_obj, "Function arguments should have 'city'") + def test_streaming_multiple_choices_finish_reason(self): + """ + Test: Verify that each choice gets its own finish_reason chunk in streaming mode with n > 1. + This tests the fix for the bug where only the last index got a finish_reason chunk. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + ] + + messages = [ + {"role": "user", "content": "What is the weather like in Los Angeles?"} + ] + + # Request with n=2 to get multiple choices + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=2048, + temperature=0.8, + stream=True, + tools=tools, + tool_choice="required", # Force tool calls + n=2, # Multiple choices + ) + + chunks = list(response_stream) + + # Track finish_reason chunks for each index + finish_reason_chunks = {} + for chunk in chunks: + if chunk.choices: + for choice in chunk.choices: + if choice.finish_reason is not None: + index = choice.index + if index not in finish_reason_chunks: + finish_reason_chunks[index] = [] + finish_reason_chunks[index].append(choice.finish_reason) + + # Verify we got finish_reason chunks for both indices + self.assertEqual( + len(finish_reason_chunks), + 2, + f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}", + ) + + # Verify both index 0 and 1 have finish_reason + self.assertIn( + 0, finish_reason_chunks, "Missing finish_reason chunk for index 0" + ) + self.assertIn( + 1, finish_reason_chunks, "Missing finish_reason chunk for index 1" + ) + + # Verify the finish_reason is "tool_calls" since we forced tool calls + for index, reasons in finish_reason_chunks.items(): + self.assertEqual( + reasons[-1], # Last finish_reason for this index + "tool_calls", + f"Expected finish_reason 'tool_calls' for index {index}, got {reasons[-1]}", + ) + + def test_function_calling_streaming_no_tool_call(self): + """ + Test: Whether the finish_reason is stop in streaming mode when no tool call is given. + - Expect no function call to be found. 
+ - Verify that finish_reason is stop + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for", + }, + "unit": { + "type": "string", + "description": "Weather unit (celsius or fahrenheit)", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "unit"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "Who are you?"}] + + response_stream = client.chat.completions.create( + model=self.model, + max_tokens=2048, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=True, + tools=tools, + tool_choice="none", + ) + + chunks = list(response_stream) + self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk") + + found_tool_call = False + for chunk in chunks: + choice = chunk.choices[0] + # Check whether the current chunk contains tool_calls + found_tool_call = choice.delta.tool_calls is not None + + self.assertFalse( + found_tool_call, + "Shouldn't have any tool_call in the streaming chunks", + ) + + finish_reason = chunks[-1].choices[0].finish_reason + self.assertEqual( + finish_reason, + "stop", + "Final response of no function calling should have finish_reason 'stop'", + ) + + def test_streaming_multiple_choices_without_tools(self): + """ + Test: Verify that each choice gets its own finish_reason chunk without tool calls. + This tests the fix for regular content streaming with multiple choices. 
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + messages = [{"role": "user", "content": "Say hello in one word."}] + + # Request with n=2 to get multiple choices, no tools + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=2048, + temperature=0.8, + stream=True, + max_tokens=10, # Keep it short + n=2, # Multiple choices + ) + + chunks = list(response_stream) + + # Track finish_reason chunks for each index + finish_reason_chunks = {} + for chunk in chunks: + if chunk.choices: + for choice in chunk.choices: + if choice.finish_reason is not None: + index = choice.index + if index not in finish_reason_chunks: + finish_reason_chunks[index] = [] + finish_reason_chunks[index].append(choice.finish_reason) + + # Verify we got finish_reason chunks for both indices + self.assertEqual( + len(finish_reason_chunks), + 2, + f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}", + ) + + # Verify both index 0 and 1 have finish_reason + self.assertIn( + 0, finish_reason_chunks, "Missing finish_reason chunk for index 0" + ) + self.assertIn( + 1, finish_reason_chunks, "Missing finish_reason chunk for index 1" + ) + + # Verify the finish_reason is "stop" (regular completion) + for index, reasons in finish_reason_chunks.items(): + self.assertIn( + reasons[-1], + ["stop", "length"], # Could be either depending on how model responds + f"Expected finish_reason 'stop' or 'length' for index {index}, got {reasons[-1]}", + ) + class TestOpenAIPythonicFunctionCalling(CustomTestCase): PYTHONIC_TOOLS = [ @@ -706,7 +863,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): client = openai.Client(api_key=self.api_key, base_url=self.base_url) response = client.chat.completions.create( model=self.model, - max_tokens=2048, messages=self.PYTHONIC_MESSAGES, tools=self.PYTHONIC_TOOLS, temperature=0.1, @@ -728,7 +884,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): client 
= openai.Client(api_key=self.api_key, base_url=self.base_url) response_stream = client.chat.completions.create( model=self.model, - max_tokens=2048, messages=self.PYTHONIC_MESSAGES, tools=self.PYTHONIC_TOOLS, temperature=0.1,