Replace the Kimi-K2 generated tool call idx with history tool call count (#10612)
Co-authored-by: eraser00 <eraser00@github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.parser.conversation import generate_chat_conv
|
||||
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
and request.tools
|
||||
and self.tool_call_parser
|
||||
):
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||
text, request.tools, finish_reason
|
||||
text, request.tools, finish_reason, history_tool_calls_cnt
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||
return ChoiceLogprobs(content=token_logprobs)
|
||||
|
||||
def _process_tool_call_id(
|
||||
self,
|
||||
call_item: ToolCallItem,
|
||||
history_tool_calls_cnt: int,
|
||||
) -> str:
|
||||
"""Process for generating a new and unique `tool_call_id`"""
|
||||
if self.tool_call_parser != "kimi_k2":
|
||||
# A simple uuid is sufficient for all models except for Kimi-K2.
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
return tool_call_id
|
||||
else:
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
|
||||
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
|
||||
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
|
||||
logger.debug(
|
||||
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
|
||||
)
|
||||
return tool_call_id
|
||||
|
||||
def _process_tool_calls(
|
||||
self,
|
||||
text: str,
|
||||
tools: List[Any],
|
||||
finish_reason: Dict[str, Any],
|
||||
history_tool_calls_cnt: int = 0,
|
||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||
"""Process tool calls in the response"""
|
||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = []
|
||||
for call_info in call_info_list:
|
||||
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||
if (
|
||||
self.tool_call_parser == "kimi_k2"
|
||||
and call_info.name is not None
|
||||
):
|
||||
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||
else:
|
||||
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
tool_id = self._process_tool_call_id(
|
||||
call_info, history_tool_calls_cnt
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
reasoning_parser = reasoning_parser_dict[index]
|
||||
return reasoning_parser.parse_stream_chunk(delta)
|
||||
|
||||
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
|
||||
"""Counts the number of tool calls in the request's message history.
|
||||
|
||||
NOTE: This method is only useful for models that include self-increasing
|
||||
history tool call idx in tool calls id, such as kimi-k2
|
||||
|
||||
Args:
|
||||
request: The chat completion request object.
|
||||
|
||||
Returns:
|
||||
The total number of tool calls in the history, or 0 if not applicable.
|
||||
"""
|
||||
messages = getattr(request, "messages", [])
|
||||
idx = 0
|
||||
for msg in messages:
|
||||
if msg.role == "assistant":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
||||
return idx
|
||||
|
||||
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
||||
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||
|
||||
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Yield tool calls
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
for call_item in calls:
|
||||
# Mark that this choice has tool calls
|
||||
has_tool_calls[index] = True
|
||||
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
# Tool call ID should be generated only once per tool call
|
||||
if call_item.name:
|
||||
# First chunk: include ID and function name
|
||||
if self.tool_call_parser == "kimi_k2":
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||
else:
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
tool_call_id = self._process_tool_call_id(
|
||||
call_item, history_tool_calls_cnt
|
||||
)
|
||||
function_name = call_item.name
|
||||
else:
|
||||
# Subsequent chunks: null ID and name for argument deltas
|
||||
|
||||
@@ -420,6 +420,181 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||
|
||||
def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
|
||||
"""Ensure non-streaming tool_call.id increase with tool calls history for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.chat.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Prepare request with tool calls history
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather today in paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me do some search first.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "functions.get_weather:0",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "It's rainy in paris now.",
|
||||
"tool_call_id": "functions.get_weather:0",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "It's rainy now.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about LA and Tokyo?",
|
||||
},
|
||||
],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# Build a mock ToolCallItem-like object
|
||||
call_info = Mock()
|
||||
call_info.name = "get_weather"
|
||||
call_info.parameters = '{"city":"Loa Angeles"}'
|
||||
# Kimi-K2 series models might generate fixed number tool_indx,
|
||||
# ignoring the tool calls history and mess up all the following tool calls
|
||||
call_info.tool_index = 0
|
||||
|
||||
call_info2 = Mock()
|
||||
call_info2.name = "get_weather"
|
||||
call_info2.parameters = '{"city":"Tokyo"}'
|
||||
call_info2.tool_index = 1
|
||||
|
||||
parser_instance.has_tool_call.return_value = True
|
||||
parser_instance.parse_non_stream.return_value = (
|
||||
"",
|
||||
[call_info, call_info2],
|
||||
)
|
||||
|
||||
finish_reason = {"type": "stop", "matched": None}
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
]
|
||||
|
||||
history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
|
||||
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||
text="<|tool_calls_section_begin|>...",
|
||||
tools=tools,
|
||||
finish_reason=finish_reason,
|
||||
history_tool_calls_cnt=history_tool_calls_cnt,
|
||||
)
|
||||
|
||||
self.assertEqual(history_tool_calls_cnt, 1)
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertEqual(len(tool_calls), 2)
|
||||
self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
|
||||
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||
self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
|
||||
self.assertEqual(tool_calls[1].function.name, "get_weather")
|
||||
|
||||
def test_kimi_k2_streaming_tool_call_id_with_history(self):
|
||||
"""Ensure streaming first chunk tool_call.id increase with tool calls history for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.chat.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Prepare request with tool calls history
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather today in paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me do some search first.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "functions.get_weather:0",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "It's rainy in paris now.",
|
||||
"tool_call_id": "functions.get_weather:0",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "It's rainy now.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about LA?",
|
||||
},
|
||||
],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# First call returns one ToolCallItem-like chunk (with name)
|
||||
first_chunk_call = Mock()
|
||||
# Kimi-K2 series models might generate fixed number tool_indx,
|
||||
# ignoring the tool calls history and mess up all the following tool calls
|
||||
first_chunk_call.tool_index = 0
|
||||
first_chunk_call.name = "get_weather"
|
||||
first_chunk_call.parameters = ""
|
||||
parser_instance.parse_stream_chunk.side_effect = [
|
||||
("", [first_chunk_call]),
|
||||
("", []),
|
||||
]
|
||||
|
||||
async def collect_first_tool_chunk():
|
||||
gen = self.chat._process_tool_call_stream(
|
||||
index=0,
|
||||
delta="irrelevant",
|
||||
parser_dict={},
|
||||
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||
request=req,
|
||||
has_tool_calls={},
|
||||
)
|
||||
# Get first yielded SSE line
|
||||
line = None
|
||||
async for emitted in gen:
|
||||
line = emitted
|
||||
break
|
||||
return line
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||
self.assertIsNotNone(line)
|
||||
self.assertTrue(line.startswith("data: "))
|
||||
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user