Replace the Kimi-K2 generated tool call idx with history tool call count (#10612)
Co-authored-by: eraser00 <eraser00@github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.parser.conversation import generate_chat_conv
|
||||
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
and request.tools
|
||||
and self.tool_call_parser
|
||||
):
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||
text, request.tools, finish_reason
|
||||
text, request.tools, finish_reason, history_tool_calls_cnt
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||
return ChoiceLogprobs(content=token_logprobs)
|
||||
|
||||
def _process_tool_call_id(
|
||||
self,
|
||||
call_item: ToolCallItem,
|
||||
history_tool_calls_cnt: int,
|
||||
) -> str:
|
||||
"""Process for generating a new and unique `tool_call_id`"""
|
||||
if self.tool_call_parser != "kimi_k2":
|
||||
# A simple uuid is sufficient for all models except for Kimi-K2.
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
return tool_call_id
|
||||
else:
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
|
||||
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
|
||||
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
|
||||
logger.debug(
|
||||
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
|
||||
)
|
||||
return tool_call_id
|
||||
|
||||
def _process_tool_calls(
|
||||
self,
|
||||
text: str,
|
||||
tools: List[Any],
|
||||
finish_reason: Dict[str, Any],
|
||||
history_tool_calls_cnt: int = 0,
|
||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||
"""Process tool calls in the response"""
|
||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = []
|
||||
for call_info in call_info_list:
|
||||
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||
if (
|
||||
self.tool_call_parser == "kimi_k2"
|
||||
and call_info.name is not None
|
||||
):
|
||||
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||
else:
|
||||
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
tool_id = self._process_tool_call_id(
|
||||
call_info, history_tool_calls_cnt
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
reasoning_parser = reasoning_parser_dict[index]
|
||||
return reasoning_parser.parse_stream_chunk(delta)
|
||||
|
||||
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
|
||||
"""Counts the number of tool calls in the request's message history.
|
||||
|
||||
NOTE: This method is only useful for models that include self-increasing
|
||||
history tool call idx in tool calls id, such as kimi-k2
|
||||
|
||||
Args:
|
||||
request: The chat completion request object.
|
||||
|
||||
Returns:
|
||||
The total number of tool calls in the history, or 0 if not applicable.
|
||||
"""
|
||||
messages = getattr(request, "messages", [])
|
||||
idx = 0
|
||||
for msg in messages:
|
||||
if msg.role == "assistant":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
||||
return idx
|
||||
|
||||
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
||||
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||
|
||||
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Yield tool calls
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
for call_item in calls:
|
||||
# Mark that this choice has tool calls
|
||||
has_tool_calls[index] = True
|
||||
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
# Tool call ID should be generated only once per tool call
|
||||
if call_item.name:
|
||||
# First chunk: include ID and function name
|
||||
if self.tool_call_parser == "kimi_k2":
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||
else:
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
tool_call_id = self._process_tool_call_id(
|
||||
call_item, history_tool_calls_cnt
|
||||
)
|
||||
function_name = call_item.name
|
||||
else:
|
||||
# Subsequent chunks: null ID and name for argument deltas
|
||||
|
||||
Reference in New Issue
Block a user