From 0936c766ed6e52ac0a05fdee9f600a1d64365713 Mon Sep 17 00:00:00 2001
From: Xiaotong Jiang
Date: Tue, 26 Aug 2025 00:50:59 -0700
Subject: [PATCH] Fix kimi k2 function calling format (#9606)

---
 .../srt/entrypoints/openai/serving_chat.py    | 30 ++++--
 .../openai_server/basic/test_serving_chat.py  | 96 +++++++++++++++++++
 2 files changed, 117 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 83f8ec2eb..4043203ef 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -835,15 +835,23 @@ class OpenAIServingChat(OpenAIServingBase):
                 finish_reason["matched"] = None
             try:
                 text, call_info_list = parser.parse_non_stream(text)
-                tool_calls = [
-                    ToolCall(
-                        id=f"call_{uuid.uuid4().hex[:24]}",
-                        function=FunctionResponse(
-                            name=call_info.name, arguments=call_info.parameters
-                        ),
+                tool_calls = []
+                for call_info in call_info_list:
+                    # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
+                    if tool_call_parser == "kimi_k2" and call_info.name is not None:
+                        tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
+                    else:
+                        tool_id = f"call_{uuid.uuid4().hex[:24]}"
+
+                    tool_calls.append(
+                        ToolCall(
+                            id=tool_id,
+                            index=getattr(call_info, "tool_index", None),
+                            function=FunctionResponse(
+                                name=call_info.name, arguments=call_info.parameters
+                            ),
+                        )
                     )
-                    for call_info in call_info_list
-                ]
                 return tool_calls, text, finish_reason
             except Exception as e:
                 logger.error(f"Tool call parsing error: {e}")
@@ -954,7 +962,11 @@ class OpenAIServingChat(OpenAIServingBase):
                     # Tool call ID should be generated only once per tool call
                     if call_item.name:
                         # First chunk: include ID and function name
-                        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+                        if self.tokenizer_manager.server_args.tool_call_parser == "kimi_k2":
+                            # Align with Kimi-K2 format: functions.{name}:{index}
+                            tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
+                        else:
+                            tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
                         function_name = call_item.name
                     else:
                         # Subsequent chunks: null ID and name for argument deltas
diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py
index 262f8b8bd..41eaea2ee 100644
--- a/test/srt/openai_server/basic/test_serving_chat.py
+++ b/test/srt/openai_server/basic/test_serving_chat.py
@@ -6,6 +6,8 @@
 or python -m unittest discover -s tests -p "test_*unit.py" -v
 """
 
+import asyncio
+import json
 import unittest
 import uuid
 from typing import Optional
@@ -325,6 +327,100 @@ class ServingChatTestCase(unittest.TestCase):
             result, "Should return None when parser has no tool call data"
         )
 
+    # ------------- kimi_k2 tool_call_id formatting -------------
+    def test_kimi_k2_non_streaming_tool_call_id_format(self):
+        """Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
+
+        # Force kimi_k2 parser
+        self.tm.server_args.tool_call_parser = "kimi_k2"
+
+        # Mock FunctionCallParser.parse_non_stream to return one tool call
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+
+            # Build a mock ToolCallItem-like object
+            call_info = Mock()
+            call_info.name = "get_weather"
+            call_info.parameters = '{"city":"Paris"}'
+            call_info.tool_index = 0
+
+            parser_instance.has_tool_call.return_value = True
+            parser_instance.parse_non_stream.return_value = ("", [call_info])
+
+            finish_reason = {"type": "stop", "matched": None}
+            tools = [
+                {"type": "function", "function": {"name": "get_weather"}},
+            ]
+
+            tool_calls, remaining_text, _ = self.chat._process_tool_calls(
+                text="<|tool_calls_section_begin|>...",
+                tools=tools,
+                tool_call_parser="kimi_k2",
+                finish_reason=finish_reason,
+            )
+
+            self.assertIsNotNone(tool_calls)
+            self.assertEqual(len(tool_calls), 1)
+            self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
+            self.assertEqual(tool_calls[0].function.name, "get_weather")
+
+    def test_kimi_k2_streaming_tool_call_id_format(self):
+        """Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
+
+        # Force kimi_k2 parser
+        self.tm.server_args.tool_call_parser = "kimi_k2"
+
+        # Prepare request with tools
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[{"role": "user", "content": "Hi?"}],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            stream=True,
+        )
+
+        # Patch FunctionCallParser used inside _process_tool_call_stream
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+
+            # First call returns one ToolCallItem-like chunk (with name)
+            first_chunk_call = Mock()
+            first_chunk_call.tool_index = 0
+            first_chunk_call.name = "get_weather"
+            first_chunk_call.parameters = ""
+            parser_instance.parse_stream_chunk.side_effect = [
+                ("", [first_chunk_call]),
+                ("", []),
+            ]
+
+            async def collect_first_tool_chunk():
+                gen = self.chat._process_tool_call_stream(
+                    index=0,
+                    delta="irrelevant",
+                    parser_dict={},
+                    content={"meta_info": {"id": "chatcmpl-test"}},
+                    request=req,
+                    has_tool_calls={},
+                )
+                # Get first yielded SSE line
+                line = None
+                async for emitted in gen:
+                    line = emitted
+                    break
+                return line
+
+            loop = asyncio.get_event_loop()
+            line = loop.run_until_complete(collect_first_tool_chunk())
+            self.assertIsNotNone(line)
+            self.assertTrue(line.startswith("data: "))
+
+            payload = json.loads(line[len("data: ") :])
+            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
+            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)