Replace the Kimi-K2 generated tool call idx with history tool call count (#10612)
Co-authored-by: eraser00 <eraser00@github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
process_hidden_states_from_ret,
|
process_hidden_states_from_ret,
|
||||||
to_openai_style_logprobs,
|
to_openai_style_logprobs,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.function_call.core_types import ToolCallItem
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.parser.conversation import generate_chat_conv
|
from sglang.srt.parser.conversation import generate_chat_conv
|
||||||
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
and request.tools
|
and request.tools
|
||||||
and self.tool_call_parser
|
and self.tool_call_parser
|
||||||
):
|
):
|
||||||
|
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||||
text, request.tools, finish_reason
|
text, request.tools, finish_reason, history_tool_calls_cnt
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||||
return ChoiceLogprobs(content=token_logprobs)
|
return ChoiceLogprobs(content=token_logprobs)
|
||||||
|
|
||||||
|
def _process_tool_call_id(
|
||||||
|
self,
|
||||||
|
call_item: ToolCallItem,
|
||||||
|
history_tool_calls_cnt: int,
|
||||||
|
) -> str:
|
||||||
|
"""Process for generating a new and unique `tool_call_id`"""
|
||||||
|
if self.tool_call_parser != "kimi_k2":
|
||||||
|
# A simple uuid is sufficient for all models except for Kimi-K2.
|
||||||
|
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
return tool_call_id
|
||||||
|
else:
|
||||||
|
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||||
|
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
|
||||||
|
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
|
||||||
|
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
|
||||||
|
logger.debug(
|
||||||
|
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
|
||||||
|
)
|
||||||
|
return tool_call_id
|
||||||
|
|
||||||
def _process_tool_calls(
|
def _process_tool_calls(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
tools: List[Any],
|
tools: List[Any],
|
||||||
finish_reason: Dict[str, Any],
|
finish_reason: Dict[str, Any],
|
||||||
|
history_tool_calls_cnt: int = 0,
|
||||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||||
"""Process tool calls in the response"""
|
"""Process tool calls in the response"""
|
||||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||||
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
text, call_info_list = parser.parse_non_stream(text)
|
text, call_info_list = parser.parse_non_stream(text)
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for call_info in call_info_list:
|
for call_info in call_info_list:
|
||||||
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
tool_id = self._process_tool_call_id(
|
||||||
if (
|
call_info, history_tool_calls_cnt
|
||||||
self.tool_call_parser == "kimi_k2"
|
)
|
||||||
and call_info.name is not None
|
|
||||||
):
|
|
||||||
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
|
||||||
else:
|
|
||||||
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
|
||||||
|
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=tool_id,
|
id=tool_id,
|
||||||
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
reasoning_parser = reasoning_parser_dict[index]
|
reasoning_parser = reasoning_parser_dict[index]
|
||||||
return reasoning_parser.parse_stream_chunk(delta)
|
return reasoning_parser.parse_stream_chunk(delta)
|
||||||
|
|
||||||
|
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
|
||||||
|
"""Counts the number of tool calls in the request's message history.
|
||||||
|
|
||||||
|
NOTE: This method is only useful for models that include self-increasing
|
||||||
|
history tool call idx in tool calls id, such as kimi-k2
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The chat completion request object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The total number of tool calls in the history, or 0 if not applicable.
|
||||||
|
"""
|
||||||
|
messages = getattr(request, "messages", [])
|
||||||
|
idx = 0
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "assistant":
|
||||||
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
|
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
||||||
|
return idx
|
||||||
|
|
||||||
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
||||||
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||||
|
|
||||||
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# Yield tool calls
|
# Yield tool calls
|
||||||
|
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||||
for call_item in calls:
|
for call_item in calls:
|
||||||
# Mark that this choice has tool calls
|
# Mark that this choice has tool calls
|
||||||
has_tool_calls[index] = True
|
has_tool_calls[index] = True
|
||||||
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
# Tool call ID should be generated only once per tool call
|
# Tool call ID should be generated only once per tool call
|
||||||
if call_item.name:
|
if call_item.name:
|
||||||
# First chunk: include ID and function name
|
# First chunk: include ID and function name
|
||||||
if self.tool_call_parser == "kimi_k2":
|
tool_call_id = self._process_tool_call_id(
|
||||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
call_item, history_tool_calls_cnt
|
||||||
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
)
|
||||||
else:
|
|
||||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
|
||||||
function_name = call_item.name
|
function_name = call_item.name
|
||||||
else:
|
else:
|
||||||
# Subsequent chunks: null ID and name for argument deltas
|
# Subsequent chunks: null ID and name for argument deltas
|
||||||
|
|||||||
@@ -420,6 +420,181 @@ class ServingChatTestCase(unittest.TestCase):
|
|||||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||||
|
|
||||||
|
def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
|
||||||
|
"""Ensure non-streaming tool_call.id increase with tool calls history for kimi_k2 parser."""
|
||||||
|
|
||||||
|
# Force kimi_k2 parser
|
||||||
|
self.chat.tool_call_parser = "kimi_k2"
|
||||||
|
|
||||||
|
# Prepare request with tool calls history
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "What's the weather today in paris?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Let me do some search first.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "functions.get_weather:0",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "Paris"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": "It's rainy in paris now.",
|
||||||
|
"tool_call_id": "functions.get_weather:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "It's rainy now.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What about LA and Tokyo?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||||
|
with patch(
|
||||||
|
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||||
|
) as ParserMock:
|
||||||
|
parser_instance = ParserMock.return_value
|
||||||
|
|
||||||
|
# Build a mock ToolCallItem-like object
|
||||||
|
call_info = Mock()
|
||||||
|
call_info.name = "get_weather"
|
||||||
|
call_info.parameters = '{"city":"Loa Angeles"}'
|
||||||
|
# Kimi-K2 series models might generate fixed number tool_indx,
|
||||||
|
# ignoring the tool calls history and mess up all the following tool calls
|
||||||
|
call_info.tool_index = 0
|
||||||
|
|
||||||
|
call_info2 = Mock()
|
||||||
|
call_info2.name = "get_weather"
|
||||||
|
call_info2.parameters = '{"city":"Tokyo"}'
|
||||||
|
call_info2.tool_index = 1
|
||||||
|
|
||||||
|
parser_instance.has_tool_call.return_value = True
|
||||||
|
parser_instance.parse_non_stream.return_value = (
|
||||||
|
"",
|
||||||
|
[call_info, call_info2],
|
||||||
|
)
|
||||||
|
|
||||||
|
finish_reason = {"type": "stop", "matched": None}
|
||||||
|
tools = [
|
||||||
|
{"type": "function", "function": {"name": "get_weather"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
|
||||||
|
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||||
|
text="<|tool_calls_section_begin|>...",
|
||||||
|
tools=tools,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
history_tool_calls_cnt=history_tool_calls_cnt,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(history_tool_calls_cnt, 1)
|
||||||
|
self.assertIsNotNone(tool_calls)
|
||||||
|
self.assertEqual(len(tool_calls), 2)
|
||||||
|
self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
|
||||||
|
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||||
|
self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
|
||||||
|
self.assertEqual(tool_calls[1].function.name, "get_weather")
|
||||||
|
|
||||||
|
def test_kimi_k2_streaming_tool_call_id_with_history(self):
|
||||||
|
"""Ensure streaming first chunk tool_call.id increase with tool calls history for kimi_k2 parser."""
|
||||||
|
|
||||||
|
# Force kimi_k2 parser
|
||||||
|
self.chat.tool_call_parser = "kimi_k2"
|
||||||
|
|
||||||
|
# Prepare request with tool calls history
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "What's the weather today in paris?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Let me do some search first.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "functions.get_weather:0",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "Paris"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": "It's rainy in paris now.",
|
||||||
|
"tool_call_id": "functions.get_weather:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "It's rainy now.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What about LA?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||||
|
with patch(
|
||||||
|
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||||
|
) as ParserMock:
|
||||||
|
parser_instance = ParserMock.return_value
|
||||||
|
|
||||||
|
# First call returns one ToolCallItem-like chunk (with name)
|
||||||
|
first_chunk_call = Mock()
|
||||||
|
# Kimi-K2 series models might generate fixed number tool_indx,
|
||||||
|
# ignoring the tool calls history and mess up all the following tool calls
|
||||||
|
first_chunk_call.tool_index = 0
|
||||||
|
first_chunk_call.name = "get_weather"
|
||||||
|
first_chunk_call.parameters = ""
|
||||||
|
parser_instance.parse_stream_chunk.side_effect = [
|
||||||
|
("", [first_chunk_call]),
|
||||||
|
("", []),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def collect_first_tool_chunk():
|
||||||
|
gen = self.chat._process_tool_call_stream(
|
||||||
|
index=0,
|
||||||
|
delta="irrelevant",
|
||||||
|
parser_dict={},
|
||||||
|
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||||
|
request=req,
|
||||||
|
has_tool_calls={},
|
||||||
|
)
|
||||||
|
# Get first yielded SSE line
|
||||||
|
line = None
|
||||||
|
async for emitted in gen:
|
||||||
|
line = emitted
|
||||||
|
break
|
||||||
|
return line
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||||
|
self.assertIsNotNone(line)
|
||||||
|
self.assertTrue(line.startswith("data: "))
|
||||||
|
|
||||||
|
payload = json.loads(line[len("data: ") :])
|
||||||
|
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||||
|
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user