// File: sglang/sgl-router/src/tool_parser/parsers/kimik2_parser.rs (Rust, 346 lines, 14 KiB)

use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::ParserResult,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// Tool-call parser for the Kimi K2 output format.
///
/// Kimi K2 wraps tool calls in dedicated special tokens:
///
/// ```text
/// <|tool_calls_section_begin|>
///   <|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>
///   ...
/// <|tool_calls_section_end|>
/// ```
///
/// Each call is identified by `functions.{name}:{index}` (the `functions.` prefix is
/// optional) and carries its arguments as a JSON object. The parser handles both a
/// finished response and an incrementally streamed one.
pub struct KimiK2Parser {
    // ---- Compiled patterns ------------------------------------------------
    /// Matches one whole tool call, from the begin token through the end token.
    tool_call_extractor: Regex,
    /// Matches a tool call whose JSON arguments may still be arriving (streaming).
    stream_tool_call_extractor: Regex,
    /// Matches the span of a finished tool call so it can be cut out of the buffer.
    tool_call_end_pattern: Regex,
    /// Splits an id such as `functions.search:0` (or bare `search:0`) into name and index.
    tool_call_id_regex: Regex,

    // ---- Streaming state --------------------------------------------------
    /// Text received from the stream but not yet fully consumed.
    buffer: String,
    /// Per-call `{ "name": ..., "arguments": ... }` record; arguments are filled in once
    /// the call's JSON is complete.
    prev_tool_call_arr: Vec<Value>,
    /// Index of the tool call currently being streamed; `-1` until the first call is seen.
    current_tool_id: i32,
    /// Whether the current call's name has already been emitted to the client.
    current_tool_name_sent: bool,
    /// Per-call argument text that has already been emitted to the client.
    streamed_args_for_tool: Vec<String>,
    /// Raw argument text consumed so far for the current call; the next chunk's
    /// arguments are diffed against it.
    last_arguments: String,
}
impl KimiK2Parser {
    /// Create a new Kimi K2 parser with compiled patterns and empty streaming state.
    ///
    /// Every pattern that uses the `.` wildcard is compiled with `(?s)` so that `.` also
    /// matches `\n`. JSON arguments may span several lines. Without the flag such a call
    /// never matches in complete parsing, and in streaming it is truncated at the first
    /// newline, so it never completes.
    ///
    /// Function names may contain `-` in addition to word characters and `.`, because
    /// OpenAI-style tool names allow dashes.
    pub fn new() -> Self {
        // Complete tool call: id, argument-begin token, JSON object, end token.
        // The lazy `\{.*?\}` stops at the first `}` that is followed by the end token,
        // so nested objects are captured whole.
        let tool_call_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w.-]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        // Streaming (possibly partial) tool call: everything after `{` is captured,
        // including any end token; callers cut at `<|tool_call_end|>` themselves.
        let stream_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w.-]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
        let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");

        // Span of the first finished tool call, used to drop it from the stream buffer.
        let end_pattern = r"(?s)<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");

        // Ids like "functions.search:0", or the fallback form "search:0".
        let id_pattern = r"^(?:functions\.)?(?P<name>[\w.-]+):(?P<index>\d+)$";
        let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");

        Self {
            tool_call_extractor,
            stream_tool_call_extractor,
            tool_call_end_pattern,
            tool_call_id_regex,
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            last_arguments: String::new(),
        }
    }

    /// Split a tool-call id into `(function_name, index)`.
    ///
    /// Accepts `functions.{name}:{index}` and the bare `{name}:{index}` form. Returns
    /// `None` when the id does not match, or when the index does not fit in `usize`.
    fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
        let captures = self.tool_call_id_regex.captures(id)?;
        let name = captures.name("name")?.as_str().to_string();
        let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
        Some((name, index))
    }
}
impl Default for KimiK2Parser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for KimiK2Parser {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
}
// Find where tool calls begin
let idx = text.find("<|tool_calls_section_begin|>").unwrap();
let normal_text = text[..idx].to_string();
// Try to extract tool calls
let mut tools = Vec::new();
for captures in self.tool_call_extractor.captures_iter(text) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let function_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Try to parse JSON arguments
match serde_json::from_str::<Value>(function_args) {
Ok(_) => {
tools.push(ToolCall {
function: FunctionCall {
name: func_name,
arguments: function_args.to_string(),
},
});
}
Err(e) => {
tracing::warn!(
"Failed to parse JSON arguments for {}: {}",
func_name,
e
);
continue;
}
}
} else {
tracing::warn!("Failed to parse function ID: {}", function_id);
continue;
}
}
}
// If no tools were successfully parsed despite having markers, return entire text as fallback
if tools.is_empty() {
return Ok((text.to_string(), vec![]));
}
Ok((normal_text, tools))
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if we have a tool call (either the start token or individual tool call)
let has_tool_call =
self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
if !has_tool_call {
// No tool markers detected - return all buffered content as normal text
let mut normal_text = std::mem::take(&mut self.buffer);
// Remove end tokens if present
for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
normal_text = normal_text.replace(e_token, "");
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Build tool indices for validation
let tool_indices = helpers::get_tool_indices(tools);
let mut calls: Vec<ToolCallItem> = Vec::new();
// Try to match streaming pattern
if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let function_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Validate tool name
if !tool_indices.contains_key(&func_name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
helpers::reset_current_tool_state(
&mut self.buffer,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&self.prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
// Initialize state if this is the first tool call
if self.current_tool_id == -1 {
self.current_tool_id = 0;
self.prev_tool_call_arr = Vec::new();
self.streamed_args_for_tool = vec![String::new()];
}
// Ensure we have enough entries in our tracking arrays
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Send tool name if not sent yet
if !self.current_tool_name_sent {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: Some(func_name.clone()),
parameters: String::new(),
});
self.current_tool_name_sent = true;
// Store the tool call info for serving layer completions endpoint
let tool_id = self.current_tool_id as usize;
if self.prev_tool_call_arr.len() <= tool_id {
self.prev_tool_call_arr
.resize_with(tool_id + 1, || Value::Null);
}
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": func_name,
"arguments": {},
});
} else {
// Compute incremental diff
let argument_diff = if function_args.starts_with(&self.last_arguments) {
&function_args[self.last_arguments.len()..]
} else {
function_args
};
// Split by end token before sending (like Python does)
let parsed_args_diff =
if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
&argument_diff[..pos]
} else {
argument_diff
};
if !parsed_args_diff.is_empty() {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: None,
parameters: parsed_args_diff.to_string(),
});
// Note: Python adds full diff to _last_arguments, not just parsed part
self.last_arguments.push_str(argument_diff);
let tool_id = self.current_tool_id as usize;
if tool_id < self.streamed_args_for_tool.len() {
self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
}
}
// Check completeness - split by end token first
let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
{
&function_args[..pos]
} else {
function_args
};
if helpers::is_complete_json(parsed_args) {
// Update the stored arguments
if let Ok(parsed_args_value) =
serde_json::from_str::<Value>(parsed_args)
{
let tool_id = self.current_tool_id as usize;
if tool_id < self.prev_tool_call_arr.len() {
if let Some(obj) =
self.prev_tool_call_arr[tool_id].as_object_mut()
{
obj.insert("arguments".to_string(), parsed_args_value);
}
}
}
// Find the end of the current tool call and remove only that part from buffer
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
// Remove the completed tool call from buffer, keep any remaining content
self.buffer = current_text[mat.end()..].to_string();
} else {
self.buffer.clear();
}
let result = StreamingParseResult {
normal_text: String::new(),
calls,
};
self.current_tool_id += 1;
self.last_arguments.clear();
self.current_tool_name_sent = false;
return Ok(result);
}
}
}
}
}
Ok(StreamingParseResult {
normal_text: String::new(),
calls,
})
}
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|tool_calls_section_begin|>")
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.current_tool_name_sent = false;
self.streamed_args_for_tool.clear();
self.last_arguments.clear();
}
}