Replace the Kimi-K2 generated tool call idx with history tool call count (#10612)

Co-authored-by: eraser00 <eraser00@github.com>
This commit is contained in:
eraser00
2025-09-26 09:47:40 +08:00
committed by GitHub
parent 7dcd689b47
commit 0ac6114694
2 changed files with 226 additions and 15 deletions

View File

@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
and request.tools
and self.tool_call_parser
):
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
tool_calls, text, finish_reason = self._process_tool_calls(
text, request.tools, finish_reason
text, request.tools, finish_reason, history_tool_calls_cnt
)
choice_data = ChatCompletionResponseChoice(
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs)
def _process_tool_call_id(
self,
call_item: ToolCallItem,
history_tool_calls_cnt: int,
) -> str:
"""Process for generating a new and unique `tool_call_id`"""
if self.tool_call_parser != "kimi_k2":
# A simple uuid is sufficient for all models except for Kimi-K2.
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
return tool_call_id
else:
# Align with Kimi-K2 format: functions.{name}:{index}
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
logger.debug(
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
)
return tool_call_id
def _process_tool_calls(
self,
text: str,
tools: List[Any],
finish_reason: Dict[str, Any],
history_tool_calls_cnt: int = 0,
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response"""
parser = FunctionCallParser(tools, self.tool_call_parser)
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = []
for call_info in call_info_list:
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
if (
self.tool_call_parser == "kimi_k2"
and call_info.name is not None
):
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
else:
tool_id = f"call_{uuid.uuid4().hex[:24]}"
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
tool_calls.append(
ToolCall(
id=tool_id,
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
"""Counts the number of tool calls in the request's message history.
NOTE: This method is only useful for models that include self-increasing
history tool call idx in tool calls id, such as kimi-k2
Args:
request: The chat completion request object.
Returns:
The total number of tool calls in the history, or 0 if not applicable.
"""
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
if self.tool_call_parser == "kimi_k2":
# Align with Kimi-K2 format: functions.{name}:{index}
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else:
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
tool_call_id = self._process_tool_call_id(
call_item, history_tool_calls_cnt
)
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas