diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index bc4c5a020..54b5a0a11 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -5,6 +5,8 @@ pub mod errors; pub mod json_parser; pub mod mistral_parser; pub mod partial_json; +pub mod python_literal_parser; +pub mod pythonic_parser; pub mod qwen_parser; pub mod registry; pub mod state; @@ -18,6 +20,7 @@ mod tests; pub use errors::{ToolParserError, ToolParserResult}; pub use json_parser::JsonParser; pub use mistral_parser::MistralParser; +pub use pythonic_parser::PythonicParser; pub use qwen_parser::QwenParser; pub use registry::ParserRegistry; pub use state::{ParsePhase, ParseState}; diff --git a/sgl-router/src/tool_parser/python_literal_parser.rs b/sgl-router/src/tool_parser/python_literal_parser.rs new file mode 100644 index 000000000..4acc69d34 --- /dev/null +++ b/sgl-router/src/tool_parser/python_literal_parser.rs @@ -0,0 +1,442 @@ +/// Minimal Python literal parser for Pythonic tool call format +/// +/// This module provides a recursive descent parser for Python literals +/// (strings, numbers, booleans, None, lists, dicts) without requiring +/// a full Python AST parser. +use serde_json::{json, Value}; +use std::collections::HashMap; + +use crate::tool_parser::errors::{ToolParserError, ToolParserResult}; + +/// Token types for Python literals +#[derive(Debug, Clone, PartialEq)] +enum Token { + // Literals + String(String), + Number(String), + True, + False, + None, + + // Delimiters + LeftBracket, // [ + RightBracket, // ] + LeftBrace, // { + RightBrace, // } + LeftParen, // ( + RightParen, // ) + Comma, // , + Colon, // : + Equals, // = + + // Identifier for function names + Identifier(String), + + // End of input + Eof, +} + +/// Lexer for Python literals +struct Lexer { + input: Vec, + position: usize, +} + +impl Lexer { + fn new(input: &str) -> Self { + Self { + input: input.chars().collect(), + position: 0, + } + } + + fn current_char(&self) -> Option { + self.input.get(self.position).copied() + } + + fn advance(&mut self) { + if self.position < self.input.len() { + self.position += 1; + } + } + + fn skip_whitespace(&mut self) { + while let Some(ch) = self.current_char() { + if ch.is_whitespace() { + self.advance(); + } else { + break; + } + } + } + + fn read_string(&mut self, quote_char: char) -> Result { + let mut result = String::new(); + self.advance(); // Skip opening quote + + while let Some(ch) = self.current_char() { + if ch == '\\' { + self.advance(); + if let Some(escaped) = self.current_char() { + match escaped { + 'n' => result.push('\n'), + 't' => result.push('\t'), + 'r' => result.push('\r'), + '\\' => result.push('\\'), + '\'' => result.push('\''), + '"' => result.push('"'), + _ => { + result.push('\\'); + result.push(escaped); + } + } + self.advance(); + } + } else if ch == quote_char { + self.advance(); // Skip closing quote + return Ok(result); + } else { + result.push(ch); + self.advance(); + } + } + + Err(ToolParserError::ParsingFailed("Unterminated string".into())) + } + + fn read_number(&mut self) -> String { + let mut result = String::new(); + + // Handle negative numbers + if self.current_char() == Some('-') { + result.push('-'); + self.advance(); + } + + // Read digits and decimal point + while let Some(ch) = self.current_char() { + if ch.is_ascii_digit() || ch == '.' || ch == 'e' || ch == 'E' || ch == '+' || ch == '-' + { + result.push(ch); + self.advance(); + } else { + break; + } + } + + result + } + + fn read_identifier(&mut self) -> String { + let mut result = String::new(); + + while let Some(ch) = self.current_char() { + if ch.is_alphanumeric() || ch == '_' { + result.push(ch); + self.advance(); + } else { + break; + } + } + + result + } + + fn next_token(&mut self) -> Result { + self.skip_whitespace(); + + match self.current_char() { + None => Ok(Token::Eof), + Some('[') => { + self.advance(); + Ok(Token::LeftBracket) + } + Some(']') => { + self.advance(); + Ok(Token::RightBracket) + } + Some('{') => { + self.advance(); + Ok(Token::LeftBrace) + } + Some('}') => { + self.advance(); + Ok(Token::RightBrace) + } + Some('(') => { + self.advance(); + Ok(Token::LeftParen) + } + Some(')') => { + self.advance(); + Ok(Token::RightParen) + } + Some(',') => { + self.advance(); + Ok(Token::Comma) + } + Some(':') => { + self.advance(); + Ok(Token::Colon) + } + Some('=') => { + self.advance(); + Ok(Token::Equals) + } + Some('"') => Ok(Token::String(self.read_string('"')?)), + Some('\'') => Ok(Token::String(self.read_string('\'')?)), + Some(ch) if ch == '-' || ch.is_ascii_digit() => Ok(Token::Number(self.read_number())), + Some(ch) if ch.is_alphabetic() || ch == '_' => { + let ident = self.read_identifier(); + match ident.as_str() { + "True" => Ok(Token::True), + "False" => Ok(Token::False), + "None" => Ok(Token::None), + _ => Ok(Token::Identifier(ident)), + } + } + Some(ch) => Err(ToolParserError::ParsingFailed(format!( + "Unexpected character: {}", + ch + ))), + } + } +} + +/// Parser for Python literals +pub struct PythonLiteralParser { + lexer: Lexer, + current_token: Token, +} + +impl PythonLiteralParser { + pub fn new(input: &str) -> Result { + let mut lexer = Lexer::new(input); + let current_token = lexer.next_token()?; + Ok(Self { + lexer, + current_token, + }) + } + + fn advance(&mut self) -> Result<(), ToolParserError> { + self.current_token = self.lexer.next_token()?; + Ok(()) + } + + fn expect(&mut self, expected: Token) -> Result<(), ToolParserError> { + if self.current_token == expected { + self.advance()?; + Ok(()) + } else { + Err(ToolParserError::ParsingFailed(format!( + "Expected {:?}, got {:?}", + expected, self.current_token + ))) + } + } + + /// Parse a Python literal value + pub fn parse_value(&mut self) -> Result { + match &self.current_token.clone() { + Token::String(s) => { + let value = s.clone(); + self.advance()?; + Ok(json!(value)) + } + Token::Number(n) => { + let value = if let Ok(int_val) = n.parse::() { + json!(int_val) + } else if let Ok(float_val) = n.parse::() { + json!(float_val) + } else { + return Err(ToolParserError::ParsingFailed(format!( + "Invalid number: {}", + n + ))); + }; + self.advance()?; + Ok(value) + } + Token::True => { + self.advance()?; + Ok(json!(true)) + } + Token::False => { + self.advance()?; + Ok(json!(false)) + } + Token::None => { + self.advance()?; + Ok(Value::Null) + } + Token::LeftBracket => self.parse_list(), + Token::LeftBrace => self.parse_dict(), + _ => Err(ToolParserError::ParsingFailed(format!( + "Unexpected token: {:?}", + self.current_token + ))), + } + } + + /// Parse a Python list: [item1, item2, ...] + fn parse_list(&mut self) -> Result { + self.expect(Token::LeftBracket)?; + let mut items = Vec::new(); + + // Handle empty list + if self.current_token == Token::RightBracket { + self.advance()?; + return Ok(json!(items)); + } + + loop { + items.push(self.parse_value()?); + + if self.current_token == Token::Comma { + self.advance()?; + // Handle trailing comma + if self.current_token == Token::RightBracket { + break; + } + } else if self.current_token == Token::RightBracket { + break; + } else { + return Err(ToolParserError::ParsingFailed(format!( + "Expected ',' or ']', got {:?}", + self.current_token + ))); + } + } + + self.expect(Token::RightBracket)?; + Ok(json!(items)) + } + + /// Parse a Python dict: {key1: value1, key2: value2, ...} + fn parse_dict(&mut self) -> Result { + self.expect(Token::LeftBrace)?; + let mut map = HashMap::new(); + + // Handle empty dict + if self.current_token == Token::RightBrace { + self.advance()?; + return Ok(json!(map)); + } + + loop { + // Parse key (must be a string) + let key = match &self.current_token { + Token::String(s) => { + let k = s.clone(); + self.advance()?; + k + } + _ => { + return Err(ToolParserError::ParsingFailed(format!( + "Expected string key, got {:?}", + self.current_token + ))) + } + }; + + self.expect(Token::Colon)?; + + // Parse value + let value = self.parse_value()?; + map.insert(key, value); + + if self.current_token == Token::Comma { + self.advance()?; + // Handle trailing comma + if self.current_token == Token::RightBrace { + break; + } + } else if self.current_token == Token::RightBrace { + break; + } else { + return Err(ToolParserError::ParsingFailed(format!( + "Expected ',' or '}}', got {:?}", + self.current_token + ))); + } + } + + self.expect(Token::RightBrace)?; + Ok(json!(map)) + } +} + +/// Parse a Python literal string into a JSON value +pub fn parse_python_literal(input: &str) -> ToolParserResult { + let mut parser = PythonLiteralParser::new(input)?; + parser.parse_value() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_primitives() { + assert_eq!(parse_python_literal("True").unwrap(), json!(true)); + assert_eq!(parse_python_literal("False").unwrap(), json!(false)); + assert_eq!(parse_python_literal("None").unwrap(), Value::Null); + assert_eq!(parse_python_literal("42").unwrap(), json!(42)); + assert_eq!(parse_python_literal("12.345").unwrap(), json!(12.345)); + assert_eq!(parse_python_literal("-42").unwrap(), json!(-42)); + assert_eq!(parse_python_literal("\"hello\"").unwrap(), json!("hello")); + assert_eq!(parse_python_literal("'world'").unwrap(), json!("world")); + } + + #[test] + fn test_parse_list() { + assert_eq!(parse_python_literal("[]").unwrap(), json!([])); + assert_eq!(parse_python_literal("[1, 2, 3]").unwrap(), json!([1, 2, 3])); + assert_eq!( + parse_python_literal("[\"a\", \"b\", \"c\"]").unwrap(), + json!(["a", "b", "c"]) + ); + assert_eq!( + parse_python_literal("[True, False, None]").unwrap(), + json!([true, false, null]) + ); + // Nested list + assert_eq!( + parse_python_literal("[[1, 2], [3, 4]]").unwrap(), + json!([[1, 2], [3, 4]]) + ); + } + + #[test] + fn test_parse_dict() { + assert_eq!(parse_python_literal("{}").unwrap(), json!({})); + assert_eq!( + parse_python_literal("{\"a\": 1, \"b\": 2}").unwrap(), + json!({"a": 1, "b": 2}) + ); + assert_eq!( + parse_python_literal("{'x': True, 'y': False}").unwrap(), + json!({"x": true, "y": false}) + ); + // Nested dict + assert_eq!( + parse_python_literal("{\"nested\": {\"value\": [1, 2, 3]}}").unwrap(), + json!({"nested": {"value": [1, 2, 3]}}) + ); + } + + #[test] + fn test_complex_nested() { + let input = r#"{"config": {"nested": {"value": [1, 2, 3]}, "enabled": True}}"#; + let expected = json!({ + "config": { + "nested": { + "value": [1, 2, 3] + }, + "enabled": true + } + }); + assert_eq!(parse_python_literal(input).unwrap(), expected); + } +} diff --git a/sgl-router/src/tool_parser/pythonic_parser.rs b/sgl-router/src/tool_parser/pythonic_parser.rs new file mode 100644 index 000000000..e74272345 --- /dev/null +++ b/sgl-router/src/tool_parser/pythonic_parser.rs @@ -0,0 +1,428 @@ +/// Pythonic format parser for tool calls +/// +/// Handles Python function call syntax within square brackets: +/// ```text +/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] +/// ``` +/// +/// This format is used by Llama-4 models and uses Python literals +/// rather than JSON for arguments. +use async_trait::async_trait; +use regex::Regex; +use serde_json::{json, Value}; + +use crate::tool_parser::{ + errors::ToolParserResult, + python_literal_parser::parse_python_literal, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// Parser for Pythonic tool call format +pub struct PythonicParser { + /// Regex to detect tool calls in Pythonic format + tool_call_regex: Regex, +} + +impl PythonicParser { + /// Create a new Pythonic parser + pub fn new() -> Self { + // Simple regex to detect start of Pythonic tool calls + // We'll use manual parsing for the actual extraction + let pattern = r"\[[a-zA-Z_]\w*\("; + let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern"); + + Self { tool_call_regex } + } + + /// Extract tool calls using bracket counting (similar to MistralParser) + fn extract_tool_calls(&self, text: &str) -> Option { + // Find the start of a tool call list - look for [ followed by a function name + let chars: Vec = text.chars().collect(); + + for start_idx in 0..chars.len() { + if chars[start_idx] != '[' { + continue; + } + + // Check if this looks like a tool call + // Skip whitespace after [ + let mut check_idx = start_idx + 1; + while check_idx < chars.len() && chars[check_idx].is_whitespace() { + check_idx += 1; + } + + // Check if we have a function name (starts with letter or underscore) + if check_idx >= chars.len() + || (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_') + { + continue; + } + + // Now count brackets to find the matching ] + let mut bracket_count = 0; + let mut _paren_count = 0; + let mut _brace_count = 0; + let mut in_string = false; + let mut string_char = ' '; + let mut escape_next = false; + + for i in start_idx..chars.len() { + let ch = chars[i]; + + if escape_next { + escape_next = false; + continue; + } + + if ch == '\\' && in_string { + escape_next = true; + continue; + } + + if !in_string && (ch == '"' || ch == '\'') { + in_string = true; + string_char = ch; + } else if in_string && ch == string_char && !escape_next { + in_string = false; + } else if !in_string { + match ch { + '[' => bracket_count += 1, + ']' => { + bracket_count -= 1; + if bracket_count == 0 { + // Found the matching bracket + let extracted: String = chars[start_idx..=i].iter().collect(); + // Verify this actually contains a function call + if extracted.contains('(') && extracted.contains(')') { + return Some(extracted); + } + } + } + '(' => _paren_count += 1, + ')' => _paren_count -= 1, + '{' => _brace_count += 1, + '}' => _brace_count -= 1, + _ => {} + } + } + } + } + None + } + + /// Strip special tokens that Llama 4 might output + fn strip_special_tokens(text: &str) -> String { + text.replace("<|python_start|>", "") + .replace("<|python_end|>", "") + } + + /// Parse a single function call from Python syntax + fn parse_function_call(&self, call_str: &str) -> ToolParserResult> { + // Match function_name(args) - use (?s) to make . match newlines + let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").unwrap(); + + if let Some(captures) = call_regex.captures(call_str.trim()) { + let function_name = captures.get(1).unwrap().as_str(); + let args_str = captures.get(2).unwrap().as_str(); + + // Parse arguments + let arguments = self.parse_arguments(args_str)?; + + Ok(Some(ToolCall { + id: format!("call_{}", uuid::Uuid::new_v4()), + r#type: "function".to_string(), + function: FunctionCall { + name: function_name.to_string(), + arguments: serde_json::to_string(&arguments)?, + }, + })) + } else { + Ok(None) + } + } + + /// Parse Python-style arguments into JSON + fn parse_arguments(&self, args_str: &str) -> ToolParserResult { + if args_str.trim().is_empty() { + return Ok(json!({})); + } + + let mut result = serde_json::Map::new(); + let mut current_key = String::new(); + let mut current_value = String::new(); + let mut in_key = true; + let mut depth = 0; + let mut in_string = false; + let mut string_char = ' '; + let mut escape_next = false; + + let chars: Vec = args_str.chars().collect(); + let mut i = 0; + + while i < chars.len() { + let ch = chars[i]; + + if escape_next { + if in_key { + current_key.push(ch); + } else { + current_value.push(ch); + } + escape_next = false; + i += 1; + continue; + } + + if ch == '\\' && in_string { + escape_next = true; + current_value.push(ch); + i += 1; + continue; + } + + // Handle string literals + if !in_string && (ch == '"' || ch == '\'') { + in_string = true; + string_char = ch; + if !in_key { + current_value.push(ch); + } + } else if in_string && ch == string_char && !escape_next { + in_string = false; + if !in_key { + current_value.push(ch); + } + } else if in_string { + if in_key { + current_key.push(ch); + } else { + current_value.push(ch); + } + } else { + // Not in string + match ch { + '=' if in_key && depth == 0 => { + in_key = false; + } + ',' if depth == 0 => { + // End of current argument + if !current_key.is_empty() { + let value = parse_python_literal(current_value.trim())?; + result.insert(current_key.trim().to_string(), value); + } + current_key.clear(); + current_value.clear(); + in_key = true; + } + '[' | '{' | '(' => { + depth += 1; + if !in_key { + current_value.push(ch); + } + } + ']' | '}' | ')' => { + depth -= 1; + if !in_key { + current_value.push(ch); + } + } + _ => { + if in_key { + if !ch.is_whitespace() || !current_key.is_empty() { + current_key.push(ch); + } + } else { + current_value.push(ch); + } + } + } + } + + i += 1; + } + + // Handle the last argument + if !current_key.is_empty() { + let value = parse_python_literal(current_value.trim())?; + result.insert(current_key.trim().to_string(), value); + } + + Ok(Value::Object(result)) + } +} + +#[async_trait] +impl ToolParser for PythonicParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + let cleaned = Self::strip_special_tokens(text); + + // Extract tool calls using bracket counting + if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) { + // Remove the outer brackets + let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1]; + + // Split into individual function calls + let mut calls = Vec::new(); + let mut current_call = String::new(); + let mut paren_depth = 0; + let mut in_string = false; + let mut string_char = ' '; + + for ch in tool_calls_str.chars() { + if !in_string && (ch == '"' || ch == '\'') { + in_string = true; + string_char = ch; + current_call.push(ch); + } else if in_string && ch == string_char { + in_string = false; + current_call.push(ch); + } else if in_string { + current_call.push(ch); + } else { + match ch { + '(' => { + paren_depth += 1; + current_call.push(ch); + } + ')' => { + paren_depth -= 1; + current_call.push(ch); + } + ',' if paren_depth == 0 => { + // End of current function call + if let Some(call) = self.parse_function_call(current_call.trim())? { + calls.push(call); + } + current_call.clear(); + } + _ => { + if !ch.is_whitespace() || !current_call.is_empty() { + current_call.push(ch); + } + } + } + } + } + + // Handle the last call (important for single calls or the last call in a list) + if !current_call.trim().is_empty() { + if let Some(call) = self.parse_function_call(current_call.trim())? { + calls.push(call); + } + } + + Ok(calls) + } else { + Ok(vec![]) + } + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + // For Pythonic format, we accumulate until we have a complete tool call + // This is a simplified implementation + state.buffer.push_str(chunk); + + // Try to parse if we have a complete tool call + let cleaned = Self::strip_special_tokens(&state.buffer); + if self.extract_tool_calls(&cleaned).is_some() { + let result = self.parse_complete(&state.buffer).await?; + if !result.is_empty() { + state.buffer.clear(); + return Ok(StreamResult::ToolComplete( + result.into_iter().next().unwrap(), + )); + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + let cleaned = Self::strip_special_tokens(text); + self.tool_call_regex.is_match(&cleaned) + } +} + +impl Default for PythonicParser { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_single_function_call() { + let parser = PythonicParser::new(); + let input = r#"[search_web(query="Rust programming", max_results=5)]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search_web"); + + let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["query"], "Rust programming"); + assert_eq!(args["max_results"], 5); + } + + #[tokio::test] + async fn test_multiple_function_calls() { + let parser = PythonicParser::new(); + let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "search"); + } + + #[tokio::test] + async fn test_python_literals() { + let parser = PythonicParser::new(); + let input = r#"[test(flag=True, disabled=False, optional=None)]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["flag"], true); + assert_eq!(args["disabled"], false); + assert_eq!(args["optional"], Value::Null); + } + + #[tokio::test] + async fn test_special_tokens() { + let parser = PythonicParser::new(); + let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "calculate"); + + let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["x"], 10); + assert_eq!(args["y"], 20); + } + + #[tokio::test] + async fn test_llama4_format() { + let parser = PythonicParser::new(); + let input = r#"[get_weather(city="London", units="celsius")]"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["city"], "London"); + assert_eq!(args["units"], "celsius"); + } +} diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index dc61fccbb..598009aa4 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,5 +1,6 @@ use crate::tool_parser::json_parser::JsonParser; use crate::tool_parser::mistral_parser::MistralParser; +use crate::tool_parser::pythonic_parser::PythonicParser; use crate::tool_parser::qwen_parser::QwenParser; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; @@ -104,6 +105,9 @@ impl ParserRegistry { // Qwen parser - ... format self.register_parser("qwen", Arc::new(QwenParser::new())); + + // Pythonic parser - [func(arg=val)] format + self.register_parser("pythonic", Arc::new(PythonicParser::new())); } /// Register default model mappings