[router][tool parser] Modify tool parser to return both normal text and tool calls (non-stream) (#10995)

This commit is contained in:
Chang Su
2025-09-27 15:10:17 -07:00
committed by GitHub
parent f6bc3f529b
commit c1c8dd1dd0
30 changed files with 1467 additions and 934 deletions

View File

@@ -45,7 +45,8 @@ impl PythonicParser {
}
/// Extract tool calls using bracket counting (similar to MistralParser)
fn extract_tool_calls(&self, text: &str) -> Option<String> {
/// Returns extracted tool call group with [] and normal content
fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> {
// Find the start of a tool call list - look for [ followed by a function name
let chars: Vec<char> = text.chars().collect();
@@ -103,7 +104,11 @@ impl PythonicParser {
// Found the matching bracket
let extracted: String = chars[start_idx..=i].iter().collect();
if extracted.contains('(') && extracted.contains(')') {
return Some(extracted);
// Build the normal text (everything outside the tool-call group) from the
// same `chars` buffer that `start_idx` and `i` index into. Those are char
// positions, not byte offsets, so slicing `text` with them would cut at the
// wrong place, or panic off a char boundary, on non-ASCII input.
let normal_text: String = chars[..start_idx]
    .iter()
    .chain(chars[(i + 1)..].iter())
    .collect();
return Some((extracted, normal_text));
}
}
}
@@ -260,11 +265,11 @@ impl PythonicParser {
#[async_trait]
impl ToolParser for PythonicParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
let cleaned = Self::strip_special_tokens(text);
// Extract tool calls using bracket counting
if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) {
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
// Remove the outer brackets
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
@@ -318,9 +323,9 @@ impl ToolParser for PythonicParser {
}
}
Ok(calls)
Ok((normal_text, calls))
} else {
Ok(vec![])
Ok((text.to_string(), vec![]))
}
}
@@ -336,11 +341,11 @@ impl ToolParser for PythonicParser {
// Try to parse if we have a complete tool call
let cleaned = Self::strip_special_tokens(&state.buffer);
if self.extract_tool_calls(&cleaned).is_some() {
let result = self.parse_complete(&state.buffer).await?;
if !result.is_empty() {
let (_normal_text, tools) = self.parse_complete(&state.buffer).await?;
if !tools.is_empty() {
state.buffer.clear();
return Ok(StreamResult::ToolComplete(
result.into_iter().next().unwrap(),
tools.into_iter().next().unwrap(),
));
}
}
@@ -369,11 +374,11 @@ mod tests {
let parser = PythonicParser::new();
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search_web");
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "search_web");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["query"], "Rust programming");
assert_eq!(args["max_results"], 5);
}
@@ -383,10 +388,10 @@ mod tests {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2);
assert_eq!(tools[0].function.name, "get_weather");
assert_eq!(tools[1].function.name, "search");
}
#[tokio::test]
@@ -394,10 +399,10 @@ mod tests {
let parser = PythonicParser::new();
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["flag"], true);
assert_eq!(args["disabled"], false);
assert_eq!(args["optional"], Value::Null);
@@ -408,11 +413,11 @@ mod tests {
let parser = PythonicParser::new();
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "calculate");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["x"], 10);
assert_eq!(args["y"], 20);
}
@@ -422,12 +427,41 @@ mod tests {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="London", units="celsius")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "get_weather");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["city"], "London");
assert_eq!(args["units"], "celsius");
}
#[tokio::test]
async fn test_normal_text_extraction() {
    // `parse_complete` yields a `(normal_text, tool_calls)` pair; each scenario
    // below checks both halves of that pair.
    let pythonic = PythonicParser::new();

    // Scenario 1: prose on both sides of a single tool call.
    let surrounded = r#"Please check the weather [get_weather(city="Tokyo")] and let me know."#;
    let (remaining, calls) = pythonic.parse_complete(surrounded).await.unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].function.name, "get_weather");
    assert_eq!(remaining, "Please check the weather and let me know.");

    // Scenario 2: no tool call at all, so the input comes back untouched.
    let plain = "This is just normal text without any tool calls.";
    let (remaining, calls) = pythonic.parse_complete(plain).await.unwrap();
    assert_eq!(calls.len(), 0);
    assert_eq!(remaining, plain);

    // Scenario 3: two calls inside one bracket group, with prose around it.
    let grouped = r#"First, [search(query="rust"), calculate(x=5, y=10)] please."#;
    let (remaining, calls) = pythonic.parse_complete(grouped).await.unwrap();
    assert_eq!(calls.len(), 2);
    let names: Vec<&str> = calls.iter().map(|c| c.function.name.as_str()).collect();
    assert_eq!(names[0], "search");
    assert_eq!(names[1], "calculate");
    assert_eq!(remaining, "First, please.");
}
}