[router] restructure tool parser module folder (#9693)
This commit is contained in:
456
sgl-router/src/tool_parser/parsers/json_parser.rs
Normal file
456
sgl-router/src/tool_parser/parsers/json_parser.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, TokenConfig, ToolCall},
|
||||
};
|
||||
|
||||
/// JSON format parser for tool calls
|
||||
///
|
||||
/// Handles various JSON formats for function calling:
|
||||
/// - Single tool call: {"name": "fn", "arguments": {...}}
|
||||
/// - Multiple tool calls: [{"name": "fn1", "arguments": {...}}, ...]
|
||||
/// - With parameters instead of arguments: {"name": "fn", "parameters": {...}}
|
||||
///
|
||||
/// Supports configurable token markers for different models
|
||||
pub struct JsonParser {
|
||||
/// Token configuration for parsing
|
||||
token_config: TokenConfig,
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex patterns for extracting content between tokens
|
||||
extractors: Vec<Regex>,
|
||||
}
|
||||
|
||||
impl JsonParser {
|
||||
/// Create a new JSON parser with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(TokenConfig {
|
||||
start_tokens: vec![],
|
||||
end_tokens: vec![],
|
||||
separator: ", ".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a parser with custom token configuration
|
||||
pub fn with_config(config: TokenConfig) -> Self {
|
||||
// Build extraction patterns for each token pair
|
||||
let extractors: Vec<Regex> = config
|
||||
.iter_pairs()
|
||||
.filter_map(|(start, end)| {
|
||||
if !start.is_empty() && !end.is_empty() {
|
||||
// Use (?s) flag to enable DOTALL mode so . matches newlines
|
||||
let pattern =
|
||||
format!(r"(?s){}(.*?){}", regex::escape(start), regex::escape(end));
|
||||
Regex::new(&pattern).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
token_config: config,
|
||||
partial_json: PartialJson::default(),
|
||||
extractors,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract JSON content from text, handling wrapper tokens if configured
|
||||
fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
|
||||
let mut content = text;
|
||||
|
||||
// Try each extractor pattern (for tokens with both start and end)
|
||||
for extractor in &self.extractors {
|
||||
if let Some(captures) = extractor.captures(content) {
|
||||
if let Some(matched) = captures.get(1) {
|
||||
return matched.as_str().trim();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle special case where there's a start token but no end token
|
||||
for (start, end) in self.token_config.iter_pairs() {
|
||||
if !start.is_empty() && end.is_empty() {
|
||||
// Find the start token and extract everything after it
|
||||
if let Some(pos) = content.find(start) {
|
||||
content = &content[pos + start.len()..];
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.trim()
|
||||
}
|
||||
|
||||
/// Try to extract a JSON object or array from text that may contain other content
|
||||
fn extract_json_from_text(&self, text: &str) -> Option<String> {
|
||||
// Look for JSON object starting with {
|
||||
if let Some(start) = text.find('{') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'{' if !in_string => depth += 1,
|
||||
'}' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Look for JSON array starting with [
|
||||
if let Some(start) = text.find('[') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'[' if !in_string => depth += 1,
|
||||
']' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Check if this looks like a tool call
|
||||
let name = obj
|
||||
.get("name")
|
||||
.or_else(|| obj.get("function"))
|
||||
.and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - support both "arguments" and "parameters" keys
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj
|
||||
.get("arguments")
|
||||
.or_else(|| obj.get("parameters"))
|
||||
.unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate a unique ID if not provided
|
||||
let id = obj
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse JSON value(s) into tool calls
|
||||
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
match value {
|
||||
Value::Array(arr) => {
|
||||
// Parse each element in the array
|
||||
for item in arr {
|
||||
if let Some(tool) = self.parse_single_object(item)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::Object(_) => {
|
||||
// Single tool call
|
||||
if let Some(tool) = self.parse_single_object(value)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Not a valid tool call format
|
||||
return Ok(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Check if text contains potential tool call markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
// If no start tokens configured, check for JSON structure
|
||||
if self.token_config.start_tokens.is_empty() {
|
||||
// For JSON, we just need to see the start of an object or array
|
||||
return text.contains('{') || text.contains('[');
|
||||
}
|
||||
|
||||
// Check for any start token
|
||||
self.token_config
|
||||
.start_tokens
|
||||
.iter()
|
||||
.any(|token| text.contains(token))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JsonParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for JsonParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Extract JSON content from wrapper tokens if present
|
||||
let json_content = self.extract_json_content(text);
|
||||
|
||||
// Try to parse as JSON
|
||||
match serde_json::from_str::<Value>(json_content) {
|
||||
Ok(value) => self.parse_json_value(&value),
|
||||
Err(_) => {
|
||||
// If no wrapper tokens configured and parse failed,
|
||||
// try to extract JSON from mixed text
|
||||
if self.token_config.start_tokens.is_empty() {
|
||||
if let Some(extracted) = self.extract_json_from_text(text) {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
return self.parse_json_value(&value);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not valid JSON, return empty
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check if we have potential tool calls
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers, return as incomplete
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Extract JSON content
|
||||
let json_content = self.extract_json_content(&state.buffer);
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(json_content) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_content.len() {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool as complete
|
||||
// TODO simplified version, address more complex version
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON, try to extract tool name
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// TODO simplified version, address more complex version
|
||||
// Just return the tool name once we see it
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as a flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for complete arguments
|
||||
if let Some(args) =
|
||||
value.get("arguments").or_else(|| value.get("parameters"))
|
||||
{
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
// Return arguments as a single update
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains JSON-like structure
|
||||
if self.has_tool_markers(text) {
|
||||
// Try to extract and parse
|
||||
let json_content = self.extract_json_content(text);
|
||||
|
||||
// Check if it looks like valid JSON for tool calls
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_content) {
|
||||
match value {
|
||||
Value::Object(ref obj) => {
|
||||
// Check for tool call structure
|
||||
obj.contains_key("name") || obj.contains_key("function")
|
||||
}
|
||||
Value::Array(ref arr) => {
|
||||
// Check if array contains tool-like objects
|
||||
arr.iter().any(|v| {
|
||||
v.as_object().is_some_and(|o| {
|
||||
o.contains_key("name") || o.contains_key("function")
|
||||
})
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_single_tool_call() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tool_calls() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "search", "arguments": {"query": "news"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_parameters_key() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
assert!(result[0].function.arguments.contains("10"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_wrapper_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_parse() {
|
||||
// Just verify that streaming eventually produces a complete tool call
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete JSON in one go
|
||||
// TODO simplified version, address more complex version
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get a complete tool immediately with complete JSON
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert!(tool.function.arguments.contains("SF"));
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON input"),
|
||||
}
|
||||
}
|
||||
}
|
||||
156
sgl-router/src/tool_parser/parsers/llama_parser.rs
Normal file
156
sgl-router/src/tool_parser/parsers/llama_parser.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::json_parser::JsonParser;
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{StreamResult, TokenConfig, ToolCall},
|
||||
};
|
||||
|
||||
/// Llama 3.2 format parser for tool calls
|
||||
///
|
||||
/// Handles the Llama 3.2 specific format:
|
||||
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
|
||||
///
|
||||
/// Also supports plain JSON without the python_tag prefix
|
||||
pub struct LlamaParser {
|
||||
/// Underlying JSON parser with Llama-specific configuration
|
||||
json_parser: JsonParser,
|
||||
}
|
||||
|
||||
impl LlamaParser {
|
||||
/// Create a new Llama parser
|
||||
pub fn new() -> Self {
|
||||
// Configure JSON parser with Llama's python_tag token
|
||||
// Note: No end token for python_tag format
|
||||
let json_parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||
end_tokens: vec!["".to_string()], // Empty end token
|
||||
separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported)
|
||||
});
|
||||
|
||||
Self { json_parser }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LlamaParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for LlamaParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// First try with the configured python_tag parser
|
||||
let result = self.json_parser.parse_complete(text).await?;
|
||||
|
||||
if !result.is_empty() {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// If no results and text starts with '{', try plain JSON
|
||||
if text.trim_start().starts_with('{') {
|
||||
// Create a temporary plain JSON parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_complete(text).await;
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// Try with the python_tag parser first
|
||||
let result = self.json_parser.parse_incremental(chunk, state).await?;
|
||||
|
||||
// If we get Incomplete and buffer starts with '{', might be plain JSON
|
||||
if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
|
||||
{
|
||||
// Check if we have python_tag in the buffer
|
||||
if !state.buffer.contains("<|python_tag|>") {
|
||||
// Likely plain JSON, create temporary parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_incremental("", state).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Llama format if contains python_tag or starts with JSON object
|
||||
text.contains("<|python_tag|>")
|
||||
|| (text.trim_start().starts_with('{')
|
||||
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_python_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("weather"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_call_with_semicolon() {
|
||||
let parser = LlamaParser::new();
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
// Test that we can at least parse a single call followed by semicolon
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// We expect this to either parse the first JSON object or fail gracefully
|
||||
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||
|
||||
// If it parses anything, it should be func1
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "func1");
|
||||
}
|
||||
}
|
||||
}
|
||||
347
sgl-router/src/tool_parser/parsers/mistral_parser.rs
Normal file
347
sgl-router/src/tool_parser/parsers/mistral_parser.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Mistral format parser for tool calls
|
||||
///
|
||||
/// Handles the Mistral-specific format:
|
||||
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
|
||||
///
|
||||
/// Features:
|
||||
/// - Bracket counting for proper JSON array extraction
|
||||
/// - Support for multiple tool calls in a single array
|
||||
/// - String-aware parsing to handle nested brackets in JSON
|
||||
pub struct MistralParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
}
|
||||
|
||||
impl MistralParser {
|
||||
/// Create a new Mistral parser
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract JSON array using bracket counting
|
||||
///
|
||||
/// Handles nested brackets in JSON content by tracking:
|
||||
/// - String boundaries (quotes)
|
||||
/// - Escape sequences
|
||||
/// - Bracket depth
|
||||
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
|
||||
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
|
||||
|
||||
// Find the start of the token
|
||||
let start_idx = text.find(BOT_TOKEN)?;
|
||||
|
||||
// Start from the opening bracket after [TOOL_CALLS]
|
||||
// The -1 is to include the opening bracket that's part of the token
|
||||
let json_start = start_idx + BOT_TOKEN.len() - 1;
|
||||
|
||||
let mut bracket_count = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
let bytes = text.as_bytes();
|
||||
|
||||
for i in json_start..text.len() {
|
||||
let char = bytes[i];
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if char == b'\\' {
|
||||
escape_next = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if char == b'"' && !escape_next {
|
||||
in_string = !in_string;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_string {
|
||||
if char == b'[' {
|
||||
bracket_count += 1;
|
||||
} else if char == b']' {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching closing bracket
|
||||
return Some(&text[json_start..=i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Incomplete array (no matching closing bracket found)
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse tool calls from a JSON array
|
||||
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let value: Value = serde_json::from_str(json_str)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
let mut tools = Vec::new();
|
||||
|
||||
if let Value::Array(arr) = value {
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single object case (shouldn't happen with Mistral format, but handle it)
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - Mistral uses "arguments" key
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj.get("arguments").unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("mistral_call_{}", index);
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Mistral tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("[TOOL_CALLS]")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MistralParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for MistralParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Mistral format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract JSON array from Mistral format
|
||||
if let Some(json_array) = self.extract_json_array(text) {
|
||||
self.parse_json_array(json_array)
|
||||
} else {
|
||||
// Markers present but no complete array found
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Try to extract complete JSON array
|
||||
if let Some(json_array) = self.extract_json_array(&state.buffer) {
|
||||
// Parse with partial JSON to handle incomplete content
|
||||
match self.partial_json.parse_value(json_array) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_array.len() {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = if let Value::Array(arr) = value {
|
||||
let mut result = Vec::new();
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
result.push(tool);
|
||||
}
|
||||
}
|
||||
result
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool (simplified for Phase 3)
|
||||
// Full multi-tool streaming will be implemented later
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON - try to extract tool name for streaming
|
||||
if let Value::Array(arr) = value {
|
||||
if let Some(first_tool) = arr.first() {
|
||||
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
|
||||
{
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = first_tool.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains Mistral-specific markers
|
||||
if self.has_tool_markers(text) {
|
||||
// Try to extract and validate the array
|
||||
if let Some(json_array) = self.extract_json_array(text) {
|
||||
// Check if it's valid JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_array) {
|
||||
// Check if it contains tool-like structures
|
||||
match value {
|
||||
Value::Array(ref arr) => arr.iter().any(|v| {
|
||||
v.as_object().is_some_and(|o| {
|
||||
o.contains_key("name") && o.contains_key("arguments")
|
||||
})
|
||||
}),
|
||||
Value::Object(ref obj) => {
|
||||
obj.contains_key("name") && obj.contains_key("arguments")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
// Has markers but no complete array - might be streaming
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_mistral_format() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [
|
||||
{"name": "search", "arguments": {"query": "rust programming"}},
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_brackets_in_json() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
// JSON serialization removes spaces, so check for [3,4] without spaces
|
||||
assert!(result[0].function.arguments.contains("[3,4]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_escaped_quotes_in_strings() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
|
||||
assert!(
|
||||
parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
|
||||
);
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
}
|
||||
}
|
||||
16
sgl-router/src/tool_parser/parsers/mod.rs
Normal file
16
sgl-router/src/tool_parser/parsers/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
/// Parser implementations for different model formats
|
||||
///
|
||||
/// This module contains concrete parser implementations for various model-specific
|
||||
/// tool/function call formats.
|
||||
// Individual parser modules
|
||||
pub mod json_parser;
|
||||
pub mod llama_parser;
|
||||
pub mod mistral_parser;
|
||||
pub mod pythonic_parser;
|
||||
pub mod qwen_parser;
|
||||
|
||||
pub use json_parser::JsonParser;
|
||||
pub use llama_parser::LlamaParser;
|
||||
pub use mistral_parser::MistralParser;
|
||||
pub use pythonic_parser::PythonicParser;
|
||||
pub use qwen_parser::QwenParser;
|
||||
428
sgl-router/src/tool_parser/parsers/pythonic_parser.rs
Normal file
428
sgl-router/src/tool_parser/parsers/pythonic_parser.rs
Normal file
@@ -0,0 +1,428 @@
|
||||
/// Pythonic format parser for tool calls
|
||||
///
|
||||
/// Handles Python function call syntax within square brackets:
|
||||
/// ```text
|
||||
/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
/// ```
|
||||
///
|
||||
/// This format is used by Llama-4 models and uses Python literals
|
||||
/// rather than JSON for arguments.
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
python_literal_parser::parse_python_literal,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
pub struct PythonicParser {
|
||||
/// Regex to detect tool calls in Pythonic format
|
||||
tool_call_regex: Regex,
|
||||
}
|
||||
|
||||
impl PythonicParser {
|
||||
/// Create a new Pythonic parser
|
||||
pub fn new() -> Self {
|
||||
// Simple regex to detect start of Pythonic tool calls
|
||||
// We'll use manual parsing for the actual extraction
|
||||
let pattern = r"\[[a-zA-Z_]\w*\(";
|
||||
let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern");
|
||||
|
||||
Self { tool_call_regex }
|
||||
}
|
||||
|
||||
/// Extract tool calls using bracket counting (similar to MistralParser)
|
||||
fn extract_tool_calls(&self, text: &str) -> Option<String> {
|
||||
// Find the start of a tool call list - look for [ followed by a function name
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
|
||||
for start_idx in 0..chars.len() {
|
||||
if chars[start_idx] != '[' {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this looks like a tool call
|
||||
// Skip whitespace after [
|
||||
let mut check_idx = start_idx + 1;
|
||||
while check_idx < chars.len() && chars[check_idx].is_whitespace() {
|
||||
check_idx += 1;
|
||||
}
|
||||
|
||||
// Check if we have a function name (starts with letter or underscore)
|
||||
if check_idx >= chars.len()
|
||||
|| (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Now count brackets to find the matching ]
|
||||
let mut bracket_count = 0;
|
||||
let mut _paren_count = 0;
|
||||
let mut _brace_count = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
for i in start_idx..chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
} else if !in_string {
|
||||
match ch {
|
||||
'[' => bracket_count += 1,
|
||||
']' => {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching bracket
|
||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||
// Verify this actually contains a function call
|
||||
if extracted.contains('(') && extracted.contains(')') {
|
||||
return Some(extracted);
|
||||
}
|
||||
}
|
||||
}
|
||||
'(' => _paren_count += 1,
|
||||
')' => _paren_count -= 1,
|
||||
'{' => _brace_count += 1,
|
||||
'}' => _brace_count -= 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Strip special tokens that Llama 4 might output
|
||||
fn strip_special_tokens(text: &str) -> String {
|
||||
text.replace("<|python_start|>", "")
|
||||
.replace("<|python_end|>", "")
|
||||
}
|
||||
|
||||
/// Parse a single function call from Python syntax
|
||||
fn parse_function_call(&self, call_str: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Match function_name(args) - use (?s) to make . match newlines
|
||||
let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").unwrap();
|
||||
|
||||
if let Some(captures) = call_regex.captures(call_str.trim()) {
|
||||
let function_name = captures.get(1).unwrap().as_str();
|
||||
let args_str = captures.get(2).unwrap().as_str();
|
||||
|
||||
// Parse arguments
|
||||
let arguments = self.parse_arguments(args_str)?;
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name.to_string(),
|
||||
arguments: serde_json::to_string(&arguments)?,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Python-style arguments into JSON
|
||||
fn parse_arguments(&self, args_str: &str) -> ToolParserResult<Value> {
|
||||
if args_str.trim().is_empty() {
|
||||
return Ok(json!({}));
|
||||
}
|
||||
|
||||
let mut result = serde_json::Map::new();
|
||||
let mut current_key = String::new();
|
||||
let mut current_value = String::new();
|
||||
let mut in_key = true;
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
let chars: Vec<char> = args_str.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
escape_next = false;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
current_value.push(ch);
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle string literals
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else {
|
||||
// Not in string
|
||||
match ch {
|
||||
'=' if in_key && depth == 0 => {
|
||||
in_key = false;
|
||||
}
|
||||
',' if depth == 0 => {
|
||||
// End of current argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
current_key.clear();
|
||||
current_value.clear();
|
||||
in_key = true;
|
||||
}
|
||||
'[' | '{' | '(' => {
|
||||
depth += 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
']' | '}' | ')' => {
|
||||
depth -= 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if in_key {
|
||||
if !ch.is_whitespace() || !current_key.is_empty() {
|
||||
current_key.push(ch);
|
||||
}
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Handle the last argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
|
||||
Ok(Value::Object(result))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for PythonicParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
|
||||
// Extract tool calls using bracket counting
|
||||
if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) {
|
||||
// Remove the outer brackets
|
||||
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
|
||||
|
||||
// Split into individual function calls
|
||||
let mut calls = Vec::new();
|
||||
let mut current_call = String::new();
|
||||
let mut paren_depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
|
||||
for ch in tool_calls_str.chars() {
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
current_call.push(ch);
|
||||
} else if in_string && ch == string_char {
|
||||
in_string = false;
|
||||
current_call.push(ch);
|
||||
} else if in_string {
|
||||
current_call.push(ch);
|
||||
} else {
|
||||
match ch {
|
||||
'(' => {
|
||||
paren_depth += 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
')' => {
|
||||
paren_depth -= 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
',' if paren_depth == 0 => {
|
||||
// End of current function call
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
current_call.clear();
|
||||
}
|
||||
_ => {
|
||||
if !ch.is_whitespace() || !current_call.is_empty() {
|
||||
current_call.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the last call (important for single calls or the last call in a list)
|
||||
if !current_call.trim().is_empty() {
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(calls)
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// For Pythonic format, we accumulate until we have a complete tool call
|
||||
// This is a simplified implementation
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Try to parse if we have a complete tool call
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if self.extract_tool_calls(&cleaned).is_some() {
|
||||
let result = self.parse_complete(&state.buffer).await?;
|
||||
if !result.is_empty() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(
|
||||
result.into_iter().next().unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
self.tool_call_regex.is_match(&cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PythonicParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_function_call() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search_web");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "Rust programming");
|
||||
assert_eq!(args["max_results"], 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_function_calls() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_literals() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["flag"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], Value::Null);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_tokens() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama4_format() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
}
|
||||
396
sgl-router/src/tool_parser/parsers/qwen_parser.rs
Normal file
396
sgl-router/src/tool_parser/parsers/qwen_parser.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Qwen format parser for tool calls
|
||||
///
|
||||
/// Handles the Qwen 2.5/3 specific format:
|
||||
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
|
||||
///
|
||||
/// Features:
|
||||
/// - XML-style tags with JSON content
|
||||
/// - Support for multiple sequential tool calls
|
||||
/// - Newline-aware parsing
|
||||
pub struct QwenParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting tool calls
|
||||
extractor: Regex,
|
||||
}
|
||||
|
||||
impl QwenParser {
|
||||
/// Create a new Qwen parser
|
||||
pub fn new() -> Self {
|
||||
// Use (?s) flag for DOTALL mode to handle newlines
|
||||
let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
|
||||
let extractor = Regex::new(pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all tool call blocks from text
|
||||
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
|
||||
self.extractor
|
||||
.captures_iter(text)
|
||||
.filter_map(|cap| cap.get(1).map(|m| m.as_str()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - Qwen uses "arguments" key
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj.get("arguments").unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("qwen_call_{}", index);
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Qwen tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
|
||||
/// Find the start position of a tool call
|
||||
fn find_tool_start(&self, text: &str) -> Option<usize> {
|
||||
text.find("<tool_call>\n")
|
||||
}
|
||||
|
||||
/// Find the end position of a tool call
|
||||
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
|
||||
let search_from = start_pos + "<tool_call>\n".len();
|
||||
text[search_from..]
|
||||
.find("\n</tool_call>")
|
||||
.map(|pos| search_from + pos + "\n</tool_call>".len())
|
||||
}
|
||||
|
||||
/// Check if buffer ends with a partial token
|
||||
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
|
||||
// Check for partial start token
|
||||
let start_token = "<tool_call>\n";
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
for i in 1..=start_token.len().min(buffer.len()) {
|
||||
if start_token.starts_with(&buffer[buffer.len() - i..]) {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial end token
|
||||
let end_token = "\n</tool_call>";
|
||||
// Only check if buffer ends with a partial match (not the complete token without newline)
|
||||
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
|
||||
if buffer.ends_with("</tool_call>") {
|
||||
// This is a complete end tag, just missing the leading newline
|
||||
// Not a partial token situation
|
||||
return None;
|
||||
}
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
(1..=end_token.len().min(buffer.len()))
|
||||
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QwenParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for QwenParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Qwen format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for (index, json_str) in tool_blocks.iter().enumerate() {
|
||||
// Parse each JSON block
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, index)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Skip malformed JSON blocks
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for partial token at end of buffer
|
||||
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
|
||||
// Hold back the partial token
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Find start and end positions
|
||||
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
|
||||
// Check if we have the complete tool call
|
||||
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
|
||||
// Extract the JSON content
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let json_end = end_pos - "\n</tool_call>".len();
|
||||
let json_str = &state.buffer[json_start..json_end];
|
||||
|
||||
// Parse the complete JSON
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
// Clear the consumed part from buffer using drain for efficiency
|
||||
state.buffer.drain(..end_pos);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We have start but no end yet - try partial parsing
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let partial_json = &state.buffer[json_start..];
|
||||
|
||||
// Remove trailing newline if present (might be start of end token)
|
||||
let partial_json = partial_json.trim_end();
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(partial_json) {
|
||||
Ok((value, _consumed)) => {
|
||||
// Extract tool name if available
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = value.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains Qwen-specific markers. If not, it's not this format.
|
||||
if !self.has_tool_markers(text) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to extract tool calls to see if we have a complete, valid one.
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
for json_str in &tool_blocks {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_str.trim()) {
|
||||
if let Some(obj) = value.as_object() {
|
||||
if obj.contains_key("name") && obj.contains_key("arguments") {
|
||||
// Found a valid, complete tool call.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have the marker but no valid complete tool call,
|
||||
// it could be a partial stream. We should detect this as the format.
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_qwen_format() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "search", "arguments": {"query": "rust programming"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_normal_text() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"Let me help you with that.
|
||||
<tool_call>
|
||||
{"name": "get_info", "arguments": {"topic": "Rust"}}
|
||||
</tool_call>
|
||||
Here are the results."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_info");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_json_structures() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{
|
||||
"name": "process_data",
|
||||
"arguments": {
|
||||
"data": {
|
||||
"nested": {
|
||||
"array": [1, 2, 3],
|
||||
"object": {"key": "value"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
assert!(result[0].function.arguments.contains("nested"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"<tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call>"#
|
||||
));
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"Text before <tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call> text after"#
|
||||
));
|
||||
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
|
||||
// Partial format should still be detected
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_partial() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<tool_call>\n",
|
||||
r#"{"name": "search","#,
|
||||
r#" "arguments": {"query":"#,
|
||||
r#" "rust"}}"#,
|
||||
"\n</tool_call>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "search");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete); // At least one should be found
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user