diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index ef88f8ac3..4f19321c3 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -135,7 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--file-storage-path` | The path of the file storage in backend. | sglang_storage | | `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False | | `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None | -| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'. | None | +| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None | ## Data parallelism diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a6f563cf0..ecff10244 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -25,6 +25,7 @@ from transformers import PretrainedConfig from sglang.srt.hf_transformers_utils import ( get_config, get_context_length, + get_generation_config, get_hf_text_config, ) from sglang.srt.layers.quantization import QUANTIZATION_METHODS @@ -83,6 +84,13 @@ class ModelConfig: **kwargs, ) + self.hf_generation_config = get_generation_config( + self.model_path, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + self.hf_text_config = get_hf_text_config(self.hf_config) self.attention_chunk_size = getattr( self.hf_text_config, "attention_chunk_size", None @@ -467,6 +475,19 @@ class ModelConfig: if eos_ids: # it can be either int or list of int eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) + if 
eos_ids is None: + eos_ids = set() + if self.hf_generation_config: + generation_eos_ids = getattr( + self.hf_generation_config, "eos_token_id", None + ) + if generation_eos_ids: + generation_eos_ids = ( + {generation_eos_ids} + if isinstance(generation_eos_ids, int) + else set(generation_eos_ids) + ) + eos_ids = eos_ids | generation_eos_ids return eos_ids def maybe_pull_model_tokenizer_from_remote(self) -> None: diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 10b92a0af..a6708024f 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector @@ -33,6 +34,7 @@ class FunctionCallParser: "mistral": MistralDetector, "deepseekv3": DeepSeekV3Detector, "pythonic": PythonicDetector, + "kimi_k2": KimiK2Detector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/kimik2_detector.py b/python/sglang/srt/function_call/kimik2_detector.py new file mode 100644 index 000000000..94457ccda --- /dev/null +++ b/python/sglang/srt/function_call/kimik2_detector.py @@ -0,0 +1,220 @@ +import json +import logging +import re +from typing import List + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types 
import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer +from sglang.srt.function_call.utils import _is_complete_json + +logger = logging.getLogger(__name__) + + +class KimiK2Detector(BaseFormatDetector): + + def __init__(self): + super().__init__() + self._buffer = "" + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + [] + ) # map what has been streamed for each tool so far to a list + + self.bot_token: str = "<|tool_calls_section_begin|>" + self.eot_token: str = "<|tool_calls_section_end|>" + + self.tool_call_start_token: str = "<|tool_call_begin|>" + self.tool_call_end_token: str = "<|tool_call_end|>" + + self.tool_call_regex = re.compile( + r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)" + ) + + self._last_arguments = "" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a KimiK2 format tool call.""" + return self.bot_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
+ """ + if self.bot_token not in text: + return StreamingParseResult(normal_text=text, calls=[]) + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall(text) + + logger.debug("function_call_tuples: %s", function_call_tuples) + + tool_calls = [] + for match in function_call_tuples: + function_id, function_args = match + function_name = function_id.split(".")[1].split(":")[0] + function_idx = int(function_id.split(".")[1].split(":")[1]) + + logger.info(f"function_name {function_name}") + + tool_calls.append( + ToolCallItem( + tool_index=function_idx, # Use the call index in the response, not tool position + name=function_name, + parameters=function_args, + ) + ) + + content = text[: text.find(self.bot_token)] + return StreamingParseResult(normal_text=content, calls=tool_calls) + + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}") + # return the normal text if parsing fails + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing tool calls for KimiK2 format. 
+ """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have a tool call (either the start token or individual tool call) + has_tool_call = ( + self.bot_token in current_text or self.tool_call_start_token in current_text + ) + + if not has_tool_call: + self._buffer = "" + for e_token in [self.eot_token, self.tool_call_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) + + if not hasattr(self, "_tool_indices"): + self._tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function and tool.function.name + } + + calls: list[ToolCallItem] = [] + try: + match = self.stream_tool_call_portion_regex.search(current_text) + if match: + function_id = match.group("tool_call_id") + function_args = match.group("function_arguments") + + function_name = function_id.split(".")[1].split(":")[0] + + # Initialize state if this is the first tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + # Store the tool call info for adapter.py + self.prev_tool_call_arr[self.current_tool_id] = { + "name": function_name, + "arguments": {}, + } + else: + argument_diff = ( + function_args[len(self._last_arguments) :] + if function_args.startswith(self._last_arguments) + else function_args + ) + + parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0] + + if parsed_args_diff: + + calls.append( + ToolCallItem( + 
tool_index=self.current_tool_id, + name=None, + parameters=parsed_args_diff, + ) + ) + self._last_arguments += argument_diff + self.streamed_args_for_tool[ + self.current_tool_id + ] += parsed_args_diff + + parsed_args = function_args.split("<|tool_call_end|>", 1)[0] + if _is_complete_json(parsed_args): + try: + parsed_args = json.loads(parsed_args) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = parsed_args + except json.JSONDecodeError: + pass + + # Find the end of the current tool call and remove only that part from buffer + tool_call_end_pattern = ( + r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>" + ) + match = re.search( + tool_call_end_pattern, current_text, re.DOTALL + ) + if match: + # Remove the completed tool call from buffer, keep any remaining content + self._buffer = current_text[match.end() :] + else: + self._buffer = "" + + result = StreamingParseResult(normal_text="", calls=calls) + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + return result + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}") + return StreamingParseResult(normal_text=current_text) + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError() + + def build_ebnf(self, tools: List[Tool]): + raise NotImplementedError() diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 5fcbb4cdc..e5b4af0c3 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -14,6 +14,7 @@ """Utilities for Huggingface Transformers.""" import contextlib +import logging import os import warnings from pathlib import Path @@ -25,6 +26,7 @@ from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, + GenerationConfig, PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerBase, @@ -153,6 +155,22 @@ def get_config( return config 
+@lru_cache_frozenset(maxsize=32) +def get_generation_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, + **kwargs, +): + try: + return GenerationConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + except OSError as e: + logging.info("model doesn't have generation_config.json") + return None + + # Models don't use the same configuration key for determining the maximum # context length. Store them here so we can sanely check them. # NOTE: The ordering here is important. Some models have two of these and we diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 73f711add..935c1b89d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1048,9 +1048,16 @@ class ServerArgs: parser.add_argument( "--tool-call-parser", type=str, - choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"], + choices=[ + "qwen25", + "mistral", + "llama3", + "deepseekv3", + "pythonic", + "kimi_k2", + ], default=ServerArgs.tool_call_parser, - help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.", + help="Specify the parser for handling tool-call interactions. 
Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.", ) # Data parallelism diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 35b75d715..f9c36a9a2 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo from sglang.srt.entrypoints.openai.protocol import Function, Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector @@ -1138,5 +1139,213 @@ class TestLlama32Detector(unittest.TestCase): self.assertTrue(result.normal_text.strip().startswith("Some intro.")) +class TestKimiK2Detector(unittest.TestCase): + + def setUp(self): + """Set up test tools and detector.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = KimiK2Detector() + + def test_single_tool_call(self): + """Test parsing a single tool call in a complete text.""" + text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": 
"Paris"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + """Test parsing multiple tool calls in a complete text.""" + text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.calls[1].parameters, '{"city": "London"}') + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|><|tool_calls_section_end|>", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], 
"get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}<|tool_call_end|>", + "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", + '"city": "London"', + "}<|tool_call_end|>", + "<|tool_calls_section_end|>", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") + self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}') + + def test_tool_call_completion(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|>", + "<|tool_calls_section_end|>", + ] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + # After processing all chunks, the buffer should be empty and current_tool_id should be reset + self.assertEqual(self.detector._buffer, "") + self.assertEqual(self.detector.current_tool_id, 1) + + def 
test_tool_name_streaming(self): + """Test that tool names are streamed correctly with the right index.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|>", + "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + self.assertEqual(result.normal_text, text) + + def test_partial_tool_call(self): + """Test that partial tool calls are handled correctly in streaming mode.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + 
tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"') + + if __name__ == "__main__": unittest.main()