diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index b2f775c8b..bc4c5a020 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -5,7 +5,7 @@ pub mod errors; pub mod json_parser; pub mod mistral_parser; pub mod partial_json; - +pub mod qwen_parser; pub mod registry; pub mod state; pub mod traits; @@ -18,6 +18,7 @@ mod tests; pub use errors::{ToolParserError, ToolParserResult}; pub use json_parser::JsonParser; pub use mistral_parser::MistralParser; +pub use qwen_parser::QwenParser; pub use registry::ParserRegistry; pub use state::{ParsePhase, ParseState}; pub use traits::{PartialJsonParser, ToolParser}; diff --git a/sgl-router/src/tool_parser/qwen_parser.rs b/sgl-router/src/tool_parser/qwen_parser.rs new file mode 100644 index 000000000..00d4c3e29 --- /dev/null +++ b/sgl-router/src/tool_parser/qwen_parser.rs @@ -0,0 +1,389 @@ +use async_trait::async_trait; +use regex::Regex; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + partial_json::PartialJson, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// Qwen format parser for tool calls +/// +/// Handles the Qwen 2.5/3 specific format: +/// `\n{"name": "func", "arguments": {...}}\n` +/// +/// Features: +/// - XML-style tags with JSON content +/// - Support for multiple sequential tool calls +/// - Newline-aware parsing +pub struct QwenParser { + /// Parser for handling incomplete JSON during streaming + partial_json: PartialJson, + /// Regex for extracting tool calls + extractor: Regex, +} + +impl QwenParser { + /// Create a new Qwen parser + pub fn new() -> Self { + // Use (?s) flag for DOTALL mode to handle newlines + let pattern = r"(?s)\n(.*?)\n"; + let extractor = Regex::new(pattern).expect("Valid regex pattern"); + + Self { + partial_json: PartialJson::default(), + extractor, + } + } + + /// Extract all tool call blocks from text + fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> { + self.extractor + .captures_iter(text) + .filter_map(|cap| cap.get(1).map(|m| m.as_str())) + .collect() + } + + /// Parse a single JSON object into a ToolCall + fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult> { + let name = obj.get("name").and_then(|v| v.as_str()); + + if let Some(name) = name { + // Get arguments - Qwen uses "arguments" key + let empty_obj = Value::Object(serde_json::Map::new()); + let args = obj.get("arguments").unwrap_or(&empty_obj); + + // Convert arguments to JSON string + let arguments = serde_json::to_string(args) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + + // Generate ID with index for multiple tools + let id = format!("qwen_call_{}", index); + + Ok(Some(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: name.to_string(), + arguments, + }, + })) + } else { + Ok(None) + } + } + + /// Check if text contains Qwen tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("") + } + + /// Find the start position of a tool call + fn find_tool_start(&self, text: &str) -> Option { + text.find("\n") + } + + /// Find the end position of a tool call + fn find_tool_end(&self, text: &str, start_pos: usize) -> Option { + let search_from = start_pos + "\n".len(); + text[search_from..] + .find("\n") + .map(|pos| search_from + pos + "\n".len()) + } + + /// Check if buffer ends with a partial token + fn ends_with_partial_token(&self, buffer: &str) -> Option { + // Check for partial start token + let start_token = "\n"; + // Use inclusive range to check if entire buffer could be a prefix + for i in 1..=start_token.len().min(buffer.len()) { + if start_token.starts_with(&buffer[buffer.len() - i..]) { + return Some(i); + } + } + + // Check for partial end token + let end_token = "\n"; + // Use inclusive range to check if entire buffer could be a prefix + (1..=end_token.len().min(buffer.len())) + .find(|&i| end_token.starts_with(&buffer[buffer.len() - i..])) + } +} + +impl Default for QwenParser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for QwenParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains Qwen format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + // Extract all tool call blocks + let tool_blocks = self.extract_tool_calls(text); + let mut tools = Vec::new(); + + for (index, json_str) in tool_blocks.iter().enumerate() { + // Parse each JSON block + match serde_json::from_str::(json_str.trim()) { + Ok(value) => { + if let Some(tool) = self.parse_single_object(&value, index)? { + tools.push(tool); + } + } + Err(_) => { + // Skip malformed JSON blocks + continue; + } + } + } + + Ok(tools) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check for partial token at end of buffer + if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) { + // Hold back the partial token + return Ok(StreamResult::Incomplete); + } + + // Check if we have the start marker + if !self.has_tool_markers(&state.buffer) { + return Ok(StreamResult::Incomplete); + } + + // Find start and end positions + if let Some(start_pos) = self.find_tool_start(&state.buffer) { + // Check if we have the complete tool call + if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) { + // Extract the JSON content + let json_start = start_pos + "\n".len(); + let json_end = end_pos - "\n".len(); + let json_str = &state.buffer[json_start..json_end]; + + // Parse the complete JSON + match serde_json::from_str::(json_str.trim()) { + Ok(value) => { + if let Some(tool) = self.parse_single_object(&value, 0)? { + // Clear the consumed part from buffer using drain for efficiency + state.buffer.drain(..end_pos); + return Ok(StreamResult::ToolComplete(tool)); + } + } + Err(_) => { + // JSON parsing failed, might be incomplete + } + } + } else { + // We have start but no end yet - try partial parsing + let json_start = start_pos + "\n".len(); + let partial_json = &state.buffer[json_start..]; + + // Remove trailing newline if present (might be start of end token) + let partial_json = partial_json.trim_end(); + + // Try to parse with partial JSON parser + match self.partial_json.parse_value(partial_json) { + Ok((value, _consumed)) => { + // Extract tool name if available + if let Some(name) = value.get("name").and_then(|v| v.as_str()) { + // Check if we've already sent the name + if !state.in_string { + state.in_string = true; // Use as flag for "name sent" + return Ok(StreamResult::ToolName { + index: 0, + name: name.to_string(), + }); + } + + // Check for arguments + if let Some(args) = value.get("arguments") { + if let Ok(args_str) = serde_json::to_string(args) { + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + } + } + } + Err(_) => { + // Failed to parse even as partial JSON + // Keep buffering + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + // Check if text contains Qwen-specific markers. If not, it's not this format. + if !self.has_tool_markers(text) { + return false; + } + + // Try to extract tool calls to see if we have a complete, valid one. + let tool_blocks = self.extract_tool_calls(text); + for json_str in &tool_blocks { + if let Ok(value) = serde_json::from_str::(json_str.trim()) { + if let Some(obj) = value.as_object() { + if obj.contains_key("name") && obj.contains_key("arguments") { + // Found a valid, complete tool call. + return true; + } + } + } + } + + // If we have the marker but no valid complete tool call, + // it could be a partial stream. We should detect this as the format. + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_qwen_format() { + let parser = QwenParser::new(); + let input = r#" +{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}} +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Beijing")); + } + + #[tokio::test] + async fn test_parse_multiple_tools() { + let parser = QwenParser::new(); + let input = r#" +{"name": "search", "arguments": {"query": "rust programming"}} + + +{"name": "calculate", "arguments": {"expression": "2 + 2"}} +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "calculate"); + } + + #[tokio::test] + async fn test_with_normal_text() { + let parser = QwenParser::new(); + let input = r#"Let me help you with that. + +{"name": "get_info", "arguments": {"topic": "Rust"}} + +Here are the results."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_info"); + } + + #[tokio::test] + async fn test_nested_json_structures() { + let parser = QwenParser::new(); + let input = r#" +{ + "name": "process_data", + "arguments": { + "data": { + "nested": { + "array": [1, 2, 3], + "object": {"key": "value"} + } + } + } +} +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process_data"); + assert!(result[0].function.arguments.contains("nested")); + } + + #[test] + fn test_detect_format() { + let parser = QwenParser::new(); + + assert!(parser.detect_format( + r#" +{"name": "test", "arguments": {}} +"# + )); + + assert!(parser.detect_format( + r#"Text before +{"name": "test", "arguments": {}} + text after"# + )); + + assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); + assert!(!parser.detect_format("plain text")); + + // Partial format should still be detected + assert!(parser.detect_format("")); + } + + #[tokio::test] + async fn test_streaming_partial() { + let parser = QwenParser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "\n", + r#"{"name": "search","#, + r#" "arguments": {"query":"#, + r#" "rust"}}"#, + "\n", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "search"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "search"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); // At least one should be found + } +} diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index c1178200a..dc61fccbb 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,4 +1,6 @@ use crate::tool_parser::json_parser::JsonParser; +use crate::tool_parser::mistral_parser::MistralParser; +use crate::tool_parser::qwen_parser::QwenParser; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; use std::sync::Arc; @@ -97,7 +99,11 @@ impl ParserRegistry { // JSON parser - most common format self.register_parser("json", Arc::new(JsonParser::new())); - // Note: Additional parsers (mistral, qwen, llama) will be added in later phases + // Mistral parser - [TOOL_CALLS] [...] format + self.register_parser("mistral", Arc::new(MistralParser::new())); + + // Qwen parser - ... format + self.register_parser("qwen", Arc::new(QwenParser::new())); } /// Register default model mappings