Support Kimi K2 (#7940)
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_generation_config,
|
||||
get_hf_text_config,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
@@ -83,6 +84,13 @@ class ModelConfig:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_generation_config = get_generation_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
@@ -467,6 +475,19 @@ class ModelConfig:
|
||||
if eos_ids:
|
||||
# it can be either int or list of int
|
||||
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||
if eos_ids is None:
|
||||
eos_ids = set()
|
||||
if self.hf_generation_config:
|
||||
generation_eos_ids = getattr(
|
||||
self.hf_generation_config, "eos_token_id", None
|
||||
)
|
||||
if generation_eos_ids:
|
||||
generation_eos_ids = (
|
||||
{generation_eos_ids}
|
||||
if isinstance(generation_eos_ids, int)
|
||||
else set(generation_eos_ids)
|
||||
)
|
||||
eos_ids = eos_ids | generation_eos_ids
|
||||
return eos_ids
|
||||
|
||||
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
|
||||
@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
@@ -33,6 +34,7 @@ class FunctionCallParser:
|
||||
"mistral": MistralDetector,
|
||||
"deepseekv3": DeepSeekV3Detector,
|
||||
"pythonic": PythonicDetector,
|
||||
"kimi_k2": KimiK2Detector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
|
||||
220
python/sglang/srt/function_call/kimik2_detector.py
Normal file
220
python/sglang/srt/function_call/kimik2_detector.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.function_call.utils import _is_complete_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KimiK2Detector(BaseFormatDetector):
    """Detector for the Kimi K2 tool-call output format.

    Kimi K2 emits tool calls as a tagged section::

        <|tool_calls_section_begin|>
          <|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json args}<|tool_call_end|>
          ...
        <|tool_calls_section_end|>

    Supports both one-shot parsing (``detect_and_parse``) and incremental
    streaming parsing (``parse_streaming_increment``).
    """

    def __init__(self):
        super().__init__()
        # Unprocessed text carried across streaming increments.
        self._buffer = ""
        # Whether the name of the tool currently being streamed was emitted.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Index of the tool call currently being streamed; -1 = none yet.
        self.current_tool_id: int = -1
        # What has been streamed for each tool so far (one entry per tool).
        self.streamed_args_for_tool: list[str] = []

        self.bot_token: str = "<|tool_calls_section_begin|>"
        self.eot_token: str = "<|tool_calls_section_end|>"

        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # Matches one complete tool call: id + full JSON argument object.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
        )

        # Matches a possibly-incomplete tool call whose arguments are still streaming.
        self.stream_tool_call_portion_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
        )

        # Argument text already emitted for the current streaming tool call.
        self._last_arguments = ""

    @staticmethod
    def _parse_function_id(function_id: str):
        """Split a tool-call id of the form ``functions.{name}:{index}``.

        Returns ``(name, index)``. Uses maxsplit so function names that
        themselves contain dots (allowed by the ``[\\w\\.]+`` id pattern)
        are preserved intact.
        """
        prefix, idx = function_id.rsplit(":", 1)
        name = prefix.split(".", 1)[1]
        return name, int(idx)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a KimiK2 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            # findall with two named groups yields (tool_call_id, arguments)
            # tuples, one per complete tool call in the text.
            function_call_tuples = self.tool_call_regex.findall(text)

            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = []
            for function_id, function_args in function_call_tuples:
                function_name, function_idx = self._parse_function_id(function_id)

                # Lazy %-formatting at DEBUG level; INFO per tool call would
                # flood logs on every request.
                logger.debug("function_name %s", function_name)

                tool_calls.append(
                    ToolCallItem(
                        tool_index=function_idx,  # Use the call index in the response, not tool position
                        name=function_name,
                        parameters=function_args,
                    )
                )

            # Everything before the section-begin token is normal content.
            content = text[: text.find(self.bot_token)]
            return StreamingParseResult(normal_text=content, calls=tool_calls)

        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for KimiK2 format.

        :param new_text: The newly generated text chunk.
        :param tools: List of available tools.
        :return: Normal text and/or incremental tool-call deltas parsed so far.
        """
        self._buffer += new_text
        current_text = self._buffer

        # Check if we have a tool call (either the start token or individual tool call)
        has_tool_call = (
            self.bot_token in current_text or self.tool_call_start_token in current_text
        )

        if not has_tool_call:
            # Plain content: flush the buffer, but strip stray end tokens
            # that can trail a completed tool call.
            self._buffer = ""
            for e_token in [self.eot_token, self.tool_call_end_token]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Lazily build a name -> position map of the advertised tools.
        # NOTE(review): not read in this method; presumably consumed by the
        # base detector — confirm before removing.
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            match = self.stream_tool_call_portion_regex.search(current_text)
            if match:
                function_id = match.group("tool_call_id")
                function_args = match.group("function_arguments")

                function_name, _ = self._parse_function_id(function_id)

                # Initialize state if this is the first tool call
                if self.current_tool_id == -1:
                    self.current_tool_id = 0
                    self.prev_tool_call_arr = []
                    self.streamed_args_for_tool = [""]

                # Ensure we have enough entries in our tracking arrays
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                while len(self.streamed_args_for_tool) <= self.current_tool_id:
                    self.streamed_args_for_tool.append("")

                if not self.current_tool_name_sent:
                    # First delta for this tool: emit the name with empty args.
                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=function_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                    # Store the tool call info for adapter.py
                    self.prev_tool_call_arr[self.current_tool_id] = {
                        "name": function_name,
                        "arguments": {},
                    }
                else:
                    # Subsequent deltas: emit only the newly arrived argument
                    # suffix (fall back to the full text on mismatch).
                    argument_diff = (
                        function_args[len(self._last_arguments) :]
                        if function_args.startswith(self._last_arguments)
                        else function_args
                    )

                    # Never stream past the per-call end token.
                    parsed_args_diff = argument_diff.split(self.tool_call_end_token, 1)[0]

                    if parsed_args_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=None,
                                parameters=parsed_args_diff,
                            )
                        )
                        self._last_arguments += argument_diff
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += parsed_args_diff

                    parsed_args = function_args.split(self.tool_call_end_token, 1)[0]
                    if _is_complete_json(parsed_args):
                        try:
                            parsed_args = json.loads(parsed_args)
                            self.prev_tool_call_arr[self.current_tool_id][
                                "arguments"
                            ] = parsed_args
                        except json.JSONDecodeError:
                            pass

                        # Find the end of the current tool call and remove only that part from buffer
                        tool_call_end_pattern = (
                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
                        )
                        match = re.search(
                            tool_call_end_pattern, current_text, re.DOTALL
                        )
                        if match:
                            # Remove the completed tool call from buffer, keep any remaining content
                            self._buffer = current_text[match.end() :]
                        else:
                            self._buffer = ""

                        # Reset per-call state and advance to the next tool.
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self.current_tool_id += 1
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        """Structured-output info is not implemented for the KimiK2 format."""
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]):
        """EBNF grammar construction is not implemented for the KimiK2 format."""
        raise NotImplementedError()
|
||||
@@ -14,6 +14,7 @@
|
||||
"""Utilities for Huggingface Transformers."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -25,6 +26,7 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
@@ -153,6 +155,22 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
@lru_cache_frozenset(maxsize=32)
def get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    **kwargs,
):
    """Load a model's ``generation_config.json`` via Hugging Face.

    Args:
        model: Model name or local path passed to ``from_pretrained``.
        trust_remote_code: Whether to trust remote code in the model repo.
        revision: Optional branch/tag/commit of the model repo.
        **kwargs: Extra arguments forwarded to ``from_pretrained``.

    Returns:
        A ``GenerationConfig``, or ``None`` when the model has no
        ``generation_config.json`` (``from_pretrained`` raises ``OSError``).
    """
    try:
        return GenerationConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    except OSError as e:
        # Include the underlying error: OSError also covers permission and
        # network failures, not just a missing file.
        logging.info("model doesn't have generation_config.json: %s", e)
        return None
|
||||
|
||||
|
||||
# Models don't use the same configuration key for determining the maximum
|
||||
# context length. Store them here so we can sanely check them.
|
||||
# NOTE: The ordering here is important. Some models have two of these and we
|
||||
|
||||
@@ -1048,9 +1048,16 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
|
||||
choices=[
|
||||
"qwen25",
|
||||
"mistral",
|
||||
"llama3",
|
||||
"deepseekv3",
|
||||
"pythonic",
|
||||
"kimi_k2",
|
||||
],
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
|
||||
)
|
||||
|
||||
# Data parallelism
|
||||
|
||||
Reference in New Issue
Block a user