diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index 7a6bdfc24..42d42ea5b 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -25,5 +25,6 @@ pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCa // Re-export parsers for convenience pub use parsers::{ - DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser, + DeepSeekParser, JsonParser, KimiK2Parser, LlamaParser, MistralParser, PythonicParser, + QwenParser, Step3Parser, }; diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs new file mode 100644 index 000000000..52f92bd90 --- /dev/null +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -0,0 +1,270 @@ +use async_trait::async_trait; +use regex::Regex; + +use crate::tool_parser::{ + errors::ToolParserResult, + partial_json::PartialJson, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// Kimi K2 format parser for tool calls +/// +/// Handles the Kimi K2 specific format: +/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>` +/// +/// Features: +/// - Token-based delimiters +/// - Function calls with explicit indexing +/// - JSON arguments +pub struct KimiK2Parser { + /// Parser for handling incomplete JSON during streaming + partial_json: PartialJson, + /// Regex for extracting complete tool calls + tool_call_extractor: Regex, + /// Regex for extracting partial tool calls (streaming) + stream_tool_call_extractor: Regex, +} + +impl KimiK2Parser { + /// Create a new Kimi K2 parser + pub fn new() -> Self { + // Pattern for complete tool calls + let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*?\})\s*<\|tool_call_end\|>"; + let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); + + // Pattern for streaming (partial) tool calls + let stream_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*)"; + let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern"); + + Self { + partial_json: PartialJson::default(), + tool_call_extractor, + stream_tool_call_extractor, + } + } + + /// Check if text contains Kimi K2 tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("<|tool_calls_section_begin|>") + } + + /// Parse function ID to extract name and index + fn parse_function_id(&self, id: &str) -> Option<(String, usize)> { + // Format: functions.{name}:{index} or namespace.functions.{name}:{index} + // Extract everything after the last dot before the colon as the function name + if let Some(colon_pos) = id.rfind(':') { + let before_colon = &id[..colon_pos]; + let index_str = &id[colon_pos + 1..]; + + // Find the last dot to extract the function name + if let Some(dot_pos) = before_colon.rfind('.') { + let func_name = &before_colon[dot_pos + 1..]; + + if let Ok(index) = index_str.parse::() { + return Some((func_name.to_string(), index)); + } + } + } + None + } +} + +impl Default for KimiK2Parser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for KimiK2Parser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains Kimi K2 format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + let mut tools = Vec::new(); + + // Extract all tool calls + for captures in self.tool_call_extractor.captures_iter(text) { + if let (Some(id_match), Some(args_match)) = ( + captures.name("tool_call_id"), + captures.name("function_arguments"), + ) { + let function_id = id_match.as_str(); + let function_args = args_match.as_str(); + + // Parse function ID + if let Some((func_name, _index)) = self.parse_function_id(function_id) { + // Validate JSON arguments + if serde_json::from_str::(function_args).is_ok() { + // Generate unique ID + let id = format!("kimi_call_{}", uuid::Uuid::new_v4()); + + tools.push(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: func_name, + arguments: function_args.to_string(), + }, + }); + } + } + } + } + + Ok(tools) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check for tool markers + let has_tool_call = + self.has_tool_markers(&state.buffer) || state.buffer.contains("<|tool_call_begin|>"); + + if !has_tool_call { + // No markers found, clear buffer and return + state.buffer.clear(); + return Ok(StreamResult::Incomplete); + } + + // Try to match streaming pattern + if let Some(captures) = self.stream_tool_call_extractor.captures(&state.buffer) { + if let (Some(id_match), Some(args_match)) = ( + captures.name("tool_call_id"), + captures.name("function_arguments"), + ) { + let function_id = id_match.as_str(); + let partial_args = args_match.as_str(); + + // Parse function ID + if let Some((func_name, _index)) = self.parse_function_id(function_id) { + // Send function name if not sent yet + if !state.in_string { + state.in_string = true; // Mark name as sent + return Ok(StreamResult::ToolName { + index: 0, + name: func_name.clone(), + }); + } + + // Check if we have a complete tool call + if let Some(end_pos) = partial_args.find("<|tool_call_end|>") { + // Extract just the JSON part + let json_args = &partial_args[..end_pos]; + + // Validate and parse JSON + if serde_json::from_str::(json_args).is_ok() { + // Generate unique ID + let id = format!("kimi_call_{}", uuid::Uuid::new_v4()); + + let tool = ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: func_name, + arguments: json_args.to_string(), + }, + }; + + // Find where this tool call ends in the buffer + if let Some(tool_end) = state.buffer.find("<|tool_call_end|>") { + let end_pos = tool_end + "<|tool_call_end|>".len(); + state.buffer.drain(..end_pos); + } + + // Reset state for next tool + state.in_string = false; + + return Ok(StreamResult::ToolComplete(tool)); + } + } else { + // Try to parse partial JSON for streaming arguments + match self.partial_json.parse_value(partial_args) { + Ok((value, _consumed)) => { + let args_str = serde_json::to_string(&value) + .unwrap_or_else(|_| "{}".to_string()); + + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + Err(_) => { + // Can't parse yet, keep buffering + } + } + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + self.has_tool_markers(text) || text.contains("<|tool_call_begin|>") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_kimi_single_tool() { + let parser = KimiK2Parser::new(); + let input = r#"Some text +<|tool_calls_section_begin|> +<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|> +<|tool_calls_section_end|>More text"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Tokyo")); + } + + #[tokio::test] + async fn test_parse_kimi_multiple_tools() { + let parser = KimiK2Parser::new(); + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|> +<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "calculate"); + } + + #[tokio::test] + async fn test_parse_kimi_with_whitespace() { + let parser = KimiK2Parser::new(); + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); + } + + #[test] + fn test_detect_format() { + let parser = KimiK2Parser::new(); + assert!(parser.detect_format("<|tool_calls_section_begin|>")); + assert!(parser.detect_format("<|tool_call_begin|>")); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format("[TOOL_CALLS]")); + } +} diff --git a/sgl-router/src/tool_parser/parsers/mod.rs b/sgl-router/src/tool_parser/parsers/mod.rs index 399e2dc98..681a5fb31 100644 --- a/sgl-router/src/tool_parser/parsers/mod.rs +++ b/sgl-router/src/tool_parser/parsers/mod.rs @@ -5,6 +5,7 @@ // Individual parser modules pub mod deepseek_parser; pub mod json_parser; +pub mod kimik2_parser; pub mod llama_parser; pub mod mistral_parser; pub mod pythonic_parser; @@ -13,9 +14,8 @@ pub mod step3_parser; // Re-export parser types for convenience pub use deepseek_parser::DeepSeekParser; - pub use json_parser::JsonParser; - +pub use kimik2_parser::KimiK2Parser; pub use llama_parser::LlamaParser; pub use mistral_parser::MistralParser; pub use pythonic_parser::PythonicParser; diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index 93ef7b785..ba01bb776 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,5 +1,6 @@ use crate::tool_parser::parsers::{ - DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser, + DeepSeekParser, JsonParser, KimiK2Parser, LlamaParser, MistralParser, PythonicParser, + QwenParser, Step3Parser, }; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; @@ -113,8 +114,12 @@ impl ParserRegistry { // DeepSeek V3 parser - Unicode tokens with JSON blocks self.register_parser("deepseek", Arc::new(DeepSeekParser::new())); + // Step3 parser - StepTML XML format self.register_parser("step3", Arc::new(Step3Parser::new())); + + // Kimi K2 parser - Token-based with indexed functions + self.register_parser("kimik2", Arc::new(KimiK2Parser::new())); } /// Register default model mappings diff --git a/sgl-router/tests/tool_parser_kimik2.rs b/sgl-router/tests/tool_parser_kimik2.rs new file mode 100644 index 000000000..66be2e88f --- /dev/null +++ b/sgl-router/tests/tool_parser_kimik2.rs @@ -0,0 +1,160 @@ +//! Kimi K2 Parser Integration Tests + +use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, ToolParser}; + +#[tokio::test] +async fn test_kimik2_complete_parsing() { + let parser = KimiK2Parser::new(); + + // Test single tool call + let input = r#"Let me help you with that. +<|tool_calls_section_begin|> +<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|> +<|tool_calls_section_end|> +The weather in Tokyo is..."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + // Verify arguments + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["location"], "Tokyo"); + assert_eq!(args["units"], "celsius"); +} + +#[tokio::test] +async fn test_kimik2_multiple_tools() { + let parser = KimiK2Parser::new(); + + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust tutorials"}<|tool_call_end|> +<|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "translate"); +} + +#[tokio::test] +async fn test_kimik2_with_whitespace() { + let parser = KimiK2Parser::new(); + + // Test with extra whitespace + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["key"], "value"); + assert_eq!(args["num"], 42); +} + +#[tokio::test] +async fn test_kimik2_streaming() { + let parser = KimiK2Parser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "<|tool_calls_section_begin|>\n", + "<|tool_call_begin|>functions.", + "calculate:0", + "<|tool_call_argument_begin|>", + r#"{"x": 10, "#, + r#""y": 20}"#, + "<|tool_call_end|>\n", + "<|tool_calls_section_end|>", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "calculate"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "calculate"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); +} + +#[test] +fn test_kimik2_format_detection() { + let parser = KimiK2Parser::new(); + + // Should detect Kimi K2 format + assert!(parser.detect_format("<|tool_calls_section_begin|>")); + assert!(parser.detect_format("<|tool_call_begin|>")); + assert!(parser.detect_format("text with <|tool_calls_section_begin|> marker")); + + // Should not detect other formats + assert!(!parser.detect_format("[TOOL_CALLS]")); + assert!(!parser.detect_format("")); + assert!(!parser.detect_format("plain text")); +} + +#[tokio::test] +async fn test_kimik2_sequential_indices() { + let parser = KimiK2Parser::new(); + + // Test with proper sequential indexing + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|> +<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|> +<|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0].function.name, "first"); + assert_eq!(result[1].function.name, "second"); + assert_eq!(result[2].function.name, "third"); +} + +#[tokio::test] +async fn test_function_index_extraction() { + let parser = KimiK2Parser::new(); + + let input = r#"Text before tool calls. +<|tool_calls_section_begin|> +<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|> +<|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "calc"); + // TODO: Verify indices are preserved: 0 and 1 + // TODO: Verify normal text = "Text before tool calls." +} + +#[tokio::test] +async fn test_namespace_extraction() { + let parser = KimiK2Parser::new(); + + let input = r#"<|tool_calls_section_begin|> +<|tool_call_begin|>api.tools.search:0<|tool_call_argument_begin|>{"q": "test"}<|tool_call_end|> +<|tool_calls_section_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search"); // Should extract after last dot +}