From 817c62a077b2ce95ea67daea93320edd03ef9b36 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 25 Aug 2025 20:09:51 -0700 Subject: [PATCH] [router] add mistral tool parser (#9622) Co-authored-by: Chang Su --- sgl-router/src/tool_parser/json_parser.rs | 139 ++++++-- sgl-router/src/tool_parser/mistral_parser.rs | 347 +++++++++++++++++++ sgl-router/src/tool_parser/mod.rs | 3 + sgl-router/src/tool_parser/registry.rs | 53 ++- sgl-router/src/tool_parser/tests.rs | 41 +-- 5 files changed, 512 insertions(+), 71 deletions(-) create mode 100644 sgl-router/src/tool_parser/mistral_parser.rs diff --git a/sgl-router/src/tool_parser/json_parser.rs b/sgl-router/src/tool_parser/json_parser.rs index 4dd7efc64..01321b6b5 100644 --- a/sgl-router/src/tool_parser/json_parser.rs +++ b/sgl-router/src/tool_parser/json_parser.rs @@ -7,7 +7,7 @@ use crate::tool_parser::{ partial_json::PartialJson, state::ParseState, traits::ToolParser, - types::{FunctionCall, StreamResult, ToolCall}, + types::{FunctionCall, StreamResult, TokenConfig, ToolCall}, }; /// JSON format parser for tool calls @@ -19,12 +19,8 @@ use crate::tool_parser::{ /// /// Supports configurable token markers for different models pub struct JsonParser { - /// Token(s) that mark the start of tool calls - start_tokens: Vec, - /// Token(s) that mark the end of tool calls - end_tokens: Vec, - /// Separator between multiple tool calls (reserved for future use) - _separator: String, + /// Token configuration for parsing + token_config: TokenConfig, /// Parser for handling incomplete JSON during streaming partial_json: PartialJson, /// Regex patterns for extracting content between tokens @@ -34,23 +30,18 @@ pub struct JsonParser { impl JsonParser { /// Create a new JSON parser with default configuration pub fn new() -> Self { - Self::with_config( - vec![], // No wrapper tokens by default - vec![], - ", ".to_string(), - ) + Self::with_config(TokenConfig { + start_tokens: vec![], + end_tokens: vec![], + separator: ", ".to_string(), + }) } /// Create a parser with custom token configuration - pub fn with_config( - start_tokens: Vec, - end_tokens: Vec, - separator: String, - ) -> Self { + pub fn with_config(config: TokenConfig) -> Self { // Build extraction patterns for each token pair - let extractors = start_tokens - .iter() - .zip(end_tokens.iter()) + let extractors: Vec = config + .iter_pairs() .filter_map(|(start, end)| { if !start.is_empty() && !end.is_empty() { // Use (?s) flag to enable DOTALL mode so . matches newlines @@ -64,9 +55,7 @@ impl JsonParser { .collect(); Self { - start_tokens, - end_tokens, - _separator: separator, + token_config: config, partial_json: PartialJson::default(), extractors, } @@ -74,26 +63,90 @@ impl JsonParser { /// Extract JSON content from text, handling wrapper tokens if configured fn extract_json_content<'a>(&self, text: &'a str) -> &'a str { - let mut content = text.trim(); + let mut content = text; - // Try each extractor pattern + // Try each extractor pattern (for tokens with both start and end) for extractor in &self.extractors { if let Some(captures) = extractor.captures(content) { if let Some(matched) = captures.get(1) { - content = matched.as_str().trim(); - break; + return matched.as_str().trim(); } } } // Handle special case where there's a start token but no end token - for (start, end) in self.start_tokens.iter().zip(self.end_tokens.iter()) { + for (start, end) in self.token_config.iter_pairs() { if !start.is_empty() && end.is_empty() { - content = content.strip_prefix(start).unwrap_or(content); + // Find the start token and extract everything after it + if let Some(pos) = content.find(start) { + content = &content[pos + start.len()..]; + return content.trim(); + } } } - content + content.trim() + } + + /// Try to extract a JSON object or array from text that may contain other content + fn extract_json_from_text(&self, text: &str) -> Option { + // Look for JSON object starting with { + if let Some(start) = text.find('{') { + let mut depth = 0; + let mut in_string = false; + let mut escape_next = false; + + for (i, ch) in text[start..].char_indices() { + if escape_next { + escape_next = false; + continue; + } + + match ch { + '\\' if in_string => escape_next = true, + '"' if !in_string => in_string = true, + '"' if in_string => in_string = false, + '{' if !in_string => depth += 1, + '}' if !in_string => { + depth -= 1; + if depth == 0 { + return Some(text[start..start + i + 1].to_string()); + } + } + _ => {} + } + } + } + + // Look for JSON array starting with [ + if let Some(start) = text.find('[') { + let mut depth = 0; + let mut in_string = false; + let mut escape_next = false; + + for (i, ch) in text[start..].char_indices() { + if escape_next { + escape_next = false; + continue; + } + + match ch { + '\\' if in_string => escape_next = true, + '"' if !in_string => in_string = true, + '"' if in_string => in_string = false, + '[' if !in_string => depth += 1, + ']' if !in_string => { + depth -= 1; + if depth == 0 { + return Some(text[start..start + i + 1].to_string()); + } + } + _ => {} + } + } + } + + None } /// Parse a single JSON object into a ToolCall @@ -167,13 +220,16 @@ impl JsonParser { /// Check if text contains potential tool call markers fn has_tool_markers(&self, text: &str) -> bool { // If no start tokens configured, check for JSON structure - if self.start_tokens.is_empty() { + if self.token_config.start_tokens.is_empty() { // For JSON, we just need to see the start of an object or array return text.contains('{') || text.contains('['); } // Check for any start token - self.start_tokens.iter().any(|token| text.contains(token)) + self.token_config + .start_tokens + .iter() + .any(|token| text.contains(token)) } } @@ -193,6 +249,15 @@ impl ToolParser for JsonParser { match serde_json::from_str::(json_content) { Ok(value) => self.parse_json_value(&value), Err(_) => { + // If no wrapper tokens configured and parse failed, + // try to extract JSON from mixed text + if self.token_config.start_tokens.is_empty() { + if let Some(extracted) = self.extract_json_from_text(text) { + if let Ok(value) = serde_json::from_str::(&extracted) { + return self.parse_json_value(&value); + } + } + } // Not valid JSON, return empty Ok(vec![]) } @@ -341,11 +406,11 @@ mod tests { #[tokio::test] async fn test_parse_with_wrapper_tokens() { - let parser = JsonParser::with_config( - vec!["".to_string()], - vec!["".to_string()], - ", ".to_string(), - ); + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["".to_string()], + end_tokens: vec!["".to_string()], + separator: ", ".to_string(), + }); let input = r#"{"name": "test", "arguments": {}}"#; let result = parser.parse_complete(input).await.unwrap(); diff --git a/sgl-router/src/tool_parser/mistral_parser.rs b/sgl-router/src/tool_parser/mistral_parser.rs new file mode 100644 index 000000000..68a3568aa --- /dev/null +++ b/sgl-router/src/tool_parser/mistral_parser.rs @@ -0,0 +1,347 @@ +use async_trait::async_trait; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + partial_json::PartialJson, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// Mistral format parser for tool calls +/// +/// Handles the Mistral-specific format: +/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]` +/// +/// Features: +/// - Bracket counting for proper JSON array extraction +/// - Support for multiple tool calls in a single array +/// - String-aware parsing to handle nested brackets in JSON +pub struct MistralParser { + /// Parser for handling incomplete JSON during streaming + partial_json: PartialJson, +} + +impl MistralParser { + /// Create a new Mistral parser + pub fn new() -> Self { + Self { + partial_json: PartialJson::default(), + } + } + + /// Extract JSON array using bracket counting + /// + /// Handles nested brackets in JSON content by tracking: + /// - String boundaries (quotes) + /// - Escape sequences + /// - Bracket depth + fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> { + const BOT_TOKEN: &str = "[TOOL_CALLS] ["; + + // Find the start of the token + let start_idx = text.find(BOT_TOKEN)?; + + // Start from the opening bracket after [TOOL_CALLS] + // The -1 is to include the opening bracket that's part of the token + let json_start = start_idx + BOT_TOKEN.len() - 1; + + let mut bracket_count = 0; + let mut in_string = false; + let mut escape_next = false; + + let bytes = text.as_bytes(); + + for i in json_start..text.len() { + let char = bytes[i]; + + if escape_next { + escape_next = false; + continue; + } + + if char == b'\\' { + escape_next = true; + continue; + } + + if char == b'"' && !escape_next { + in_string = !in_string; + continue; + } + + if !in_string { + if char == b'[' { + bracket_count += 1; + } else if char == b']' { + bracket_count -= 1; + if bracket_count == 0 { + // Found the matching closing bracket + return Some(&text[json_start..=i]); + } + } + } + } + + // Incomplete array (no matching closing bracket found) + None + } + + /// Parse tool calls from a JSON array + fn parse_json_array(&self, json_str: &str) -> ToolParserResult> { + let value: Value = serde_json::from_str(json_str) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + + let mut tools = Vec::new(); + + if let Value::Array(arr) = value { + for (index, item) in arr.iter().enumerate() { + if let Some(tool) = self.parse_single_object(item, index)? { + tools.push(tool); + } + } + } else { + // Single object case (shouldn't happen with Mistral format, but handle it) + if let Some(tool) = self.parse_single_object(&value, 0)? { + tools.push(tool); + } + } + + Ok(tools) + } + + /// Parse a single JSON object into a ToolCall + fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult> { + let name = obj.get("name").and_then(|v| v.as_str()); + + if let Some(name) = name { + // Get arguments - Mistral uses "arguments" key + let empty_obj = Value::Object(serde_json::Map::new()); + let args = obj.get("arguments").unwrap_or(&empty_obj); + + // Convert arguments to JSON string + let arguments = serde_json::to_string(args) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + + // Generate ID with index for multiple tools + let id = format!("mistral_call_{}", index); + + Ok(Some(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: name.to_string(), + arguments, + }, + })) + } else { + Ok(None) + } + } + + /// Check if text contains Mistral tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("[TOOL_CALLS]") + } +} + +impl Default for MistralParser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for MistralParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains Mistral format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + // Extract JSON array from Mistral format + if let Some(json_array) = self.extract_json_array(text) { + self.parse_json_array(json_array) + } else { + // Markers present but no complete array found + Ok(vec![]) + } + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check if we have the start marker + if !self.has_tool_markers(&state.buffer) { + return Ok(StreamResult::Incomplete); + } + + // Try to extract complete JSON array + if let Some(json_array) = self.extract_json_array(&state.buffer) { + // Parse with partial JSON to handle incomplete content + match self.partial_json.parse_value(json_array) { + Ok((value, consumed)) => { + // Check if we have a complete JSON structure + if consumed == json_array.len() { + // Complete JSON, parse tool calls + let tools = if let Value::Array(arr) = value { + let mut result = Vec::new(); + for (index, item) in arr.iter().enumerate() { + if let Some(tool) = self.parse_single_object(item, index)? { + result.push(tool); + } + } + result + } else { + vec![] + }; + + if !tools.is_empty() { + // Clear buffer since we consumed everything + state.buffer.clear(); + + // Return the first tool (simplified for Phase 3) + // Full multi-tool streaming will be implemented later + if let Some(tool) = tools.into_iter().next() { + return Ok(StreamResult::ToolComplete(tool)); + } + } + } else { + // Partial JSON - try to extract tool name for streaming + if let Value::Array(arr) = value { + if let Some(first_tool) = arr.first() { + if let Some(name) = first_tool.get("name").and_then(|v| v.as_str()) + { + // Check if we've already sent the name + if !state.in_string { + state.in_string = true; // Use as flag for "name sent" + return Ok(StreamResult::ToolName { + index: 0, + name: name.to_string(), + }); + } + + // Check for arguments + if let Some(args) = first_tool.get("arguments") { + if let Ok(args_str) = serde_json::to_string(args) { + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + } + } + } + } + } + } + Err(_) => { + // Failed to parse even as partial JSON + // Keep buffering + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + // Check if text contains Mistral-specific markers + if self.has_tool_markers(text) { + // Try to extract and validate the array + if let Some(json_array) = self.extract_json_array(text) { + // Check if it's valid JSON + if let Ok(value) = serde_json::from_str::(json_array) { + // Check if it contains tool-like structures + match value { + Value::Array(ref arr) => arr.iter().any(|v| { + v.as_object().is_some_and(|o| { + o.contains_key("name") && o.contains_key("arguments") + }) + }), + Value::Object(ref obj) => { + obj.contains_key("name") && obj.contains_key("arguments") + } + _ => false, + } + } else { + false + } + } else { + // Has markers but no complete array - might be streaming + true + } + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_mistral_format() { + let parser = MistralParser::new(); + let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Paris")); + } + + #[tokio::test] + async fn test_parse_multiple_tools() { + let parser = MistralParser::new(); + let input = r#"[TOOL_CALLS] [ + {"name": "search", "arguments": {"query": "rust programming"}}, + {"name": "calculate", "arguments": {"expression": "2 + 2"}} + ]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "calculate"); + } + + #[tokio::test] + async fn test_nested_brackets_in_json() { + let parser = MistralParser::new(); + let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process"); + // JSON serialization removes spaces, so check for [3,4] without spaces + assert!(result[0].function.arguments.contains("[3,4]")); + } + + #[tokio::test] + async fn test_escaped_quotes_in_strings() { + let parser = MistralParser::new(); + let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "echo"); + } + + #[test] + fn test_detect_format() { + let parser = MistralParser::new(); + + assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)); + assert!( + parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#) + ); + assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); + assert!(!parser.detect_format("plain text")); + } +} diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index 01d42385f..b2f775c8b 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -3,7 +3,9 @@ /// This module provides infrastructure for parsing tool calls from various model formats. pub mod errors; pub mod json_parser; +pub mod mistral_parser; pub mod partial_json; + pub mod registry; pub mod state; pub mod traits; @@ -15,6 +17,7 @@ mod tests; // Re-export commonly used types pub use errors::{ToolParserError, ToolParserResult}; pub use json_parser::JsonParser; +pub use mistral_parser::MistralParser; pub use registry::ParserRegistry; pub use state::{ParsePhase, ParseState}; pub use traits::{PartialJsonParser, ToolParser}; diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index 11153dfd5..c1178200a 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -50,15 +50,28 @@ impl ParserRegistry { } } - // Try prefix matching (e.g., "gpt-4" matches "gpt-*") - for (pattern, parser_name) in &self.model_mapping { - if pattern.ends_with('*') { - let prefix = &pattern[..pattern.len() - 1]; - if model.starts_with(prefix) { - if let Some(parser) = self.parsers.get(parser_name) { - return Some(parser.clone()); - } + // Try prefix matching with more specific patterns first + // Collect all matching patterns and sort by specificity (longer = more specific) + let mut matches: Vec<(&String, &String)> = self + .model_mapping + .iter() + .filter(|(pattern, _)| { + if pattern.ends_with('*') { + let prefix = &pattern[..pattern.len() - 1]; + model.starts_with(prefix) + } else { + false } + }) + .collect(); + + // Sort by pattern length in descending order (longer patterns are more specific) + matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len())); + + // Return the first matching parser + for (_, parser_name) in matches { + if let Some(parser) = self.parsers.get(parser_name) { + return Some(parser.clone()); } } @@ -97,20 +110,32 @@ impl ParserRegistry { // Anthropic models self.map_model("claude-*", "json"); - // Mistral models (will use json until mistral parser is implemented) - self.map_model("mistral-*", "json"); - self.map_model("mixtral-*", "json"); + // Mistral models - use Mistral parser + self.map_model("mistral-*", "mistral"); + self.map_model("mixtral-*", "mistral"); - // Qwen models (will use json until qwen parser is implemented) - self.map_model("qwen*", "json"); + // Qwen models - use Qwen parser + self.map_model("qwen*", "qwen"); + self.map_model("Qwen*", "qwen"); - // Llama models (will use json until llama parser is implemented) + // Llama models + // Llama 4 uses pythonic format + self.map_model("llama-4*", "pythonic"); + self.map_model("meta-llama-4*", "pythonic"); + // Llama 3.2 uses python_tag format + self.map_model("llama-3.2*", "llama"); + self.map_model("meta-llama-3.2*", "llama"); + // Other Llama models use JSON self.map_model("llama-*", "json"); self.map_model("meta-llama-*", "json"); + // DeepSeek models - DeepSeek v3 would need custom parser, v2 uses pythonic + self.map_model("deepseek-*", "pythonic"); + // Other models default to JSON self.map_model("gemini-*", "json"); self.map_model("palm-*", "json"); + self.map_model("gemma-*", "json"); } /// Set the default parser diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs index 2635e0350..a9284586a 100644 --- a/sgl-router/src/tool_parser/tests.rs +++ b/sgl-router/src/tool_parser/tests.rs @@ -4,6 +4,7 @@ use crate::tool_parser::partial_json::{ compute_diff, find_common_prefix, is_complete_json, PartialJson, }; use crate::tool_parser::traits::ToolParser; +use crate::tool_parser::types::TokenConfig; #[test] fn test_parse_state_new() { @@ -299,11 +300,11 @@ async fn test_json_parser_with_parameters() { #[tokio::test] async fn test_json_parser_with_tokens() { // Test with custom wrapper tokens - let parser = JsonParser::with_config( - vec!["[TOOL_CALLS] [".to_string()], - vec!["]".to_string()], - ", ".to_string(), - ); + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["[TOOL_CALLS] [".to_string()], + end_tokens: vec!["]".to_string()], + separator: ", ".to_string(), + }); let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#; let result = parser.parse_complete(input).await.unwrap(); @@ -315,11 +316,11 @@ async fn test_json_parser_with_tokens() { #[tokio::test] async fn test_multiline_json_with_tokens() { // Test that regex with (?s) flag properly handles multi-line JSON - let parser = JsonParser::with_config( - vec!["".to_string()], - vec!["".to_string()], - ", ".to_string(), - ); + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["".to_string()], + end_tokens: vec!["".to_string()], + separator: ", ".to_string(), + }); // Pretty-printed multi-line JSON let input = r#"{ @@ -493,11 +494,11 @@ mod failure_cases { #[tokio::test] async fn test_broken_wrapper_tokens() { - let parser = JsonParser::with_config( - vec!["".to_string()], - vec!["".to_string()], - ", ".to_string(), - ); + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["".to_string()], + end_tokens: vec!["".to_string()], + separator: ", ".to_string(), + }); // Missing end token let input = r#"{"name": "test", "arguments": {}}"#; @@ -678,11 +679,11 @@ mod edge_cases { #[tokio::test] async fn test_multiple_token_pairs_with_conflicts() { // Test with overlapping token patterns - let parser = JsonParser::with_config( - vec!["<<".to_string(), "".to_string()], - vec![">>".to_string(), "".to_string()], - ", ".to_string(), - ); + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["<<".to_string(), "".to_string()], + end_tokens: vec![">>".to_string(), "".to_string()], + separator: ", ".to_string(), + }); // First pattern let input = r#"<<{"name": "test1", "arguments": {}}>>"#;