320 lines
12 KiB
Rust
320 lines
12 KiB
Rust
//! Pythonic format parser for tool calls.
//!
//! Handles Python function-call syntax inside square brackets:
//!
//! ```text
//! [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
//! ```
//!
//! This format is used by Llama models; arguments are written as Python
//! literals rather than JSON.
|
|
use async_trait::async_trait;
|
|
use num_traits::ToPrimitive;
|
|
use regex::Regex;
|
|
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp};
|
|
use rustpython_parser::{parse, Mode};
|
|
use serde_json::{Map, Number, Value};
|
|
use std::sync::OnceLock;
|
|
|
|
use crate::tool_parser::{
|
|
errors::{ToolParserError, ToolParserResult},
|
|
state::ParseState,
|
|
traits::ToolParser,
|
|
types::{FunctionCall, StreamResult, ToolCall},
|
|
};
|
|
|
|
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
|
|
|
|
/// Lazily compiled regex that locates pythonic tool call blocks.
|
|
fn pythonic_block_regex() -> &'static Regex {
|
|
PYTHONIC_BLOCK_REGEX.get_or_init(|| {
|
|
// Matches one or more function calls inside a list. The `(?s)` flag allows
|
|
// newlines inside argument lists while keeping the pattern anchored to
|
|
// identifiers followed by parentheses, preventing plain lists like
|
|
// `[1, 2, 3]` from matching.
|
|
Regex::new(r"(?s)\[\s*[A-Za-z_]\w*\s*\(.*?\)\s*(?:,\s*[A-Za-z_]\w*\s*\(.*?\)\s*)*\]")
|
|
.expect("pythonic tool call regex must compile")
|
|
})
|
|
}
|
|
|
|
/// Parser for the Pythonic tool call format (`[tool(arg=value), ...]`).
///
/// The parser is stateless; streaming state lives in the `ParseState` passed to it.
/// `Debug` and `Clone` are derived so the type can be logged and duplicated like other
/// public types.
#[derive(Debug, Clone, Default)]
pub struct PythonicParser;
|
|
|
|
impl PythonicParser {
|
|
/// Create a new Pythonic parser
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
|
|
/// Extract the first pythonic tool call block and return it along with the
|
|
/// surrounding "normal" content.
|
|
fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> {
|
|
pythonic_block_regex().find(text).map(|mat| {
|
|
let block = mat.as_str().to_string();
|
|
let normal = format!("{}{}", &text[..mat.start()], &text[mat.end()..]);
|
|
(block, normal)
|
|
})
|
|
}
|
|
|
|
/// Strip special tokens that Llama models might output
|
|
fn strip_special_tokens(text: &str) -> String {
|
|
text.replace("<|python_start|>", "")
|
|
.replace("<|python_end|>", "")
|
|
}
|
|
|
|
fn parse_tool_call_block(&self, block: &str) -> ToolParserResult<Vec<ToolCall>> {
|
|
let expr = parse_python_expression(block)?;
|
|
match expr {
|
|
Expr::List(list_expr) => list_expr
|
|
.elts
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(idx, call_expr)| build_tool_call(call_expr, idx))
|
|
.collect(),
|
|
_ => Err(ToolParserError::ParsingFailed(
|
|
"Expected a list of function calls in pythonic tool call".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ToolParser for PythonicParser {
|
|
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
|
let cleaned = Self::strip_special_tokens(text);
|
|
|
|
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
|
|
match self.parse_tool_call_block(&tool_calls_text) {
|
|
Ok(calls) => {
|
|
if calls.is_empty() {
|
|
// No tools successfully parsed despite having markers
|
|
Ok((text.to_string(), vec![]))
|
|
} else {
|
|
Ok((normal_text, calls))
|
|
}
|
|
}
|
|
Err(e) => {
|
|
// Log warning and return entire text as fallback
|
|
tracing::warn!("Failed to parse pythonic tool calls: {}", e);
|
|
Ok((text.to_string(), vec![]))
|
|
}
|
|
}
|
|
} else {
|
|
Ok((text.to_string(), vec![]))
|
|
}
|
|
}
|
|
|
|
async fn parse_incremental(
|
|
&self,
|
|
chunk: &str,
|
|
state: &mut ParseState,
|
|
) -> ToolParserResult<StreamResult> {
|
|
state.buffer.push_str(chunk);
|
|
|
|
let cleaned = Self::strip_special_tokens(&state.buffer);
|
|
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
|
|
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
|
|
if let Some(tool) = tools.into_iter().next() {
|
|
state.buffer.clear();
|
|
return Ok(StreamResult::ToolComplete(tool));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(StreamResult::Incomplete)
|
|
}
|
|
|
|
fn detect_format(&self, text: &str) -> bool {
|
|
let cleaned = Self::strip_special_tokens(text);
|
|
if pythonic_block_regex().is_match(&cleaned) {
|
|
return true;
|
|
}
|
|
|
|
false
|
|
}
|
|
}
|
|
|
|
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
|
|
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
|
|
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
|
|
|
|
match module {
|
|
Mod::Expression(expr_mod) => Ok(*expr_mod.body),
|
|
_ => Err(ToolParserError::ParsingFailed(
|
|
"Expected a Python expression".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
|
|
match expr {
|
|
Expr::Call(call_expr) => {
|
|
if !call_expr.args.is_empty() {
|
|
return Err(ToolParserError::ParsingFailed(
|
|
"Positional arguments are not supported in pythonic tool calls".to_string(),
|
|
));
|
|
}
|
|
|
|
let function_name = match *call_expr.func {
|
|
Expr::Name(name_expr) => name_expr.id.to_string(),
|
|
_ => {
|
|
return Err(ToolParserError::ParsingFailed(
|
|
"Unsupported function reference in pythonic tool call".to_string(),
|
|
))
|
|
}
|
|
};
|
|
|
|
let mut arguments_map = Map::with_capacity(call_expr.keywords.len());
|
|
for keyword in call_expr.keywords {
|
|
let arg_name = keyword.arg.ok_or_else(|| {
|
|
ToolParserError::ParsingFailed(
|
|
"pythonic tool calls do not support **kwargs".to_string(),
|
|
)
|
|
})?;
|
|
let value_json = expression_to_json(&keyword.value)?;
|
|
arguments_map.insert(arg_name.to_string(), value_json);
|
|
}
|
|
|
|
let arguments_json = Value::Object(arguments_map);
|
|
let arguments_string = serde_json::to_string(&arguments_json)?;
|
|
|
|
Ok(ToolCall {
|
|
id: format!("call-{}", index + 1),
|
|
r#type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: function_name,
|
|
arguments: arguments_string,
|
|
},
|
|
})
|
|
}
|
|
_ => Err(ToolParserError::ParsingFailed(
|
|
"Expected function calls inside pythonic tool call list".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
|
|
match expr {
|
|
Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value),
|
|
Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array),
|
|
Expr::Tuple(tuple_expr) => collect_sequence(&tuple_expr.elts).map(Value::Array),
|
|
Expr::Dict(dict_expr) => {
|
|
collect_dict(&dict_expr.keys, &dict_expr.values).map(Value::Object)
|
|
}
|
|
Expr::UnaryOp(unary_expr) => match unary_expr.op {
|
|
UnaryOp::USub => match unary_expr.operand.as_ref() {
|
|
Expr::Constant(const_expr) => negate_constant(&const_expr.value),
|
|
_ => Err(ToolParserError::ParsingFailed(
|
|
"Unsupported unary operand in pythonic tool call".to_string(),
|
|
)),
|
|
},
|
|
UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()),
|
|
_ => Err(ToolParserError::ParsingFailed(format!(
|
|
"Unsupported unary operator in pythonic tool call: {:?}",
|
|
unary_expr.op
|
|
))),
|
|
},
|
|
Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())),
|
|
_ => Err(ToolParserError::ParsingFailed(format!(
|
|
"Unsupported expression in pythonic tool call: {:?}",
|
|
expr
|
|
))),
|
|
}
|
|
}
|
|
|
|
fn constant_to_json(constant: &Constant) -> ToolParserResult<Value> {
|
|
match constant {
|
|
Constant::None => Ok(Value::Null),
|
|
Constant::Bool(b) => Ok(Value::Bool(*b)),
|
|
Constant::Int(value) => Ok(integer_constant_to_value(value, false)),
|
|
Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| {
|
|
ToolParserError::ParsingFailed(
|
|
"Invalid float literal in pythonic tool call".to_string(),
|
|
)
|
|
}),
|
|
Constant::Str(s) => Ok(Value::String(s.clone())),
|
|
Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())),
|
|
Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array),
|
|
Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed(
|
|
"Unsupported literal in pythonic tool call".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn negate_constant(constant: &Constant) -> ToolParserResult<Value> {
|
|
match constant {
|
|
Constant::Int(value) => Ok(integer_constant_to_value(value, true)),
|
|
Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| {
|
|
ToolParserError::ParsingFailed(
|
|
"Invalid float literal in pythonic tool call".to_string(),
|
|
)
|
|
}),
|
|
_ => Err(ToolParserError::ParsingFailed(
|
|
"Unsupported unary operand in pythonic tool call".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn value_to_key_string(value: Value) -> ToolParserResult<String> {
|
|
match value {
|
|
Value::String(s) => Ok(s),
|
|
Value::Number(num) => Ok(num.to_string()),
|
|
Value::Bool(b) => Ok(b.to_string()),
|
|
Value::Null => Ok("null".to_string()),
|
|
other => Err(ToolParserError::ParsingFailed(format!(
|
|
"Unsupported key type in pythonic tool call: {:?}",
|
|
other
|
|
))),
|
|
}
|
|
}
|
|
|
|
fn collect_sequence(elements: &[Expr]) -> ToolParserResult<Vec<Value>> {
|
|
elements.iter().map(expression_to_json).collect()
|
|
}
|
|
|
|
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<String, Value>> {
|
|
let mut map = Map::with_capacity(keys.len());
|
|
for (key_expr, value_expr) in keys.iter().zip(values.iter()) {
|
|
let key_expr = key_expr.as_ref().ok_or_else(|| {
|
|
ToolParserError::ParsingFailed(
|
|
"pythonic tool calls do not support **kwargs".to_string(),
|
|
)
|
|
})?;
|
|
let key_value = expression_to_json(key_expr)?;
|
|
let key = value_to_key_string(key_value)?;
|
|
let value_json = expression_to_json(value_expr)?;
|
|
map.insert(key, value_json);
|
|
}
|
|
Ok(map)
|
|
}
|
|
|
|
fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult<Vec<Value>> {
|
|
values.iter().map(constant_to_json).collect()
|
|
}
|
|
|
|
fn integer_constant_to_value<T>(value: &T, negate: bool) -> Value
|
|
where
|
|
T: ToPrimitive + std::fmt::Display,
|
|
{
|
|
if let Some(mut i) = value.to_i64() {
|
|
if negate {
|
|
i = -i;
|
|
}
|
|
return Value::Number(Number::from(i));
|
|
}
|
|
|
|
if negate {
|
|
if let Some(u) = value.to_u64() {
|
|
if u <= i64::MAX as u64 {
|
|
return Value::Number(Number::from(-(u as i64)));
|
|
}
|
|
return Value::String(format!("-{}", value));
|
|
}
|
|
Value::String(format!("-{}", value))
|
|
} else if let Some(u) = value.to_u64() {
|
|
Value::Number(Number::from(u))
|
|
} else {
|
|
Value::String(value.to_string())
|
|
}
|
|
}
|