From 07ee0ab7507aebfd9240ba143c190e66b056608a Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 27 Aug 2025 11:26:00 -0700 Subject: [PATCH] [router] add gpt-oss and glm4 tool parser (#9703) Co-authored-by: Chang Su --- sgl-router/src/tool_parser/mod.rs | 4 +- .../tool_parser/parsers/glm4_moe_parser.rs | 292 ++++++++++++++++++ .../src/tool_parser/parsers/gpt_oss_parser.rs | 292 ++++++++++++++++++ sgl-router/src/tool_parser/parsers/mod.rs | 4 + sgl-router/src/tool_parser/registry.rs | 31 +- sgl-router/tests/tool_parser_glm4_moe.rs | 194 ++++++++++++ sgl-router/tests/tool_parser_gpt_oss.rs | 201 ++++++++++++ 7 files changed, 1014 insertions(+), 4 deletions(-) create mode 100644 sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs create mode 100644 sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs create mode 100644 sgl-router/tests/tool_parser_glm4_moe.rs create mode 100644 sgl-router/tests/tool_parser_gpt_oss.rs diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index 42d42ea5b..41b8fae2f 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -25,6 +25,6 @@ pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCa // Re-export parsers for convenience pub use parsers::{ - DeepSeekParser, JsonParser, KimiK2Parser, LlamaParser, MistralParser, PythonicParser, - QwenParser, Step3Parser, + DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser, + MistralParser, PythonicParser, QwenParser, Step3Parser, }; diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs new file mode 100644 index 000000000..017de1256 --- /dev/null +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -0,0 +1,292 @@ +use async_trait::async_trait; +use regex::Regex; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// GLM-4 MoE format parser for tool calls +/// +/// Handles the GLM-4 MoE specific format: +/// `{name}\n{key}\n{value}\n` +/// +/// Features: +/// - XML-style tags for tool calls +/// - Key-value pairs for arguments +/// - Support for multiple sequential tool calls +pub struct Glm4MoeParser { + /// Regex for extracting complete tool calls + tool_call_extractor: Regex, + /// Regex for extracting function details + func_detail_extractor: Regex, + /// Regex for extracting argument key-value pairs + arg_extractor: Regex, +} + +impl Glm4MoeParser { + /// Create a new GLM-4 MoE parser + pub fn new() -> Self { + // Use (?s) flag for DOTALL mode to handle newlines + let tool_call_pattern = r"(?s).*?"; + let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); + + let func_detail_pattern = r"(?s)([^\n]*)\n(.*)"; + let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern"); + + let arg_pattern = r"(?s)(.*?)\s*(.*?)"; + let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern"); + + Self { + tool_call_extractor, + func_detail_extractor, + arg_extractor, + } + } + + /// Check if text contains GLM-4 MoE tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("") + } + + /// Parse arguments from key-value pairs + fn parse_arguments(&self, args_text: &str) -> ToolParserResult> { + let mut arguments = serde_json::Map::new(); + + for capture in self.arg_extractor.captures_iter(args_text) { + let key = capture.get(1).map_or("", |m| m.as_str()).trim(); + let value_str = capture.get(2).map_or("", |m| m.as_str()).trim(); + + // Try to parse the value as JSON first, fallback to string + let value = if let Ok(json_val) = serde_json::from_str::(value_str) { + json_val + } else { + // Try parsing as Python literal (similar to Python's ast.literal_eval) + if value_str == "true" || value_str == "True" { + Value::Bool(true) + } else if value_str == "false" || value_str == "False" { + Value::Bool(false) + } else if value_str == "null" || value_str == "None" { + Value::Null + } else if let Ok(num) = value_str.parse::() { + Value::Number(num.into()) + } else if let Ok(num) = value_str.parse::() { + if let Some(n) = serde_json::Number::from_f64(num) { + Value::Number(n) + } else { + Value::String(value_str.to_string()) + } + } else { + Value::String(value_str.to_string()) + } + }; + + arguments.insert(key.to_string(), value); + } + + Ok(arguments) + } + + /// Parse a single tool call block + fn parse_tool_call(&self, block: &str) -> ToolParserResult> { + if let Some(captures) = self.func_detail_extractor.captures(block) { + // Get function name + let func_name = captures.get(1).map_or("", |m| m.as_str()).trim(); + + // Get arguments text + let args_text = captures.get(2).map_or("", |m| m.as_str()); + + // Parse arguments + let arguments = self.parse_arguments(args_text)?; + + let arguments_str = serde_json::to_string(&arguments) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + + // Generate ID + let id = format!("glm4_call_{}", uuid::Uuid::new_v4()); + + Ok(Some(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: func_name.to_string(), + arguments: arguments_str, + }, + })) + } else { + Ok(None) + } + } +} + +impl Default for Glm4MoeParser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for Glm4MoeParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains GLM-4 MoE format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + // Extract all tool call blocks + let mut tools = Vec::new(); + for mat in self.tool_call_extractor.find_iter(text) { + if let Some(tool) = self.parse_tool_call(mat.as_str())? { + tools.push(tool); + } + } + + Ok(tools) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check for tool markers + if !self.has_tool_markers(&state.buffer) { + // No markers found, return as incomplete + return Ok(StreamResult::Incomplete); + } + + // Look for start of tool call + if let Some(start_pos) = state.buffer.find("") { + // Look for the end of this tool call + let search_from = start_pos + "".len(); + if let Some(end_pos) = state.buffer[search_from..].find("") { + let end_abs = search_from + end_pos + "".len(); + + // Extract and parse the complete tool call + let tool_call_text = &state.buffer[start_pos..end_abs]; + + if let Some(tool) = self.parse_tool_call(tool_call_text)? { + // Remove the processed part from buffer + state.buffer.drain(..end_abs); + + return Ok(StreamResult::ToolComplete(tool)); + } + } else { + // Tool call not complete yet, try to extract partial info + let partial = &state.buffer[search_from..]; + + // Try to extract function name (first line after ) + if let Some(name_end) = partial.find('\n') { + let func_name = partial[..name_end].trim(); + + if !func_name.is_empty() && !state.in_string { + state.in_string = true; // Mark name as sent + return Ok(StreamResult::ToolName { + index: 0, + name: func_name.to_string(), + }); + } + + // Try to extract partial arguments + let args_text = &partial[name_end + 1..]; + let partial_args = self.parse_arguments(args_text)?; + + if !partial_args.is_empty() { + let args_str = serde_json::to_string(&partial_args) + .unwrap_or_else(|_| "{}".to_string()); + + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + self.has_tool_markers(text) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_glm4_single_tool() { + let parser = Glm4MoeParser::new(); + let input = r#"Some text +get_weather +city +Beijing +date +2024-06-27 +More text"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Beijing")); + assert!(result[0].function.arguments.contains("2024-06-27")); + } + + #[tokio::test] + async fn test_parse_glm4_multiple_tools() { + let parser = Glm4MoeParser::new(); + let input = r#"get_weather +city +Beijing + +get_weather +city +Shanghai +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Beijing")); + assert!(result[1].function.arguments.contains("Shanghai")); + } + + #[tokio::test] + async fn test_parse_glm4_mixed_types() { + let parser = Glm4MoeParser::new(); + let input = r#"process_data +count +42 +active +true +name +test +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process_data"); + + // Parse arguments to check types + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["count"], 42); + assert_eq!(args["active"], true); + assert_eq!(args["name"], "test"); + } + + #[test] + fn test_detect_format() { + let parser = Glm4MoeParser::new(); + assert!(parser.detect_format("")); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format("[TOOL_CALLS]")); + } +} diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs new file mode 100644 index 000000000..646161a72 --- /dev/null +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -0,0 +1,292 @@ +use async_trait::async_trait; +use regex::Regex; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + partial_json::PartialJson, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// GPT-OSS format parser for tool calls +/// +/// Handles the GPT-OSS specific channel format: +/// `<|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{json_args}<|call|>` +/// +/// Features: +/// - Channel-based format with commentary +/// - Namespaced function calls +/// - JSON arguments +pub struct GptOssParser { + /// Parser for handling incomplete JSON during streaming + partial_json: PartialJson, + /// Regex for extracting complete function calls + function_call_extractor: Regex, + /// Regex for extracting streaming function calls + streaming_extractor: Regex, +} + +impl GptOssParser { + /// Create a new GPT-OSS parser + pub fn new() -> Self { + // Pattern for complete function calls with to= parameter + // Handles optional <|start|>assistant prefix and whitespace after function name + let function_call_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?"; + let function_call_extractor = + Regex::new(function_call_pattern).expect("Valid regex pattern"); + + // Pattern for streaming function calls (incomplete) + let streaming_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*)"; + let streaming_extractor = Regex::new(streaming_pattern).expect("Valid regex pattern"); + + Self { + partial_json: PartialJson::default(), + function_call_extractor, + streaming_extractor, + } + } + + /// Check if text contains GPT-OSS tool markers + fn has_tool_markers(&self, text: &str) -> bool { + text.contains("<|channel|>commentary to=") + } + + /// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather") + fn extract_function_name(&self, full_name: &str) -> String { + if let Some(dot_pos) = full_name.rfind('.') { + full_name[dot_pos + 1..].to_string() + } else { + full_name.to_string() + } + } +} + +impl Default for GptOssParser { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ToolParser for GptOssParser { + async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if text contains GPT-OSS format + if !self.has_tool_markers(text) { + return Ok(vec![]); + } + + let mut tools = Vec::new(); + let mut _tool_index = 0; + + // Extract all function calls + for captures in self.function_call_extractor.captures_iter(text) { + if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) { + let full_function_name = name_match.as_str(); + let args_content = args_match.as_str().trim(); + + // Extract actual function name + let function_name = self.extract_function_name(full_function_name); + + // Parse JSON arguments + let arguments = if args_content.is_empty() { + "{}".to_string() + } else { + match serde_json::from_str::(args_content) { + Ok(value) => serde_json::to_string(&value) + .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?, + Err(_) => { + // Skip malformed JSON + continue; + } + } + }; + + // Generate unique ID + let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4()); + + tools.push(ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: function_name, + arguments, + }, + }); + + _tool_index += 1; + } + } + + Ok(tools) + } + + async fn parse_incremental( + &self, + chunk: &str, + state: &mut ParseState, + ) -> ToolParserResult { + state.buffer.push_str(chunk); + + // Check for tool markers + if !self.has_tool_markers(&state.buffer) { + // No markers found, clear buffer and return + state.buffer.clear(); + return Ok(StreamResult::Incomplete); + } + + // Try to match streaming pattern + if let Some(captures) = self.streaming_extractor.captures(&state.buffer) { + if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) { + let full_function_name = name_match.as_str(); + let partial_args = args_match.as_str(); + + // Extract actual function name + let function_name = self.extract_function_name(full_function_name); + + // Send function name if not sent yet + if !state.in_string { + state.in_string = true; // Mark name as sent + return Ok(StreamResult::ToolName { + index: 0, + name: function_name.clone(), + }); + } + + // Check if we have a complete function call + if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) { + if let Some(args_match) = complete_match.get(2) { + let args_content = args_match.as_str().trim(); + + // Parse JSON arguments + let arguments = if args_content.is_empty() { + "{}".to_string() + } else { + match serde_json::from_str::(args_content) { + Ok(value) => serde_json::to_string(&value) + .unwrap_or_else(|_| "{}".to_string()), + Err(_) => "{}".to_string(), + } + }; + + // Generate unique ID + let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4()); + + let tool = ToolCall { + id, + r#type: "function".to_string(), + function: FunctionCall { + name: function_name, + arguments, + }, + }; + + // Remove the processed part from buffer + let complete_end = complete_match.get(0).unwrap().end(); + state.buffer.drain(..complete_end); + + // Reset state for next tool + state.in_string = false; + + return Ok(StreamResult::ToolComplete(tool)); + } + } else { + // Try to parse partial JSON for streaming arguments + if !partial_args.is_empty() { + // Look for the end of JSON (before <|call|>) + let json_part = if let Some(call_pos) = partial_args.find("<|call|>") { + &partial_args[..call_pos] + } else { + partial_args + }; + + match self.partial_json.parse_value(json_part) { + Ok((value, _consumed)) => { + let args_str = serde_json::to_string(&value) + .unwrap_or_else(|_| "{}".to_string()); + + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + Err(_) => { + // Can't parse yet, keep buffering + } + } + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + self.has_tool_markers(text) || text.contains("<|channel|>commentary") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_gpt_oss_single_tool() { + let parser = GptOssParser::new(); + let input = r#"Some text +<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|> +More text"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("San Francisco")); + } + + #[tokio::test] + async fn test_parse_gpt_oss_multiple_tools() { + let parser = GptOssParser::new(); + let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary +<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "search"); + assert!(result[0].function.arguments.contains("Paris")); + assert!(result[1].function.arguments.contains("Paris tourism")); + } + + #[tokio::test] + async fn test_parse_gpt_oss_with_prefix() { + let parser = GptOssParser::new(); + let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); + } + + #[tokio::test] + async fn test_parse_gpt_oss_empty_args() { + let parser = GptOssParser::new(); + let input = + r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_time"); + assert_eq!(result[0].function.arguments, "{}"); + } + + #[test] + fn test_detect_format() { + let parser = GptOssParser::new(); + assert!(parser.detect_format("<|channel|>commentary to=")); + assert!(parser.detect_format("<|channel|>commentary")); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format("[TOOL_CALLS]")); + } +} diff --git a/sgl-router/src/tool_parser/parsers/mod.rs b/sgl-router/src/tool_parser/parsers/mod.rs index 681a5fb31..693aeedf4 100644 --- a/sgl-router/src/tool_parser/parsers/mod.rs +++ b/sgl-router/src/tool_parser/parsers/mod.rs @@ -4,6 +4,8 @@ /// tool/function call formats. // Individual parser modules pub mod deepseek_parser; +pub mod glm4_moe_parser; +pub mod gpt_oss_parser; pub mod json_parser; pub mod kimik2_parser; pub mod llama_parser; @@ -14,6 +16,8 @@ pub mod step3_parser; // Re-export parser types for convenience pub use deepseek_parser::DeepSeekParser; +pub use glm4_moe_parser::Glm4MoeParser; +pub use gpt_oss_parser::GptOssParser; pub use json_parser::JsonParser; pub use kimik2_parser::KimiK2Parser; pub use llama_parser::LlamaParser; diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index ba01bb776..1a740f1a2 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,6 +1,6 @@ use crate::tool_parser::parsers::{ - DeepSeekParser, JsonParser, KimiK2Parser, LlamaParser, MistralParser, PythonicParser, - QwenParser, Step3Parser, + DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser, + MistralParser, PythonicParser, QwenParser, Step3Parser, }; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; @@ -115,11 +115,17 @@ impl ParserRegistry { // DeepSeek V3 parser - Unicode tokens with JSON blocks self.register_parser("deepseek", Arc::new(DeepSeekParser::new())); + // GLM-4 MoE parser - XML-style key-value format + self.register_parser("glm4_moe", Arc::new(Glm4MoeParser::new())); + // Step3 parser - StepTML XML format self.register_parser("step3", Arc::new(Step3Parser::new())); // Kimi K2 parser - Token-based with indexed functions self.register_parser("kimik2", Arc::new(KimiK2Parser::new())); + + // GPT-OSS parser - Channel format + self.register_parser("gpt_oss", Arc::new(GptOssParser::new())); } /// Register default model mappings @@ -158,6 +164,27 @@ impl ParserRegistry { // DeepSeek V2 uses pythonic format self.map_model("deepseek-*", "pythonic"); + // GLM models + // GLM-4 MoE uses XML-style format + self.map_model("glm-4-moe*", "glm4_moe"); + self.map_model("THUDM/glm-4-moe*", "glm4_moe"); + self.map_model("glm-4.5*", "glm4_moe"); + // Other GLM models may use JSON + self.map_model("glm-*", "json"); + + // Step3 models + self.map_model("step3*", "step3"); + self.map_model("Step-3*", "step3"); + + // Kimi models + self.map_model("kimi-k2*", "kimik2"); + self.map_model("Kimi-K2*", "kimik2"); + self.map_model("moonshot*/Kimi-K2*", "kimik2"); + + // GPT-OSS models (T4-style) + self.map_model("gpt-oss*", "gpt_oss"); + self.map_model("t4-*", "gpt_oss"); + // Other models default to JSON self.map_model("gemini-*", "json"); self.map_model("palm-*", "json"); diff --git a/sgl-router/tests/tool_parser_glm4_moe.rs b/sgl-router/tests/tool_parser_glm4_moe.rs new file mode 100644 index 000000000..bae8fe727 --- /dev/null +++ b/sgl-router/tests/tool_parser_glm4_moe.rs @@ -0,0 +1,194 @@ +//! GLM-4 MoE Parser Integration Tests + +use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, ToolParser}; + +#[tokio::test] +async fn test_glm4_complete_parsing() { + let parser = Glm4MoeParser::new(); + + // Test single tool call + let input = r#"Let me search for that. +get_weather +city +Beijing +date +2024-12-25 + +The weather will be..."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + // Verify arguments + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["city"], "Beijing"); + assert_eq!(args["date"], "2024-12-25"); +} + +#[tokio::test] +async fn test_glm4_multiple_tools() { + let parser = Glm4MoeParser::new(); + + let input = r#"search +query +rust tutorials + +translate +text +Hello World +target_lang +zh +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "translate"); +} + +#[tokio::test] +async fn test_glm4_type_conversion() { + let parser = Glm4MoeParser::new(); + + // Test various value types + let input = r#"process +count +42 +rate +1.5 +enabled +true +data +null +text +string value +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["count"], 42); + assert_eq!(args["rate"], 1.5); + assert_eq!(args["enabled"], true); + assert_eq!(args["data"], serde_json::Value::Null); + assert_eq!(args["text"], "string value"); +} + +#[tokio::test] +async fn test_glm4_streaming() { + let parser = Glm4MoeParser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "", + "get_weather\n", + "city\n", + "Shanghai\n", + "units\n", + "celsius\n", + "", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "get_weather"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "get_weather"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); +} + +#[test] +fn test_glm4_format_detection() { + let parser = Glm4MoeParser::new(); + + // Should detect GLM-4 format + assert!(parser.detect_format("")); + assert!(parser.detect_format("text with marker")); + + // Should not detect other formats + assert!(!parser.detect_format("[TOOL_CALLS]")); + assert!(!parser.detect_format("<|tool▁calls▁begin|>")); + assert!(!parser.detect_format("plain text")); +} + +#[tokio::test] +async fn test_glm4_python_literal_values() { + let parser = Glm4MoeParser::new(); + + // Test Python-style boolean values + let input = r#"config +debug +True +verbose +False +optional +None +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["debug"], true); + assert_eq!(args["verbose"], false); + assert_eq!(args["optional"], serde_json::Value::Null); +} + +#[tokio::test] +async fn test_python_literals() { + let parser = Glm4MoeParser::new(); + + let input = r#"test_func +bool_true +True +bool_false +False +none_val +None +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test_func"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["bool_true"], true); + assert_eq!(args["bool_false"], false); + assert_eq!(args["none_val"], serde_json::Value::Null); +} + +#[tokio::test] +async fn test_nested_values() { + let parser = Glm4MoeParser::new(); + + let input = r#"process +data +{"nested": {"key": "value"}} +list +[1, 2, 3] +"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert!(args["data"].is_object()); + assert!(args["list"].is_array()); +} diff --git a/sgl-router/tests/tool_parser_gpt_oss.rs b/sgl-router/tests/tool_parser_gpt_oss.rs new file mode 100644 index 000000000..50dc0be15 --- /dev/null +++ b/sgl-router/tests/tool_parser_gpt_oss.rs @@ -0,0 +1,201 @@ +//! GPT-OSS Parser Integration Tests + +use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, ToolParser}; + +#[tokio::test] +async fn test_gpt_oss_complete_parsing() { + let parser = GptOssParser::new(); + + // Test single tool call + let input = r#"Let me search for that information. +<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|> +Here are the results..."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "search"); + + // Verify arguments + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["query"], "rust programming"); + assert_eq!(args["limit"], 10); +} + +#[tokio::test] +async fn test_gpt_oss_multiple_tools() { + let parser = GptOssParser::new(); + + let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary +<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "search"); +} + +#[tokio::test] +async fn test_gpt_oss_with_namespace() { + let parser = GptOssParser::new(); + + // Test with different namespace patterns + let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|> +<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "create"); // Should extract last part + assert_eq!(result[1].function.name, "add"); +} + +#[tokio::test] +async fn test_gpt_oss_with_assistant_prefix() { + let parser = GptOssParser::new(); + + // Test with <|start|>assistant prefix + let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); +} + +#[tokio::test] +async fn test_gpt_oss_empty_args() { + let parser = GptOssParser::new(); + + // Test with empty arguments + let input = + r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_time"); + assert_eq!(result[0].function.arguments, "{}"); +} + +#[tokio::test] +async fn test_gpt_oss_streaming() { + let parser = GptOssParser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "<|channel|>commentary to=", + "functions.calculate", + "<|constrain|>json<|message|>", + r#"{"x": 10"#, + r#", "y": 20}"#, + "<|call|>", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "calculate"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "calculate"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); +} + +#[test] +fn test_gpt_oss_format_detection() { + let parser = GptOssParser::new(); + + // Should detect GPT-OSS format + assert!(parser.detect_format("<|channel|>commentary to=")); + assert!(parser.detect_format("<|channel|>commentary")); + assert!(parser.detect_format("text with <|channel|>commentary to= marker")); + + // Should not detect other formats + assert!(!parser.detect_format("[TOOL_CALLS]")); + assert!(!parser.detect_format("")); + assert!(!parser.detect_format("plain text")); +} + +#[tokio::test] +async fn test_gpt_oss_with_whitespace() { + let parser = GptOssParser::new(); + + // Test with whitespace after function name + let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); +} + +#[tokio::test] +async fn test_gpt_oss_complex_json() { + let parser = GptOssParser::new(); + + // Test with complex nested JSON + let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{ + "nested": { + "data": [1, 2, 3], + "config": { + "enabled": true + } + } +}<|call|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert!(args["nested"]["data"].is_array()); + assert_eq!(args["nested"]["config"]["enabled"], true); +} + +#[tokio::test] +async fn test_commentary_without_function() { + let parser = GptOssParser::new(); + + // Python should extract commentary as normal text + let input = r#"<|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 0); // No tool calls + // TODO: Verify normal text = "**Action plan**: 1. Do X 2. Do Y" +} + +#[tokio::test] +async fn test_final_channel() { + let parser = GptOssParser::new(); + + let input = r#"<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"x": 1}<|call|> +<|channel|>final<|message|>The result is calculated.<|return|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "test"); + // TODO: Verify normal text = "The result is calculated." +} + +#[tokio::test] +async fn test_mixed_commentary_and_calls() { + let parser = GptOssParser::new(); + + let input = r#"<|channel|>commentary<|message|>Let me think<|end|> +<|channel|>commentary to=functions.calc<|constrain|>json<|message|>{"x": 5}<|call|> +<|channel|>commentary<|message|>Processing...<|end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "calc"); + // TODO: Verify normal text = "Let me think Processing..." +}