# sglang/python/sglang/srt/function_call/kimik2_detector.py

import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class KimiK2Detector(BaseFormatDetector):
    """
    Detector for Kimi K2 model function call format.

    Format Structure:
    ```
    <|tool_calls_section_begin|>
    <|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|>
    <|tool_calls_section_end|>
    ```

    Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
    """

    def __init__(self):
        super().__init__()
        self.bot_token: str = "<|tool_calls_section_begin|>"
        self.eot_token: str = "<|tool_calls_section_end|>"
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # Matches one complete tool call.
        # - re.DOTALL lets the JSON arguments span several lines (e.g. a
        #   pretty-printed object). Without it `.` stops at the first newline
        #   and such calls are silently dropped.
        # - The id accepts "-" because tool-schema function names may contain it.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w.\-]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>",
            re.DOTALL,
        )
        # Matches the first (possibly still incomplete) tool call in the
        # streaming buffer. `function_arguments` greedily captures everything
        # after the opening brace, including any end token and later text, so
        # callers must cut it at `tool_call_end_token`.
        self.stream_tool_call_portion_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w.\-]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)",
            re.DOTALL,
        )
        # Spans one whole tool call, begin token through end token. Used to
        # drop a finished call from the streaming buffer.
        self._tool_call_span_regex = re.compile(
            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>", re.DOTALL
        )
        # Argument text already consumed for the tool call currently streaming.
        self._last_arguments = ""

    @staticmethod
    def _parse_tool_call_id(tool_call_id: str) -> tuple:
        """Split a tool-call id such as ``functions.get_weather:0`` into ``(name, index)``.

        A leading ``functions.`` prefix is removed when present. Everything
        between it and the last ``:`` is kept as the name, so dotted names stay
        intact. Ids lacking the prefix are also accepted.

        :param tool_call_id: Id text captured by the tool-call regexes. It
            always ends in ``:<digits>``.
        :return: ``(function_name, call_index)``.
        """
        name, _, index = tool_call_id.rpartition(":")
        prefix = "functions."
        if name.startswith(prefix):
            name = name[len(prefix) :]
        return name, int(index)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a KimiK2 format tool call (the section-begin token)."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools (not consulted here).
        :return: StreamingParseResult. `normal_text` is everything before the
            tool-call section and `calls` holds one ToolCallItem per complete
            call. If parsing raises, the whole text is returned as normal_text.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            # findall yields one (tool_call_id, function_arguments) tuple per
            # complete tool call; both groups are always populated.
            function_call_tuples = self.tool_call_regex.findall(text)
            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = []
            for function_id, function_args in function_call_tuples:
                function_name, function_idx = self._parse_tool_call_id(function_id)
                logger.debug("function_name %s", function_name)
                tool_calls.append(
                    ToolCallItem(
                        tool_index=function_idx,  # Use the call index in the response, not tool position
                        name=function_name,
                        parameters=function_args,
                    )
                )

            content = text[: text.find(self.bot_token)]
            return StreamingParseResult(normal_text=content, calls=tool_calls)

        except Exception as e:
            logger.error("Error in detect_and_parse: %s", e)
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for KimiK2 format.

        Protocol for each tool call:
        - The first increment in which the call's id is visible emits a
          ToolCallItem carrying only the function name.
        - Later increments emit the argument text not yet sent.
        - Once the buffered arguments form complete JSON, the parsed arguments
          are stored in prev_tool_call_arr. The call is then removed from the
          buffer and the detector advances to the next tool index.

        :param new_text: Text produced since the previous call.
        :param tools: Available tools; used once to build self._tool_indices.
        :return: StreamingParseResult with plain text and/or tool-call deltas.
        """
        self._buffer += new_text
        current_text = self._buffer

        # Check if we have a tool call (either the start token or individual tool call)
        has_tool_call = (
            self.bot_token in current_text or self.tool_call_start_token in current_text
        )

        if not has_tool_call:
            # Plain text: flush the buffer and pass the new text through.
            # Stray end tokens are dropped, e.g. a <|tool_call_end|> arriving
            # after its call was already closed on complete JSON below.
            self._buffer = ""
            for e_token in [self.eot_token, self.tool_call_end_token]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        # NOTE(review): text that precedes the section-begin token inside the
        # buffer is never emitted as normal_text on this path. Confirm the
        # upstream chunking makes that unreachable.
        calls: list[ToolCallItem] = []
        try:
            match = self.stream_tool_call_portion_regex.search(current_text)
            if match:
                function_id = match.group("tool_call_id")
                function_args = match.group("function_arguments")
                function_name, _ = self._parse_tool_call_id(function_id)

                # Initialize state if this is the first tool call
                if self.current_tool_id == -1:
                    self.current_tool_id = 0
                    self.prev_tool_call_arr = []
                    self.streamed_args_for_tool = [""]

                # Ensure we have enough entries in our tracking arrays
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                while len(self.streamed_args_for_tool) <= self.current_tool_id:
                    self.streamed_args_for_tool.append("")

                if not self.current_tool_name_sent:
                    # First sighting of this call: announce the function name only.
                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=function_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                    # Store the tool call info for serving layer completions endpoint
                    self.prev_tool_call_arr[self.current_tool_id] = {
                        "name": function_name,
                        "arguments": {},
                    }
                else:
                    # Only the argument text beyond what was already consumed.
                    argument_diff = (
                        function_args[len(self._last_arguments) :]
                        if function_args.startswith(self._last_arguments)
                        else function_args
                    )
                    # function_args runs to the end of the buffer; keep only the
                    # part before this call's end token.
                    parsed_args_diff = argument_diff.split(
                        self.tool_call_end_token, 1
                    )[0]

                    if parsed_args_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=None,
                                parameters=parsed_args_diff,
                            )
                        )
                        self._last_arguments += argument_diff
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += parsed_args_diff

                    parsed_args = function_args.split(self.tool_call_end_token, 1)[0]
                    if _is_complete_json(parsed_args):
                        try:
                            parsed_args = json.loads(parsed_args)
                            self.prev_tool_call_arr[self.current_tool_id][
                                "arguments"
                            ] = parsed_args
                        except json.JSONDecodeError:
                            pass

                        # Find the end of the current tool call and remove only that part from buffer
                        end_match = self._tool_call_span_regex.search(current_text)
                        if end_match:
                            # Remove the completed tool call from buffer, keep any remaining content
                            self._buffer = current_text[end_match.end() :]
                        else:
                            # End token not received yet; when it arrives it is
                            # stripped by the plain-text branch above.
                            self._buffer = ""

                        result = StreamingParseResult(normal_text="", calls=calls)
                        self.current_tool_id += 1
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error("Error in parse_streaming_increment: %s", e)
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        """Return function that creates StructureInfo for guided generation."""

        def get_info(name: str) -> StructureInfo:
            # Wraps a single call (index 0) in its own tool-call section.
            return StructureInfo(
                begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>",
                end="<|tool_call_end|><|tool_calls_section_end|>",
                trigger="<|tool_calls_section_begin|>",
            )

        return get_info

    def build_ebnf(self, tools: List[Tool]) -> str:
        """
        Build EBNF grammar for KimiK2 tool call format.

        NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar
        to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in
        multiple function call scenarios, while still maintaining the correct KimiK2
        format structure for constrained generation.
        """
        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token=self.bot_token,
            sequence_end_token=self.eot_token,
            tool_call_separator="",
            call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"',
            function_format="json",
        )