diff --git a/sgl-router/src/tool_parser/llama_parser.rs b/sgl-router/src/tool_parser/llama_parser.rs new file mode 100644 index 000000000..677945d85 --- /dev/null +++ b/sgl-router/src/tool_parser/llama_parser.rs @@ -0,0 +1,156 @@ +use async_trait::async_trait; + +use crate::tool_parser::{ + errors::ToolParserResult, + json_parser::JsonParser, + state::ParseState, + traits::ToolParser, + types::{StreamResult, TokenConfig, ToolCall}, +}; + +/// Llama 3.2 format parser for tool calls +/// +/// Handles the Llama 3.2 specific format: +/// `<|python_tag|>{"name": "func", "arguments": {...}}` +/// +/// Also supports plain JSON without the python_tag prefix +pub struct LlamaParser { + /// Underlying JSON parser with Llama-specific configuration + json_parser: JsonParser, +} + +impl LlamaParser { + /// Create a new Llama parser + pub fn new() -> Self { + // Configure JSON parser with Llama's python_tag token + // Note: No end token for python_tag format + let json_parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["<|python_tag|>".to_string()], + end_tokens: vec!["".to_string()], // Empty end token + separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported) + }); + + Self { json_parser } + } +} + +impl Default for LlamaParser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for LlamaParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // First try with the configured python_tag parser + let result = self.json_parser.parse_complete(text).await?; + + if !result.is_empty() { + return Ok(result); + } + + // If no results and text starts with '{', try plain JSON + if text.trim_start().starts_with('{') { + // Create a temporary plain JSON parser + let plain_parser = JsonParser::new(); + return plain_parser.parse_complete(text).await; + } + + Ok(vec![]) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + // Try with the python_tag parser first + let result = self.json_parser.parse_incremental(chunk, state).await?; + + // If we get Incomplete and buffer starts with '{', might be plain JSON + if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{') + { + // Check if we have python_tag in the buffer + if !state.buffer.contains("<|python_tag|>") { + // Likely plain JSON, create temporary parser + let plain_parser = JsonParser::new(); + return plain_parser.parse_incremental("", state).await; + } + } + + Ok(result) + } + + fn detect_format(&self, text: &str) -> bool { + // Llama format if contains python_tag or starts with JSON object + text.contains("<|python_tag|>") + || (text.trim_start().starts_with('{') + && (text.contains(r#""name""#) || text.contains(r#""function""#))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_with_python_tag() { + let parser = LlamaParser::new(); + let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search"); + assert!(result[0].function.arguments.contains("weather")); + } + + #[tokio::test] + async fn test_parse_plain_json() { + let parser = LlamaParser::new(); + let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "calculate"); + } + + #[tokio::test] + async fn test_parse_with_text_before() { + let parser = LlamaParser::new(); + let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_time"); + } + + #[test] + fn test_detect_format() { + let parser = LlamaParser::new(); + + assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#)); + assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field + } + + #[tokio::test] + async fn test_single_call_with_semicolon() { + let parser = LlamaParser::new(); + // Note: Llama 3.2 doesn't handle multiple calls well + // Test that we can at least parse a single call followed by semicolon + let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#; + + let result = parser.parse_complete(input).await.unwrap(); + + // We expect this to either parse the first JSON object or fail gracefully + // Since the semicolon makes it invalid JSON, it will likely return empty + // This is acceptable as Llama 3.2 doesn't reliably support parallel calls + + // If it parses anything, it should be func1 + if !result.is_empty() { + assert_eq!(result[0].function.name, "func1"); + } + } +} diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index 54b5a0a11..ce83bf112 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -3,6 +3,7 @@ /// This module provides infrastructure for parsing tool calls from various model formats. pub mod errors; pub mod json_parser; +pub mod llama_parser; pub mod mistral_parser; pub mod partial_json; pub mod python_literal_parser; @@ -19,6 +20,7 @@ mod tests; // Re-export commonly used types pub use errors::{ToolParserError, ToolParserResult}; pub use json_parser::JsonParser; +pub use llama_parser::LlamaParser; pub use mistral_parser::MistralParser; pub use pythonic_parser::PythonicParser; pub use qwen_parser::QwenParser; diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index 598009aa4..a71fd6232 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,4 +1,5 @@ use crate::tool_parser::json_parser::JsonParser; +use crate::tool_parser::llama_parser::LlamaParser; use crate::tool_parser::mistral_parser::MistralParser; use crate::tool_parser::pythonic_parser::PythonicParser; use crate::tool_parser::qwen_parser::QwenParser; @@ -108,6 +109,9 @@ impl ParserRegistry { // Pythonic parser - [func(arg=val)] format self.register_parser("pythonic", Arc::new(PythonicParser::new())); + + // Llama parser - <|python_tag|>{...} or plain JSON format + self.register_parser("llama", Arc::new(LlamaParser::new())); } /// Register default model mappings