From 01079e174ff8a7a052b4f8f74b4f8a59edd13f61 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Wed, 23 Jul 2025 17:37:31 -0700 Subject: [PATCH] feat(function call): complete utility method for KimiK2Detector and enhance documentation (#8043) --- .../srt/function_call/base_format_detector.py | 82 ++++++++++++++++--- .../srt/function_call/deepseekv3_detector.py | 35 +++++--- .../srt/function_call/kimik2_detector.py | 57 +++++++++---- .../srt/function_call/llama32_detector.py | 9 +- .../srt/function_call/mistral_detector.py | 14 +++- .../srt/function_call/pythonic_detector.py | 21 +++-- .../srt/function_call/qwen25_detector.py | 15 +++- test/srt/test_function_call_parser.py | 28 +++++++ 8 files changed, 205 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index 3989ec98d..d9ac71253 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -25,23 +25,49 @@ class BaseFormatDetector(ABC): """Base class providing two sets of interfaces: one-time and streaming incremental.""" def __init__(self): - # initialize properties used for state when parsing tool calls in + # Streaming state management + # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks self._buffer = "" - # streaming mode + # Stores complete tool call info (name and arguments) for each tool being parsed. + # Used by serving layer for completion handling when streaming ends. + # Format: [{"name": str, "arguments": dict}, ...] self.prev_tool_call_arr: List[Dict] = [] + # Index of currently streaming tool call. Starts at -1 (no active tool), + # increments as each tool completes. Tracks which tool's arguments are streaming. self.current_tool_id: int = -1 + # Flag for whether current tool's name has been sent to client. 
+ # Tool names sent first with empty parameters, then arguments stream incrementally. self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = ( - [] - ) # map what has been streamed for each tool so far to a list + # Tracks raw JSON string content streamed to client for each tool's arguments. + # Critical for serving layer to calculate remaining content when streaming ends. + # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72'] + self.streamed_args_for_tool: List[str] = [] + + # Token configuration (override in subclasses) self.bot_token = "" self.eot_token = "" self.tool_call_separator = ", " - def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: - tool_indices = { + def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]: + """ + Get a mapping of tool names to their indices in the tools list. + + This utility method creates a dictionary mapping function names to their + indices in the tools list, which is commonly needed for tool validation + and ToolCallItem creation. 
+ + Args: + tools: List of available tools + + Returns: + Dictionary mapping tool names to their indices + """ + return { tool.function.name: i for i, tool in enumerate(tools) if tool.function.name } + + def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: + tool_indices = self._get_tool_indices(tools) if not isinstance(action, list): action = [action] @@ -130,11 +156,7 @@ class BaseFormatDetector(ABC): # Build tool indices if not already built if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR @@ -294,12 +316,48 @@ class BaseFormatDetector(ABC): @abstractmethod def has_tool_call(self, text: str) -> bool: + """ + Check if the given text contains function call markers specific to this format. + """ raise NotImplementedError() @abstractmethod def structure_info(self) -> _GetInfoFunc: + """ + Return a function that creates StructureInfo for constrained generation. + + The returned function takes a tool name and returns a StructureInfo object + containing the begin/end patterns and trigger tokens needed for constrained + generation of function calls in this format. + + Returns: + A function that takes a tool name (str) and returns StructureInfo + """ raise NotImplementedError() @abstractmethod def build_ebnf(self, tools: List[Tool]) -> str: + """ + Build an EBNF grammar for constrained generation of function calls. + + This method generates an Extended Backus-Naur Form (EBNF) grammar that + constrains the model's output to valid function calls in this format. + The grammar should include all available tools and their parameter schemas. 
+ + Args: + tools: List of available tools/functions that can be called + + Returns: + A string containing the EBNF grammar for this function call format + + The EBNF grammar should: + - Define the overall structure of function calls in this format + - Include all tool names from the provided tools list + - Define valid JSON structures for function arguments + - Handle multiple function calls if the format supports them + + Note: + Most implementations use EBNFComposer.build_ebnf() utility with + format-specific parameters rather than writing EBNF from scratch. + """ raise NotImplementedError() diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py index e3befca5b..35e96c715 100644 --- a/python/sglang/srt/function_call/deepseekv3_detector.py +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -19,9 +19,28 @@ logger = logging.getLogger(__name__) class DeepSeekV3Detector(BaseFormatDetector): """ - Detector for DeepSeek models. - Assumes function call format: - '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> + Detector for DeepSeek V3 model function call format. + + The DeepSeek V3 format uses special Unicode tokens to delimit function calls + with JSON code blocks for arguments. 
+ + Format Structure: + ``` + <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁calls▁end|><|end▁of▁sentence|> + ``` + Examples: + ``` + <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> + ``` + + Key Components: + - Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>` + - Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>` + - Function Declaration: `function<|tool▁sep|>{function_name}` + - Arguments: JSON code block between ````json` and ```` + - Supports multiple tool calls + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default """ def __init__(self): @@ -89,11 +108,7 @@ class DeepSeekV3Detector(BaseFormatDetector): return StreamingParseResult(normal_text=new_text) if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) calls: list[ToolCallItem] = [] try: @@ -127,7 +142,7 @@ class DeepSeekV3Detector(BaseFormatDetector): ) ) self.current_tool_name_sent = True - # Store the tool call info for adapter.py + # Store the tool call info for serving layer completions endpoint self.prev_tool_call_arr[self.current_tool_id] = { "name": func_name, "arguments": {}, @@ -153,7 +168,7 @@ class DeepSeekV3Detector(BaseFormatDetector): ] += argument_diff if _is_complete_json(func_args_raw): - # Update the stored arguments for adapter.py + # Update the stored arguments try: parsed_args = json.loads(func_args_raw) self.prev_tool_call_arr[self.current_tool_id][ diff --git a/python/sglang/srt/function_call/kimik2_detector.py 
b/python/sglang/srt/function_call/kimik2_detector.py index 94457ccda..54ee77787 100644 --- a/python/sglang/srt/function_call/kimik2_detector.py +++ b/python/sglang/srt/function_call/kimik2_detector.py @@ -18,16 +18,21 @@ logger = logging.getLogger(__name__) class KimiK2Detector(BaseFormatDetector): + """ + Detector for Kimi K2 model function call format. + + Format Structure: + ``` + <|tool_calls_section_begin|> + <|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + <|tool_calls_section_end|> + ``` + + Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md + """ def __init__(self): super().__init__() - self._buffer = "" - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - [] - ) # map what has been streamed for each tool so far to a list self.bot_token: str = "<|tool_calls_section_begin|>" self.eot_token: str = "<|tool_calls_section_end|>" @@ -114,11 +119,7 @@ class KimiK2Detector(BaseFormatDetector): return StreamingParseResult(normal_text=new_text) if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) calls: list[ToolCallItem] = [] try: @@ -150,7 +151,7 @@ class KimiK2Detector(BaseFormatDetector): ) ) self.current_tool_name_sent = True - # Store the tool call info for adapter.py + # Store the tool call info for serving layer completions endpoint self.prev_tool_call_arr[self.current_tool_id] = { "name": function_name, "arguments": {}, @@ -214,7 +215,31 @@ class KimiK2Detector(BaseFormatDetector): return StreamingParseResult(normal_text=current_text) def structure_info(self) -> _GetInfoFunc: - raise NotImplementedError() + """Return function that creates StructureInfo for guided 
generation.""" - def build_ebnf(self, tools: List[Tool]): - raise NotImplementedError() + def get_info(name: str) -> StructureInfo: + return StructureInfo( + begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>", + end="<|tool_call_end|><|tool_calls_section_end|>", + trigger="<|tool_calls_section_begin|>", + ) + + return get_info + + def build_ebnf(self, tools: List[Tool]) -> str: + """ + Build EBNF grammar for KimiK2 tool call format. + + NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar + to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in + multiple function call scenarios, while still maintaining the correct KimiK2 + format structure for constrained generation. + """ + return EBNFComposer.build_ebnf( + tools, + sequence_start_token=self.bot_token, + sequence_end_token=self.eot_token, + tool_call_separator="", + call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"', + function_format="json", + ) diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py index e7afeddb0..453bcbc9a 100644 --- a/python/sglang/srt/function_call/llama32_detector.py +++ b/python/sglang/srt/function_call/llama32_detector.py @@ -16,9 +16,12 @@ logger = logging.getLogger(__name__) class Llama32Detector(BaseFormatDetector): """ - Detector for Llama 3.2 models. - Assumes function call format: - <|python_tag|>{"name":"xxx", "arguments":{...}} + Detector for Llama 3.2 models with json tool call format. 
+ + Format Structure: + ``` + {"name":"xxx", "arguments":{...}} + ``` """ def __init__(self): diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index 031368006..49767fd53 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -17,9 +17,17 @@ logger = logging.getLogger(__name__) class MistralDetector(BaseFormatDetector): """ - Detector for Mistral models. - Assumes function call format: - [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}] + Detector for Mistral model function call format. + + The Mistral format uses a simple bracket-delimited structure with JSON arrays + containing function call objects. + + Format Structure: + ``` + [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...] + ``` + + Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default """ def __init__(self): diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index d3096d919..85c3cd135 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -19,10 +19,17 @@ logger = logging.getLogger(__name__) class PythonicDetector(BaseFormatDetector): """ - Detector for Llama-3.2 and Llama-4 models with pythonic tool call format. - Assumes function call format: - [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] - Arguments are Python literals (not JSON). + Detector for Llama-4 models with Pythonic tool call format. + + The Pythonic format uses Python function call syntax within square brackets, + with arguments as Python literals rather than JSON. 
+ + Format Structure: + ``` + [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] + ``` + + Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default """ def __init__(self): @@ -75,11 +82,7 @@ class PythonicDetector(BaseFormatDetector): return StreamingParseResult(normal_text=normal_text, calls=[]) calls = [] - tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function.name - } + tool_indices = self._get_tool_indices(tools) for call_index, call in enumerate(parsed.elts): if not isinstance(call.func, ast.Name): continue diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index cee3f18ea..40a65e5df 100644 --- a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -17,9 +17,18 @@ logger = logging.getLogger(__name__) class Qwen25Detector(BaseFormatDetector): """ - Detector for Qwen 2.5 models. - Assumes function call format: - \n{"name":"func1", "arguments":{...}}\n\n\n{"name":"func2", "arguments":{...}}\n + Detector for Qwen 2.5 and Qwen 3 model function call format. 
+ + Format Structure: + ``` + <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call> + ``` + + Key Components: + - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call + - Function Call Object: JSON object with "name" and "arguments" fields + + Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default """ def __init__(self): diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index f9c36a9a2..c2f63e7e4 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -507,6 +507,7 @@ class TestEBNFGeneration(unittest.TestCase): self.llama32_detector = Llama32Detector() self.mistral_detector = MistralDetector() self.qwen25_detector = Qwen25Detector() + self.kimik2_detector = KimiK2Detector() def test_pythonic_detector_ebnf(self): """Test that the PythonicDetector generates valid EBNF.""" @@ -542,6 +543,33 @@ class TestEBNFGeneration(unittest.TestCase): except RuntimeError as e: self.fail(f"Failed to compile EBNF: {e}") + def test_kimik2_detector_ebnf(self): + """Test that the KimiK2Detector generates valid EBNF.""" + ebnf = self.kimik2_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns for KimiK2 format + self.assertIn("<|tool_calls_section_begin|>", ebnf) + self.assertIn("<|tool_calls_section_end|>", ebnf) + + # Check for KimiK2-specific function call structure + self.assertIn("<|tool_call_begin|>functions.get_weather:", ebnf) + self.assertIn("<|tool_call_begin|>functions.search:", ebnf) + self.assertIn("<|tool_call_argument_begin|>", ebnf) + self.assertIn("<|tool_call_end|>", ebnf) + + # Check that it uses the correct namespace.function format with numeric index pattern + self.assertIn("functions.get_weather:", ebnf) + self.assertIn("functions.search:", ebnf) + self.assertIn("[0-9]+", ebnf) # Numeric index pattern + + # Validate that the EBNF can be compiled by GrammarCompiler + try:
+ ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + def test_llama32_detector_ebnf(self): """Test that the Llama32Detector generates valid EBNF.""" ebnf = self.llama32_detector.build_ebnf(self.tools)