[router] migrate to rust python module for pythonic parser (#11033)

This commit is contained in:
Simo Lin
2025-09-28 14:48:59 -04:00
committed by GitHub
parent abb6781573
commit 336e9a6058
4 changed files with 284 additions and 780 deletions

View File

@@ -5,262 +5,77 @@
/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
/// ```
///
/// This format is used by Llama-4 models and uses Python literals
/// This format is used by Llama models and uses Python literals
/// rather than JSON for arguments.
use async_trait::async_trait;
use num_traits::ToPrimitive;
use regex::Regex;
use serde_json::{json, Value};
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp};
use rustpython_parser::{parse, Mode};
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::tool_parser::{
errors::ToolParserResult,
python_literal_parser::parse_python_literal,
errors::{ToolParserError, ToolParserResult},
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Parser for Pythonic tool call format
pub struct PythonicParser {
/// Regex to detect tool calls in Pythonic format
tool_call_regex: Regex,
/// Regex to parse function calls - cached for reuse
call_regex: Regex,
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
/// Lazily compiled regex that locates pythonic tool call blocks.
fn pythonic_block_regex() -> &'static Regex {
PYTHONIC_BLOCK_REGEX.get_or_init(|| {
// Matches one or more function calls inside a list. The `(?s)` flag allows
// newlines inside argument lists while keeping the pattern anchored to
// identifiers followed by parentheses, preventing plain lists like
// `[1, 2, 3]` from matching.
Regex::new(r"(?s)\[\s*[A-Za-z_]\w*\s*\(.*?\)\s*(?:,\s*[A-Za-z_]\w*\s*\(.*?\)\s*)*\]")
.expect("pythonic tool call regex must compile")
})
}
/// Parser for Pythonic tool call format
#[derive(Default)]
pub struct PythonicParser;
impl PythonicParser {
/// Create a new Pythonic parser
pub fn new() -> Self {
// Simple regex to detect start of Pythonic tool calls
// We'll use manual parsing for the actual extraction
let pattern = r"\[[a-zA-Z_]\w*\(";
let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern");
// Compile the function call regex once
let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").expect("Valid regex pattern");
Self {
tool_call_regex,
call_regex,
}
Self
}
/// Extract tool calls using bracket counting (similar to MistralParser)
/// Returns extracted tool call group with [] and normal content
/// Extract the first pythonic tool call block and return it along with the
/// surrounding "normal" content.
fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> {
// Find the start of a tool call list - look for [ followed by a function name
let chars: Vec<char> = text.chars().collect();
for start_idx in 0..chars.len() {
if chars[start_idx] != '[' {
continue;
}
// Check if this looks like a tool call
// Skip whitespace after [
let mut check_idx = start_idx + 1;
while check_idx < chars.len() && chars[check_idx].is_whitespace() {
check_idx += 1;
}
// Check if we have a function name (starts with letter or underscore)
if check_idx >= chars.len()
|| (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_')
{
continue;
}
// Now count brackets to find the matching ]
let mut bracket_count = 0;
let mut _paren_count = 0;
let mut _brace_count = 0;
let mut in_string = false;
let mut string_char = ' ';
let mut escape_next = false;
for i in start_idx..chars.len() {
let ch = chars[i];
if escape_next {
escape_next = false;
continue;
}
if ch == '\\' && in_string {
escape_next = true;
continue;
}
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
} else if in_string && ch == string_char && !escape_next {
in_string = false;
} else if !in_string {
match ch {
'[' => bracket_count += 1,
']' => {
bracket_count -= 1;
if bracket_count == 0 {
// Found the matching bracket
let extracted: String = chars[start_idx..=i].iter().collect();
if extracted.contains('(') && extracted.contains(')') {
// Calculate normal text by removing the tool call portion
let before = &text[..start_idx];
let after = &text[(i + 1)..];
let normal_text = format!("{}{}", before, after);
return Some((extracted, normal_text));
}
}
}
'(' => _paren_count += 1,
')' => _paren_count -= 1,
'{' => _brace_count += 1,
'}' => _brace_count -= 1,
_ => {}
}
}
}
}
None
pythonic_block_regex().find(text).map(|mat| {
let block = mat.as_str().to_string();
let normal = format!("{}{}", &text[..mat.start()], &text[mat.end()..]);
(block, normal)
})
}
/// Strip special tokens that Llama 4 might output
/// Strip special tokens that Llama models might output
fn strip_special_tokens(text: &str) -> String {
text.replace("<|python_start|>", "")
.replace("<|python_end|>", "")
}
/// Parse a single function call from Python syntax
fn parse_function_call(&self, call_str: &str) -> ToolParserResult<Option<ToolCall>> {
// Use cached regex instead of creating new one
if let Some(captures) = self.call_regex.captures(call_str.trim()) {
let function_name = captures.get(1).unwrap().as_str();
let args_str = captures.get(2).unwrap().as_str();
// Parse arguments
let arguments = self.parse_arguments(args_str)?;
Ok(Some(ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
r#type: "function".to_string(),
function: FunctionCall {
name: function_name.to_string(),
arguments: serde_json::to_string(&arguments)?,
},
}))
} else {
Ok(None)
fn parse_tool_call_block(&self, block: &str) -> ToolParserResult<Vec<ToolCall>> {
let expr = parse_python_expression(block)?;
match expr {
Expr::List(list_expr) => list_expr
.elts
.into_iter()
.enumerate()
.map(|(idx, call_expr)| build_tool_call(call_expr, idx))
.collect(),
_ => Err(ToolParserError::ParsingFailed(
"Expected a list of function calls in pythonic tool call".to_string(),
)),
}
}
/// Parse Python-style arguments into JSON
fn parse_arguments(&self, args_str: &str) -> ToolParserResult<Value> {
if args_str.trim().is_empty() {
return Ok(json!({}));
}
let mut result = serde_json::Map::new();
let mut current_key = String::new();
let mut current_value = String::new();
let mut in_key = true;
let mut depth = 0;
let mut in_string = false;
let mut string_char = ' ';
let mut escape_next = false;
let chars: Vec<char> = args_str.chars().collect();
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
if escape_next {
if in_key {
current_key.push(ch);
} else {
current_value.push(ch);
}
escape_next = false;
i += 1;
continue;
}
if ch == '\\' && in_string {
escape_next = true;
current_value.push(ch);
i += 1;
continue;
}
// Handle string literals
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
if !in_key {
current_value.push(ch);
}
} else if in_string && ch == string_char && !escape_next {
in_string = false;
if !in_key {
current_value.push(ch);
}
} else if in_string {
if in_key {
current_key.push(ch);
} else {
current_value.push(ch);
}
} else {
// Not in string
match ch {
'=' if in_key && depth == 0 => {
in_key = false;
}
',' if depth == 0 => {
// End of current argument
if !current_key.is_empty() {
let value = parse_python_literal(current_value.trim())?;
result.insert(current_key.trim().to_string(), value);
}
current_key.clear();
current_value.clear();
in_key = true;
}
'[' | '{' | '(' => {
depth += 1;
if !in_key {
current_value.push(ch);
}
}
']' | '}' | ')' => {
depth -= 1;
if !in_key {
current_value.push(ch);
}
}
_ => {
if in_key {
if !ch.is_whitespace() || !current_key.is_empty() {
current_key.push(ch);
}
} else {
current_value.push(ch);
}
}
}
}
i += 1;
}
// Handle the last argument
if !current_key.is_empty() {
let value = parse_python_literal(current_value.trim())?;
result.insert(current_key.trim().to_string(), value);
}
Ok(Value::Object(result))
}
}
#[async_trait]
@@ -268,61 +83,8 @@ impl ToolParser for PythonicParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
let cleaned = Self::strip_special_tokens(text);
// Extract tool calls using bracket counting
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
// Remove the outer brackets
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
// Split into individual function calls
let mut calls = Vec::new();
let mut current_call = String::new();
let mut paren_depth = 0;
let mut in_string = false;
let mut string_char = ' ';
for ch in tool_calls_str.chars() {
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
current_call.push(ch);
} else if in_string && ch == string_char {
in_string = false;
current_call.push(ch);
} else if in_string {
current_call.push(ch);
} else {
match ch {
'(' => {
paren_depth += 1;
current_call.push(ch);
}
')' => {
paren_depth -= 1;
current_call.push(ch);
}
',' if paren_depth == 0 => {
// End of current function call
if let Some(call) = self.parse_function_call(current_call.trim())? {
calls.push(call);
}
current_call.clear();
}
_ => {
if !ch.is_whitespace() || !current_call.is_empty() {
current_call.push(ch);
}
}
}
}
}
// Handle the last call (important for single calls or the last call in a list)
if !current_call.trim().is_empty() {
if let Some(call) = self.parse_function_call(current_call.trim())? {
calls.push(call);
}
}
let calls = self.parse_tool_call_block(&tool_calls_text)?;
Ok((normal_text, calls))
} else {
Ok((text.to_string(), vec![]))
@@ -334,19 +96,15 @@ impl ToolParser for PythonicParser {
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
// For Pythonic format, we accumulate until we have a complete tool call
// This is a simplified implementation
state.buffer.push_str(chunk);
// Try to parse if we have a complete tool call
let cleaned = Self::strip_special_tokens(&state.buffer);
if self.extract_tool_calls(&cleaned).is_some() {
let (_normal_text, tools) = self.parse_complete(&state.buffer).await?;
if !tools.is_empty() {
state.buffer.clear();
return Ok(StreamResult::ToolComplete(
tools.into_iter().next().unwrap(),
));
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
if let Some(tool) = tools.into_iter().next() {
state.buffer.clear();
return Ok(StreamResult::ToolComplete(tool));
}
}
}
@@ -355,13 +113,220 @@ impl ToolParser for PythonicParser {
fn detect_format(&self, text: &str) -> bool {
let cleaned = Self::strip_special_tokens(text);
self.tool_call_regex.is_match(&cleaned)
if pythonic_block_regex().is_match(&cleaned) {
return true;
}
let trimmed = cleaned.trim();
let Some(open_idx) = trimmed.find('[') else {
return false;
};
let after_bracket = trimmed[open_idx + 1..].trim_start();
let mut chars = after_bracket.char_indices();
let Some((_, first_char)) = chars.next() else {
return false;
};
if !(first_char.is_ascii_alphabetic() || first_char == '_') {
return false;
}
let mut ident_len = first_char.len_utf8();
for (idx, ch) in chars {
if ch.is_alphanumeric() || ch == '_' {
ident_len = idx + ch.len_utf8();
} else {
break;
}
}
let remaining = after_bracket[ident_len..].trim_start();
remaining.starts_with('(')
}
}
impl Default for PythonicParser {
fn default() -> Self {
Self::new()
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
match module {
Mod::Expression(expr_mod) => Ok(*expr_mod.body),
_ => Err(ToolParserError::ParsingFailed(
"Expected a Python expression".to_string(),
)),
}
}
fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
match expr {
Expr::Call(call_expr) => {
if !call_expr.args.is_empty() {
return Err(ToolParserError::ParsingFailed(
"Positional arguments are not supported in pythonic tool calls".to_string(),
));
}
let function_name = match *call_expr.func {
Expr::Name(name_expr) => name_expr.id.to_string(),
_ => {
return Err(ToolParserError::ParsingFailed(
"Unsupported function reference in pythonic tool call".to_string(),
))
}
};
let mut arguments_map = Map::with_capacity(call_expr.keywords.len());
for keyword in call_expr.keywords {
let arg_name = keyword.arg.ok_or_else(|| {
ToolParserError::ParsingFailed(
"pythonic tool calls do not support **kwargs".to_string(),
)
})?;
let value_json = expression_to_json(&keyword.value)?;
arguments_map.insert(arg_name.to_string(), value_json);
}
let arguments_json = Value::Object(arguments_map);
let arguments_string = serde_json::to_string(&arguments_json)?;
Ok(ToolCall {
id: format!("call-{}", index + 1),
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments: arguments_string,
},
})
}
_ => Err(ToolParserError::ParsingFailed(
"Expected function calls inside pythonic tool call list".to_string(),
)),
}
}
fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
match expr {
Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value),
Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array),
Expr::Tuple(tuple_expr) => collect_sequence(&tuple_expr.elts).map(Value::Array),
Expr::Dict(dict_expr) => {
collect_dict(&dict_expr.keys, &dict_expr.values).map(Value::Object)
}
Expr::UnaryOp(unary_expr) => match unary_expr.op {
UnaryOp::USub => match unary_expr.operand.as_ref() {
Expr::Constant(const_expr) => negate_constant(&const_expr.value),
_ => Err(ToolParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(),
)),
},
UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()),
_ => Err(ToolParserError::ParsingFailed(format!(
"Unsupported unary operator in pythonic tool call: {:?}",
unary_expr.op
))),
},
Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())),
_ => Err(ToolParserError::ParsingFailed(format!(
"Unsupported expression in pythonic tool call: {:?}",
expr
))),
}
}
fn constant_to_json(constant: &Constant) -> ToolParserResult<Value> {
match constant {
Constant::None => Ok(Value::Null),
Constant::Bool(b) => Ok(Value::Bool(*b)),
Constant::Int(value) => Ok(integer_constant_to_value(value, false)),
Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed(
"Invalid float literal in pythonic tool call".to_string(),
)
}),
Constant::Str(s) => Ok(Value::String(s.clone())),
Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())),
Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array),
Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed(
"Unsupported literal in pythonic tool call".to_string(),
)),
}
}
fn negate_constant(constant: &Constant) -> ToolParserResult<Value> {
match constant {
Constant::Int(value) => Ok(integer_constant_to_value(value, true)),
Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed(
"Invalid float literal in pythonic tool call".to_string(),
)
}),
_ => Err(ToolParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(),
)),
}
}
fn value_to_key_string(value: Value) -> ToolParserResult<String> {
match value {
Value::String(s) => Ok(s),
Value::Number(num) => Ok(num.to_string()),
Value::Bool(b) => Ok(b.to_string()),
Value::Null => Ok("null".to_string()),
other => Err(ToolParserError::ParsingFailed(format!(
"Unsupported key type in pythonic tool call: {:?}",
other
))),
}
}
fn collect_sequence(elements: &[Expr]) -> ToolParserResult<Vec<Value>> {
elements.iter().map(expression_to_json).collect()
}
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<String, Value>> {
let mut map = Map::with_capacity(keys.len());
for (key_expr, value_expr) in keys.iter().zip(values.iter()) {
let key_expr = key_expr.as_ref().ok_or_else(|| {
ToolParserError::ParsingFailed(
"pythonic tool calls do not support **kwargs".to_string(),
)
})?;
let key_value = expression_to_json(key_expr)?;
let key = value_to_key_string(key_value)?;
let value_json = expression_to_json(value_expr)?;
map.insert(key, value_json);
}
Ok(map)
}
fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult<Vec<Value>> {
values.iter().map(constant_to_json).collect()
}
fn integer_constant_to_value<T>(value: &T, negate: bool) -> Value
where
T: ToPrimitive + std::fmt::Display,
{
if let Some(mut i) = value.to_i64() {
if negate {
i = -i;
}
return Value::Number(Number::from(i));
}
if negate {
if let Some(u) = value.to_u64() {
if u <= i64::MAX as u64 {
return Value::Number(Number::from(-(u as i64)));
}
return Value::String(format!("-{}", value));
}
Value::String(format!("-{}", value))
} else if let Some(u) = value.to_u64() {
Value::Number(Number::from(u))
} else {
Value::String(value.to_string())
}
}
@@ -405,63 +370,43 @@ mod tests {
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["flag"], true);
assert_eq!(args["disabled"], false);
assert_eq!(args["optional"], Value::Null);
assert!(args["optional"].is_null());
}
#[tokio::test]
async fn test_special_tokens() {
async fn test_strip_special_tokens() {
let parser = PythonicParser::new();
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
let input = "<|python_start|>[call(arg=1)]<|python_end|>";
assert!(parser.detect_format(input));
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "calculate");
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["x"], 10);
assert_eq!(args["y"], 20);
}
#[tokio::test]
async fn test_llama4_format() {
async fn test_detect_format() {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="London", units="celsius")]"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "get_weather");
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["city"], "London");
assert_eq!(args["units"], "celsius");
assert!(parser.detect_format("[foo(bar=1)]"));
assert!(!parser.detect_format("No python here"));
}
#[tokio::test]
async fn test_normal_text_extraction() {
async fn test_parse_incremental() {
let parser = PythonicParser::new();
let mut state = ParseState::new();
// Test with text before and after
let input = r#"Please check the weather [get_weather(city="Tokyo")] and let me know."#;
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
let chunk1 = "[call(arg=";
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
assert!(matches!(result1, StreamResult::Incomplete));
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "get_weather");
assert_eq!(normal_text, "Please check the weather and let me know.");
let chunk2 = "1)]";
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
// Test with only normal text (no tool calls)
let input_no_tools = "This is just normal text without any tool calls.";
let (normal_text, tools) = parser.parse_complete(input_no_tools).await.unwrap();
assert_eq!(tools.len(), 0);
assert_eq!(normal_text, input_no_tools);
// Test with multiple tool calls in single bracket group and normal text
let input_multiple = r#"First, [search(query="rust"), calculate(x=5, y=10)] please."#;
let (normal_text, tools) = parser.parse_complete(input_multiple).await.unwrap();
assert_eq!(tools.len(), 2);
assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "calculate");
assert_eq!(normal_text, "First, please.");
match result2 {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "call");
}
other => panic!("Expected ToolComplete, got {:?}", other),
}
}
}