Fix kimi k2 function calling format (#9606)
This commit is contained in:
@@ -835,15 +835,23 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
function=FunctionResponse(
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
tool_calls = []
|
||||
for call_info in call_info_list:
|
||||
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||
if tool_call_parser == "kimi_k2" and call_info.name is not None:
|
||||
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||
else:
|
||||
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
index=getattr(call_info, "tool_index", None),
|
||||
function=FunctionResponse(
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
)
|
||||
)
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return tool_calls, text, finish_reason
|
||||
except Exception as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
@@ -954,7 +962,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
# Tool call ID should be generated only once per tool call
|
||||
if call_item.name:
|
||||
# First chunk: include ID and function name
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
if self.tokenizer_manager.server_args.tool_call_parser == "kimi_k2":
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||
else:
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
function_name = call_item.name
|
||||
else:
|
||||
# Subsequent chunks: null ID and name for argument deltas
|
||||
|
||||
@@ -6,6 +6,8 @@ or
|
||||
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Optional
|
||||
@@ -325,6 +327,100 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
result, "Should return None when parser has no tool call data"
|
||||
)
|
||||
|
||||
# ------------- kimi_k2 tool_call_id formatting -------------
|
||||
def test_kimi_k2_non_streaming_tool_call_id_format(self):
|
||||
"""Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# Build a mock ToolCallItem-like object
|
||||
call_info = Mock()
|
||||
call_info.name = "get_weather"
|
||||
call_info.parameters = '{"city":"Paris"}'
|
||||
call_info.tool_index = 0
|
||||
|
||||
parser_instance.has_tool_call.return_value = True
|
||||
parser_instance.parse_non_stream.return_value = ("", [call_info])
|
||||
|
||||
finish_reason = {"type": "stop", "matched": None}
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
]
|
||||
|
||||
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||
text="<|tool_calls_section_begin|>...",
|
||||
tools=tools,
|
||||
tool_call_parser="kimi_k2",
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertEqual(len(tool_calls), 1)
|
||||
self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
|
||||
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_kimi_k2_streaming_tool_call_id_format(self):
|
||||
"""Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Prepare request with tools
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# First call returns one ToolCallItem-like chunk (with name)
|
||||
first_chunk_call = Mock()
|
||||
first_chunk_call.tool_index = 0
|
||||
first_chunk_call.name = "get_weather"
|
||||
first_chunk_call.parameters = ""
|
||||
parser_instance.parse_stream_chunk.side_effect = [
|
||||
("", [first_chunk_call]),
|
||||
("", []),
|
||||
]
|
||||
|
||||
async def collect_first_tool_chunk():
|
||||
gen = self.chat._process_tool_call_stream(
|
||||
index=0,
|
||||
delta="irrelevant",
|
||||
parser_dict={},
|
||||
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||
request=req,
|
||||
has_tool_calls={},
|
||||
)
|
||||
# Get first yielded SSE line
|
||||
line = None
|
||||
async for emitted in gen:
|
||||
line = emitted
|
||||
break
|
||||
return line
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||
self.assertIsNotNone(line)
|
||||
self.assertTrue(line.startswith("data: "))
|
||||
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user