[router] add llama tool parser (#9629)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
156
sgl-router/src/tool_parser/llama_parser.rs
Normal file
156
sgl-router/src/tool_parser/llama_parser.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
json_parser::JsonParser,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{StreamResult, TokenConfig, ToolCall},
|
||||
};
|
||||
|
||||
/// Llama 3.2 format parser for tool calls
|
||||
///
|
||||
/// Handles the Llama 3.2 specific format:
|
||||
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
|
||||
///
|
||||
/// Also supports plain JSON without the python_tag prefix
|
||||
pub struct LlamaParser {
|
||||
/// Underlying JSON parser with Llama-specific configuration
|
||||
json_parser: JsonParser,
|
||||
}
|
||||
|
||||
impl LlamaParser {
|
||||
/// Create a new Llama parser
|
||||
pub fn new() -> Self {
|
||||
// Configure JSON parser with Llama's python_tag token
|
||||
// Note: No end token for python_tag format
|
||||
let json_parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||
end_tokens: vec!["".to_string()], // Empty end token
|
||||
separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported)
|
||||
});
|
||||
|
||||
Self { json_parser }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LlamaParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for LlamaParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// First try with the configured python_tag parser
|
||||
let result = self.json_parser.parse_complete(text).await?;
|
||||
|
||||
if !result.is_empty() {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// If no results and text starts with '{', try plain JSON
|
||||
if text.trim_start().starts_with('{') {
|
||||
// Create a temporary plain JSON parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_complete(text).await;
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// Try with the python_tag parser first
|
||||
let result = self.json_parser.parse_incremental(chunk, state).await?;
|
||||
|
||||
// If we get Incomplete and buffer starts with '{', might be plain JSON
|
||||
if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
|
||||
{
|
||||
// Check if we have python_tag in the buffer
|
||||
if !state.buffer.contains("<|python_tag|>") {
|
||||
// Likely plain JSON, create temporary parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_incremental("", state).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Llama format if contains python_tag or starts with JSON object
|
||||
text.contains("<|python_tag|>")
|
||||
|| (text.trim_start().starts_with('{')
|
||||
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_python_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("weather"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_call_with_semicolon() {
|
||||
let parser = LlamaParser::new();
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
// Test that we can at least parse a single call followed by semicolon
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// We expect this to either parse the first JSON object or fail gracefully
|
||||
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||
|
||||
// If it parses anything, it should be func1
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "func1");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||
pub mod errors;
|
||||
pub mod json_parser;
|
||||
pub mod llama_parser;
|
||||
pub mod mistral_parser;
|
||||
pub mod partial_json;
|
||||
pub mod python_literal_parser;
|
||||
@@ -19,6 +20,7 @@ mod tests;
|
||||
// Re-export commonly used types
|
||||
pub use errors::{ToolParserError, ToolParserResult};
|
||||
pub use json_parser::JsonParser;
|
||||
pub use llama_parser::LlamaParser;
|
||||
pub use mistral_parser::MistralParser;
|
||||
pub use pythonic_parser::PythonicParser;
|
||||
pub use qwen_parser::QwenParser;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::tool_parser::json_parser::JsonParser;
|
||||
use crate::tool_parser::llama_parser::LlamaParser;
|
||||
use crate::tool_parser::mistral_parser::MistralParser;
|
||||
use crate::tool_parser::pythonic_parser::PythonicParser;
|
||||
use crate::tool_parser::qwen_parser::QwenParser;
|
||||
@@ -108,6 +109,9 @@ impl ParserRegistry {
|
||||
|
||||
// Pythonic parser - [func(arg=val)] format
|
||||
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
|
||||
|
||||
// Llama parser - <|python_tag|>{...} or plain JSON format
|
||||
self.register_parser("llama", Arc::new(LlamaParser::new()));
|
||||
}
|
||||
|
||||
/// Register default model mappings
|
||||
|
||||
Reference in New Issue
Block a user