From 6f6beca49dc403ca7792fb42e76b630ae3ab798b Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 27 Aug 2025 10:44:52 -0700 Subject: [PATCH] [router] add step3 tool parser (#9695) Co-authored-by: Chang Su --- sgl-router/src/tool_parser/mod.rs | 2 +- sgl-router/src/tool_parser/parsers/mod.rs | 3 + .../src/tool_parser/parsers/step3_parser.rs | 348 ++++++++++++++++++ sgl-router/src/tool_parser/registry.rs | 4 +- sgl-router/tests/tool_parser_step3.rs | 245 ++++++++++++ 5 files changed, 600 insertions(+), 2 deletions(-) create mode 100644 sgl-router/src/tool_parser/parsers/step3_parser.rs create mode 100644 sgl-router/tests/tool_parser_step3.rs diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index dad9c23b5..7a6bdfc24 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -25,5 +25,5 @@ pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCa // Re-export parsers for convenience pub use parsers::{ - DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, + DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser, }; diff --git a/sgl-router/src/tool_parser/parsers/mod.rs b/sgl-router/src/tool_parser/parsers/mod.rs index 1166b70d1..399e2dc98 100644 --- a/sgl-router/src/tool_parser/parsers/mod.rs +++ b/sgl-router/src/tool_parser/parsers/mod.rs @@ -9,12 +9,15 @@ pub mod llama_parser; pub mod mistral_parser; pub mod pythonic_parser; pub mod qwen_parser; +pub mod step3_parser; // Re-export parser types for convenience pub use deepseek_parser::DeepSeekParser; pub use json_parser::JsonParser; + pub use llama_parser::LlamaParser; pub use mistral_parser::MistralParser; pub use pythonic_parser::PythonicParser; pub use qwen_parser::QwenParser; +pub use step3_parser::Step3Parser; diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs new file mode 100644 index 000000000..721d5c037 --- /dev/null +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -0,0 +1,348 @@ +use async_trait::async_trait; +use regex::Regex; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// Step3 format parser for tool calls +/// +/// Handles the Step3 specific format with steptml XML: +/// `<|tool_calls_begin|><|tool_call_begin|>function<|tool_sep|>{v}<|tool_call_end|><|tool_calls_end|>` +/// +/// Features: +/// - Unicode token delimiters +/// - StepTML XML format for invocations +/// - Support for multiple sequential tool calls +pub struct Step3Parser { + /// Regex for extracting tool call blocks + tool_call_extractor: Regex, + /// Regex for extracting steptml invocations + invoke_extractor: Regex, + /// Regex for extracting parameters + param_extractor: Regex, +} + +impl Step3Parser { + /// Create a new Step3 parser + pub fn new() -> Self { + // Pattern for individual tool calls + let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>"; + let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); + + // Pattern for steptml invocations + let invoke_pattern = r#"(?s)(.+?)"#; + let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern"); + + // Pattern for steptml parameters - using non-greedy match for values to handle < characters + let param_pattern = r#"(?s)(.+?)"#; + let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern"); + + Self { + tool_call_extractor, + invoke_extractor, + param_extractor, + } + } + + /// Check if text contains Step3 tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("<|tool_calls_begin|>") + } + + /// Parse parameters from steptml format + fn parse_steptml_parameters( + &self, + params_text: &str, + ) -> ToolParserResult> { + let mut parameters = serde_json::Map::new(); + + for capture in self.param_extractor.captures_iter(params_text) { + let param_name = capture.get(1).map_or("", |m| m.as_str()).trim(); + let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim(); + + // Try to parse the value as JSON first, fallback to string + let param_value = if let Ok(json_val) = serde_json::from_str::(param_value_str) { + json_val + } else { + // Try parsing as Python literal + if param_value_str == "true" || param_value_str == "True" { + Value::Bool(true) + } else if param_value_str == "false" || param_value_str == "False" { + Value::Bool(false) + } else if param_value_str == "null" || param_value_str == "None" { + Value::Null + } else if let Ok(num) = param_value_str.parse::() { + Value::Number(num.into()) + } else if let Ok(num) = param_value_str.parse::() { + if let Some(n) = serde_json::Number::from_f64(num) { + Value::Number(n) + } else { + Value::String(param_value_str.to_string()) + } + } else { + Value::String(param_value_str.to_string()) + } + }; + + parameters.insert(param_name.to_string(), param_value); + } + + Ok(parameters) + } + + /// Parse a single tool call block + fn parse_tool_call(&self, block: &str) -> ToolParserResult> { + // Check if it contains function marker and tool separator + if !block.contains("function") || !block.contains("<|tool_sep|>") { + return Ok(None); + } + + // Split by tool separator + let parts: Vec<&str> = block.split("<|tool_sep|>").collect(); + if parts.len() != 2 { + return Ok(None); + } + + // Check if it's a function type + if !parts[0].contains("function") { + return Ok(None); + } + + let invoke_part = parts[1]; + + // Extract steptml invoke + if let Some(captures) = self.invoke_extractor.captures(invoke_part) { + let func_name = captures.get(1).map_or("", |m| m.as_str()).trim(); + + // Validate function name is not empty + if func_name.is_empty() { + return Ok(None); + } + + let params_text = captures.get(2).map_or("", |m| m.as_str()); + + // Parse parameters + let parameters = self.parse_steptml_parameters(params_text)?; + + let arguments_str = serde_json::to_string(¶meters) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + + // Generate ID + let id = format!("step3_call_{}", uuid::Uuid::new_v4()); + + Ok(Some(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: func_name.to_string(), + arguments: arguments_str, + }, + })) + } else { + Ok(None) + } + } +} + +impl Default for Step3Parser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for Step3Parser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains Step3 format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + // Find the tool calls section + if let Some(start_pos) = text.find("<|tool_calls_begin|>") { + let search_from = start_pos + "<|tool_calls_begin|>".len(); + + // Find the end of tool calls section + if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") { + let tool_section = &text[search_from..search_from + end_pos]; + + // Extract all tool call blocks + let mut tools = Vec::new(); + for mat in self.tool_call_extractor.find_iter(tool_section) { + if let Some(tool) = self.parse_tool_call(mat.as_str())? { + tools.push(tool); + } + } + + return Ok(tools); + } + } + + Ok(vec![]) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check for tool markers + if !self.has_tool_markers(&state.buffer) { + // No markers found, return as incomplete + return Ok(StreamResult::Incomplete); + } + + // Look for start of tool calls + if let Some(start_pos) = state.buffer.find("<|tool_calls_begin|>") { + let search_from = start_pos + "<|tool_calls_begin|>".len(); + + // Look for individual tool call start + if let Some(call_start) = state.buffer[search_from..].find("<|tool_call_begin|>") { + let call_start_abs = search_from + call_start; + + // Look for the end of this tool call + let search_end_from = call_start_abs + "<|tool_call_begin|>".len(); + if let Some(call_end) = state.buffer[search_end_from..].find("<|tool_call_end|>") + { + let call_end_abs = search_end_from + call_end + "<|tool_call_end|>".len(); + + // Extract and parse the complete tool call + let tool_call_text = &state.buffer[call_start_abs..call_end_abs]; + + if let Some(tool) = self.parse_tool_call(tool_call_text)? { + // Remove the processed part from buffer + state.buffer.drain(..call_end_abs); + + return Ok(StreamResult::ToolComplete(tool)); + } + } else { + // Tool call not complete yet, try to extract partial info + let partial = &state.buffer[search_end_from..]; + + // Check for tool separator + if let Some(sep_pos) = partial.find("<|tool_sep|>") { + // Check if it's a function + if partial[..sep_pos].contains("function") { + let after_sep = &partial[sep_pos + "<|tool_sep|>".len()..]; + + // Try to extract function name from steptml:invoke + if let Some(name_match) = self.invoke_extractor.captures(after_sep) { + let func_name = name_match.get(1).map_or("", |m| m.as_str()).trim(); + + if !state.in_string && !func_name.is_empty() { + state.in_string = true; // Mark name as sent + return Ok(StreamResult::ToolName { + index: 0, + name: func_name.to_string(), + }); + } + + // Try to extract partial parameters + if let Some(params_text) = name_match.get(2) { + let parameters = + self.parse_steptml_parameters(params_text.as_str())?; + + if !parameters.is_empty() { + let args_str = serde_json::to_string(¶meters) + .unwrap_or_else(|_| "{}".to_string()); + + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + } + } + } + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + self.has_tool_markers(text) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_step3_single_tool() { + let parser = Step3Parser::new(); + let input = r#"Some text +<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +Tokyo +celsius +<|tool_call_end|> +<|tool_calls_end|>More text"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Tokyo")); + assert!(result[0].function.arguments.contains("celsius")); + } + + #[tokio::test] + async fn test_parse_step3_multiple_tools() { + let parser = Step3Parser::new(); + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +rust programming +<|tool_call_end|> +<|tool_call_begin|>function<|tool_sep|> +2 + 2 +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "calculate"); + } + + #[tokio::test] + async fn test_parse_step3_mixed_types() { + let parser = Step3Parser::new(); + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +42 +true +1.5 +test +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process_data"); + + // Parse arguments to check types + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["count"], 42); + assert_eq!(args["active"], true); + assert_eq!(args["rate"], 1.5); + assert_eq!(args["name"], "test"); + } + + #[test] + fn test_detect_format() { + let parser = Step3Parser::new(); + assert!(parser.detect_format("<|tool_calls_begin|>")); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format("[TOOL_CALLS]")); + } +} diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index e29c6c136..93ef7b785 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,5 +1,5 @@ use crate::tool_parser::parsers::{ - DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, + DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser, }; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; @@ -113,6 +113,8 @@ impl ParserRegistry { // DeepSeek V3 parser - Unicode tokens with JSON blocks self.register_parser("deepseek", Arc::new(DeepSeekParser::new())); + // Step3 parser - StepTML XML format + self.register_parser("step3", Arc::new(Step3Parser::new())); } /// Register default model mappings diff --git a/sgl-router/tests/tool_parser_step3.rs b/sgl-router/tests/tool_parser_step3.rs new file mode 100644 index 000000000..6c1808b31 --- /dev/null +++ b/sgl-router/tests/tool_parser_step3.rs @@ -0,0 +1,245 @@ +//! Step3 Parser Integration Tests + +use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolParser}; + +#[tokio::test] +async fn test_step3_complete_parsing() { + let parser = Step3Parser::new(); + + // Test single tool call + let input = r#"Let me help you. +<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +rust programming +10 +<|tool_call_end|> +<|tool_calls_end|> +Here are the results..."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search"); + + // Verify arguments + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["query"], "rust programming"); + assert_eq!(args["limit"], 10); +} + +#[tokio::test] +async fn test_step3_multiple_tools() { + let parser = Step3Parser::new(); + + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +Tokyo +<|tool_call_end|> +<|tool_call_begin|>function<|tool_sep|> +tech +5 +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "get_news"); +} + +#[tokio::test] +async fn test_step3_type_conversion() { + let parser = Step3Parser::new(); + + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +100 +2.5 +true +null +hello world +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["count"], 100); + assert_eq!(args["rate"], 2.5); + assert_eq!(args["active"], true); + assert_eq!(args["optional"], serde_json::Value::Null); + assert_eq!(args["text"], "hello world"); +} + +#[tokio::test] +async fn test_step3_streaming() { + let parser = Step3Parser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "<|tool_calls_begin|>\n", + "<|tool_call_begin|>function", + "<|tool_sep|>", + "\n10", + "\n20", + "\n<|tool_call_end|>", + "\n<|tool_calls_end|>", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "calc"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "calc"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); +} + +#[test] +fn test_step3_format_detection() { + let parser = Step3Parser::new(); + + // Should detect Step3 format + assert!(parser.detect_format("<|tool_calls_begin|>")); + assert!(parser.detect_format("text with <|tool_calls_begin|> marker")); + + // Should not detect other formats + assert!(!parser.detect_format("[TOOL_CALLS]")); + assert!(!parser.detect_format("")); + assert!(!parser.detect_format("plain text")); +} + +#[tokio::test] +async fn test_step3_nested_steptml() { + let parser = Step3Parser::new(); + + // Test with complex parameter values + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +{"nested": {"key": "value"}} +[1, 2, 3] +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "config"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert!(args["settings"].is_object()); + assert!(args["array"].is_array()); +} + +#[tokio::test] +async fn test_step3_python_literals() { + let parser = Step3Parser::new(); + + // Test Python-style literals + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +True +False +None +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["bool_true"], true); + assert_eq!(args["bool_false"], false); + assert_eq!(args["none_value"], serde_json::Value::Null); +} + +#[tokio::test] +async fn test_steptml_format() { + let parser = Step3Parser::new(); + + let input = r#"Text before. +<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +rust lang +10 +<|tool_call_end|> +<|tool_calls_end|>Text after."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["query"], "rust lang"); + assert_eq!(args["limit"], 10); + // TODO: Verify normal text extraction +} + +#[tokio::test] +async fn test_json_parameter_values() { + let parser = Step3Parser::new(); + + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +{"nested": {"value": true}} +[1, 2, 3] +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert!(args["settings"].is_object()); + assert!(args["items"].is_array()); +} + +#[tokio::test] +async fn test_step3_parameter_with_angle_brackets() { + let parser = Step3Parser::new(); + + // Test parameter value containing < character + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +a < b && b > c +comparison test +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "compare"); + + // Verify the parameter value was parsed correctly + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["expression"], "a < b && b > c"); + assert_eq!(args["context"], "comparison test"); +} + +#[tokio::test] +async fn test_step3_empty_function_name() { + let parser = Step3Parser::new(); + + // Test empty function name + let input = r#"<|tool_calls_begin|> +<|tool_call_begin|>function<|tool_sep|> +value +<|tool_call_end|> +<|tool_calls_end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 0); // Should reject empty function name +}