Fix kimi k2 function calling format (#9606)
This commit is contained in:
@@ -835,15 +835,23 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
finish_reason["matched"] = None
|
finish_reason["matched"] = None
|
||||||
try:
|
try:
|
||||||
text, call_info_list = parser.parse_non_stream(text)
|
text, call_info_list = parser.parse_non_stream(text)
|
||||||
tool_calls = [
|
tool_calls = []
|
||||||
ToolCall(
|
for call_info in call_info_list:
|
||||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||||
function=FunctionResponse(
|
if tool_call_parser == "kimi_k2" and call_info.name is not None:
|
||||||
name=call_info.name, arguments=call_info.parameters
|
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||||
),
|
else:
|
||||||
|
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=tool_id,
|
||||||
|
index=getattr(call_info, "tool_index", None),
|
||||||
|
function=FunctionResponse(
|
||||||
|
name=call_info.name, arguments=call_info.parameters
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for call_info in call_info_list
|
|
||||||
]
|
|
||||||
return tool_calls, text, finish_reason
|
return tool_calls, text, finish_reason
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Tool call parsing error: {e}")
|
logger.error(f"Tool call parsing error: {e}")
|
||||||
@@ -954,7 +962,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
# Tool call ID should be generated only once per tool call
|
# Tool call ID should be generated only once per tool call
|
||||||
if call_item.name:
|
if call_item.name:
|
||||||
# First chunk: include ID and function name
|
# First chunk: include ID and function name
|
||||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
if self.tokenizer_manager.server_args.tool_call_parser == "kimi_k2":
|
||||||
|
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||||
|
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||||
|
else:
|
||||||
|
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
function_name = call_item.name
|
function_name = call_item.name
|
||||||
else:
|
else:
|
||||||
# Subsequent chunks: null ID and name for argument deltas
|
# Subsequent chunks: null ID and name for argument deltas
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ or
|
|||||||
python -m unittest discover -s tests -p "test_*unit.py" -v
|
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -325,6 +327,100 @@ class ServingChatTestCase(unittest.TestCase):
|
|||||||
result, "Should return None when parser has no tool call data"
|
result, "Should return None when parser has no tool call data"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------- kimi_k2 tool_call_id formatting -------------
|
||||||
|
def test_kimi_k2_non_streaming_tool_call_id_format(self):
|
||||||
|
"""Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||||
|
|
||||||
|
# Force kimi_k2 parser
|
||||||
|
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||||
|
|
||||||
|
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||||
|
with patch(
|
||||||
|
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||||
|
) as ParserMock:
|
||||||
|
parser_instance = ParserMock.return_value
|
||||||
|
|
||||||
|
# Build a mock ToolCallItem-like object
|
||||||
|
call_info = Mock()
|
||||||
|
call_info.name = "get_weather"
|
||||||
|
call_info.parameters = '{"city":"Paris"}'
|
||||||
|
call_info.tool_index = 0
|
||||||
|
|
||||||
|
parser_instance.has_tool_call.return_value = True
|
||||||
|
parser_instance.parse_non_stream.return_value = ("", [call_info])
|
||||||
|
|
||||||
|
finish_reason = {"type": "stop", "matched": None}
|
||||||
|
tools = [
|
||||||
|
{"type": "function", "function": {"name": "get_weather"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||||
|
text="<|tool_calls_section_begin|>...",
|
||||||
|
tools=tools,
|
||||||
|
tool_call_parser="kimi_k2",
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNotNone(tool_calls)
|
||||||
|
self.assertEqual(len(tool_calls), 1)
|
||||||
|
self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
|
||||||
|
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||||
|
|
||||||
|
def test_kimi_k2_streaming_tool_call_id_format(self):
|
||||||
|
"""Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||||
|
|
||||||
|
# Force kimi_k2 parser
|
||||||
|
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||||
|
|
||||||
|
# Prepare request with tools
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Hi?"}],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||||
|
with patch(
|
||||||
|
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||||
|
) as ParserMock:
|
||||||
|
parser_instance = ParserMock.return_value
|
||||||
|
|
||||||
|
# First call returns one ToolCallItem-like chunk (with name)
|
||||||
|
first_chunk_call = Mock()
|
||||||
|
first_chunk_call.tool_index = 0
|
||||||
|
first_chunk_call.name = "get_weather"
|
||||||
|
first_chunk_call.parameters = ""
|
||||||
|
parser_instance.parse_stream_chunk.side_effect = [
|
||||||
|
("", [first_chunk_call]),
|
||||||
|
("", []),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def collect_first_tool_chunk():
|
||||||
|
gen = self.chat._process_tool_call_stream(
|
||||||
|
index=0,
|
||||||
|
delta="irrelevant",
|
||||||
|
parser_dict={},
|
||||||
|
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||||
|
request=req,
|
||||||
|
has_tool_calls={},
|
||||||
|
)
|
||||||
|
# Get first yielded SSE line
|
||||||
|
line = None
|
||||||
|
async for emitted in gen:
|
||||||
|
line = emitted
|
||||||
|
break
|
||||||
|
return line
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||||
|
self.assertIsNotNone(line)
|
||||||
|
self.assertTrue(line.startswith("data: "))
|
||||||
|
|
||||||
|
payload = json.loads(line[len("data: ") :])
|
||||||
|
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||||
|
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user