[router] add mistral tool parser (#9622)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -7,7 +7,7 @@ use crate::tool_parser::{
|
|||||||
partial_json::PartialJson,
|
partial_json::PartialJson,
|
||||||
state::ParseState,
|
state::ParseState,
|
||||||
traits::ToolParser,
|
traits::ToolParser,
|
||||||
types::{FunctionCall, StreamResult, ToolCall},
|
types::{FunctionCall, StreamResult, TokenConfig, ToolCall},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// JSON format parser for tool calls
|
/// JSON format parser for tool calls
|
||||||
@@ -19,12 +19,8 @@ use crate::tool_parser::{
|
|||||||
///
|
///
|
||||||
/// Supports configurable token markers for different models
|
/// Supports configurable token markers for different models
|
||||||
pub struct JsonParser {
|
pub struct JsonParser {
|
||||||
/// Token(s) that mark the start of tool calls
|
/// Token configuration for parsing
|
||||||
start_tokens: Vec<String>,
|
token_config: TokenConfig,
|
||||||
/// Token(s) that mark the end of tool calls
|
|
||||||
end_tokens: Vec<String>,
|
|
||||||
/// Separator between multiple tool calls (reserved for future use)
|
|
||||||
_separator: String,
|
|
||||||
/// Parser for handling incomplete JSON during streaming
|
/// Parser for handling incomplete JSON during streaming
|
||||||
partial_json: PartialJson,
|
partial_json: PartialJson,
|
||||||
/// Regex patterns for extracting content between tokens
|
/// Regex patterns for extracting content between tokens
|
||||||
@@ -34,23 +30,18 @@ pub struct JsonParser {
|
|||||||
impl JsonParser {
|
impl JsonParser {
|
||||||
/// Create a new JSON parser with default configuration
|
/// Create a new JSON parser with default configuration
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::with_config(
|
Self::with_config(TokenConfig {
|
||||||
vec![], // No wrapper tokens by default
|
start_tokens: vec![],
|
||||||
vec![],
|
end_tokens: vec![],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a parser with custom token configuration
|
/// Create a parser with custom token configuration
|
||||||
pub fn with_config(
|
pub fn with_config(config: TokenConfig) -> Self {
|
||||||
start_tokens: Vec<String>,
|
|
||||||
end_tokens: Vec<String>,
|
|
||||||
separator: String,
|
|
||||||
) -> Self {
|
|
||||||
// Build extraction patterns for each token pair
|
// Build extraction patterns for each token pair
|
||||||
let extractors = start_tokens
|
let extractors: Vec<Regex> = config
|
||||||
.iter()
|
.iter_pairs()
|
||||||
.zip(end_tokens.iter())
|
|
||||||
.filter_map(|(start, end)| {
|
.filter_map(|(start, end)| {
|
||||||
if !start.is_empty() && !end.is_empty() {
|
if !start.is_empty() && !end.is_empty() {
|
||||||
// Use (?s) flag to enable DOTALL mode so . matches newlines
|
// Use (?s) flag to enable DOTALL mode so . matches newlines
|
||||||
@@ -64,9 +55,7 @@ impl JsonParser {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
start_tokens,
|
token_config: config,
|
||||||
end_tokens,
|
|
||||||
_separator: separator,
|
|
||||||
partial_json: PartialJson::default(),
|
partial_json: PartialJson::default(),
|
||||||
extractors,
|
extractors,
|
||||||
}
|
}
|
||||||
@@ -74,26 +63,90 @@ impl JsonParser {
|
|||||||
|
|
||||||
/// Extract JSON content from text, handling wrapper tokens if configured
|
/// Extract JSON content from text, handling wrapper tokens if configured
|
||||||
fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
|
fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
|
||||||
let mut content = text.trim();
|
let mut content = text;
|
||||||
|
|
||||||
// Try each extractor pattern
|
// Try each extractor pattern (for tokens with both start and end)
|
||||||
for extractor in &self.extractors {
|
for extractor in &self.extractors {
|
||||||
if let Some(captures) = extractor.captures(content) {
|
if let Some(captures) = extractor.captures(content) {
|
||||||
if let Some(matched) = captures.get(1) {
|
if let Some(matched) = captures.get(1) {
|
||||||
content = matched.as_str().trim();
|
return matched.as_str().trim();
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle special case where there's a start token but no end token
|
// Handle special case where there's a start token but no end token
|
||||||
for (start, end) in self.start_tokens.iter().zip(self.end_tokens.iter()) {
|
for (start, end) in self.token_config.iter_pairs() {
|
||||||
if !start.is_empty() && end.is_empty() {
|
if !start.is_empty() && end.is_empty() {
|
||||||
content = content.strip_prefix(start).unwrap_or(content);
|
// Find the start token and extract everything after it
|
||||||
|
if let Some(pos) = content.find(start) {
|
||||||
|
content = &content[pos + start.len()..];
|
||||||
|
return content.trim();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content
|
content.trim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to extract a JSON object or array from text that may contain other content
|
||||||
|
fn extract_json_from_text(&self, text: &str) -> Option<String> {
|
||||||
|
// Look for JSON object starting with {
|
||||||
|
if let Some(start) = text.find('{') {
|
||||||
|
let mut depth = 0;
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut escape_next = false;
|
||||||
|
|
||||||
|
for (i, ch) in text[start..].char_indices() {
|
||||||
|
if escape_next {
|
||||||
|
escape_next = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match ch {
|
||||||
|
'\\' if in_string => escape_next = true,
|
||||||
|
'"' if !in_string => in_string = true,
|
||||||
|
'"' if in_string => in_string = false,
|
||||||
|
'{' if !in_string => depth += 1,
|
||||||
|
'}' if !in_string => {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
return Some(text[start..start + i + 1].to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for JSON array starting with [
|
||||||
|
if let Some(start) = text.find('[') {
|
||||||
|
let mut depth = 0;
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut escape_next = false;
|
||||||
|
|
||||||
|
for (i, ch) in text[start..].char_indices() {
|
||||||
|
if escape_next {
|
||||||
|
escape_next = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match ch {
|
||||||
|
'\\' if in_string => escape_next = true,
|
||||||
|
'"' if !in_string => in_string = true,
|
||||||
|
'"' if in_string => in_string = false,
|
||||||
|
'[' if !in_string => depth += 1,
|
||||||
|
']' if !in_string => {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
return Some(text[start..start + i + 1].to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a single JSON object into a ToolCall
|
/// Parse a single JSON object into a ToolCall
|
||||||
@@ -167,13 +220,16 @@ impl JsonParser {
|
|||||||
/// Check if text contains potential tool call markers
|
/// Check if text contains potential tool call markers
|
||||||
fn has_tool_markers(&self, text: &str) -> bool {
|
fn has_tool_markers(&self, text: &str) -> bool {
|
||||||
// If no start tokens configured, check for JSON structure
|
// If no start tokens configured, check for JSON structure
|
||||||
if self.start_tokens.is_empty() {
|
if self.token_config.start_tokens.is_empty() {
|
||||||
// For JSON, we just need to see the start of an object or array
|
// For JSON, we just need to see the start of an object or array
|
||||||
return text.contains('{') || text.contains('[');
|
return text.contains('{') || text.contains('[');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for any start token
|
// Check for any start token
|
||||||
self.start_tokens.iter().any(|token| text.contains(token))
|
self.token_config
|
||||||
|
.start_tokens
|
||||||
|
.iter()
|
||||||
|
.any(|token| text.contains(token))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +249,15 @@ impl ToolParser for JsonParser {
|
|||||||
match serde_json::from_str::<Value>(json_content) {
|
match serde_json::from_str::<Value>(json_content) {
|
||||||
Ok(value) => self.parse_json_value(&value),
|
Ok(value) => self.parse_json_value(&value),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
// If no wrapper tokens configured and parse failed,
|
||||||
|
// try to extract JSON from mixed text
|
||||||
|
if self.token_config.start_tokens.is_empty() {
|
||||||
|
if let Some(extracted) = self.extract_json_from_text(text) {
|
||||||
|
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||||
|
return self.parse_json_value(&value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// Not valid JSON, return empty
|
// Not valid JSON, return empty
|
||||||
Ok(vec![])
|
Ok(vec![])
|
||||||
}
|
}
|
||||||
@@ -341,11 +406,11 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_with_wrapper_tokens() {
|
async fn test_parse_with_wrapper_tokens() {
|
||||||
let parser = JsonParser::with_config(
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
vec!["<tool>".to_string()],
|
start_tokens: vec!["<tool>".to_string()],
|
||||||
vec!["</tool>".to_string()],
|
end_tokens: vec!["</tool>".to_string()],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
);
|
});
|
||||||
|
|
||||||
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
|||||||
347
sgl-router/src/tool_parser/mistral_parser.rs
Normal file
347
sgl-router/src/tool_parser/mistral_parser.rs
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::tool_parser::{
|
||||||
|
errors::{ToolParserError, ToolParserResult},
|
||||||
|
partial_json::PartialJson,
|
||||||
|
state::ParseState,
|
||||||
|
traits::ToolParser,
|
||||||
|
types::{FunctionCall, StreamResult, ToolCall},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Mistral format parser for tool calls
|
||||||
|
///
|
||||||
|
/// Handles the Mistral-specific format:
|
||||||
|
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
|
||||||
|
///
|
||||||
|
/// Features:
|
||||||
|
/// - Bracket counting for proper JSON array extraction
|
||||||
|
/// - Support for multiple tool calls in a single array
|
||||||
|
/// - String-aware parsing to handle nested brackets in JSON
|
||||||
|
pub struct MistralParser {
|
||||||
|
/// Parser for handling incomplete JSON during streaming
|
||||||
|
partial_json: PartialJson,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MistralParser {
|
||||||
|
/// Create a new Mistral parser
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
partial_json: PartialJson::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract JSON array using bracket counting
|
||||||
|
///
|
||||||
|
/// Handles nested brackets in JSON content by tracking:
|
||||||
|
/// - String boundaries (quotes)
|
||||||
|
/// - Escape sequences
|
||||||
|
/// - Bracket depth
|
||||||
|
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
|
||||||
|
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
|
||||||
|
|
||||||
|
// Find the start of the token
|
||||||
|
let start_idx = text.find(BOT_TOKEN)?;
|
||||||
|
|
||||||
|
// Start from the opening bracket after [TOOL_CALLS]
|
||||||
|
// The -1 is to include the opening bracket that's part of the token
|
||||||
|
let json_start = start_idx + BOT_TOKEN.len() - 1;
|
||||||
|
|
||||||
|
let mut bracket_count = 0;
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut escape_next = false;
|
||||||
|
|
||||||
|
let bytes = text.as_bytes();
|
||||||
|
|
||||||
|
for i in json_start..text.len() {
|
||||||
|
let char = bytes[i];
|
||||||
|
|
||||||
|
if escape_next {
|
||||||
|
escape_next = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == b'\\' {
|
||||||
|
escape_next = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == b'"' && !escape_next {
|
||||||
|
in_string = !in_string;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !in_string {
|
||||||
|
if char == b'[' {
|
||||||
|
bracket_count += 1;
|
||||||
|
} else if char == b']' {
|
||||||
|
bracket_count -= 1;
|
||||||
|
if bracket_count == 0 {
|
||||||
|
// Found the matching closing bracket
|
||||||
|
return Some(&text[json_start..=i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incomplete array (no matching closing bracket found)
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse tool calls from a JSON array
|
||||||
|
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||||
|
let value: Value = serde_json::from_str(json_str)
|
||||||
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
|
let mut tools = Vec::new();
|
||||||
|
|
||||||
|
if let Value::Array(arr) = value {
|
||||||
|
for (index, item) in arr.iter().enumerate() {
|
||||||
|
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||||
|
tools.push(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Single object case (shouldn't happen with Mistral format, but handle it)
|
||||||
|
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||||
|
tools.push(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a single JSON object into a ToolCall
|
||||||
|
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||||
|
let name = obj.get("name").and_then(|v| v.as_str());
|
||||||
|
|
||||||
|
if let Some(name) = name {
|
||||||
|
// Get arguments - Mistral uses "arguments" key
|
||||||
|
let empty_obj = Value::Object(serde_json::Map::new());
|
||||||
|
let args = obj.get("arguments").unwrap_or(&empty_obj);
|
||||||
|
|
||||||
|
// Convert arguments to JSON string
|
||||||
|
let arguments = serde_json::to_string(args)
|
||||||
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
|
// Generate ID with index for multiple tools
|
||||||
|
let id = format!("mistral_call_{}", index);
|
||||||
|
|
||||||
|
Ok(Some(ToolCall {
|
||||||
|
id,
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if text contains Mistral tool markers
|
||||||
|
fn has_tool_markers(&self, text: &str) -> bool {
|
||||||
|
text.contains("[TOOL_CALLS]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MistralParser {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ToolParser for MistralParser {
|
||||||
|
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||||
|
// Check if text contains Mistral format
|
||||||
|
if !self.has_tool_markers(text) {
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract JSON array from Mistral format
|
||||||
|
if let Some(json_array) = self.extract_json_array(text) {
|
||||||
|
self.parse_json_array(json_array)
|
||||||
|
} else {
|
||||||
|
// Markers present but no complete array found
|
||||||
|
Ok(vec![])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn parse_incremental(
|
||||||
|
&self,
|
||||||
|
chunk: &str,
|
||||||
|
state: &mut ParseState,
|
||||||
|
) -> ToolParserResult<StreamResult> {
|
||||||
|
state.buffer.push_str(chunk);
|
||||||
|
|
||||||
|
// Check if we have the start marker
|
||||||
|
if !self.has_tool_markers(&state.buffer) {
|
||||||
|
return Ok(StreamResult::Incomplete);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extract complete JSON array
|
||||||
|
if let Some(json_array) = self.extract_json_array(&state.buffer) {
|
||||||
|
// Parse with partial JSON to handle incomplete content
|
||||||
|
match self.partial_json.parse_value(json_array) {
|
||||||
|
Ok((value, consumed)) => {
|
||||||
|
// Check if we have a complete JSON structure
|
||||||
|
if consumed == json_array.len() {
|
||||||
|
// Complete JSON, parse tool calls
|
||||||
|
let tools = if let Value::Array(arr) = value {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for (index, item) in arr.iter().enumerate() {
|
||||||
|
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||||
|
result.push(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
if !tools.is_empty() {
|
||||||
|
// Clear buffer since we consumed everything
|
||||||
|
state.buffer.clear();
|
||||||
|
|
||||||
|
// Return the first tool (simplified for Phase 3)
|
||||||
|
// Full multi-tool streaming will be implemented later
|
||||||
|
if let Some(tool) = tools.into_iter().next() {
|
||||||
|
return Ok(StreamResult::ToolComplete(tool));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Partial JSON - try to extract tool name for streaming
|
||||||
|
if let Value::Array(arr) = value {
|
||||||
|
if let Some(first_tool) = arr.first() {
|
||||||
|
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
// Check if we've already sent the name
|
||||||
|
if !state.in_string {
|
||||||
|
state.in_string = true; // Use as flag for "name sent"
|
||||||
|
return Ok(StreamResult::ToolName {
|
||||||
|
index: 0,
|
||||||
|
name: name.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for arguments
|
||||||
|
if let Some(args) = first_tool.get("arguments") {
|
||||||
|
if let Ok(args_str) = serde_json::to_string(args) {
|
||||||
|
return Ok(StreamResult::ToolArguments {
|
||||||
|
index: 0,
|
||||||
|
arguments: args_str,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Failed to parse even as partial JSON
|
||||||
|
// Keep buffering
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(StreamResult::Incomplete)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
|
// Check if text contains Mistral-specific markers
|
||||||
|
if self.has_tool_markers(text) {
|
||||||
|
// Try to extract and validate the array
|
||||||
|
if let Some(json_array) = self.extract_json_array(text) {
|
||||||
|
// Check if it's valid JSON
|
||||||
|
if let Ok(value) = serde_json::from_str::<Value>(json_array) {
|
||||||
|
// Check if it contains tool-like structures
|
||||||
|
match value {
|
||||||
|
Value::Array(ref arr) => arr.iter().any(|v| {
|
||||||
|
v.as_object().is_some_and(|o| {
|
||||||
|
o.contains_key("name") && o.contains_key("arguments")
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
Value::Object(ref obj) => {
|
||||||
|
obj.contains_key("name") && obj.contains_key("arguments")
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Has markers but no complete array - might be streaming
|
||||||
|
true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a complete parse with a fresh parser, panicking on parser errors.
    async fn parse_all(raw: &str) -> Vec<ToolCall> {
        MistralParser::new().parse_complete(raw).await.unwrap()
    }

    /// A single call in the canonical Mistral layout is parsed.
    #[tokio::test]
    async fn test_parse_mistral_format() {
        let calls = parse_all(
            r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#,
        )
        .await;

        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.function.name, "get_weather");
        assert!(call.function.arguments.contains("Paris"));
    }

    /// Several calls in one array come back in order.
    #[tokio::test]
    async fn test_parse_multiple_tools() {
        let calls = parse_all(
            r#"[TOOL_CALLS] [
            {"name": "search", "arguments": {"query": "rust programming"}},
            {"name": "calculate", "arguments": {"expression": "2 + 2"}}
        ]"#,
        )
        .await;

        let names: Vec<&str> = calls.iter().map(|c| c.function.name.as_str()).collect();
        assert_eq!(calls.len(), 2);
        assert_eq!(names[0], "search");
        assert_eq!(names[1], "calculate");
    }

    /// Nested arrays inside arguments do not end the extraction early.
    #[tokio::test]
    async fn test_nested_brackets_in_json() {
        let calls = parse_all(
            r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#,
        )
        .await;

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "process");
        // JSON serialization removes spaces, so check for [3,4] without spaces
        assert!(calls[0].function.arguments.contains("[3,4]"));
    }

    /// Brackets and escaped quotes inside a string value are ignored.
    #[tokio::test]
    async fn test_escaped_quotes_in_strings() {
        let calls = parse_all(
            r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#,
        )
        .await;

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "echo");
    }

    /// Detection requires the marker; bare JSON and plain text are rejected.
    #[test]
    fn test_detect_format() {
        let parser = MistralParser::new();

        let cases = [
            (r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#, true),
            (
                r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#,
                true,
            ),
            (r#"{"name": "test", "arguments": {}}"#, false),
            ("plain text", false),
        ];

        for (raw, expected) in cases.iter() {
            assert_eq!(parser.detect_format(raw), *expected, "input: {}", raw);
        }
    }
}
|
||||||
@@ -3,7 +3,9 @@
|
|||||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
pub mod json_parser;
|
pub mod json_parser;
|
||||||
|
pub mod mistral_parser;
|
||||||
pub mod partial_json;
|
pub mod partial_json;
|
||||||
|
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
@@ -15,6 +17,7 @@ mod tests;
|
|||||||
// Re-export commonly used types
|
// Re-export commonly used types
|
||||||
pub use errors::{ToolParserError, ToolParserResult};
|
pub use errors::{ToolParserError, ToolParserResult};
|
||||||
pub use json_parser::JsonParser;
|
pub use json_parser::JsonParser;
|
||||||
|
pub use mistral_parser::MistralParser;
|
||||||
pub use registry::ParserRegistry;
|
pub use registry::ParserRegistry;
|
||||||
pub use state::{ParsePhase, ParseState};
|
pub use state::{ParsePhase, ParseState};
|
||||||
pub use traits::{PartialJsonParser, ToolParser};
|
pub use traits::{PartialJsonParser, ToolParser};
|
||||||
|
|||||||
@@ -50,15 +50,28 @@ impl ParserRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try prefix matching (e.g., "gpt-4" matches "gpt-*")
|
// Try prefix matching with more specific patterns first
|
||||||
for (pattern, parser_name) in &self.model_mapping {
|
// Collect all matching patterns and sort by specificity (longer = more specific)
|
||||||
if pattern.ends_with('*') {
|
let mut matches: Vec<(&String, &String)> = self
|
||||||
let prefix = &pattern[..pattern.len() - 1];
|
.model_mapping
|
||||||
if model.starts_with(prefix) {
|
.iter()
|
||||||
if let Some(parser) = self.parsers.get(parser_name) {
|
.filter(|(pattern, _)| {
|
||||||
return Some(parser.clone());
|
if pattern.ends_with('*') {
|
||||||
}
|
let prefix = &pattern[..pattern.len() - 1];
|
||||||
|
model.starts_with(prefix)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Sort by pattern length in descending order (longer patterns are more specific)
|
||||||
|
matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));
|
||||||
|
|
||||||
|
// Return the first matching parser
|
||||||
|
for (_, parser_name) in matches {
|
||||||
|
if let Some(parser) = self.parsers.get(parser_name) {
|
||||||
|
return Some(parser.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,20 +110,32 @@ impl ParserRegistry {
|
|||||||
// Anthropic models
|
// Anthropic models
|
||||||
self.map_model("claude-*", "json");
|
self.map_model("claude-*", "json");
|
||||||
|
|
||||||
// Mistral models (will use json until mistral parser is implemented)
|
// Mistral models - use Mistral parser
|
||||||
self.map_model("mistral-*", "json");
|
self.map_model("mistral-*", "mistral");
|
||||||
self.map_model("mixtral-*", "json");
|
self.map_model("mixtral-*", "mistral");
|
||||||
|
|
||||||
// Qwen models (will use json until qwen parser is implemented)
|
// Qwen models - use Qwen parser
|
||||||
self.map_model("qwen*", "json");
|
self.map_model("qwen*", "qwen");
|
||||||
|
self.map_model("Qwen*", "qwen");
|
||||||
|
|
||||||
// Llama models (will use json until llama parser is implemented)
|
// Llama models
|
||||||
|
// Llama 4 uses pythonic format
|
||||||
|
self.map_model("llama-4*", "pythonic");
|
||||||
|
self.map_model("meta-llama-4*", "pythonic");
|
||||||
|
// Llama 3.2 uses python_tag format
|
||||||
|
self.map_model("llama-3.2*", "llama");
|
||||||
|
self.map_model("meta-llama-3.2*", "llama");
|
||||||
|
// Other Llama models use JSON
|
||||||
self.map_model("llama-*", "json");
|
self.map_model("llama-*", "json");
|
||||||
self.map_model("meta-llama-*", "json");
|
self.map_model("meta-llama-*", "json");
|
||||||
|
|
||||||
|
// DeepSeek models - DeepSeek v3 would need custom parser, v2 uses pythonic
|
||||||
|
self.map_model("deepseek-*", "pythonic");
|
||||||
|
|
||||||
// Other models default to JSON
|
// Other models default to JSON
|
||||||
self.map_model("gemini-*", "json");
|
self.map_model("gemini-*", "json");
|
||||||
self.map_model("palm-*", "json");
|
self.map_model("palm-*", "json");
|
||||||
|
self.map_model("gemma-*", "json");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the default parser
|
/// Set the default parser
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use crate::tool_parser::partial_json::{
|
|||||||
compute_diff, find_common_prefix, is_complete_json, PartialJson,
|
compute_diff, find_common_prefix, is_complete_json, PartialJson,
|
||||||
};
|
};
|
||||||
use crate::tool_parser::traits::ToolParser;
|
use crate::tool_parser::traits::ToolParser;
|
||||||
|
use crate::tool_parser::types::TokenConfig;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_state_new() {
|
fn test_parse_state_new() {
|
||||||
@@ -299,11 +300,11 @@ async fn test_json_parser_with_parameters() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_parser_with_tokens() {
|
async fn test_json_parser_with_tokens() {
|
||||||
// Test with custom wrapper tokens
|
// Test with custom wrapper tokens
|
||||||
let parser = JsonParser::with_config(
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
vec!["[TOOL_CALLS] [".to_string()],
|
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
|
||||||
vec!["]".to_string()],
|
end_tokens: vec!["]".to_string()],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
);
|
});
|
||||||
|
|
||||||
let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#;
|
let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#;
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
@@ -315,11 +316,11 @@ async fn test_json_parser_with_tokens() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiline_json_with_tokens() {
|
async fn test_multiline_json_with_tokens() {
|
||||||
// Test that regex with (?s) flag properly handles multi-line JSON
|
// Test that regex with (?s) flag properly handles multi-line JSON
|
||||||
let parser = JsonParser::with_config(
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
vec!["<tool>".to_string()],
|
start_tokens: vec!["<tool>".to_string()],
|
||||||
vec!["</tool>".to_string()],
|
end_tokens: vec!["</tool>".to_string()],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
);
|
});
|
||||||
|
|
||||||
// Pretty-printed multi-line JSON
|
// Pretty-printed multi-line JSON
|
||||||
let input = r#"<tool>{
|
let input = r#"<tool>{
|
||||||
@@ -493,11 +494,11 @@ mod failure_cases {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_broken_wrapper_tokens() {
|
async fn test_broken_wrapper_tokens() {
|
||||||
let parser = JsonParser::with_config(
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
vec!["<tool>".to_string()],
|
start_tokens: vec!["<tool>".to_string()],
|
||||||
vec!["</tool>".to_string()],
|
end_tokens: vec!["</tool>".to_string()],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
);
|
});
|
||||||
|
|
||||||
// Missing end token
|
// Missing end token
|
||||||
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
|
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
|
||||||
@@ -678,11 +679,11 @@ mod edge_cases {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiple_token_pairs_with_conflicts() {
|
async fn test_multiple_token_pairs_with_conflicts() {
|
||||||
// Test with overlapping token patterns
|
// Test with overlapping token patterns
|
||||||
let parser = JsonParser::with_config(
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
vec!["<<".to_string(), "<tool>".to_string()],
|
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
|
||||||
vec![">>".to_string(), "</tool>".to_string()],
|
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
|
||||||
", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
);
|
});
|
||||||
|
|
||||||
// First pattern
|
// First pattern
|
||||||
let input = r#"<<{"name": "test1", "arguments": {}}>>"#;
|
let input = r#"<<{"name": "test1", "arguments": {}}>>"#;
|
||||||
|
|||||||
Reference in New Issue
Block a user