[router] add llama tool parser (#9629)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
156
sgl-router/src/tool_parser/llama_parser.rs
Normal file
156
sgl-router/src/tool_parser/llama_parser.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use crate::tool_parser::{
|
||||||
|
errors::ToolParserResult,
|
||||||
|
json_parser::JsonParser,
|
||||||
|
state::ParseState,
|
||||||
|
traits::ToolParser,
|
||||||
|
types::{StreamResult, TokenConfig, ToolCall},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Llama 3.2 format parser for tool calls
|
||||||
|
///
|
||||||
|
/// Handles the Llama 3.2 specific format:
|
||||||
|
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
|
||||||
|
///
|
||||||
|
/// Also supports plain JSON without the python_tag prefix
|
||||||
|
pub struct LlamaParser {
|
||||||
|
/// Underlying JSON parser with Llama-specific configuration
|
||||||
|
json_parser: JsonParser,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlamaParser {
|
||||||
|
/// Create a new Llama parser
|
||||||
|
pub fn new() -> Self {
|
||||||
|
// Configure JSON parser with Llama's python_tag token
|
||||||
|
// Note: No end token for python_tag format
|
||||||
|
let json_parser = JsonParser::with_config(TokenConfig {
|
||||||
|
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||||
|
end_tokens: vec!["".to_string()], // Empty end token
|
||||||
|
separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported)
|
||||||
|
});
|
||||||
|
|
||||||
|
Self { json_parser }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LlamaParser {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ToolParser for LlamaParser {
|
||||||
|
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||||
|
// First try with the configured python_tag parser
|
||||||
|
let result = self.json_parser.parse_complete(text).await?;
|
||||||
|
|
||||||
|
if !result.is_empty() {
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no results and text starts with '{', try plain JSON
|
||||||
|
if text.trim_start().starts_with('{') {
|
||||||
|
// Create a temporary plain JSON parser
|
||||||
|
let plain_parser = JsonParser::new();
|
||||||
|
return plain_parser.parse_complete(text).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(vec![])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn parse_incremental(
|
||||||
|
&self,
|
||||||
|
chunk: &str,
|
||||||
|
state: &mut ParseState,
|
||||||
|
) -> ToolParserResult<StreamResult> {
|
||||||
|
// Try with the python_tag parser first
|
||||||
|
let result = self.json_parser.parse_incremental(chunk, state).await?;
|
||||||
|
|
||||||
|
// If we get Incomplete and buffer starts with '{', might be plain JSON
|
||||||
|
if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
|
||||||
|
{
|
||||||
|
// Check if we have python_tag in the buffer
|
||||||
|
if !state.buffer.contains("<|python_tag|>") {
|
||||||
|
// Likely plain JSON, create temporary parser
|
||||||
|
let plain_parser = JsonParser::new();
|
||||||
|
return plain_parser.parse_incremental("", state).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
|
// Llama format if contains python_tag or starts with JSON object
|
||||||
|
text.contains("<|python_tag|>")
|
||||||
|
|| (text.trim_start().starts_with('{')
|
||||||
|
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_parse_with_python_tag() {
|
||||||
|
let parser = LlamaParser::new();
|
||||||
|
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||||
|
|
||||||
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
assert_eq!(result[0].function.name, "search");
|
||||||
|
assert!(result[0].function.arguments.contains("weather"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_parse_plain_json() {
|
||||||
|
let parser = LlamaParser::new();
|
||||||
|
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||||
|
|
||||||
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
assert_eq!(result[0].function.name, "calculate");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_parse_with_text_before() {
|
||||||
|
let parser = LlamaParser::new();
|
||||||
|
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||||
|
|
||||||
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
assert_eq!(result[0].function.name, "get_time");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_detect_format() {
|
||||||
|
let parser = LlamaParser::new();
|
||||||
|
|
||||||
|
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||||
|
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||||
|
assert!(!parser.detect_format("plain text"));
|
||||||
|
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_single_call_with_semicolon() {
|
||||||
|
let parser = LlamaParser::new();
|
||||||
|
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||||
|
// Test that we can at least parse a single call followed by semicolon
|
||||||
|
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||||
|
|
||||||
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
|
||||||
|
// We expect this to either parse the first JSON object or fail gracefully
|
||||||
|
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||||
|
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||||
|
|
||||||
|
// If it parses anything, it should be func1
|
||||||
|
if !result.is_empty() {
|
||||||
|
assert_eq!(result[0].function.name, "func1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
pub mod json_parser;
|
pub mod json_parser;
|
||||||
|
pub mod llama_parser;
|
||||||
pub mod mistral_parser;
|
pub mod mistral_parser;
|
||||||
pub mod partial_json;
|
pub mod partial_json;
|
||||||
pub mod python_literal_parser;
|
pub mod python_literal_parser;
|
||||||
@@ -19,6 +20,7 @@ mod tests;
|
|||||||
// Re-export commonly used types
|
// Re-export commonly used types
|
||||||
pub use errors::{ToolParserError, ToolParserResult};
|
pub use errors::{ToolParserError, ToolParserResult};
|
||||||
pub use json_parser::JsonParser;
|
pub use json_parser::JsonParser;
|
||||||
|
pub use llama_parser::LlamaParser;
|
||||||
pub use mistral_parser::MistralParser;
|
pub use mistral_parser::MistralParser;
|
||||||
pub use pythonic_parser::PythonicParser;
|
pub use pythonic_parser::PythonicParser;
|
||||||
pub use qwen_parser::QwenParser;
|
pub use qwen_parser::QwenParser;
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::tool_parser::json_parser::JsonParser;
|
use crate::tool_parser::json_parser::JsonParser;
|
||||||
|
use crate::tool_parser::llama_parser::LlamaParser;
|
||||||
use crate::tool_parser::mistral_parser::MistralParser;
|
use crate::tool_parser::mistral_parser::MistralParser;
|
||||||
use crate::tool_parser::pythonic_parser::PythonicParser;
|
use crate::tool_parser::pythonic_parser::PythonicParser;
|
||||||
use crate::tool_parser::qwen_parser::QwenParser;
|
use crate::tool_parser::qwen_parser::QwenParser;
|
||||||
@@ -108,6 +109,9 @@ impl ParserRegistry {
|
|||||||
|
|
||||||
// Pythonic parser - [func(arg=val)] format
|
// Pythonic parser - [func(arg=val)] format
|
||||||
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
|
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
|
||||||
|
|
||||||
|
// Llama parser - <|python_tag|>{...} or plain JSON format
|
||||||
|
self.register_parser("llama", Arc::new(LlamaParser::new()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register default model mappings
|
/// Register default model mappings
|
||||||
|
|||||||
Reference in New Issue
Block a user