feat(function call): complete utility method for KimiK2Detector and enhance documentation (#8043)
This commit is contained in:
@@ -25,23 +25,49 @@ class BaseFormatDetector(ABC):
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(self):
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# Streaming state management
|
||||
# Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
|
||||
self._buffer = ""
|
||||
# streaming mode
|
||||
# Stores complete tool call info (name and arguments) for each tool being parsed.
|
||||
# Used by serving layer for completion handling when streaming ends.
|
||||
# Format: [{"name": str, "arguments": dict}, ...]
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
# Index of currently streaming tool call. Starts at -1 (no active tool),
|
||||
# increments as each tool completes. Tracks which tool's arguments are streaming.
|
||||
self.current_tool_id: int = -1
|
||||
# Flag for whether current tool's name has been sent to client.
|
||||
# The tool name is sent first with empty parameters; the arguments are then streamed incrementally.
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = (
|
||||
[]
|
||||
) # map what has been streamed for each tool so far to a list
|
||||
# Tracks raw JSON string content streamed to client for each tool's arguments.
|
||||
# Critical for serving layer to calculate remaining content when streaming ends.
|
||||
# Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
|
||||
self.streamed_args_for_tool: List[str] = []
|
||||
|
||||
# Token configuration (override in subclasses)
|
||||
self.bot_token = ""
|
||||
self.eot_token = ""
|
||||
self.tool_call_separator = ", "
|
||||
|
||||
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||
tool_indices = {
|
||||
def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
    """
    Get a mapping of tool names to their indices in the tools list.

    This utility method creates a dictionary mapping function names to their
    indices in the tools list, which is commonly needed for tool validation
    and ToolCallItem creation.

    Entries without a function, or whose function has no name, are skipped.
    If several tools share the same name, the index of the last one wins
    (later keys overwrite earlier ones in the dict comprehension).

    Args:
        tools: List of available tools

    Returns:
        Dictionary mapping tool names to their indices
    """
    return {
        tool.function.name: i
        for i, tool in enumerate(tools)
        # Check `tool.function` before `.name`: the inline code this helper
        # replaces guarded both, and dropping the first check would raise
        # AttributeError for a tool entry that has no function attached.
        if tool.function and tool.function.name
    }
|
||||
|
||||
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||
tool_indices = self._get_tool_indices(tools)
|
||||
if not isinstance(action, list):
|
||||
action = [action]
|
||||
|
||||
@@ -130,11 +156,7 @@ class BaseFormatDetector(ABC):
|
||||
|
||||
# Build tool indices if not already built
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function and tool.function.name
|
||||
}
|
||||
self._tool_indices = self._get_tool_indices(tools)
|
||||
|
||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||
|
||||
@@ -294,12 +316,48 @@ class BaseFormatDetector(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""
|
||||
Check if the given text contains function call markers specific to this format.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
"""
|
||||
Return a function that creates StructureInfo for constrained generation.
|
||||
|
||||
The returned function takes a tool name and returns a StructureInfo object
|
||||
containing the begin/end patterns and trigger tokens needed for constrained
|
||||
generation of function calls in this format.
|
||||
|
||||
Returns:
|
||||
A function that takes a tool name (str) and returns StructureInfo
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build an EBNF grammar for constrained generation of function calls.
|
||||
|
||||
This method generates an Extended Backus-Naur Form (EBNF) grammar that
|
||||
constrains the model's output to valid function calls in this format.
|
||||
The grammar should include all available tools and their parameter schemas.
|
||||
|
||||
Args:
|
||||
tools: List of available tools/functions that can be called
|
||||
|
||||
Returns:
|
||||
A string containing the EBNF grammar for this function call format
|
||||
|
||||
The EBNF grammar should:
|
||||
- Define the overall structure of function calls in this format
|
||||
- Include all tool names from the provided tools list
|
||||
- Define valid JSON structures for function arguments
|
||||
- Handle multiple function calls if the format supports them
|
||||
|
||||
Note:
|
||||
Most implementations use EBNFComposer.build_ebnf() utility with
|
||||
format-specific parameters rather than writing EBNF from scratch.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -19,9 +19,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DeepSeekV3Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for DeepSeek models.
|
||||
Assumes function call format:
|
||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
||||
Detector for DeepSeek V3 model function call format.
|
||||
|
||||
The DeepSeek V3 format uses special Unicode tokens to delimit function calls
|
||||
with JSON code blocks for arguments.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
||||
```
|
||||
Examples:
|
||||
```
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
||||
```
|
||||
|
||||
Key Components:
|
||||
- Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>`
|
||||
- Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>`
|
||||
- Function Declaration: `function<|tool▁sep|>{function_name}`
|
||||
- Arguments: JSON code block that opens with a "```json" fence line and closes with a "```" fence line
|
||||
- Supports multiple tool calls
|
||||
|
||||
Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -89,11 +108,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function and tool.function.name
|
||||
}
|
||||
self._tool_indices = self._get_tool_indices(tools)
|
||||
|
||||
calls: list[ToolCallItem] = []
|
||||
try:
|
||||
@@ -127,7 +142,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
)
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
# Store the tool call info for adapter.py
|
||||
# Store the tool call info for serving layer completions endpoint
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
@@ -153,7 +168,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
] += argument_diff
|
||||
|
||||
if _is_complete_json(func_args_raw):
|
||||
# Update the stored arguments for adapter.py
|
||||
# Update the stored arguments
|
||||
try:
|
||||
parsed_args = json.loads(func_args_raw)
|
||||
self.prev_tool_call_arr[self.current_tool_id][
|
||||
|
||||
@@ -18,16 +18,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KimiK2Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Kimi K2 model function call format.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
```
|
||||
|
||||
Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._buffer = ""
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]
|
||||
) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.bot_token: str = "<|tool_calls_section_begin|>"
|
||||
self.eot_token: str = "<|tool_calls_section_end|>"
|
||||
@@ -114,11 +119,7 @@ class KimiK2Detector(BaseFormatDetector):
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function and tool.function.name
|
||||
}
|
||||
self._tool_indices = self._get_tool_indices(tools)
|
||||
|
||||
calls: list[ToolCallItem] = []
|
||||
try:
|
||||
@@ -150,7 +151,7 @@ class KimiK2Detector(BaseFormatDetector):
|
||||
)
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
# Store the tool call info for adapter.py
|
||||
# Store the tool call info for serving layer completions endpoint
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": function_name,
|
||||
"arguments": {},
|
||||
@@ -214,7 +215,31 @@ class KimiK2Detector(BaseFormatDetector):
|
||||
return StreamingParseResult(normal_text=current_text)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
raise NotImplementedError()
|
||||
"""Return function that creates StructureInfo for guided generation."""
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
raise NotImplementedError()
|
||||
def get_info(name: str) -> StructureInfo:
|
||||
return StructureInfo(
|
||||
begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>",
|
||||
end="<|tool_call_end|><|tool_calls_section_end|>",
|
||||
trigger="<|tool_calls_section_begin|>",
|
||||
)
|
||||
|
||||
return get_info
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build EBNF grammar for KimiK2 tool call format.
|
||||
|
||||
NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar
|
||||
to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in
|
||||
multiple function call scenarios, while still maintaining the correct KimiK2
|
||||
format structure for constrained generation.
|
||||
"""
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
sequence_start_token=self.bot_token,
|
||||
sequence_end_token=self.eot_token,
|
||||
tool_call_separator="",
|
||||
call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"',
|
||||
function_format="json",
|
||||
)
|
||||
|
||||
@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama 3.2 models.
|
||||
Assumes function call format:
|
||||
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
||||
Detector for Llama 3.2 models with json tool call format.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
|
||||
Detector for Mistral model function call format.
|
||||
|
||||
The Mistral format uses a simple bracket-delimited structure with JSON arrays
|
||||
containing function call objects.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
[TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
|
||||
```
|
||||
|
||||
Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -19,10 +19,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PythonicDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
|
||||
Assumes function call format:
|
||||
[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
Arguments are Python literals (not JSON).
|
||||
Detector for Llama-4 models with Pythonic tool call format.
|
||||
|
||||
The Pythonic format uses Python function call syntax within square brackets,
|
||||
with arguments as Python literals rather than JSON.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
```
|
||||
|
||||
Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -75,11 +82,7 @@ class PythonicDetector(BaseFormatDetector):
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
calls = []
|
||||
tool_indices = {
|
||||
tool.function.name: i
|
||||
for i, tool in enumerate(tools)
|
||||
if tool.function.name
|
||||
}
|
||||
tool_indices = self._get_tool_indices(tools)
|
||||
for call_index, call in enumerate(parsed.elts):
|
||||
if not isinstance(call.func, ast.Name):
|
||||
continue
|
||||
|
||||
@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 2.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
|
||||
Detector for Qwen 2.5 and Qwen 3 model function call format.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
|
||||
```
|
||||
|
||||
Key Components:
|
||||
- Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
|
||||
- Function Call Object: JSON object with "name" and "arguments" fields
|
||||
|
||||
Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user