[router] migrate to rust python module for pythonic parser (#11033)
This commit is contained in:
@@ -71,6 +71,8 @@ rmcp = { version = "0.6.3", features = ["client", "server",
|
||||
serde_yaml = "0.9"
|
||||
oracle = { version = "0.6.3", features = ["chrono"] }
|
||||
subtle = "2.6"
|
||||
rustpython-parser = "0.4.0"
|
||||
num-traits = "0.2"
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
// Core modules
|
||||
pub mod errors;
|
||||
pub mod partial_json;
|
||||
pub mod python_literal_parser;
|
||||
pub mod registry;
|
||||
pub mod state;
|
||||
pub mod traits;
|
||||
|
||||
@@ -5,262 +5,77 @@
|
||||
/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
/// ```
|
||||
///
|
||||
/// This format is used by Llama-4 models and uses Python literals
|
||||
/// This format is used by Llama models and uses Python literals
|
||||
/// rather than JSON for arguments.
|
||||
use async_trait::async_trait;
|
||||
use num_traits::ToPrimitive;
|
||||
use regex::Regex;
|
||||
use serde_json::{json, Value};
|
||||
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp};
|
||||
use rustpython_parser::{parse, Mode};
|
||||
use serde_json::{Map, Number, Value};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
python_literal_parser::parse_python_literal,
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
pub struct PythonicParser {
|
||||
/// Regex to detect tool calls in Pythonic format
|
||||
tool_call_regex: Regex,
|
||||
/// Regex to parse function calls - cached for reuse
|
||||
call_regex: Regex,
|
||||
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
/// Lazily compiled regex that locates pythonic tool call blocks.
|
||||
fn pythonic_block_regex() -> &'static Regex {
|
||||
PYTHONIC_BLOCK_REGEX.get_or_init(|| {
|
||||
// Matches one or more function calls inside a list. The `(?s)` flag allows
|
||||
// newlines inside argument lists while keeping the pattern anchored to
|
||||
// identifiers followed by parentheses, preventing plain lists like
|
||||
// `[1, 2, 3]` from matching.
|
||||
Regex::new(r"(?s)\[\s*[A-Za-z_]\w*\s*\(.*?\)\s*(?:,\s*[A-Za-z_]\w*\s*\(.*?\)\s*)*\]")
|
||||
.expect("pythonic tool call regex must compile")
|
||||
})
|
||||
}
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
#[derive(Default)]
|
||||
pub struct PythonicParser;
|
||||
|
||||
impl PythonicParser {
|
||||
/// Create a new Pythonic parser
|
||||
pub fn new() -> Self {
|
||||
// Simple regex to detect start of Pythonic tool calls
|
||||
// We'll use manual parsing for the actual extraction
|
||||
let pattern = r"\[[a-zA-Z_]\w*\(";
|
||||
let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern");
|
||||
|
||||
// Compile the function call regex once
|
||||
let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
tool_call_regex,
|
||||
call_regex,
|
||||
}
|
||||
Self
|
||||
}
|
||||
|
||||
/// Extract tool calls using bracket counting (similar to MistralParser)
|
||||
/// Returns extracted tool call group with [] and normal content
|
||||
/// Extract the first pythonic tool call block and return it along with the
|
||||
/// surrounding "normal" content.
|
||||
fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> {
|
||||
// Find the start of a tool call list - look for [ followed by a function name
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
|
||||
for start_idx in 0..chars.len() {
|
||||
if chars[start_idx] != '[' {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this looks like a tool call
|
||||
// Skip whitespace after [
|
||||
let mut check_idx = start_idx + 1;
|
||||
while check_idx < chars.len() && chars[check_idx].is_whitespace() {
|
||||
check_idx += 1;
|
||||
}
|
||||
|
||||
// Check if we have a function name (starts with letter or underscore)
|
||||
if check_idx >= chars.len()
|
||||
|| (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Now count brackets to find the matching ]
|
||||
let mut bracket_count = 0;
|
||||
let mut _paren_count = 0;
|
||||
let mut _brace_count = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
for i in start_idx..chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
} else if !in_string {
|
||||
match ch {
|
||||
'[' => bracket_count += 1,
|
||||
']' => {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching bracket
|
||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||
if extracted.contains('(') && extracted.contains(')') {
|
||||
// Calculate normal text by removing the tool call portion
|
||||
let before = &text[..start_idx];
|
||||
let after = &text[(i + 1)..];
|
||||
let normal_text = format!("{}{}", before, after);
|
||||
return Some((extracted, normal_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
'(' => _paren_count += 1,
|
||||
')' => _paren_count -= 1,
|
||||
'{' => _brace_count += 1,
|
||||
'}' => _brace_count -= 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
pythonic_block_regex().find(text).map(|mat| {
|
||||
let block = mat.as_str().to_string();
|
||||
let normal = format!("{}{}", &text[..mat.start()], &text[mat.end()..]);
|
||||
(block, normal)
|
||||
})
|
||||
}
|
||||
|
||||
/// Strip special tokens that Llama 4 might output
|
||||
/// Strip special tokens that Llama models might output
|
||||
fn strip_special_tokens(text: &str) -> String {
|
||||
text.replace("<|python_start|>", "")
|
||||
.replace("<|python_end|>", "")
|
||||
}
|
||||
|
||||
/// Parse a single function call from Python syntax
|
||||
fn parse_function_call(&self, call_str: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Use cached regex instead of creating new one
|
||||
if let Some(captures) = self.call_regex.captures(call_str.trim()) {
|
||||
let function_name = captures.get(1).unwrap().as_str();
|
||||
let args_str = captures.get(2).unwrap().as_str();
|
||||
|
||||
// Parse arguments
|
||||
let arguments = self.parse_arguments(args_str)?;
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name.to_string(),
|
||||
arguments: serde_json::to_string(&arguments)?,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
fn parse_tool_call_block(&self, block: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let expr = parse_python_expression(block)?;
|
||||
match expr {
|
||||
Expr::List(list_expr) => list_expr
|
||||
.elts
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, call_expr)| build_tool_call(call_expr, idx))
|
||||
.collect(),
|
||||
_ => Err(ToolParserError::ParsingFailed(
|
||||
"Expected a list of function calls in pythonic tool call".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Python-style arguments into JSON
|
||||
fn parse_arguments(&self, args_str: &str) -> ToolParserResult<Value> {
|
||||
if args_str.trim().is_empty() {
|
||||
return Ok(json!({}));
|
||||
}
|
||||
|
||||
let mut result = serde_json::Map::new();
|
||||
let mut current_key = String::new();
|
||||
let mut current_value = String::new();
|
||||
let mut in_key = true;
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
let chars: Vec<char> = args_str.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
escape_next = false;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
current_value.push(ch);
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle string literals
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else {
|
||||
// Not in string
|
||||
match ch {
|
||||
'=' if in_key && depth == 0 => {
|
||||
in_key = false;
|
||||
}
|
||||
',' if depth == 0 => {
|
||||
// End of current argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
current_key.clear();
|
||||
current_value.clear();
|
||||
in_key = true;
|
||||
}
|
||||
'[' | '{' | '(' => {
|
||||
depth += 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
']' | '}' | ')' => {
|
||||
depth -= 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if in_key {
|
||||
if !ch.is_whitespace() || !current_key.is_empty() {
|
||||
current_key.push(ch);
|
||||
}
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Handle the last argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
|
||||
Ok(Value::Object(result))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -268,61 +83,8 @@ impl ToolParser for PythonicParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
|
||||
// Extract tool calls using bracket counting
|
||||
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
|
||||
// Remove the outer brackets
|
||||
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
|
||||
|
||||
// Split into individual function calls
|
||||
let mut calls = Vec::new();
|
||||
let mut current_call = String::new();
|
||||
let mut paren_depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
|
||||
for ch in tool_calls_str.chars() {
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
current_call.push(ch);
|
||||
} else if in_string && ch == string_char {
|
||||
in_string = false;
|
||||
current_call.push(ch);
|
||||
} else if in_string {
|
||||
current_call.push(ch);
|
||||
} else {
|
||||
match ch {
|
||||
'(' => {
|
||||
paren_depth += 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
')' => {
|
||||
paren_depth -= 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
',' if paren_depth == 0 => {
|
||||
// End of current function call
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
current_call.clear();
|
||||
}
|
||||
_ => {
|
||||
if !ch.is_whitespace() || !current_call.is_empty() {
|
||||
current_call.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the last call (important for single calls or the last call in a list)
|
||||
if !current_call.trim().is_empty() {
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
}
|
||||
|
||||
let calls = self.parse_tool_call_block(&tool_calls_text)?;
|
||||
Ok((normal_text, calls))
|
||||
} else {
|
||||
Ok((text.to_string(), vec![]))
|
||||
@@ -334,19 +96,15 @@ impl ToolParser for PythonicParser {
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// For Pythonic format, we accumulate until we have a complete tool call
|
||||
// This is a simplified implementation
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Try to parse if we have a complete tool call
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if self.extract_tool_calls(&cleaned).is_some() {
|
||||
let (_normal_text, tools) = self.parse_complete(&state.buffer).await?;
|
||||
if !tools.is_empty() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(
|
||||
tools.into_iter().next().unwrap(),
|
||||
));
|
||||
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
|
||||
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,13 +113,220 @@ impl ToolParser for PythonicParser {
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
self.tool_call_regex.is_match(&cleaned)
|
||||
if pythonic_block_regex().is_match(&cleaned) {
|
||||
return true;
|
||||
}
|
||||
|
||||
let trimmed = cleaned.trim();
|
||||
let Some(open_idx) = trimmed.find('[') else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let after_bracket = trimmed[open_idx + 1..].trim_start();
|
||||
let mut chars = after_bracket.char_indices();
|
||||
let Some((_, first_char)) = chars.next() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if !(first_char.is_ascii_alphabetic() || first_char == '_') {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut ident_len = first_char.len_utf8();
|
||||
for (idx, ch) in chars {
|
||||
if ch.is_alphanumeric() || ch == '_' {
|
||||
ident_len = idx + ch.len_utf8();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let remaining = after_bracket[ident_len..].trim_start();
|
||||
remaining.starts_with('(')
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PythonicParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
|
||||
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
|
||||
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
|
||||
|
||||
match module {
|
||||
Mod::Expression(expr_mod) => Ok(*expr_mod.body),
|
||||
_ => Err(ToolParserError::ParsingFailed(
|
||||
"Expected a Python expression".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
|
||||
match expr {
|
||||
Expr::Call(call_expr) => {
|
||||
if !call_expr.args.is_empty() {
|
||||
return Err(ToolParserError::ParsingFailed(
|
||||
"Positional arguments are not supported in pythonic tool calls".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let function_name = match *call_expr.func {
|
||||
Expr::Name(name_expr) => name_expr.id.to_string(),
|
||||
_ => {
|
||||
return Err(ToolParserError::ParsingFailed(
|
||||
"Unsupported function reference in pythonic tool call".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut arguments_map = Map::with_capacity(call_expr.keywords.len());
|
||||
for keyword in call_expr.keywords {
|
||||
let arg_name = keyword.arg.ok_or_else(|| {
|
||||
ToolParserError::ParsingFailed(
|
||||
"pythonic tool calls do not support **kwargs".to_string(),
|
||||
)
|
||||
})?;
|
||||
let value_json = expression_to_json(&keyword.value)?;
|
||||
arguments_map.insert(arg_name.to_string(), value_json);
|
||||
}
|
||||
|
||||
let arguments_json = Value::Object(arguments_map);
|
||||
let arguments_string = serde_json::to_string(&arguments_json)?;
|
||||
|
||||
Ok(ToolCall {
|
||||
id: format!("call-{}", index + 1),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name,
|
||||
arguments: arguments_string,
|
||||
},
|
||||
})
|
||||
}
|
||||
_ => Err(ToolParserError::ParsingFailed(
|
||||
"Expected function calls inside pythonic tool call list".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
|
||||
match expr {
|
||||
Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value),
|
||||
Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array),
|
||||
Expr::Tuple(tuple_expr) => collect_sequence(&tuple_expr.elts).map(Value::Array),
|
||||
Expr::Dict(dict_expr) => {
|
||||
collect_dict(&dict_expr.keys, &dict_expr.values).map(Value::Object)
|
||||
}
|
||||
Expr::UnaryOp(unary_expr) => match unary_expr.op {
|
||||
UnaryOp::USub => match unary_expr.operand.as_ref() {
|
||||
Expr::Constant(const_expr) => negate_constant(&const_expr.value),
|
||||
_ => Err(ToolParserError::ParsingFailed(
|
||||
"Unsupported unary operand in pythonic tool call".to_string(),
|
||||
)),
|
||||
},
|
||||
UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()),
|
||||
_ => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unsupported unary operator in pythonic tool call: {:?}",
|
||||
unary_expr.op
|
||||
))),
|
||||
},
|
||||
Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())),
|
||||
_ => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unsupported expression in pythonic tool call: {:?}",
|
||||
expr
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn constant_to_json(constant: &Constant) -> ToolParserResult<Value> {
|
||||
match constant {
|
||||
Constant::None => Ok(Value::Null),
|
||||
Constant::Bool(b) => Ok(Value::Bool(*b)),
|
||||
Constant::Int(value) => Ok(integer_constant_to_value(value, false)),
|
||||
Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| {
|
||||
ToolParserError::ParsingFailed(
|
||||
"Invalid float literal in pythonic tool call".to_string(),
|
||||
)
|
||||
}),
|
||||
Constant::Str(s) => Ok(Value::String(s.clone())),
|
||||
Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())),
|
||||
Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array),
|
||||
Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed(
|
||||
"Unsupported literal in pythonic tool call".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn negate_constant(constant: &Constant) -> ToolParserResult<Value> {
|
||||
match constant {
|
||||
Constant::Int(value) => Ok(integer_constant_to_value(value, true)),
|
||||
Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| {
|
||||
ToolParserError::ParsingFailed(
|
||||
"Invalid float literal in pythonic tool call".to_string(),
|
||||
)
|
||||
}),
|
||||
_ => Err(ToolParserError::ParsingFailed(
|
||||
"Unsupported unary operand in pythonic tool call".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn value_to_key_string(value: Value) -> ToolParserResult<String> {
|
||||
match value {
|
||||
Value::String(s) => Ok(s),
|
||||
Value::Number(num) => Ok(num.to_string()),
|
||||
Value::Bool(b) => Ok(b.to_string()),
|
||||
Value::Null => Ok("null".to_string()),
|
||||
other => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unsupported key type in pythonic tool call: {:?}",
|
||||
other
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_sequence(elements: &[Expr]) -> ToolParserResult<Vec<Value>> {
|
||||
elements.iter().map(expression_to_json).collect()
|
||||
}
|
||||
|
||||
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<String, Value>> {
|
||||
let mut map = Map::with_capacity(keys.len());
|
||||
for (key_expr, value_expr) in keys.iter().zip(values.iter()) {
|
||||
let key_expr = key_expr.as_ref().ok_or_else(|| {
|
||||
ToolParserError::ParsingFailed(
|
||||
"pythonic tool calls do not support **kwargs".to_string(),
|
||||
)
|
||||
})?;
|
||||
let key_value = expression_to_json(key_expr)?;
|
||||
let key = value_to_key_string(key_value)?;
|
||||
let value_json = expression_to_json(value_expr)?;
|
||||
map.insert(key, value_json);
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult<Vec<Value>> {
|
||||
values.iter().map(constant_to_json).collect()
|
||||
}
|
||||
|
||||
fn integer_constant_to_value<T>(value: &T, negate: bool) -> Value
|
||||
where
|
||||
T: ToPrimitive + std::fmt::Display,
|
||||
{
|
||||
if let Some(mut i) = value.to_i64() {
|
||||
if negate {
|
||||
i = -i;
|
||||
}
|
||||
return Value::Number(Number::from(i));
|
||||
}
|
||||
|
||||
if negate {
|
||||
if let Some(u) = value.to_u64() {
|
||||
if u <= i64::MAX as u64 {
|
||||
return Value::Number(Number::from(-(u as i64)));
|
||||
}
|
||||
return Value::String(format!("-{}", value));
|
||||
}
|
||||
Value::String(format!("-{}", value))
|
||||
} else if let Some(u) = value.to_u64() {
|
||||
Value::Number(Number::from(u))
|
||||
} else {
|
||||
Value::String(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,63 +370,43 @@ mod tests {
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["flag"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], Value::Null);
|
||||
assert!(args["optional"].is_null());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_tokens() {
|
||||
async fn test_strip_special_tokens() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
|
||||
let input = "<|python_start|>[call(arg=1)]<|python_end|>";
|
||||
|
||||
assert!(parser.detect_format(input));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "calculate");
|
||||
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama4_format() {
|
||||
async fn test_detect_format() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
assert!(parser.detect_format("[foo(bar=1)]"));
|
||||
assert!(!parser.detect_format("No python here"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_normal_text_extraction() {
|
||||
async fn test_parse_incremental() {
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Test with text before and after
|
||||
let input = r#"Please check the weather [get_weather(city="Tokyo")] and let me know."#;
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
let chunk1 = "[call(arg=";
|
||||
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||
assert!(matches!(result1, StreamResult::Incomplete));
|
||||
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(normal_text, "Please check the weather and let me know.");
|
||||
let chunk2 = "1)]";
|
||||
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||
|
||||
// Test with only normal text (no tool calls)
|
||||
let input_no_tools = "This is just normal text without any tool calls.";
|
||||
let (normal_text, tools) = parser.parse_complete(input_no_tools).await.unwrap();
|
||||
|
||||
assert_eq!(tools.len(), 0);
|
||||
assert_eq!(normal_text, input_no_tools);
|
||||
|
||||
// Test with multiple tool calls in single bracket group and normal text
|
||||
let input_multiple = r#"First, [search(query="rust"), calculate(x=5, y=10)] please."#;
|
||||
let (normal_text, tools) = parser.parse_complete(input_multiple).await.unwrap();
|
||||
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
assert_eq!(normal_text, "First, please.");
|
||||
match result2 {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "call");
|
||||
}
|
||||
other => panic!("Expected ToolComplete, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,442 +0,0 @@
|
||||
/// Minimal Python literal parser for Pythonic tool call format
|
||||
///
|
||||
/// This module provides a recursive descent parser for Python literals
|
||||
/// (strings, numbers, booleans, None, lists, dicts) without requiring
|
||||
/// a full Python AST parser.
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
|
||||
|
||||
/// Token types for Python literals
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum Token {
|
||||
// Literals
|
||||
String(String),
|
||||
Number(String),
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
|
||||
// Delimiters
|
||||
LeftBracket, // [
|
||||
RightBracket, // ]
|
||||
LeftBrace, // {
|
||||
RightBrace, // }
|
||||
LeftParen, // (
|
||||
RightParen, // )
|
||||
Comma, // ,
|
||||
Colon, // :
|
||||
Equals, // =
|
||||
|
||||
// Identifier for function names
|
||||
Identifier(String),
|
||||
|
||||
// End of input
|
||||
Eof,
|
||||
}
|
||||
|
||||
/// Lexer for Python literals
|
||||
struct Lexer {
|
||||
input: Vec<char>,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl Lexer {
|
||||
fn new(input: &str) -> Self {
|
||||
Self {
|
||||
input: input.chars().collect(),
|
||||
position: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn current_char(&self) -> Option<char> {
|
||||
self.input.get(self.position).copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
if self.position < self.input.len() {
|
||||
self.position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_string(&mut self, quote_char: char) -> Result<String, ToolParserError> {
|
||||
let mut result = String::new();
|
||||
self.advance(); // Skip opening quote
|
||||
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch == '\\' {
|
||||
self.advance();
|
||||
if let Some(escaped) = self.current_char() {
|
||||
match escaped {
|
||||
'n' => result.push('\n'),
|
||||
't' => result.push('\t'),
|
||||
'r' => result.push('\r'),
|
||||
'\\' => result.push('\\'),
|
||||
'\'' => result.push('\''),
|
||||
'"' => result.push('"'),
|
||||
_ => {
|
||||
result.push('\\');
|
||||
result.push(escaped);
|
||||
}
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
} else if ch == quote_char {
|
||||
self.advance(); // Skip closing quote
|
||||
return Ok(result);
|
||||
} else {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
}
|
||||
}
|
||||
|
||||
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
|
||||
}
|
||||
|
||||
fn read_number(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
// Handle negative numbers
|
||||
if self.current_char() == Some('-') {
|
||||
result.push('-');
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Read digits and decimal point
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_ascii_digit() || ch == '.' || ch == 'e' || ch == 'E' || ch == '+' || ch == '-'
|
||||
{
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn read_identifier(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_alphanumeric() || ch == '_' {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Result<Token, ToolParserError> {
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.current_char() {
|
||||
None => Ok(Token::Eof),
|
||||
Some('[') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftBracket)
|
||||
}
|
||||
Some(']') => {
|
||||
self.advance();
|
||||
Ok(Token::RightBracket)
|
||||
}
|
||||
Some('{') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftBrace)
|
||||
}
|
||||
Some('}') => {
|
||||
self.advance();
|
||||
Ok(Token::RightBrace)
|
||||
}
|
||||
Some('(') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftParen)
|
||||
}
|
||||
Some(')') => {
|
||||
self.advance();
|
||||
Ok(Token::RightParen)
|
||||
}
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
Ok(Token::Comma)
|
||||
}
|
||||
Some(':') => {
|
||||
self.advance();
|
||||
Ok(Token::Colon)
|
||||
}
|
||||
Some('=') => {
|
||||
self.advance();
|
||||
Ok(Token::Equals)
|
||||
}
|
||||
Some('"') => Ok(Token::String(self.read_string('"')?)),
|
||||
Some('\'') => Ok(Token::String(self.read_string('\'')?)),
|
||||
Some(ch) if ch == '-' || ch.is_ascii_digit() => Ok(Token::Number(self.read_number())),
|
||||
Some(ch) if ch.is_alphabetic() || ch == '_' => {
|
||||
let ident = self.read_identifier();
|
||||
match ident.as_str() {
|
||||
"True" => Ok(Token::True),
|
||||
"False" => Ok(Token::False),
|
||||
"None" => Ok(Token::None),
|
||||
_ => Ok(Token::Identifier(ident)),
|
||||
}
|
||||
}
|
||||
Some(ch) => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unexpected character: {}",
|
||||
ch
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parser for Python literals
|
||||
pub struct PythonLiteralParser {
|
||||
lexer: Lexer,
|
||||
current_token: Token,
|
||||
}
|
||||
|
||||
impl PythonLiteralParser {
|
||||
pub fn new(input: &str) -> Result<Self, ToolParserError> {
|
||||
let mut lexer = Lexer::new(input);
|
||||
let current_token = lexer.next_token()?;
|
||||
Ok(Self {
|
||||
lexer,
|
||||
current_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn advance(&mut self) -> Result<(), ToolParserError> {
|
||||
self.current_token = self.lexer.next_token()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expect(&mut self, expected: Token) -> Result<(), ToolParserError> {
|
||||
if self.current_token == expected {
|
||||
self.advance()?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected {:?}, got {:?}",
|
||||
expected, self.current_token
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python literal value
|
||||
pub fn parse_value(&mut self) -> Result<Value, ToolParserError> {
|
||||
match &self.current_token.clone() {
|
||||
Token::String(s) => {
|
||||
let value = s.clone();
|
||||
self.advance()?;
|
||||
Ok(json!(value))
|
||||
}
|
||||
Token::Number(n) => {
|
||||
let value = if let Ok(int_val) = n.parse::<i64>() {
|
||||
json!(int_val)
|
||||
} else if let Ok(float_val) = n.parse::<f64>() {
|
||||
json!(float_val)
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Invalid number: {}",
|
||||
n
|
||||
)));
|
||||
};
|
||||
self.advance()?;
|
||||
Ok(value)
|
||||
}
|
||||
Token::True => {
|
||||
self.advance()?;
|
||||
Ok(json!(true))
|
||||
}
|
||||
Token::False => {
|
||||
self.advance()?;
|
||||
Ok(json!(false))
|
||||
}
|
||||
Token::None => {
|
||||
self.advance()?;
|
||||
Ok(Value::Null)
|
||||
}
|
||||
Token::LeftBracket => self.parse_list(),
|
||||
Token::LeftBrace => self.parse_dict(),
|
||||
_ => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unexpected token: {:?}",
|
||||
self.current_token
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python list: [item1, item2, ...]
|
||||
fn parse_list(&mut self) -> Result<Value, ToolParserError> {
|
||||
self.expect(Token::LeftBracket)?;
|
||||
let mut items = Vec::new();
|
||||
|
||||
// Handle empty list
|
||||
if self.current_token == Token::RightBracket {
|
||||
self.advance()?;
|
||||
return Ok(json!(items));
|
||||
}
|
||||
|
||||
loop {
|
||||
items.push(self.parse_value()?);
|
||||
|
||||
if self.current_token == Token::Comma {
|
||||
self.advance()?;
|
||||
// Handle trailing comma
|
||||
if self.current_token == Token::RightBracket {
|
||||
break;
|
||||
}
|
||||
} else if self.current_token == Token::RightBracket {
|
||||
break;
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected ',' or ']', got {:?}",
|
||||
self.current_token
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
self.expect(Token::RightBracket)?;
|
||||
Ok(json!(items))
|
||||
}
|
||||
|
||||
/// Parse a Python dict: {key1: value1, key2: value2, ...}
|
||||
fn parse_dict(&mut self) -> Result<Value, ToolParserError> {
|
||||
self.expect(Token::LeftBrace)?;
|
||||
let mut map = HashMap::new();
|
||||
|
||||
// Handle empty dict
|
||||
if self.current_token == Token::RightBrace {
|
||||
self.advance()?;
|
||||
return Ok(json!(map));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse key (must be a string)
|
||||
let key = match &self.current_token {
|
||||
Token::String(s) => {
|
||||
let k = s.clone();
|
||||
self.advance()?;
|
||||
k
|
||||
}
|
||||
_ => {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected string key, got {:?}",
|
||||
self.current_token
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
self.expect(Token::Colon)?;
|
||||
|
||||
// Parse value
|
||||
let value = self.parse_value()?;
|
||||
map.insert(key, value);
|
||||
|
||||
if self.current_token == Token::Comma {
|
||||
self.advance()?;
|
||||
// Handle trailing comma
|
||||
if self.current_token == Token::RightBrace {
|
||||
break;
|
||||
}
|
||||
} else if self.current_token == Token::RightBrace {
|
||||
break;
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected ',' or '}}', got {:?}",
|
||||
self.current_token
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
self.expect(Token::RightBrace)?;
|
||||
Ok(json!(map))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python literal string into a JSON value
|
||||
pub fn parse_python_literal(input: &str) -> ToolParserResult<Value> {
|
||||
let mut parser = PythonLiteralParser::new(input)?;
|
||||
parser.parse_value()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: parse `input` and unwrap, since every case here expects success.
    fn parsed(input: &str) -> Value {
        parse_python_literal(input).unwrap()
    }

    #[test]
    fn test_parse_primitives() {
        // Keywords map to their JSON counterparts.
        assert_eq!(parsed("True"), json!(true));
        assert_eq!(parsed("False"), json!(false));
        assert_eq!(parsed("None"), Value::Null);

        // Integers, floats, and negatives.
        assert_eq!(parsed("42"), json!(42));
        assert_eq!(parsed("12.345"), json!(12.345));
        assert_eq!(parsed("-42"), json!(-42));

        // Both Python quoting styles produce plain JSON strings.
        assert_eq!(parsed("\"hello\""), json!("hello"));
        assert_eq!(parsed("'world'"), json!("world"));
    }

    #[test]
    fn test_parse_list() {
        assert_eq!(parsed("[]"), json!([]));
        assert_eq!(parsed("[1, 2, 3]"), json!([1, 2, 3]));
        assert_eq!(parsed("[\"a\", \"b\", \"c\"]"), json!(["a", "b", "c"]));
        assert_eq!(parsed("[True, False, None]"), json!([true, false, null]));

        // Nested list
        assert_eq!(parsed("[[1, 2], [3, 4]]"), json!([[1, 2], [3, 4]]));
    }

    #[test]
    fn test_parse_dict() {
        assert_eq!(parsed("{}"), json!({}));
        assert_eq!(parsed("{\"a\": 1, \"b\": 2}"), json!({"a": 1, "b": 2}));
        assert_eq!(
            parsed("{'x': True, 'y': False}"),
            json!({"x": true, "y": false})
        );

        // Nested dict
        assert_eq!(
            parsed("{\"nested\": {\"value\": [1, 2, 3]}}"),
            json!({"nested": {"value": [1, 2, 3]}})
        );
    }

    #[test]
    fn test_complex_nested() {
        // A realistic mix of nesting, lists, and a Python boolean.
        let input = r#"{"config": {"nested": {"value": [1, 2, 3]}, "enabled": True}}"#;
        let expected = json!({
            "config": {
                "nested": {
                    "value": [1, 2, 3]
                },
                "enabled": true
            }
        });
        assert_eq!(parsed(input), expected);
    }
}
|
||||
Reference in New Issue
Block a user