From 0ac61146947ad5bb202ce08a81431eb0daf43aef Mon Sep 17 00:00:00 2001
From: eraser00 <145762019+eraser00@users.noreply.github.com>
Date: Fri, 26 Sep 2025 09:47:40 +0800
Subject: [PATCH] Replace the Kimi-K2 generated tool call idx with history tool call count (#10612)

Co-authored-by: eraser00
---
 .../srt/entrypoints/openai/serving_chat.py   |  66 +++++--
 .../openai_server/basic/test_serving_chat.py | 175 ++++++++++++++++++
 2 files changed, 226 insertions(+), 15 deletions(-)

diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 90572be6c..ff62e0988 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
     process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
+from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
                 and request.tools
                 and self.tool_call_parser
             ):
+                history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
                 tool_calls, text, finish_reason = self._process_tool_calls(
-                    text, request.tools, finish_reason
+                    text, request.tools, finish_reason, history_tool_calls_cnt
                 )
 
             choice_data = ChatCompletionResponseChoice(
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
         token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
         return ChoiceLogprobs(content=token_logprobs)
 
+    def _process_tool_call_id(
+        self,
+        call_item: ToolCallItem,
+        history_tool_calls_cnt: int,
+    ) -> str:
+        """Generate a new, unique `tool_call_id` for `call_item`."""
+        if self.tool_call_parser != "kimi_k2" or call_item.name is None:
+            # A simple uuid suffices for every parser except Kimi-K2, and for unnamed calls (the guard the inline code had).
+            tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+            return tool_call_id
+        else:
+            # Align with Kimi-K2 format: functions.{name}:{index}
+            # Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
+            # Therefore, the index is offset as `history_tool_calls_cnt + call_item.tool_index` so ids stay globally unique and properly ordered.
+            tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
+            logger.debug(
+                f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
+            )
+            return tool_call_id
+
     def _process_tool_calls(
         self,
         text: str,
         tools: List[Any],
         finish_reason: Dict[str, Any],
+        history_tool_calls_cnt: int = 0,
     ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
         """Process tool calls in the response"""
         parser = FunctionCallParser(tools, self.tool_call_parser)
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
                 text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = []
                 for call_info in call_info_list:
-                    # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
-                    if (
-                        self.tool_call_parser == "kimi_k2"
-                        and call_info.name is not None
-                    ):
-                        tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
-                    else:
-                        tool_id = f"call_{uuid.uuid4().hex[:24]}"
-
+                    tool_id = self._process_tool_call_id(
+                        call_info, history_tool_calls_cnt
+                    )
                     tool_calls.append(
                         ToolCall(
                             id=tool_id,
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
         reasoning_parser = reasoning_parser_dict[index]
         return reasoning_parser.parse_stream_chunk(delta)
 
+    def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
+        """Counts the number of tool calls in the request's message history.
+
+        NOTE: This method is only useful for models that embed a self-increasing
+        history tool call index in their tool call ids, such as kimi-k2.
+
+        Args:
+            request: The chat completion request object.
+
+        Returns:
+            The total number of tool calls in the history, or 0 if not applicable.
+        """
+        messages = getattr(request, "messages", [])
+        idx = 0
+        for msg in messages:
+            if msg.role == "assistant":
+                tool_calls = getattr(msg, "tool_calls", None)
+                idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
+        return idx
+
     def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
         """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
 
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
             yield f"data: {chunk.model_dump_json()}\n\n"
 
         # Yield tool calls
+        history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
         for call_item in calls:
             # Mark that this choice has tool calls
             has_tool_calls[index] = True
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
             # Tool call ID should be generated only once per tool call
             if call_item.name:
                 # First chunk: include ID and function name
-                if self.tool_call_parser == "kimi_k2":
-                    # Align with Kimi-K2 format: functions.{name}:{index}
-                    tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
-                else:
-                    tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+                tool_call_id = self._process_tool_call_id(
+                    call_item, history_tool_calls_cnt
+                )
                 function_name = call_item.name
             else:
                 # Subsequent chunks: null ID and name for argument deltas
diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py
index 9f0d48004..6f1901d75 100644
--- a/test/srt/openai_server/basic/test_serving_chat.py
+++ b/test/srt/openai_server/basic/test_serving_chat.py
@@ -420,6 +420,181 @@ class ServingChatTestCase(unittest.TestCase):
             tool_calls = payload["choices"][0]["delta"]["tool_calls"]
             self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
 
+    def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
+        """Ensure non-streaming tool_call.id increases with the tool call history for the kimi_k2 parser."""
+
+        # Force kimi_k2 parser
+        self.chat.tool_call_parser = "kimi_k2"
+
+        # Prepare request with tool calls history
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[
+                {"role": "user", "content": "What's the weather today in paris?"},
+                {
+                    "role": "assistant",
+                    "content": "Let me do some search first.",
+                    "tool_calls": [
+                        {
+                            "id": "functions.get_weather:0",
+                            "type": "function",
+                            "function": {
+                                "name": "get_weather",
+                                "arguments": '{"city": "Paris"}',
+                            },
+                        }
+                    ],
+                },
+                {
+                    "role": "tool",
+                    "content": "It's rainy in paris now.",
+                    "tool_call_id": "functions.get_weather:0",
+                },
+                {
+                    "role": "assistant",
+                    "content": "It's rainy now.",
+                },
+                {
+                    "role": "user",
+                    "content": "What about LA and Tokyo?",
+                },
+            ],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            stream=False,
+        )
+
+        # Mock FunctionCallParser.parse_non_stream to return two tool calls
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+
+            # Build a mock ToolCallItem-like object
+            call_info = Mock()
+            call_info.name = "get_weather"
+            call_info.parameters = '{"city":"Loa Angeles"}'
+            # Kimi-K2 series models might generate a fixed tool_index that ignores
+            # the tool call history, messing up the ids of all following tool calls
+            call_info.tool_index = 0
+
+            call_info2 = Mock()
+            call_info2.name = "get_weather"
+            call_info2.parameters = '{"city":"Tokyo"}'
+            call_info2.tool_index = 1
+
+            parser_instance.has_tool_call.return_value = True
+            parser_instance.parse_non_stream.return_value = (
+                "",
+                [call_info, call_info2],
+            )
+
+            finish_reason = {"type": "stop", "matched": None}
+            tools = [
+                {"type": "function", "function": {"name": "get_weather"}},
+            ]
+
+            history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
+            tool_calls, remaining_text, _ = self.chat._process_tool_calls(
+                text="<|tool_calls_section_begin|>...",
+                tools=tools,
+                finish_reason=finish_reason,
+                history_tool_calls_cnt=history_tool_calls_cnt,
+            )
+
+            self.assertEqual(history_tool_calls_cnt, 1)
+            self.assertIsNotNone(tool_calls)
+            self.assertEqual(len(tool_calls), 2)
+            self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
+            self.assertEqual(tool_calls[0].function.name, "get_weather")
+            self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
+            self.assertEqual(tool_calls[1].function.name, "get_weather")
+
+    def test_kimi_k2_streaming_tool_call_id_with_history(self):
+        """Ensure the streaming first-chunk tool_call.id increases with the tool call history for the kimi_k2 parser."""
+
+        # Force kimi_k2 parser
+        self.chat.tool_call_parser = "kimi_k2"
+
+        # Prepare request with tool calls history
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[
+                {"role": "user", "content": "What's the weather today in paris?"},
+                {
+                    "role": "assistant",
+                    "content": "Let me do some search first.",
+                    "tool_calls": [
+                        {
+                            "id": "functions.get_weather:0",
+                            "type": "function",
+                            "function": {
+                                "name": "get_weather",
+                                "arguments": '{"city": "Paris"}',
+                            },
+                        }
+                    ],
+                },
+                {
+                    "role": "tool",
+                    "content": "It's rainy in paris now.",
+                    "tool_call_id": "functions.get_weather:0",
+                },
+                {
+                    "role": "assistant",
+                    "content": "It's rainy now.",
+                },
+                {
+                    "role": "user",
+                    "content": "What about LA?",
+                },
+            ],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            stream=True,
+        )
+
+        # Patch FunctionCallParser used inside _process_tool_call_stream
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+
+            # First call returns one ToolCallItem-like chunk (with name)
+            first_chunk_call = Mock()
+            # Kimi-K2 series models might generate a fixed tool_index that ignores
+            # the tool call history, messing up the ids of all following tool calls
+            first_chunk_call.tool_index = 0
+            first_chunk_call.name = "get_weather"
+            first_chunk_call.parameters = ""
+            parser_instance.parse_stream_chunk.side_effect = [
+                ("", [first_chunk_call]),
+                ("", []),
+            ]
+
+            async def collect_first_tool_chunk():
+                gen = self.chat._process_tool_call_stream(
+                    index=0,
+                    delta="irrelevant",
+                    parser_dict={},
+                    content={"meta_info": {"id": "chatcmpl-test"}},
+                    request=req,
+                    has_tool_calls={},
+                )
+                # Get first yielded SSE line
+                line = None
+                async for emitted in gen:
+                    line = emitted
+                    break
+                return line
+
+            loop = asyncio.get_event_loop()
+            line = loop.run_until_complete(collect_first_tool_chunk())
+            self.assertIsNotNone(line)
+            self.assertTrue(line.startswith("data: "))
+
+            payload = json.loads(line[len("data: ") :])
+            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
+            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)