[router][tool parser] Modify tool parser to return both normal text and tool calls (non-stream) (#10995)
This commit is contained in:
@@ -9,11 +9,11 @@ async fn test_llama_python_tag_format() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "weather");
|
||||
}
|
||||
|
||||
@@ -22,11 +22,11 @@ async fn test_llama_plain_json_fallback() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "calculate");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 5);
|
||||
assert_eq!(args["y"], 10);
|
||||
}
|
||||
@@ -36,11 +36,11 @@ async fn test_llama_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_time");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["timezone"], "UTC");
|
||||
}
|
||||
|
||||
@@ -58,11 +58,11 @@ async fn test_llama_with_nested_json() {
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "update_settings");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "update_settings");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["preferences"]["theme"], "dark");
|
||||
assert_eq!(args["notifications"], true);
|
||||
}
|
||||
@@ -73,15 +73,15 @@ async fn test_llama_empty_arguments() {
|
||||
|
||||
// With python_tag
|
||||
let input = r#"<|python_tag|>{"name": "ping", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "ping");
|
||||
|
||||
// Plain JSON
|
||||
let input = r#"{"name": "ping", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "ping");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -99,8 +99,8 @@ async fn test_llama_invalid_json_after_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
let input = r#"<|python_tag|>{"name": invalid}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -112,9 +112,9 @@ async fn test_llama_real_world_output() {
|
||||
|
||||
<|python_tag|>{"name": "web_search", "arguments": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "web_search");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "web_search");
|
||||
|
||||
let formatted_input = r#"<|python_tag|>{
|
||||
"name": "get_current_time",
|
||||
@@ -124,9 +124,9 @@ async fn test_llama_real_world_output() {
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result2 = parser.parse_complete(formatted_input).await.unwrap();
|
||||
assert_eq!(result2.len(), 1);
|
||||
assert_eq!(result2[0].function.name, "get_current_time");
|
||||
let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
|
||||
assert_eq!(tools2.len(), 1);
|
||||
assert_eq!(tools2[0].function.name, "get_current_time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -136,9 +136,9 @@ async fn test_llama_json_array_format() {
|
||||
// Plain JSON array (should work as fallback)
|
||||
let input = r#"[{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
// Current implementation might handle this through JSON fallback
|
||||
assert!(!result.is_empty());
|
||||
assert!(!tools.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -146,11 +146,11 @@ async fn test_single_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "Paris");
|
||||
}
|
||||
|
||||
@@ -159,10 +159,10 @@ async fn test_multiple_json_with_separator() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
// Note: Current implementation may only parse the first one due to semicolon handling
|
||||
assert!(!result.is_empty());
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(!tools.is_empty());
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -170,10 +170,10 @@ async fn test_multiple_json_with_separator_customized() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
// Current implementation may handle this differently
|
||||
assert!(!result.is_empty());
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(!tools.is_empty());
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -181,9 +181,9 @@ async fn test_json_with_trailing_text() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -191,10 +191,10 @@ async fn test_invalid_then_valid_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
// Should parse at least one valid JSON
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
if !tools.is_empty() {
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ async fn test_plain_text_only() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = "This is just plain explanation text.";
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(tools.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -212,9 +212,9 @@ async fn test_with_python_tag_prefix() {
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
// STREAMING TESTS
|
||||
|
||||
Reference in New Issue
Block a user