From ab926dd6977cd1bad1876cdf83c3b164f7cff655 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 9 Oct 2025 03:53:23 -0700 Subject: [PATCH] [router][grpc] Fix streaming bugs: empty tool names, state pollution, and panics (#11373) --- sgl-router/benches/tool_parser_benchmark.rs | 2 +- sgl-router/src/reasoning_parser/factory.rs | 68 ++- sgl-router/src/reasoning_parser/mod.rs | 2 +- sgl-router/src/routers/grpc/context.rs | 4 +- sgl-router/src/routers/grpc/pd_router.rs | 4 +- sgl-router/src/routers/grpc/processing.rs | 4 +- sgl-router/src/routers/grpc/router.rs | 4 +- sgl-router/src/routers/grpc/streaming.rs | 86 +++- sgl-router/src/routers/grpc/utils.rs | 60 ++- sgl-router/src/server.rs | 8 +- sgl-router/src/tool_parser/errors.rs | 4 +- sgl-router/src/tool_parser/factory.rs | 113 ++++- sgl-router/src/tool_parser/mod.rs | 4 +- .../tool_parser/parsers/deepseek_parser.rs | 28 +- .../tool_parser/parsers/glm4_moe_parser.rs | 21 +- .../parsers/gpt_oss_harmony_parser.rs | 10 +- .../src/tool_parser/parsers/gpt_oss_parser.rs | 10 +- sgl-router/src/tool_parser/parsers/helpers.rs | 113 +++-- .../src/tool_parser/parsers/json_parser.rs | 24 +- .../src/tool_parser/parsers/kimik2_parser.rs | 15 +- .../src/tool_parser/parsers/llama_parser.rs | 24 +- .../src/tool_parser/parsers/mistral_parser.rs | 24 +- .../tool_parser/parsers/pythonic_parser.rs | 64 ++- .../src/tool_parser/parsers/qwen_parser.rs | 22 +- .../src/tool_parser/parsers/step3_parser.rs | 30 +- sgl-router/src/tool_parser/partial_json.rs | 99 ++-- sgl-router/src/tool_parser/tests.rs | 20 +- sgl-router/src/tool_parser/traits.rs | 21 +- sgl-router/tests/common/mod.rs | 1 + sgl-router/tests/common/streaming_helpers.rs | 134 +++++ sgl-router/tests/tool_parser_glm4_moe.rs | 24 +- sgl-router/tests/tool_parser_partial_json.rs | 156 ++++++ sgl-router/tests/tool_parser_streaming.rs | 476 +++++++++--------- 33 files changed, 1145 insertions(+), 534 deletions(-) create mode 100644 sgl-router/tests/common/streaming_helpers.rs create 
mode 100644 sgl-router/tests/tool_parser_partial_json.rs diff --git a/sgl-router/benches/tool_parser_benchmark.rs b/sgl-router/benches/tool_parser_benchmark.rs index 6fe174383..636a32366 100644 --- a/sgl-router/benches/tool_parser_benchmark.rs +++ b/sgl-router/benches/tool_parser_benchmark.rs @@ -10,7 +10,7 @@ use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use serde_json::json; use sglang_router_rs::protocols::spec::{Function, Tool}; -use sglang_router_rs::tool_parser::{JsonParser, ToolParser, ToolParserFactory}; +use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}; use std::collections::BTreeMap; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs index 7e2367f7b..28f3d0836 100644 --- a/sgl-router/src/reasoning_parser/factory.rs +++ b/sgl-router/src/reasoning_parser/factory.rs @@ -82,9 +82,15 @@ impl ParserRegistry { } } - /// Get a parser by exact name (creates new instance, not pooled). - /// Use this for compatibility or when you need a fresh instance. - pub fn get_parser(&self, name: &str) -> Option> { + /// Check if a parser with the given name is registered. + pub fn has_parser(&self, name: &str) -> bool { + let creators = self.creators.read().unwrap(); + creators.contains_key(name) + } + + /// Create a fresh parser instance by exact name (not pooled). + /// Returns a new parser instance for each call - useful for streaming where state isolation is needed. + pub fn create_parser(&self, name: &str) -> Option> { let creators = self.creators.read().unwrap(); creators.get(name).map(|creator| creator()) } @@ -102,14 +108,30 @@ impl ParserRegistry { None } - /// Find a parser for a given model ID by pattern matching (creates new instance). 
- pub fn find_parser_for_model(&self, model_id: &str) -> Option> { + /// Check if a parser can be created for a specific model without actually creating it. + /// Returns true if a parser is available (registered) for this model. + pub fn has_parser_for_model(&self, model_id: &str) -> bool { let patterns = self.patterns.read().unwrap(); let model_lower = model_id.to_lowercase(); for (pattern, parser_name) in patterns.iter() { if model_lower.contains(&pattern.to_lowercase()) { - return self.get_parser(parser_name); + let creators = self.creators.read().unwrap(); + return creators.contains_key(parser_name); + } + } + false + } + + /// Create a fresh parser instance for a given model ID by pattern matching (not pooled). + /// Returns a new parser instance for each call - useful for streaming where state isolation is needed. + pub fn create_for_model(&self, model_id: &str) -> Option> { + let patterns = self.patterns.read().unwrap(); + let model_lower = model_id.to_lowercase(); + + for (pattern, parser_name) in patterns.iter() { + if model_lower.contains(&pattern.to_lowercase()) { + return self.create_parser(parser_name); } } None @@ -131,11 +153,11 @@ impl Default for ParserRegistry { /// Factory for creating reasoning parsers based on model type. #[derive(Clone)] -pub struct ReasoningParserFactory { +pub struct ParserFactory { registry: ParserRegistry, } -impl ReasoningParserFactory { +impl ParserFactory { /// Create a new factory with default parsers registered. pub fn new() -> Self { let registry = ParserRegistry::new(); @@ -211,7 +233,7 @@ impl ReasoningParserFactory { /// Use this when you need an isolated parser instance. 
pub fn create(&self, model_id: &str) -> Result, ParseError> { // First try to find by pattern - if let Some(parser) = self.registry.find_parser_for_model(model_id) { + if let Some(parser) = self.registry.create_for_model(model_id) { return Ok(parser); } @@ -240,7 +262,7 @@ impl ReasoningParserFactory { } } -impl Default for ReasoningParserFactory { +impl Default for ParserFactory { fn default() -> Self { Self::new() } @@ -252,35 +274,35 @@ mod tests { #[test] fn test_factory_creates_deepseek_r1() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser = factory.create("deepseek-r1-distill").unwrap(); assert_eq!(parser.model_type(), "deepseek_r1"); } #[test] fn test_factory_creates_qwen3() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser = factory.create("qwen3-7b").unwrap(); assert_eq!(parser.model_type(), "qwen3"); } #[test] fn test_factory_creates_kimi() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser = factory.create("kimi-chat").unwrap(); assert_eq!(parser.model_type(), "kimi"); } #[test] fn test_factory_fallback_to_passthrough() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser = factory.create("unknown-model").unwrap(); assert_eq!(parser.model_type(), "passthrough"); } #[test] fn test_case_insensitive_matching() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser1 = factory.create("DeepSeek-R1").unwrap(); let parser2 = factory.create("QWEN3").unwrap(); let parser3 = factory.create("Kimi").unwrap(); @@ -292,21 +314,21 @@ mod tests { #[test] fn test_step3_model() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let step3 = factory.create("step3-model").unwrap(); assert_eq!(step3.model_type(), "step3"); } #[test] fn test_glm45_model() { - let factory = ReasoningParserFactory::new(); + let 
factory = ParserFactory::new(); let glm45 = factory.create("glm45-v2").unwrap(); assert_eq!(glm45.model_type(), "glm45"); } #[tokio::test] async fn test_pooled_parser_reuse() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); // Get the same parser twice - should be the same instance let parser1 = factory.get_pooled("deepseek-r1"); @@ -322,7 +344,7 @@ mod tests { #[tokio::test] async fn test_pooled_parser_concurrent_access() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let parser = factory.get_pooled("deepseek-r1"); // Spawn multiple async tasks that use the same parser @@ -348,7 +370,7 @@ mod tests { #[tokio::test] async fn test_pool_clearing() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); // Get a pooled parser let parser1 = factory.get_pooled("deepseek-r1"); @@ -365,7 +387,7 @@ mod tests { #[tokio::test] async fn test_passthrough_parser_pooling() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); // Unknown models should get passthrough parser let parser1 = factory.get_pooled("unknown-model-1"); @@ -383,7 +405,7 @@ mod tests { use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Instant; - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let num_tasks = 100; let requests_per_task = 50; let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"]; @@ -512,7 +534,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_concurrent_pool_modifications() { - let factory = ReasoningParserFactory::new(); + let factory = ParserFactory::new(); let mut handles = vec![]; // Task 1: Continuously get parsers diff --git a/sgl-router/src/reasoning_parser/mod.rs b/sgl-router/src/reasoning_parser/mod.rs index 8cc7e8357..95ffcbc4f 100644 --- a/sgl-router/src/reasoning_parser/mod.rs +++ b/sgl-router/src/reasoning_parser/mod.rs @@ -2,7 +2,7 @@ pub 
mod factory; pub mod parsers; pub mod traits; -pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory}; +pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use parsers::{ BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, QwenThinkingParser, Step3Parser, diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs index bc9f3c7a5..50e713fe2 100644 --- a/sgl-router/src/routers/grpc/context.rs +++ b/sgl-router/src/routers/grpc/context.rs @@ -15,10 +15,10 @@ use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::protocols::spec::{ ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse, }; -use crate::reasoning_parser::ReasoningParserFactory; +use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::stop::StopSequenceDecoder; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tool_parser::ParserFactory as ToolParserFactory; // ============================================================================ // Core Context Types diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 0fc29a6c5..de6f79a2d 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -7,11 +7,11 @@ use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, ResponsesRequest, }; -use crate::reasoning_parser::ReasoningParserFactory; +use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tool_parser::ParserFactory as ToolParserFactory; use async_trait::async_trait; use axum::{ body::Body, diff --git a/sgl-router/src/routers/grpc/processing.rs 
b/sgl-router/src/routers/grpc/processing.rs index 91f831663..523480465 100644 --- a/sgl-router/src/routers/grpc/processing.rs +++ b/sgl-router/src/routers/grpc/processing.rs @@ -13,10 +13,10 @@ use crate::protocols::spec::{ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, }; -use crate::reasoning_parser::ReasoningParserFactory; +use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tool_parser::ParserFactory as ToolParserFactory; use super::utils; diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index c35358209..5666823de 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -18,11 +18,11 @@ use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, ResponsesRequest, }; -use crate::reasoning_parser::ReasoningParserFactory; +use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tool_parser::ParserFactory as ToolParserFactory; /// gRPC router implementation for SGLang #[derive(Clone)] diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index 4920b30b1..1e7707767 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -34,8 +34,8 @@ use tokio::sync::mpsc; #[derive(Clone)] pub struct StreamingProcessor { tokenizer: Arc, - tool_parser_factory: crate::tool_parser::ToolParserFactory, - reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, + tool_parser_factory: 
crate::tool_parser::ParserFactory, + reasoning_parser_factory: crate::reasoning_parser::ParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, } @@ -43,8 +43,8 @@ pub struct StreamingProcessor { impl StreamingProcessor { pub fn new( tokenizer: Arc, - tool_parser_factory: crate::tool_parser::ToolParserFactory, - reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, + tool_parser_factory: crate::tool_parser::ParserFactory, + reasoning_parser_factory: crate::reasoning_parser::ParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, ) -> Self { @@ -195,6 +195,47 @@ impl StreamingProcessor { let created = dispatch.created; let system_fingerprint = dispatch.weight_version.as_deref(); + // Check parser availability once upfront (log warning only once per request) + let reasoning_parser_available = if separate_reasoning { + if let Some(parser_name) = self.configured_reasoning_parser.as_ref() { + self.reasoning_parser_factory + .registry() + .has_parser(parser_name) + } else { + self.reasoning_parser_factory + .registry() + .has_parser_for_model(model) + } + } else { + false + }; + + let tool_parser_available = if tools.is_some() { + if let Some(parser_name) = self.configured_tool_parser.as_ref() { + self.tool_parser_factory.registry().has_parser(parser_name) + } else { + self.tool_parser_factory + .registry() + .has_parser_for_model(model) + } + } else { + false + }; + + if separate_reasoning && !reasoning_parser_available { + warn!( + "No reasoning parser found for model '{}', skipping reasoning parsing", + model + ); + } + + if tools.is_some() && !tool_parser_available { + warn!( + "No tool parser found for model '{}', skipping tool call parsing", + model + ); + } + // Phase 2: Main streaming loop while let Some(response) = grpc_stream.next().await { let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; @@ -276,7 +317,7 @@ impl StreamingProcessor { 
stream_buffer.push_str(&delta); // Reasoning content handling - let in_reasoning = if separate_reasoning { + let in_reasoning = if separate_reasoning && reasoning_parser_available { let (normal_text, reasoning_chunk, in_reasoning) = self .process_reasoning_stream( &delta, @@ -303,8 +344,12 @@ impl StreamingProcessor { let tool_choice_enabled = !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); - if !in_reasoning && tool_choice_enabled && tools.is_some() { - let (should_skip, tool_chunks) = self + if !in_reasoning + && tool_choice_enabled + && tools.is_some() + && tool_parser_available + { + let tool_chunks = self .process_tool_calls_stream( &delta, index, @@ -325,10 +370,9 @@ impl StreamingProcessor { .map_err(|_| "Failed to send tool call chunk".to_string())?; } - // Continue to process the next chunk as we have tool chunks - if should_skip { - continue; - } + // Always skip regular content when tool parsing is active + // Parser either emitted chunks or buffered content + continue; } // Regular content emission @@ -963,13 +1007,15 @@ impl StreamingProcessor { created: u64, system_fingerprint: Option<&str>, ) -> (String, Option, bool) { - // Get or create parser for this index + // Create fresh parser for this index (not pooled, to avoid state pollution) reasoning_parsers.entry(index).or_insert_with(|| { - utils::get_reasoning_parser( + let parser = utils::create_reasoning_parser( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_ref(), model, ) + .expect("Parser should be available - checked upfront"); + Arc::new(tokio::sync::Mutex::new(parser)) }); if let Some(pooled_parser) = reasoning_parsers.get(&index) { @@ -1034,20 +1080,23 @@ impl StreamingProcessor { created: u64, system_fingerprint: Option<&str>, history_tool_calls_count: usize, - ) -> (bool, Vec) { + ) -> Vec { let mut chunks = Vec::new(); - // Get or create parser for this index + // Create fresh parser for this index (not pooled, to avoid state pollution) 
tool_parsers.entry(index).or_insert_with(|| { - utils::get_tool_parser( + let parser = utils::create_tool_parser( &self.tool_parser_factory, self.configured_tool_parser.as_ref(), model, ) + .expect("Parser should be available - checked upfront"); + Arc::new(tokio::sync::Mutex::new(parser)) }); if let Some(pooled_parser) = tool_parsers.get(&index) { let mut parser = pooled_parser.lock().await; + match parser.parse_incremental(delta, tools).await { Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => { // Emit normal text if present @@ -1129,8 +1178,7 @@ impl StreamingProcessor { }); } - // If we emitted chunks, skip regular content - return (!chunks.is_empty(), chunks); + return chunks; } Err(e) => { error!("Tool call parsing error: {}", e); @@ -1138,7 +1186,7 @@ impl StreamingProcessor { } } - (false, chunks) + chunks } /// Format a response as SSE chunk into a reusable buffer diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index 01474fb10..4422671bf 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -677,13 +677,12 @@ pub fn generate_tool_call_id( /// /// If a parser name is explicitly configured, use that parser. /// Otherwise, auto-detect based on the model name. 
+/// Get a pooled reasoning parser (for non-streaming where state doesn't matter) pub fn get_reasoning_parser( - reasoning_parser_factory: &crate::reasoning_parser::ReasoningParserFactory, + reasoning_parser_factory: &crate::reasoning_parser::ParserFactory, configured_parser: Option<&String>, model: &str, ) -> crate::reasoning_parser::PooledParser { - use tracing::warn; - if let Some(parser_name) = configured_parser { // Use configured parser if specified reasoning_parser_factory @@ -702,17 +701,40 @@ pub fn get_reasoning_parser( } } +/// Create a fresh reasoning parser instance (for streaming where state isolation is needed) +pub fn create_reasoning_parser( + reasoning_parser_factory: &crate::reasoning_parser::ParserFactory, + configured_parser: Option<&String>, + model: &str, +) -> Option> { + if let Some(parser_name) = configured_parser { + // Use configured parser if specified + reasoning_parser_factory + .registry() + .create_parser(parser_name) + .or_else(|| { + warn!( + "Configured reasoning parser '{}' not found, falling back to model-based selection", + parser_name + ); + reasoning_parser_factory.registry().create_for_model(model) + }) + } else { + // Auto-detect based on model + reasoning_parser_factory.registry().create_for_model(model) + } +} + /// Get the appropriate tool parser for a model /// /// If a parser name is explicitly configured, use that parser. /// Otherwise, auto-detect based on the model name. 
+/// Get a pooled tool parser (for non-streaming where state doesn't matter) pub fn get_tool_parser( - tool_parser_factory: &crate::tool_parser::ToolParserFactory, + tool_parser_factory: &crate::tool_parser::ParserFactory, configured_parser: Option<&String>, model: &str, -) -> crate::tool_parser::PooledToolParser { - use tracing::warn; - +) -> crate::tool_parser::PooledParser { if let Some(parser_name) = configured_parser { // Use configured parser if specified tool_parser_factory @@ -731,6 +753,30 @@ pub fn get_tool_parser( } } +/// Create a fresh tool parser instance (for streaming where state isolation is needed) +pub fn create_tool_parser( + tool_parser_factory: &crate::tool_parser::ParserFactory, + configured_parser: Option<&String>, + model: &str, +) -> Option> { + if let Some(parser_name) = configured_parser { + // Use configured parser if specified + tool_parser_factory + .registry() + .create_parser(parser_name) + .or_else(|| { + warn!( + "Configured tool parser '{}' not found, falling back to model-based selection", + parser_name + ); + tool_parser_factory.registry().create_for_model(model) + }) + } else { + // Auto-detect based on model + tool_parser_factory.registry().create_for_model(model) + } +} + /// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format /// /// This function decodes token IDs using the tokenizer and builds the logprobs structure diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index f20c21b26..ad222f4fd 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -18,11 +18,11 @@ use crate::{ }, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, }, - reasoning_parser::ReasoningParserFactory, + reasoning_parser::ParserFactory as ReasoningParserFactory, routers::{router_manager::RouterManager, RouterTrait}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, - tool_parser::ToolParserFactory, + 
tool_parser::ParserFactory as ToolParserFactory, }; use axum::{ extract::{Path, Query, Request, State}, @@ -88,8 +88,8 @@ impl AppContext { tokenizer_factory::create_tokenizer(&tokenizer_path) .map_err(|e| format!("Failed to create tokenizer: {e}"))?, ); - let reasoning_parser_factory = Some(ReasoningParserFactory::new()); - let tool_parser_factory = Some(ToolParserFactory::new()); + let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new()); + let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new()); (tokenizer, reasoning_parser_factory, tool_parser_factory) } else { diff --git a/sgl-router/src/tool_parser/errors.rs b/sgl-router/src/tool_parser/errors.rs index 30129a596..8a34e5f93 100644 --- a/sgl-router/src/tool_parser/errors.rs +++ b/sgl-router/src/tool_parser/errors.rs @@ -1,11 +1,11 @@ use thiserror::Error; /// Result type for tool parser operations -pub type ToolParserResult = Result; +pub type ParserResult = Result; /// Errors that can occur during tool parsing #[derive(Debug, Error)] -pub enum ToolParserError { +pub enum ParserError { #[error("Parsing failed: {0}")] ParsingFailed(String), diff --git a/sgl-router/src/tool_parser/factory.rs b/sgl-router/src/tool_parser/factory.rs index ae7bee418..3d6bede95 100644 --- a/sgl-router/src/tool_parser/factory.rs +++ b/sgl-router/src/tool_parser/factory.rs @@ -11,25 +11,25 @@ use crate::tool_parser::parsers::{ use crate::tool_parser::traits::ToolParser; /// Type alias for pooled parser instances. -pub type PooledToolParser = Arc>>; +pub type PooledParser = Arc>>; /// Type alias for parser creator functions. type ParserCreator = Arc Box + Send + Sync>; /// Registry for model-specific tool parsers with pooling support. 
#[derive(Clone)] -pub struct ToolParserRegistry { +pub struct ParserRegistry { /// Creator functions for parsers (used when pool is empty) creators: Arc>>, /// Pooled parser instances for reuse - pool: Arc>>, + pool: Arc>>, /// Model pattern to parser name mappings model_mapping: Arc>>, /// Default parser name default_parser: Arc>, } -impl ToolParserRegistry { +impl ParserRegistry { /// Create a new empty registry. pub fn new() -> Self { Self { @@ -57,7 +57,7 @@ impl ToolParserRegistry { /// Get a pooled parser by exact name. /// Returns a shared parser instance from the pool, creating one if needed. - pub fn get_pooled_parser(&self, name: &str) -> Option { + pub fn get_pooled_parser(&self, name: &str) -> Option { // First check if we have a pooled instance { let pool = self.pool.read().unwrap(); @@ -81,8 +81,91 @@ impl ToolParserRegistry { } } + /// Check if a parser with the given name is registered. + pub fn has_parser(&self, name: &str) -> bool { + let creators = self.creators.read().unwrap(); + creators.contains_key(name) + } + + /// Create a fresh (non-pooled) parser instance by exact name. + /// Returns a new parser instance for each call - useful for streaming where state isolation is needed. + pub fn create_parser(&self, name: &str) -> Option> { + let creators = self.creators.read().unwrap(); + creators.get(name).map(|creator| creator()) + } + + /// Check if a parser can be created for a specific model without actually creating it. + /// Returns true if a parser is available (registered) for this model. 
+ pub fn has_parser_for_model(&self, model: &str) -> bool { + // Try exact match first + { + let mapping = self.model_mapping.read().unwrap(); + if let Some(parser_name) = mapping.get(model) { + let creators = self.creators.read().unwrap(); + if creators.contains_key(parser_name) { + return true; + } + } + } + + // Try prefix matching + let model_mapping = self.model_mapping.read().unwrap(); + let best_match = model_mapping + .iter() + .filter(|(pattern, _)| { + pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1]) + }) + .max_by_key(|(pattern, _)| pattern.len()); + + if let Some((_, parser_name)) = best_match { + let creators = self.creators.read().unwrap(); + if creators.contains_key(parser_name) { + return true; + } + } + + // Check if default parser exists + let default = self.default_parser.read().unwrap().clone(); + let creators = self.creators.read().unwrap(); + creators.contains_key(&default) + } + + /// Create a fresh (non-pooled) parser instance for a specific model. + /// Returns a new parser instance for each call - useful for streaming where state isolation is needed. 
+ pub fn create_for_model(&self, model: &str) -> Option> { + // Try exact match first + { + let mapping = self.model_mapping.read().unwrap(); + if let Some(parser_name) = mapping.get(model) { + if let Some(parser) = self.create_parser(parser_name) { + return Some(parser); + } + } + } + + // Try prefix matching with more specific patterns first + let model_mapping = self.model_mapping.read().unwrap(); + let best_match = model_mapping + .iter() + .filter(|(pattern, _)| { + pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1]) + }) + .max_by_key(|(pattern, _)| pattern.len()); + + // Return the best matching parser + if let Some((_, parser_name)) = best_match { + if let Some(parser) = self.create_parser(parser_name) { + return Some(parser); + } + } + + // Fall back to default parser + let default = self.default_parser.read().unwrap().clone(); + self.create_parser(&default) + } + /// Get parser for a specific model - pub fn get_pooled_for_model(&self, model: &str) -> Option { + pub fn get_pooled_for_model(&self, model: &str) -> Option { // Try exact match first { let mapping = self.model_mapping.read().unwrap(); @@ -127,7 +210,7 @@ impl ToolParserRegistry { } } -impl Default for ToolParserRegistry { +impl Default for ParserRegistry { fn default() -> Self { Self::new() } @@ -135,14 +218,14 @@ impl Default for ToolParserRegistry { /// Factory for creating tool parsers based on model type. #[derive(Clone)] -pub struct ToolParserFactory { - registry: ToolParserRegistry, +pub struct ParserFactory { + registry: ParserRegistry, } -impl ToolParserFactory { +impl ParserFactory { /// Create a new factory with default parsers registered. 
pub fn new() -> Self { - let registry = ToolParserRegistry::new(); + let registry = ParserRegistry::new(); // Register default parsers registry.register_parser("json", || Box::new(JsonParser::new())); @@ -172,7 +255,7 @@ impl ToolParserFactory { Self { registry } } - fn register_default_mappings(registry: &ToolParserRegistry) { + fn register_default_mappings(registry: &ParserRegistry) { // OpenAI models registry.map_model("gpt-4*", "json"); registry.map_model("gpt-3.5*", "json"); @@ -229,7 +312,7 @@ impl ToolParserFactory { /// Get a pooled parser for the given model ID. /// Returns a shared instance that can be used concurrently. /// Falls back to JSON parser if model is not recognized. - pub fn get_pooled(&self, model_id: &str) -> PooledToolParser { + pub fn get_pooled(&self, model_id: &str) -> PooledParser { self.registry .get_pooled_for_model(model_id) .unwrap_or_else(|| { @@ -241,7 +324,7 @@ impl ToolParserFactory { } /// Get the internal registry for custom registration. - pub fn registry(&self) -> &ToolParserRegistry { + pub fn registry(&self) -> &ParserRegistry { &self.registry } @@ -299,7 +382,7 @@ impl ToolParserFactory { } } -impl Default for ToolParserFactory { +impl Default for ParserFactory { fn default() -> Self { Self::new() } diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index 80b19506e..d4521b10c 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -16,8 +16,8 @@ pub mod parsers; mod tests; // Re-export commonly used types -pub use errors::{ToolParserError, ToolParserResult}; -pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry}; +pub use errors::{ParserError, ParserResult}; +pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use traits::{PartialJsonParser, ToolParser}; pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall}; diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs 
b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs index 32774797c..371be9b68 100644 --- a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, traits::ToolParser, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, @@ -78,15 +78,15 @@ impl DeepSeekParser { } /// Parse a single tool call block - throws error if parsing fails - fn parse_tool_call(&self, block: &str) -> ToolParserResult { + fn parse_tool_call(&self, block: &str) -> ParserResult { let captures = self.func_detail_extractor.captures(block).ok_or_else(|| { - ToolParserError::ParsingFailed("Failed to match tool call pattern".to_string()) + ParserError::ParsingFailed("Failed to match tool call pattern".to_string()) })?; // Get function type (should be "function") let func_type = captures.get(1).map_or("", |m| m.as_str()); if func_type != "function" { - return Err(ToolParserError::ParsingFailed(format!( + return Err(ParserError::ParsingFailed(format!( "Invalid function type: {}", func_type ))); @@ -95,7 +95,7 @@ impl DeepSeekParser { // Get function name let func_name = captures.get(2).map_or("", |m| m.as_str()).trim(); if func_name.is_empty() { - return Err(ToolParserError::ParsingFailed( + return Err(ParserError::ParsingFailed( "Empty function name".to_string(), )); } @@ -105,7 +105,7 @@ impl DeepSeekParser { // Parse JSON arguments let value = serde_json::from_str::(json_args) - .map_err(|e| ToolParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?; + .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?; // Create arguments object let args = if value.is_object() { @@ -115,8 +115,8 @@ impl DeepSeekParser { serde_json::json!({ "value": value }) }; - let arguments = serde_json::to_string(&args) - 
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + let arguments = + serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(ToolCall { function: FunctionCall { @@ -135,7 +135,7 @@ impl Default for DeepSeekParser { #[async_trait] impl ToolParser for DeepSeekParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); } @@ -168,7 +168,7 @@ impl ToolParser for DeepSeekParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -314,4 +314,12 @@ impl ToolParser for DeepSeekParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + self.buffer.clear(); + self.prev_tool_call_arr.clear(); + self.current_tool_id = -1; + self.current_tool_name_sent = false; + self.streamed_args_for_tool.clear(); + } } diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index 3980709ea..d40273466 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, traits::ToolParser, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, @@ -72,7 +72,7 @@ impl Glm4MoeParser { } /// Parse arguments from key-value pairs - fn parse_arguments(&self, args_text: &str) -> ToolParserResult> { + fn parse_arguments(&self, args_text: &str) -> ParserResult> { let mut arguments = serde_json::Map::new(); for capture in 
self.arg_extractor.captures_iter(args_text) { @@ -110,7 +110,7 @@ impl Glm4MoeParser { } /// Parse a single tool call block - fn parse_tool_call(&self, block: &str) -> ToolParserResult> { + fn parse_tool_call(&self, block: &str) -> ParserResult> { if let Some(captures) = self.func_detail_extractor.captures(block) { // Get function name let func_name = captures.get(1).map_or("", |m| m.as_str()).trim(); @@ -122,7 +122,7 @@ impl Glm4MoeParser { let arguments = self.parse_arguments(args_text)?; let arguments_str = serde_json::to_string(&arguments) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -137,7 +137,7 @@ impl Glm4MoeParser { /// Parse and return StreamingParseResult (mirrors Python's detect_and_parse) /// Parse all tool calls from text (shared logic for complete and incremental parsing) - fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult> { + fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult> { let mut tools = Vec::new(); for mat in self.tool_call_extractor.find_iter(text) { @@ -163,7 +163,7 @@ impl Default for Glm4MoeParser { #[async_trait] impl ToolParser for Glm4MoeParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Check if text contains GLM-4 MoE format if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); @@ -188,7 +188,7 @@ impl ToolParser for Glm4MoeParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Python logic: Wait for complete tool call, then parse it all at once self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -315,4 +315,11 @@ impl ToolParser for Glm4MoeParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, 
&self.streamed_args_for_tool) } + + fn reset(&mut self) { + self.buffer.clear(); + self.prev_tool_call_arr.clear(); + self.current_tool_id = -1; + self.streamed_args_for_tool.clear(); + } } diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs index e7ca5179b..091971df9 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::ToolParserResult, + errors::ParserResult, traits::{TokenToolParser, ToolParser}, types::{StreamingParseResult, ToolCall}, }; @@ -23,7 +23,7 @@ impl GptOssHarmonyParser { #[async_trait] impl ToolParser for GptOssHarmonyParser { - async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec)> { // Temporary stub: fall back to returning the raw text with no tool calls. // Later phases will decode Harmony tokens into structured tool calls. Ok((output.to_string(), Vec::new())) @@ -33,7 +33,7 @@ impl ToolParser for GptOssHarmonyParser { &mut self, _chunk: &str, _tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Temporary stub until the Harmony streaming pipeline is implemented. Ok(StreamingParseResult::default()) } @@ -54,7 +54,7 @@ impl TokenToolParser for GptOssHarmonyParser { async fn parse_complete_tokens( &self, _tokens: &[u32], - ) -> ToolParserResult<(String, Vec)> { + ) -> ParserResult<(String, Vec)> { // Placeholder until Harmony integration lands. Returning an empty tool list ensures // that enabling the parser without full implementation results in a no-op rather // than a runtime panic. 
@@ -65,7 +65,7 @@ impl TokenToolParser for GptOssHarmonyParser { &mut self, _tokens: &[u32], _tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { Ok(StreamingParseResult::default()) } } diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs index ddca0d32b..6aacdb6f4 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, partial_json::PartialJson, traits::ToolParser, @@ -76,7 +76,7 @@ impl Default for GptOssParser { #[async_trait] impl ToolParser for GptOssParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Check if text contains GPT-OSS format if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); @@ -100,7 +100,7 @@ impl ToolParser for GptOssParser { } else { match serde_json::from_str::(args_content) { Ok(value) => serde_json::to_string(&value) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?, + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?, Err(_) => { // Skip malformed JSON continue; @@ -126,7 +126,7 @@ impl ToolParser for GptOssParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { self.buffer.push_str(chunk); // Check for tool markers @@ -211,7 +211,7 @@ impl ToolParser for GptOssParser { partial_args }; - match self.partial_json.parse_value(json_part) { + match self.partial_json.parse_value(json_part, true) { Ok((value, _consumed)) => { let args_str = serde_json::to_string(&value) .unwrap_or_else(|_| "{}".to_string()); diff --git a/sgl-router/src/tool_parser/parsers/helpers.rs 
b/sgl-router/src/tool_parser/parsers/helpers.rs index 46dcd71c6..c71cf66a0 100644 --- a/sgl-router/src/tool_parser/parsers/helpers.rs +++ b/sgl-router/src/tool_parser/parsers/helpers.rs @@ -2,7 +2,7 @@ use crate::protocols::spec::Tool; use serde_json::Value; use std::collections::HashMap; -use crate::tool_parser::errors::{ToolParserError, ToolParserResult}; +use crate::tool_parser::errors::{ParserError, ParserResult}; use crate::tool_parser::types::{StreamingParseResult, ToolCallItem}; /// Get a mapping of tool names to their indices @@ -14,6 +14,16 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap { .collect() } +/// Find the common prefix of two strings +/// Used for incremental argument streaming when partial JSON returns different intermediate states +pub fn find_common_prefix(s1: &str, s2: &str) -> String { + s1.chars() + .zip(s2.chars()) + .take_while(|(c1, c2)| c1 == c2) + .map(|(c1, _)| c1) + .collect() +} + /// Get unstreamed tool call arguments /// Returns tool call items for arguments that have been parsed but not yet streamed /// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk @@ -96,7 +106,7 @@ pub fn reset_parser_state( ) { buffer.clear(); prev_tool_call_arr.clear(); - *current_tool_id = 0; + *current_tool_id = -1; *current_tool_name_sent = false; streamed_args_for_tool.clear(); } @@ -169,7 +179,7 @@ pub fn normalize_arguments_field(mut obj: Value) -> Value { /// /// # Returns /// - `Ok(StreamingParseResult)` with any tool call items to stream -/// - `Err(ToolParserError)` if JSON parsing or serialization fails +/// - `Err(ParserError)` if JSON parsing or serialization fails #[allow(clippy::too_many_arguments)] pub fn handle_json_tool_streaming( current_text: &str, @@ -181,7 +191,7 @@ pub fn handle_json_tool_streaming( current_tool_name_sent: &mut bool, streamed_args_for_tool: &mut Vec, prev_tool_call_arr: &mut Vec, -) -> ToolParserResult { +) -> ParserResult { // Check if we have 
content to parse if start_idx >= current_text.len() { return Ok(StreamingParseResult::default()); @@ -190,8 +200,12 @@ pub fn handle_json_tool_streaming( // Extract JSON string from current position let json_str = &current_text[start_idx..]; + // When current_tool_name_sent is false, don't allow partial strings to avoid + // parsing incomplete tool names as empty strings + let allow_partial_strings = *current_tool_name_sent; + // Parse partial JSON - let (obj, end_idx) = match partial_json.parse_value(json_str) { + let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) { Ok(result) => result, Err(_) => { return Ok(StreamingParseResult::default()); @@ -252,49 +266,68 @@ pub fn handle_json_tool_streaming( .map(|s| s.len()) .unwrap_or(0); let cur_args_json = serde_json::to_string(cur_arguments) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; - // Compute diff: everything after what we've already sent - let diff = cur_args_json[sent..].to_string(); + // Get prev_arguments (matches Python's structure) + let prev_arguments = if tool_id < prev_tool_call_arr.len() { + prev_tool_call_arr[tool_id].get("arguments") + } else { + None + }; - // Send diff if there's new content - if !diff.is_empty() { - // Only accumulate if not complete - if !is_complete && tool_id < streamed_args_for_tool.len() { - streamed_args_for_tool[tool_id].push_str(&diff); - } + // Calculate diff: everything after we've already sent + let mut argument_diff = None; - result.calls.push(ToolCallItem { - tool_index: tool_id, - name: None, - parameters: diff, - }); - } - - // If JSON is complete, advance to next tool if is_complete { - // Remove processed portion, keep unprocessed content - *buffer = current_text[start_idx + end_idx..].to_string(); + // Python: argument_diff = cur_args_json[sent:] + // Rust needs bounds check (Python returns "" automatically) + argument_diff = if sent < cur_args_json.len() { +
Some(cur_args_json[sent..].to_string()) + } else { + Some(String::new()) + }; + } else if let Some(prev_args) = prev_arguments { + let prev_args_json = serde_json::to_string(prev_args) + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; - // Clear completed tool data - if tool_id < prev_tool_call_arr.len() { - prev_tool_call_arr[tool_id] = Value::Null; + if cur_args_json != prev_args_json { + let prefix = find_common_prefix(&prev_args_json, &cur_args_json); + argument_diff = if sent < prefix.len() { + Some(prefix[sent..].to_string()) + } else { + Some(String::new()) + }; } - *current_tool_name_sent = false; - if tool_id < streamed_args_for_tool.len() { - streamed_args_for_tool[tool_id].clear(); - } - *current_tool_id += 1; } - } - // Update prev_tool_call_arr with current state - if *current_tool_id >= 0 { - ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool); - let tool_id = *current_tool_id as usize; + // Send diff if present + if let Some(diff) = argument_diff { + if !diff.is_empty() { + if tool_id < streamed_args_for_tool.len() { + streamed_args_for_tool[tool_id].push_str(&diff); + } + result.calls.push(ToolCallItem { + tool_index: tool_id, + name: None, + parameters: diff, + }); + } + } - if tool_id < prev_tool_call_arr.len() { - prev_tool_call_arr[tool_id] = current_tool_call; + // Update prev_tool_call_arr with current state + if *current_tool_id >= 0 { + ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool); + + if tool_id < prev_tool_call_arr.len() { + prev_tool_call_arr[tool_id] = current_tool_call; + } + } + + // If complete, advance to next tool + if is_complete { + *buffer = current_text[start_idx + end_idx..].to_string(); + *current_tool_name_sent = false; + *current_tool_id += 1; } } @@ -371,7 +404,7 @@ mod tests { assert_eq!(buffer, ""); assert_eq!(prev_tools.len(), 0); - assert_eq!(current_tool_id, 0); + assert_eq!(current_tool_id, -1); assert!(!current_tool_name_sent); 
assert_eq!(streamed_args.len(), 0); } diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index 660e113e0..04b0ca1de 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -4,7 +4,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, partial_json::PartialJson, traits::ToolParser, @@ -117,7 +117,7 @@ impl JsonParser { } /// Parse a single JSON object into a ToolCall - fn parse_single_object(&self, obj: &Value) -> ToolParserResult> { + fn parse_single_object(&self, obj: &Value) -> ParserResult> { // Check if this looks like a tool call let name = obj .get("name") @@ -134,7 +134,7 @@ impl JsonParser { // Convert arguments to JSON string let arguments = serde_json::to_string(args) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -148,7 +148,7 @@ impl JsonParser { } /// Parse JSON value(s) into tool calls - fn parse_json_value(&self, value: &Value) -> ToolParserResult> { + fn parse_json_value(&self, value: &Value) -> ParserResult> { let mut tools = Vec::new(); match value { @@ -184,11 +184,11 @@ impl Default for JsonParser { #[async_trait] impl ToolParser for JsonParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Always use extract_json_from_text to handle both pure JSON and mixed content if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) { let parsed = serde_json::from_str::(&extracted_json) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string())) + .map_err(|e| ParserError::ParsingFailed(e.to_string())) .and_then(|v| 
self.parse_json_value(&v)); match parsed { @@ -205,7 +205,7 @@ impl ToolParser for JsonParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Append new text to buffer self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -264,4 +264,14 @@ impl ToolParser for JsonParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + helpers::reset_parser_state( + &mut self.buffer, + &mut self.prev_tool_call_arr, + &mut self.current_tool_id, + &mut self.current_tool_name_sent, + &mut self.streamed_args_for_tool, + ); + } } diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs index b2d6e85d8..2e2237f0c 100644 --- a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::ToolParserResult, + errors::ParserResult, parsers::helpers, traits::ToolParser, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, @@ -102,7 +102,7 @@ impl Default for KimiK2Parser { #[async_trait] impl ToolParser for KimiK2Parser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); } @@ -161,7 +161,7 @@ impl ToolParser for KimiK2Parser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -333,4 +333,13 @@ impl ToolParser for KimiK2Parser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + 
self.buffer.clear(); + self.prev_tool_call_arr.clear(); + self.current_tool_id = -1; + self.current_tool_name_sent = false; + self.streamed_args_for_tool.clear(); + self.last_arguments.clear(); + } } diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index 14c09c1c5..3af8b9bda 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -4,7 +4,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, partial_json::PartialJson, traits::ToolParser, @@ -70,7 +70,7 @@ impl LlamaParser { } /// Parse a single JSON object into a ToolCall (Llama format: name + parameters) - fn parse_single_object(&self, obj: &Value) -> ToolParserResult> { + fn parse_single_object(&self, obj: &Value) -> ParserResult> { // Llama format only: {"name": "function_name", "parameters": {...}} let name = obj.get("name").and_then(|v| v.as_str()); @@ -81,7 +81,7 @@ impl LlamaParser { // Convert parameters to JSON string let arguments = serde_json::to_string(parameters) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -95,7 +95,7 @@ impl LlamaParser { } /// Parse semicolon-separated JSON objects - fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult> { + fn parse_semicolon_separated(&self, content: &str) -> ParserResult> { let mut all_tools = Vec::new(); // Split by semicolon and parse each JSON object @@ -131,7 +131,7 @@ impl Default for LlamaParser { #[async_trait] impl ToolParser for LlamaParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Extract normal text and JSON content let 
(normal_text, json_content) = if let Some((normal, json)) = self.extract_content_after_python_tag(text) { @@ -149,7 +149,7 @@ impl ToolParser for LlamaParser { } else { // Try single JSON object let parsed = serde_json::from_str::(json_content.trim()) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string())) + .map_err(|e| ParserError::ParsingFailed(e.to_string())) .and_then(|v| { self.parse_single_object(&v) .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool])) @@ -173,7 +173,7 @@ impl ToolParser for LlamaParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Append new text to buffer self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -231,4 +231,14 @@ impl ToolParser for LlamaParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + helpers::reset_parser_state( + &mut self.buffer, + &mut self.prev_tool_call_arr, + &mut self.current_tool_id, + &mut self.current_tool_name_sent, + &mut self.streamed_args_for_tool, + ); + } } diff --git a/sgl-router/src/tool_parser/parsers/mistral_parser.rs b/sgl-router/src/tool_parser/parsers/mistral_parser.rs index a7cf79aba..c87d8ce7a 100644 --- a/sgl-router/src/tool_parser/parsers/mistral_parser.rs +++ b/sgl-router/src/tool_parser/parsers/mistral_parser.rs @@ -4,7 +4,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, partial_json::PartialJson, traits::ToolParser, @@ -111,9 +111,9 @@ impl MistralParser { } /// Parse tool calls from a JSON array - fn parse_json_array(&self, json_str: &str) -> ToolParserResult> { + fn parse_json_array(&self, json_str: &str) -> ParserResult> { let value: Value = serde_json::from_str(json_str) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| 
ParserError::ParsingFailed(e.to_string()))?; let mut tools = Vec::new(); @@ -134,7 +134,7 @@ impl MistralParser { } /// Parse a single JSON object into a ToolCall - fn parse_single_object(&self, obj: &Value) -> ToolParserResult> { + fn parse_single_object(&self, obj: &Value) -> ParserResult> { let name = obj.get("name").and_then(|v| v.as_str()); if let Some(name) = name { @@ -144,7 +144,7 @@ impl MistralParser { // Convert arguments to JSON string let arguments = serde_json::to_string(args) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -166,7 +166,7 @@ impl Default for MistralParser { #[async_trait] impl ToolParser for MistralParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Check if text contains Mistral format if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); @@ -199,7 +199,7 @@ impl ToolParser for MistralParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Append new text to buffer self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -256,4 +256,14 @@ impl ToolParser for MistralParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + helpers::reset_parser_state( + &mut self.buffer, + &mut self.prev_tool_call_arr, + &mut self.current_tool_id, + &mut self.current_tool_name_sent, + &mut self.streamed_args_for_tool, + ); + } } diff --git a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs index 2b5ad8bad..4c712c7bd 100644 --- a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs +++ b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs @@ -18,7 +18,7 @@ use 
std::sync::OnceLock; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, traits::ToolParser, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, @@ -74,7 +74,7 @@ impl PythonicParser { .replace("<|python_end|>", "") } - fn parse_tool_call_block(&self, block: &str) -> ToolParserResult> { + fn parse_tool_call_block(&self, block: &str) -> ParserResult> { let expr = parse_python_expression(block)?; match expr { Expr::List(list_expr) => list_expr @@ -83,7 +83,7 @@ impl PythonicParser { .enumerate() .map(|(idx, call_expr)| build_tool_call(call_expr, idx)) .collect(), - _ => Err(ToolParserError::ParsingFailed( + _ => Err(ParserError::ParsingFailed( "Expected a list of function calls in pythonic tool call".to_string(), )), } @@ -92,7 +92,7 @@ impl PythonicParser { #[async_trait] impl ToolParser for PythonicParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { let cleaned = Self::strip_special_tokens(text); if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) { @@ -120,7 +120,7 @@ impl ToolParser for PythonicParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { self.buffer.push_str(chunk); let cleaned = Self::strip_special_tokens(&self.buffer); @@ -232,23 +232,23 @@ fn find_matching_bracket(buffer: &str, start: usize) -> Option { None // No matching bracket found } -fn parse_python_expression(source: &str) -> ToolParserResult { +fn parse_python_expression(source: &str) -> ParserResult { let module = parse(source, Mode::Expression, "") - .map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?; + .map_err(|err| ParserError::ParsingFailed(err.to_string()))?; match module { Mod::Expression(expr_mod) => Ok(*expr_mod.body), - _ => 
Err(ToolParserError::ParsingFailed( + _ => Err(ParserError::ParsingFailed( "Expected a Python expression".to_string(), )), } } -fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult { +fn build_tool_call(expr: Expr, _index: usize) -> ParserResult { match expr { Expr::Call(call_expr) => { if !call_expr.args.is_empty() { - return Err(ToolParserError::ParsingFailed( + return Err(ParserError::ParsingFailed( "Positional arguments are not supported in pythonic tool calls".to_string(), )); } @@ -256,7 +256,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult { let function_name = match *call_expr.func { Expr::Name(name_expr) => name_expr.id.to_string(), _ => { - return Err(ToolParserError::ParsingFailed( + return Err(ParserError::ParsingFailed( "Unsupported function reference in pythonic tool call".to_string(), )) } @@ -265,7 +265,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult { let mut arguments_map = Map::with_capacity(call_expr.keywords.len()); for keyword in call_expr.keywords { let arg_name = keyword.arg.ok_or_else(|| { - ToolParserError::ParsingFailed( + ParserError::ParsingFailed( "pythonic tool calls do not support **kwargs".to_string(), ) })?; @@ -283,13 +283,13 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult { }, }) } - _ => Err(ToolParserError::ParsingFailed( + _ => Err(ParserError::ParsingFailed( "Expected function calls inside pythonic tool call list".to_string(), )), } } -fn expression_to_json(expr: &Expr) -> ToolParserResult { +fn expression_to_json(expr: &Expr) -> ParserResult { match expr { Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value), Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array), @@ -300,81 +300,75 @@ fn expression_to_json(expr: &Expr) -> ToolParserResult { Expr::UnaryOp(unary_expr) => match unary_expr.op { UnaryOp::USub => match unary_expr.operand.as_ref() { Expr::Constant(const_expr) => 
negate_constant(&const_expr.value), - _ => Err(ToolParserError::ParsingFailed( + _ => Err(ParserError::ParsingFailed( "Unsupported unary operand in pythonic tool call".to_string(), )), }, UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()), - _ => Err(ToolParserError::ParsingFailed(format!( + _ => Err(ParserError::ParsingFailed(format!( "Unsupported unary operator in pythonic tool call: {:?}", unary_expr.op ))), }, Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())), - _ => Err(ToolParserError::ParsingFailed(format!( + _ => Err(ParserError::ParsingFailed(format!( "Unsupported expression in pythonic tool call: {:?}", expr ))), } } -fn constant_to_json(constant: &Constant) -> ToolParserResult { +fn constant_to_json(constant: &Constant) -> ParserResult { match constant { Constant::None => Ok(Value::Null), Constant::Bool(b) => Ok(Value::Bool(*b)), Constant::Int(value) => Ok(integer_constant_to_value(value, false)), Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| { - ToolParserError::ParsingFailed( - "Invalid float literal in pythonic tool call".to_string(), - ) + ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string()) }), Constant::Str(s) => Ok(Value::String(s.clone())), Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())), Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array), - Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed( + Constant::Ellipsis | Constant::Complex { .. 
} => Err(ParserError::ParsingFailed( "Unsupported literal in pythonic tool call".to_string(), )), } } -fn negate_constant(constant: &Constant) -> ToolParserResult { +fn negate_constant(constant: &Constant) -> ParserResult { match constant { Constant::Int(value) => Ok(integer_constant_to_value(value, true)), Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| { - ToolParserError::ParsingFailed( - "Invalid float literal in pythonic tool call".to_string(), - ) + ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string()) }), - _ => Err(ToolParserError::ParsingFailed( + _ => Err(ParserError::ParsingFailed( "Unsupported unary operand in pythonic tool call".to_string(), )), } } -fn value_to_key_string(value: Value) -> ToolParserResult { +fn value_to_key_string(value: Value) -> ParserResult { match value { Value::String(s) => Ok(s), Value::Number(num) => Ok(num.to_string()), Value::Bool(b) => Ok(b.to_string()), Value::Null => Ok("null".to_string()), - other => Err(ToolParserError::ParsingFailed(format!( + other => Err(ParserError::ParsingFailed(format!( "Unsupported key type in pythonic tool call: {:?}", other ))), } } -fn collect_sequence(elements: &[Expr]) -> ToolParserResult> { +fn collect_sequence(elements: &[Expr]) -> ParserResult> { elements.iter().map(expression_to_json).collect() } -fn collect_dict(keys: &[Option], values: &[Expr]) -> ToolParserResult> { +fn collect_dict(keys: &[Option], values: &[Expr]) -> ParserResult> { let mut map = Map::with_capacity(keys.len()); for (key_expr, value_expr) in keys.iter().zip(values.iter()) { let key_expr = key_expr.as_ref().ok_or_else(|| { - ToolParserError::ParsingFailed( - "pythonic tool calls do not support **kwargs".to_string(), - ) + ParserError::ParsingFailed("pythonic tool calls do not support **kwargs".to_string()) })?; let key_value = expression_to_json(key_expr)?; let key = value_to_key_string(key_value)?; @@ -384,7 +378,7 @@ fn collect_dict(keys: &[Option], values: 
&[Expr]) -> ToolParserResult ToolParserResult> { +fn constant_tuple_to_array(values: &[Constant]) -> ParserResult> { values.iter().map(constant_to_json).collect() } diff --git a/sgl-router/src/tool_parser/parsers/qwen_parser.rs b/sgl-router/src/tool_parser/parsers/qwen_parser.rs index 62eac4f64..e0072debc 100644 --- a/sgl-router/src/tool_parser/parsers/qwen_parser.rs +++ b/sgl-router/src/tool_parser/parsers/qwen_parser.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, partial_json::PartialJson, traits::ToolParser, @@ -76,7 +76,7 @@ impl QwenParser { } /// Parse a single JSON object into a ToolCall - fn parse_single_object(&self, obj: &Value) -> ToolParserResult> { + fn parse_single_object(&self, obj: &Value) -> ParserResult> { let name = obj.get("name").and_then(|v| v.as_str()); if let Some(name) = name { @@ -86,7 +86,7 @@ impl QwenParser { // Convert arguments to JSON string let arguments = serde_json::to_string(args) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -108,7 +108,7 @@ impl Default for QwenParser { #[async_trait] impl ToolParser for QwenParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { // Check if text contains Qwen format if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); @@ -123,7 +123,7 @@ impl ToolParser for QwenParser { for captures in self.extractor.captures_iter(text) { if let Some(json_str) = captures.get(1) { let parsed = serde_json::from_str::(json_str.as_str().trim()) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string())) + .map_err(|e| ParserError::ParsingFailed(e.to_string())) .and_then(|v| 
self.parse_single_object(&v)); match parsed { @@ -149,7 +149,7 @@ impl ToolParser for QwenParser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { // Append new text to buffer self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); @@ -240,4 +240,14 @@ impl ToolParser for QwenParser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + helpers::reset_parser_state( + &mut self.buffer, + &mut self.prev_tool_call_arr, + &mut self.current_tool_id, + &mut self.current_tool_name_sent, + &mut self.streamed_args_for_tool, + ); + } } diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs index 622843c0c..01f3674aa 100644 --- a/sgl-router/src/tool_parser/parsers/step3_parser.rs +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, parsers::helpers, traits::ToolParser, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, @@ -108,7 +108,7 @@ impl Step3Parser { fn parse_partial_tool_call( &mut self, tool_indices: &HashMap, - ) -> ToolParserResult { + ) -> ParserResult { let mut calls = Vec::new(); // Check if we have tool_sep (means we're past the type declaration) @@ -321,7 +321,7 @@ impl Step3Parser { fn parse_steptml_parameters( &self, params_text: &str, - ) -> ToolParserResult> { + ) -> ParserResult> { let mut parameters = serde_json::Map::new(); for capture in self.param_extractor.captures_iter(params_text) { @@ -359,7 +359,7 @@ impl Step3Parser { } /// Parse a single tool call block - fn parse_tool_call(&self, block: &str) -> ToolParserResult> { + fn parse_tool_call(&self, block: &str) -> ParserResult> { // Check if it contains function 
marker and tool separator if !block.contains("function") || !block.contains("<|tool_sep|>") { return Ok(None); @@ -393,7 +393,7 @@ impl Step3Parser { let parameters = self.parse_steptml_parameters(params_text)?; let arguments_str = serde_json::to_string(¶meters) - .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; Ok(Some(ToolCall { function: FunctionCall { @@ -415,7 +415,7 @@ impl Default for Step3Parser { #[async_trait] impl ToolParser for Step3Parser { - async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { + async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec)> { if !self.has_tool_markers(text) { return Ok((text.to_string(), vec![])); } @@ -449,7 +449,7 @@ impl ToolParser for Step3Parser { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult { + ) -> ParserResult { self.buffer.push_str(chunk); // Build tool indices for validation @@ -555,4 +555,20 @@ impl ToolParser for Step3Parser { fn get_unstreamed_tool_args(&self) -> Option> { helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) } + + fn reset(&mut self) { + // Reset standard state + self.buffer.clear(); + self.prev_tool_call_arr.clear(); + self.current_tool_id = -1; + self.streamed_args_for_tool.clear(); + + // Reset Step3-specific fields + self.in_tool_block = false; + self.tool_block_finished = false; + self.current_function_name.clear(); + self.current_parameters.clear(); + self.in_tool_call = false; + self.function_name_sent = false; + } } diff --git a/sgl-router/src/tool_parser/partial_json.rs b/sgl-router/src/tool_parser/partial_json.rs index 4a4504fe0..c6d474d6a 100644 --- a/sgl-router/src/tool_parser/partial_json.rs +++ b/sgl-router/src/tool_parser/partial_json.rs @@ -1,5 +1,5 @@ use crate::tool_parser::{ - errors::{ToolParserError, ToolParserResult}, + errors::{ParserError, ParserResult}, traits::PartialJsonParser, }; use serde_json::{Map, 
Value}; @@ -22,8 +22,22 @@ impl PartialJson { } /// Parse potentially incomplete JSON, returning parsed value and consumed bytes - pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> { - let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete); + /// + /// # Arguments + /// * `input` - The JSON string to parse + /// * `allow_partial_strings` - When false, incomplete strings cause parsing to stop + /// (matches Python's Allow.ALL & ~Allow.STR behavior) + pub fn parse_value( + &self, + input: &str, + allow_partial_strings: bool, + ) -> ParserResult<(Value, usize)> { + let mut parser = Parser::new( + input, + self.max_depth, + self.allow_incomplete, + allow_partial_strings, + ); let value = parser.parse_value(0)?; Ok((value, parser.position)) } @@ -36,8 +50,9 @@ impl Default for PartialJson { } impl PartialJsonParser for PartialJson { - fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> { - self.parse_value(input) + fn parse(&self, input: &str) -> ParserResult<(Value, usize)> { + // Default to allowing partial strings + self.parse_value(input, true) } fn is_complete(&self, input: &str) -> bool { @@ -56,15 +71,22 @@ struct Parser<'a> { position: usize, max_depth: usize, allow_incomplete: bool, + allow_partial_strings: bool, } impl<'a> Parser<'a> { - fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self { + fn new( + input: &'a str, + max_depth: usize, + allow_incomplete: bool, + allow_partial_strings: bool, + ) -> Self { Self { chars: input.chars().peekable(), position: 0, max_depth, allow_incomplete, + allow_partial_strings, } } @@ -88,9 +110,9 @@ impl<'a> Parser<'a> { } } - fn parse_value(&mut self, depth: usize) -> ToolParserResult { + fn parse_value(&mut self, depth: usize) -> ParserResult { if depth > self.max_depth { - return Err(ToolParserError::DepthExceeded(self.max_depth)); + return Err(ParserError::DepthExceeded(self.max_depth)); } self.skip_whitespace(); @@ -106,17 +128,15 @@ impl<'a> 
Parser<'a> { if self.allow_incomplete { Ok(Value::Null) } else { - Err(ToolParserError::ParsingFailed( - "Unexpected character".into(), - )) + Err(ParserError::ParsingFailed("Unexpected character".into())) } } } } - fn parse_object(&mut self, depth: usize) -> ToolParserResult { + fn parse_object(&mut self, depth: usize) -> ParserResult { if depth > self.max_depth { - return Err(ToolParserError::DepthExceeded(self.max_depth)); + return Err(ParserError::DepthExceeded(self.max_depth)); } let mut object = Map::new(); @@ -140,7 +160,7 @@ impl<'a> Parser<'a> { return Ok(Value::Object(object)); } Err(e) => return Err(e), - _ => return Err(ToolParserError::ParsingFailed("Expected string key".into())), + _ => return Err(ParserError::ParsingFailed("Expected string key".into())), }; self.skip_whitespace(); @@ -152,7 +172,7 @@ impl<'a> Parser<'a> { object.insert(key, Value::Null); return Ok(Value::Object(object)); } - return Err(ToolParserError::ParsingFailed("Expected ':'".into())); + return Err(ParserError::ParsingFailed("Expected ':'".into())); } self.advance(); self.skip_whitespace(); @@ -161,8 +181,13 @@ impl<'a> Parser<'a> { let value = match self.parse_value(depth) { Ok(v) => v, Err(_) if self.allow_incomplete => { - // Add null for incomplete value - object.insert(key, Value::Null); + // When allow_partial_strings is false, don't add the key with Null + // Just return the object without this incomplete key-value pair + // This matches Python's behavior: Allow.ALL & ~Allow.STR + if self.allow_partial_strings { + // Add null for incomplete value + object.insert(key, Value::Null); + } return Ok(Value::Object(object)); } Err(e) => return Err(e), @@ -192,15 +217,15 @@ impl<'a> Parser<'a> { if self.allow_incomplete { return Ok(Value::Object(object)); } - return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into())); + return Err(ParserError::ParsingFailed("Expected ',' or '}'".into())); } } } } - fn parse_array(&mut self, depth: usize) -> ToolParserResult { + fn 
parse_array(&mut self, depth: usize) -> ParserResult { if depth > self.max_depth { - return Err(ToolParserError::DepthExceeded(self.max_depth)); + return Err(ParserError::DepthExceeded(self.max_depth)); } let mut array = Vec::new(); @@ -249,15 +274,15 @@ impl<'a> Parser<'a> { if self.allow_incomplete { return Ok(Value::Array(array)); } - return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into())); + return Err(ParserError::ParsingFailed("Expected ',' or ']'".into())); } } } } - fn parse_string(&mut self) -> ToolParserResult { + fn parse_string(&mut self) -> ParserResult { if self.peek() != Some('"') { - return Err(ToolParserError::ParsingFailed("Expected '\"'".into())); + return Err(ParserError::ParsingFailed("Expected '\"'".into())); } // Consume opening quote @@ -301,14 +326,14 @@ impl<'a> Parser<'a> { } // Incomplete string - if self.allow_incomplete { + if self.allow_incomplete && self.allow_partial_strings { Ok(Value::String(string)) } else { - Err(ToolParserError::ParsingFailed("Unterminated string".into())) + Err(ParserError::ParsingFailed("Unterminated string".into())) } } - fn parse_unicode_escape(&mut self) -> ToolParserResult { + fn parse_unicode_escape(&mut self) -> ParserResult { let mut hex = String::new(); for _ in 0..4 { if let Some(ch) = self.peek() { @@ -327,17 +352,17 @@ impl<'a> Parser<'a> { u32::from_str_radix(&hex, 16) .ok() .and_then(char::from_u32) - .ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into())) + .ok_or_else(|| ParserError::ParsingFailed("Invalid unicode escape".into())) } else if self.allow_incomplete { Ok('\u{FFFD}') // Replacement character } else { - Err(ToolParserError::ParsingFailed( + Err(ParserError::ParsingFailed( "Incomplete unicode escape".into(), )) } } - fn parse_number(&mut self) -> ToolParserResult { + fn parse_number(&mut self) -> ParserResult { let mut number = String::new(); // Handle negative sign @@ -410,11 +435,11 @@ impl<'a> Parser<'a> { } else if self.allow_incomplete { 
Ok(Value::Number(serde_json::Number::from(0))) } else { - Err(ToolParserError::ParsingFailed("Invalid number".into())) + Err(ParserError::ParsingFailed("Invalid number".into())) } } - fn parse_bool(&mut self) -> ToolParserResult { + fn parse_bool(&mut self) -> ParserResult { let mut word = String::new(); // Peek at upcoming characters to validate it looks like a boolean @@ -435,7 +460,7 @@ impl<'a> Parser<'a> { || (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word))); if !is_valid { - return Err(ToolParserError::ParsingFailed("Invalid boolean".into())); + return Err(ParserError::ParsingFailed("Invalid boolean".into())); } // Now actually consume the characters @@ -458,14 +483,14 @@ impl<'a> Parser<'a> { } else if "false".starts_with(partial) { Ok(Value::Bool(false)) } else { - Err(ToolParserError::ParsingFailed("Invalid boolean".into())) + Err(ParserError::ParsingFailed("Invalid boolean".into())) } } - _ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())), + _ => Err(ParserError::ParsingFailed("Invalid boolean".into())), } } - fn parse_null(&mut self) -> ToolParserResult { + fn parse_null(&mut self) -> ParserResult { let mut word = String::new(); // Peek at upcoming characters to validate it looks like "null" @@ -484,7 +509,7 @@ impl<'a> Parser<'a> { let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word)); if !is_valid { - return Err(ToolParserError::ParsingFailed("Invalid null".into())); + return Err(ParserError::ParsingFailed("Invalid null".into())); } // Now actually consume the characters @@ -501,7 +526,7 @@ impl<'a> Parser<'a> { if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) { Ok(Value::Null) } else { - Err(ToolParserError::ParsingFailed("Invalid null".into())) + Err(ParserError::ParsingFailed("Invalid null".into())) } } } diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs index cd10b23ee..b440382b6 100644 --- 
a/sgl-router/src/tool_parser/tests.rs +++ b/sgl-router/src/tool_parser/tests.rs @@ -7,7 +7,7 @@ use crate::tool_parser::traits::ToolParser; #[tokio::test] async fn test_tool_parser_factory() { - let factory = ToolParserFactory::new(); + let factory = ParserFactory::new(); // Test that we can get a pooled parser let pooled_parser = factory.get_pooled("gpt-4"); @@ -17,7 +17,7 @@ async fn test_tool_parser_factory() { #[tokio::test] async fn test_tool_parser_factory_model_mapping() { - let factory = ToolParserFactory::new(); + let factory = ParserFactory::new(); // Test model mapping factory.registry().map_model("test-model", "json"); @@ -54,22 +54,22 @@ fn test_partial_json_parser() { let parser = PartialJson::default(); let input = r#"{"name": "test", "value": 42}"#; - let (value, consumed) = parser.parse_value(input).unwrap(); + let (value, consumed) = parser.parse_value(input, true).unwrap(); assert_eq!(value["name"], "test"); assert_eq!(value["value"], 42); assert_eq!(consumed, input.len()); let input = r#"{"name": "test", "value": "#; - let (value, _consumed) = parser.parse_value(input).unwrap(); + let (value, _consumed) = parser.parse_value(input, true).unwrap(); assert_eq!(value["name"], "test"); assert!(value["value"].is_null()); let input = r#"{"name": "tes"#; - let (value, _consumed) = parser.parse_value(input).unwrap(); + let (value, _consumed) = parser.parse_value(input, true).unwrap(); assert_eq!(value["name"], "tes"); let input = r#"[1, 2, "#; - let (value, _consumed) = parser.parse_value(input).unwrap(); + let (value, _consumed) = parser.parse_value(input, true).unwrap(); assert!(value.is_array()); assert_eq!(value[0], 1); assert_eq!(value[1], 2); @@ -83,17 +83,17 @@ fn test_partial_json_depth_limit() { // This should work (simple object) let input = r#"{"a": 1}"#; - let result = parser.parse_value(input); + let result = parser.parse_value(input, true); assert!(result.is_ok()); // This should work (nested to depth 3) let input = r#"{"a": {"b": {"c": 
1}}}"#; - let result = parser.parse_value(input); + let result = parser.parse_value(input, true); assert!(result.is_ok()); // This should fail (nested to depth 4, exceeds limit) let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#; - let result = parser.parse_value(input); + let result = parser.parse_value(input, true); assert!(result.is_err()); } @@ -244,7 +244,7 @@ fn test_json_parser_format_detection() { #[tokio::test] async fn test_factory_with_json_parser() { - let factory = ToolParserFactory::new(); + let factory = ParserFactory::new(); // Should get JSON parser for OpenAI models let pooled_parser = factory.get_pooled("gpt-4-turbo"); diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs index f9f23216f..f4e64a053 100644 --- a/sgl-router/src/tool_parser/traits.rs +++ b/sgl-router/src/tool_parser/traits.rs @@ -1,6 +1,6 @@ use crate::protocols::spec::Tool; use crate::tool_parser::{ - errors::ToolParserResult, + errors::ParserResult, types::{StreamingParseResult, ToolCall}, }; use async_trait::async_trait; @@ -10,7 +10,7 @@ use async_trait::async_trait; pub trait ToolParser: Send + Sync { /// Parse complete tool calls from final output /// Returns (remaining_normal_text, tool_calls) tuple - async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec)>; + async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec)>; /// Parse tool calls from model output (streaming) /// Parsers now maintain internal state, so self is mutable @@ -22,7 +22,7 @@ pub trait ToolParser: Send + Sync { &mut self, chunk: &str, tools: &[Tool], - ) -> ToolParserResult; + ) -> ParserResult; /// Check if text contains tool calls in this parser's format fn has_tool_markers(&self, text: &str) -> bool; @@ -38,12 +38,18 @@ pub trait ToolParser: Send + Sync { fn get_unstreamed_tool_args(&self) -> Option> { None } + + /// Reset the parser state for reuse across requests. 
+ /// This should clear all buffers and reset state to initial values. + fn reset(&mut self) { + // Default no-op implementation + } } /// Trait for partial JSON parsing pub trait PartialJsonParser: Send + Sync { /// Parse potentially incomplete JSON - fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>; + fn parse(&self, input: &str) -> ParserResult<(serde_json::Value, usize)>; /// Check if JSON is complete fn is_complete(&self, input: &str) -> bool; @@ -55,10 +61,7 @@ pub trait PartialJsonParser: Send + Sync { #[async_trait] pub trait TokenToolParser: ToolParser { /// Parse complete tool calls when provided with raw token IDs. - async fn parse_complete_tokens( - &self, - tokens: &[u32], - ) -> ToolParserResult<(String, Vec)>; + async fn parse_complete_tokens(&self, tokens: &[u32]) -> ParserResult<(String, Vec)>; /// Streaming parser entrypoint for token chunks. /// Parsers maintain internal state, so self is mutable @@ -66,5 +69,5 @@ pub trait TokenToolParser: ToolParser { &mut self, tokens: &[u32], tools: &[Tool], - ) -> ToolParserResult; + ) -> ParserResult; } diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 0e7022bff..9288a9b06 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -4,6 +4,7 @@ pub mod mock_mcp_server; pub mod mock_openai_server; pub mod mock_worker; +pub mod streaming_helpers; pub mod test_app; use serde_json::json; diff --git a/sgl-router/tests/common/streaming_helpers.rs b/sgl-router/tests/common/streaming_helpers.rs new file mode 100644 index 000000000..0c993168e --- /dev/null +++ b/sgl-router/tests/common/streaming_helpers.rs @@ -0,0 +1,134 @@ +//! Streaming Test Helpers +//! +//! Utilities for creating realistic streaming chunks that simulate +//! how LLM tokens actually arrive (1-5 characters at a time). 
+ +/// Split input into realistic char-level chunks (2-3 chars each for determinism) +pub fn create_realistic_chunks(input: &str) -> Vec { + let mut chunks = Vec::new(); + let chars: Vec = input.chars().collect(); + let mut i = 0; + + while i < chars.len() { + // Take 2-3 characters at a time (deterministic for testing) + let chunk_size = if i + 3 <= chars.len() && chars[i].is_ascii_alphanumeric() { + 3 // Longer chunks for alphanumeric sequences + } else { + 2 // Shorter chunks for special characters + }; + + let end = (i + chunk_size).min(chars.len()); + let chunk: String = chars[i..end].iter().collect(); + chunks.push(chunk); + i = end; + } + + chunks +} + +/// Split input at strategic positions to test edge cases +/// This creates chunks that break at critical positions like after quotes, colons, etc. +pub fn create_strategic_chunks(input: &str) -> Vec { + let mut chunks = Vec::new(); + let mut current = String::new(); + let chars: Vec = input.chars().collect(); + + for (i, &ch) in chars.iter().enumerate() { + current.push(ch); + + // Break after strategic characters + let should_break = matches!(ch, '"' | ':' | ',' | '{' | '}' | '[' | ']') + || (i > 0 && chars[i-1] == '"' && ch == ' ') // Space after quote + || current.len() >= 5; // Max 5 chars per chunk + + if should_break && !current.is_empty() { + chunks.push(current.clone()); + current.clear(); + } + } + + if !current.is_empty() { + chunks.push(current); + } + + chunks +} + +/// Create the bug scenario chunks: `{"name": "` arrives in parts +pub fn create_bug_scenario_chunks() -> Vec<&'static str> { + vec![ + r#"{"#, + r#"""#, + r#"name"#, + r#"""#, + r#":"#, + r#" "#, + r#"""#, // Bug occurs here: parser has {"name": " + r#"search"#, // Use valid tool name + r#"""#, + r#","#, + r#" "#, + r#"""#, + r#"arguments"#, + r#"""#, + r#":"#, + r#" "#, + r#"{"#, + r#"""#, + r#"query"#, + r#"""#, + r#":"#, + r#" "#, + r#"""#, + r#"test query"#, + r#"""#, + r#"}"#, + r#"}"#, + ] +} + +#[cfg(test)] +mod tests { + 
#[allow(unused_imports)] + use super::*; + + #[test] + fn test_realistic_chunks() { + let input = r#"{"name": "test"}"#; + let chunks = create_realistic_chunks(input); + + // Should have multiple chunks + assert!(chunks.len() > 3); + + // Reconstructed should equal original + let reconstructed: String = chunks.join(""); + assert_eq!(reconstructed, input); + } + + #[test] + fn test_strategic_chunks_breaks_after_quotes() { + let input = r#"{"name": "value"}"#; + let chunks = create_strategic_chunks(input); + + // Should break after quotes and colons + assert!(chunks.iter().any(|c| c.ends_with('"'))); + assert!(chunks.iter().any(|c| c.ends_with(':'))); + + // Reconstructed should equal original + let reconstructed: String = chunks.join(""); + assert_eq!(reconstructed, input); + } + + #[test] + fn test_bug_scenario_chunks() { + let chunks = create_bug_scenario_chunks(); + let reconstructed: String = chunks.join(""); + + // Should reconstruct to valid JSON + assert!(reconstructed.contains(r#"{"name": "search""#)); + + // The critical chunk sequence should be present (space after colon, then quote in next chunk) + let joined = chunks.join("|"); + assert!(joined.contains(r#" |"#)); // The bug happens at {"name": " and then " + } +} diff --git a/sgl-router/tests/tool_parser_glm4_moe.rs b/sgl-router/tests/tool_parser_glm4_moe.rs index e92848cf4..86d161c9e 100644 --- a/sgl-router/tests/tool_parser_glm4_moe.rs +++ b/sgl-router/tests/tool_parser_glm4_moe.rs @@ -126,28 +126,6 @@ fn test_glm4_format_detection() { assert!(!parser.has_tool_markers("plain text")); } -#[tokio::test] -async fn test_glm4_python_literal_values() { - let parser = Glm4MoeParser::new(); - - let input = r#"config -debug -True -verbose -False -optional -None -"#; - - let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(tools.len(), 1); - - let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); - assert_eq!(args["debug"], true); - 
assert_eq!(args["verbose"], false); - assert_eq!(args["optional"], serde_json::Value::Null); -} - #[tokio::test] async fn test_python_literals() { let parser = Glm4MoeParser::new(); @@ -172,7 +150,7 @@ async fn test_python_literals() { } #[tokio::test] -async fn test_nested_values() { +async fn test_glm4_nested_json_in_arg_values() { let parser = Glm4MoeParser::new(); let input = r#"process diff --git a/sgl-router/tests/tool_parser_partial_json.rs b/sgl-router/tests/tool_parser_partial_json.rs new file mode 100644 index 000000000..36d493651 --- /dev/null +++ b/sgl-router/tests/tool_parser_partial_json.rs @@ -0,0 +1,156 @@ +//! Partial JSON Parser Tests +//! +//! Tests for the partial JSON parser with allow_partial_strings flag behavior + +use sglang_router_rs::tool_parser::partial_json::PartialJson; + +#[test] +fn test_partial_string_flag_disallows_incomplete_strings() { + // Test case from the bug report: {"name": " + // With allow_partial_strings=false, should return {} (stop before incomplete string) + let parser = PartialJson::new(32, true); + let input = r#"{"name": ""#; + + let result = parser.parse_value(input, false); + assert!(result.is_ok()); + + let (obj, consumed) = result.unwrap(); + + // Should parse just the opening brace and stop at the incomplete string + assert!(obj.is_object()); + let obj_map = obj.as_object().unwrap(); + + // Should have empty object (stopped before parsing incomplete "name" key) + assert!( + obj_map.is_empty() || !obj_map.contains_key("name"), + "Should not parse incomplete string key, got: {:?}", + obj_map + ); + + // Should consume characters up to the incomplete string + assert!(consumed <= input.len()); +} + +#[test] +fn test_partial_string_flag_allows_incomplete_strings() { + // Test case: {"name": " + // With allow_partial_strings=true, should parse the incomplete string + let parser = PartialJson::new(32, true); + let input = r#"{"name": ""#; + + let result = parser.parse_value(input, true); + assert!(result.is_ok()); + 
+ let (obj, consumed) = result.unwrap(); + + // Should parse the object with incomplete string value + assert!(obj.is_object()); + let obj_map = obj.as_object().unwrap(); + + // With allow_partial_strings=true, should parse "name" key with empty string value + assert!( + obj_map.contains_key("name"), + "Should parse incomplete string with allow_partial_strings=true" + ); + + assert_eq!(consumed, input.len()); +} + +#[test] +fn test_partial_string_flag_complete_json() { + // Test case: {"name": "test"} + // Both flags should parse complete JSON the same way + let input = r#"{"name": "test"}"#; + + let parser = PartialJson::new(32, true); + let result1 = parser.parse_value(input, false); + assert!(result1.is_ok()); + let (obj1, consumed1) = result1.unwrap(); + + let result2 = parser.parse_value(input, true); + assert!(result2.is_ok()); + let (obj2, consumed2) = result2.unwrap(); + + // Both should parse the same complete JSON + assert_eq!(obj1, obj2); + assert_eq!(consumed1, consumed2); + assert_eq!(consumed1, input.len()); + + // Check the parsed value + assert!(obj1.is_object()); + let obj_map = obj1.as_object().unwrap(); + assert_eq!(obj_map.get("name").and_then(|v| v.as_str()), Some("test")); +} + +#[test] +fn test_backward_compatibility_default() { + // Test that default PartialJson still allows partial strings (backward compatible) + let parser = PartialJson::default(); + let input = r#"{"name": ""#; + + let result = parser.parse_value(input, true); + assert!(result.is_ok()); + + let (obj, _) = result.unwrap(); + assert!(obj.is_object()); + + // Default behavior should allow partial strings + let obj_map = obj.as_object().unwrap(); + assert!( + obj_map.contains_key("name"), + "Default should allow partial strings for backward compatibility" + ); +} + +#[test] +fn test_partial_string_in_nested_object() { + // Test case: {"tool": {"name": " + let parser = PartialJson::new(32, true); + let input = r#"{"tool": {"name": ""#; + + let result = 
parser.parse_value(input, false); + assert!(result.is_ok()); + + let (obj, _) = result.unwrap(); + assert!(obj.is_object()); + + // With allow_partial_strings=false, should stop before incomplete nested string + let obj_map = obj.as_object().unwrap(); + if let Some(tool) = obj_map.get("tool") { + if let Some(tool_map) = tool.as_object() { + assert!( + !tool_map.contains_key("name") + || tool_map.get("name").and_then(|v| v.as_str()).is_none(), + "Should not parse incomplete nested string" + ); + } + } +} + +#[test] +fn test_bug_fix_exact_scenario() { + // This test verifies the exact bug scenario from the issue: + // buffer = "{\"name\": \"" + // flags = Allow.ALL & ~Allow.STR + // Python returns: Parsed object: {}, consumed length: 10 + + let parser = PartialJson::new(32, true); + let input = r#"{"name": ""#; + + let result = parser.parse_value(input, false); + assert!(result.is_ok()); + + let (obj, consumed) = result.unwrap(); + + // Should return empty object (not {"name": null} or {"name": ""}) + assert!(obj.is_object()); + let obj_map = obj.as_object().unwrap(); + assert!( + obj_map.is_empty(), + "Expected empty object, got: {:?}. This matches Python behavior with Allow.ALL & ~Allow.STR", + obj_map + ); + + // Should consume all characters (10 bytes) + assert_eq!(consumed, 10, "Should consume all 10 characters"); +} diff --git a/sgl-router/tests/tool_parser_streaming.rs b/sgl-router/tests/tool_parser_streaming.rs index b2d5ef1a8..73484c9f8 100644 --- a/sgl-router/tests/tool_parser_streaming.rs +++ b/sgl-router/tests/tool_parser_streaming.rs @@ -1,73 +1,199 @@ -//! Streaming Parser Tests +//! Realistic Streaming Parser Tests //! -//! Tests for incremental/streaming parsing capabilities across all parsers +//! Tests incremental parsing with realistic char-level chunks (2-5 chars) +//! that simulate how LLM tokens actually arrive. +//! +//! These tests are designed to catch bugs like `{"name": "` being parsed +//! as an empty tool name. 
-use sglang_router_rs::tool_parser::{ - JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, ToolParser, -}; +use sglang_router_rs::tool_parser::{JsonParser, LlamaParser, QwenParser, ToolParser}; mod common; -use common::create_test_tools; +use common::{create_test_tools, streaming_helpers::*}; + +// ============================================================================= +// THE BUG SCENARIO - Most Critical Test +// ============================================================================= #[tokio::test] -async fn test_json_streaming_simple() { +async fn test_json_bug_incomplete_tool_name_string() { let tools = create_test_tools(); - - let mut parser = JsonParser::new(); - - let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; - - let result = parser.parse_incremental(full_json, &tools).await.unwrap(); - - assert!(!result.calls.is_empty(), "Should have parsed a tool call"); - assert_eq!(result.calls[0].name, Some("get_weather".to_string())); -} - -#[tokio::test] -async fn test_json_streaming_array() { - let tools = create_test_tools(); - let mut parser = JsonParser::new(); + // This exact sequence triggered the bug: + // Parser receives {"name": " and must NOT parse it as empty name let chunks = vec![ - r#"["#, - r#"{"name": "tool1", "#, - r#""arguments": {}}, "#, - r#"{"name": "tool2", "#, - r#""arguments": {"x": 1"#, - r#"}}]"#, + r#"{"#, + r#"""#, + r#"name"#, + r#"""#, + r#":"#, + r#" "#, + r#"""#, // ← Critical moment: parser has {"name": " + // At this point, partial_json should NOT allow incomplete strings + // when current_tool_name_sent=false + r#"search"#, // Use valid tool name from create_test_tools() + r#"""#, + r#", "#, + r#"""#, + r#"arguments"#, + r#"""#, + r#": {"#, + r#"""#, + r#"query"#, + r#"""#, + r#": "#, + r#"""#, + r#"rust programming"#, + r#"""#, + r#"}}"#, ]; - let mut tool_count = 0; + let mut got_tool_name = false; + let mut saw_empty_name = false; - for chunk in chunks { + for 
chunk in chunks.iter() { let result = parser.parse_incremental(chunk, &tools).await.unwrap(); + for call in result.calls { - if call.name.is_some() { - tool_count += 1; + if let Some(name) = &call.name { + if name.is_empty() { + saw_empty_name = true; + } + if name == "search" { + got_tool_name = true; + } } } } - // Current implementation may handle this differently - assert!(tool_count <= 2, "Should parse at most 2 tools"); + assert!( + !saw_empty_name, + "Parser should NEVER return empty tool name" + ); + assert!(got_tool_name, "Should have parsed tool name correctly"); +} + +// ============================================================================= +// JSON PARSER REALISTIC STREAMING +// ============================================================================= + +#[tokio::test] +async fn test_json_realistic_chunks_simple_tool() { + let tools = create_test_tools(); + let mut parser = JsonParser::new(); + + let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#; + let chunks = create_realistic_chunks(input); + + assert!(chunks.len() > 10, "Should have many small chunks"); + + let mut got_tool_name = false; + + for chunk in chunks { + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if let Some(name) = call.name { + assert_eq!(name, "get_weather"); + got_tool_name = true; + } + } + } + + assert!(got_tool_name, "Should have parsed tool name"); } #[tokio::test] -async fn test_mistral_streaming() { +async fn test_json_strategic_chunks_with_quotes() { let tools = create_test_tools(); + let mut parser = JsonParser::new(); - let mut parser = MistralParser::new(); + let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#; + let chunks = create_strategic_chunks(input); + // Strategic chunks break after quotes and colons + assert!(chunks.iter().any(|c| c.ends_with('"'))); + + let mut got_tool_name = false; + + for chunk in chunks { + let result = 
parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if call.name.is_some() { + got_tool_name = true; + } + } + } + + assert!(got_tool_name, "Should have parsed tool name"); +} + +#[tokio::test] +async fn test_json_incremental_arguments_streaming() { + let tools = create_test_tools(); + let mut parser = JsonParser::new(); + + let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#; + let chunks = create_realistic_chunks(input); + + let mut tool_name_sent = false; + let mut got_arguments = false; + + for chunk in chunks { + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if call.name.is_some() { + tool_name_sent = true; + } + if tool_name_sent && !call.parameters.is_empty() { + got_arguments = true; + } + } + } + + assert!(tool_name_sent, "Should have sent tool name"); + assert!(got_arguments, "Should have sent arguments"); +} + +// ============================================================================= +// LLAMA PARSER REALISTIC STREAMING +// ============================================================================= + +#[tokio::test] +async fn test_llama_realistic_chunks_with_python_tag() { + let tools = create_test_tools(); + let mut parser = LlamaParser::new(); + + let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#; + let chunks = create_realistic_chunks(input); + + assert!(chunks.len() > 15, "Should have many small chunks"); + + let mut got_tool_name = false; + + for chunk in chunks { + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if let Some(name) = call.name { + assert_eq!(name, "calculate"); + got_tool_name = true; + } + } + } + + assert!(got_tool_name, "Should have parsed tool name"); +} + +#[tokio::test] +async fn test_llama_python_tag_arrives_in_parts() { + let tools = create_test_tools(); + let mut parser = LlamaParser::new(); + + // Python tag 
itself arrives in small chunks let chunks = vec![ - r#"Here is the result: "#, - r#"[TOOL_CALLS] ["#, - r#"{"name": "#, - r#""search", "#, - r#""arguments": "#, - r#"{"query": "#, - r#""rust lang""#, - r#"}}]"#, + "<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch", + r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ", + r#"""#, "tes", "t", r#"""#, "}}", ]; let mut got_tool_name = false; @@ -82,40 +208,47 @@ async fn test_mistral_streaming() { } } - assert!(got_tool_name, "Should have found tool name"); + assert!(got_tool_name, "Should have parsed tool name"); +} + +// ============================================================================= +// QWEN PARSER REALISTIC STREAMING +// ============================================================================= + +#[tokio::test] +async fn test_qwen_realistic_chunks_with_xml_tags() { + let tools = create_test_tools(); + let mut parser = QwenParser::new(); + + let input = "\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n"; + let chunks = create_realistic_chunks(input); + + assert!(chunks.len() > 20, "Should have many small chunks"); + + let mut got_tool_name = false; + + for chunk in chunks { + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if let Some(name) = call.name { + assert_eq!(name, "get_weather"); + got_tool_name = true; + } + } + } + + assert!(got_tool_name, "Should have parsed tool name"); } #[tokio::test] -async fn test_pythonic_streaming() { +async fn test_qwen_xml_tag_arrives_in_parts() { let tools = create_test_tools(); - - let mut parser = PythonicParser::new(); - - let full_input = r#"[get_weather(city="London", units="celsius")]"#; - - let result = parser.parse_incremental(full_input, &tools).await.unwrap(); - - assert!(!result.calls.is_empty(), "Should have parsed a tool call"); - assert_eq!(result.calls[0].name, Some("get_weather".to_string())); - let 
args: serde_json::Value = serde_json::from_str(&result.calls[0].parameters).unwrap(); - assert_eq!(args["city"], "London"); -} - -#[tokio::test] -async fn test_llama_streaming_with_python_tag() { - let tools = create_test_tools(); - - let mut parser = LlamaParser::new(); + let mut parser = QwenParser::new(); let chunks = vec![ - r#"Let me help. "#, - r#"<|python"#, - r#"_tag|>"#, - r#"{"name": "#, - r#""calculate", "#, - r#""arguments": "#, - r#"{"x": 10}"#, - r#"}"#, + "\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl", + "ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t", + r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "", ]; let mut got_tool_name = false; @@ -124,191 +257,66 @@ async fn test_llama_streaming_with_python_tag() { let result = parser.parse_incremental(chunk, &tools).await.unwrap(); for call in result.calls { if let Some(name) = call.name { - assert_eq!(name, "calculate"); + assert_eq!(name, "translate"); got_tool_name = true; } } } - assert!(got_tool_name, "Should have found tool name"); + assert!(got_tool_name, "Should have parsed tool name"); } -#[tokio::test] -async fn test_qwen_streaming() { - let tools = create_test_tools(); - - let mut parser = QwenParser::new(); - - // Note: Parser expects newline after both tags - let full_input = "\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n"; - - let result = parser.parse_incremental(full_input, &tools).await.unwrap(); - - assert!(!result.calls.is_empty(), "Should have parsed a tool call"); - assert_eq!(result.calls[0].name, Some("translate".to_string())); -} +// ============================================================================= +// EDGE CASES WITH REALISTIC CHUNKS +// ============================================================================= #[tokio::test] -async fn test_streaming_incomplete_stays_incomplete() { +async fn test_json_very_long_url_in_arguments() { let tools = create_test_tools(); - let 
mut parser = JsonParser::new(); - let chunks = vec![r#"{"na"#, r#"me": "#]; + // Simulate long URL arriving in many chunks + let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50); + let input = format!( + r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#, + long_url + ); + let chunks = create_realistic_chunks(&input); + + assert!(chunks.len() > 100, "Long URL should create many chunks"); + + let mut got_tool_name = false; for chunk in chunks { - let result = parser.parse_incremental(chunk, &tools).await.unwrap(); - assert!( - result.calls.is_empty(), - "Should return empty calls for partial JSON, got: {:?}", - result - ); + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if call.name.is_some() { + got_tool_name = true; + } + } } + + assert!(got_tool_name, "Should have parsed tool name"); } #[tokio::test] -async fn test_streaming_buffer_accumulation() { +async fn test_json_unicode_arrives_byte_by_byte() { let tools = create_test_tools(); - let mut parser = JsonParser::new(); - let result1 = parser.parse_incremental(r#"{"na"#, &tools).await.unwrap(); + let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#; + let chunks = create_realistic_chunks(input); - assert!(result1.calls.is_empty(), "Should not parse incomplete JSON"); + let mut got_tool_name = false; - let result2 = parser - .parse_incremental(r#"me": "test", "arguments": {}}"#, &tools) - .await - .unwrap(); - - assert!( - !result2.calls.is_empty(), - "Should parse complete JSON after buffering" - ); - assert_eq!(result2.calls[0].name, Some("test".to_string())); -} - -#[tokio::test] -async fn test_streaming_multiple_tools_sequential() { - let tools = create_test_tools(); - - let mut parser = QwenParser::new(); - - let full_input = r#" -{"name": "tool1", "arguments": {}} -"#; - - let result = parser.parse_incremental(full_input, &tools).await.unwrap(); - - assert!(!result.calls.is_empty(), "Should 
have parsed a tool call"); - assert_eq!(result.calls[0].name, Some("tool1".to_string())); -} - -#[tokio::test] -async fn test_streaming_reset_after_error() { - let tools = create_test_tools(); - - let mut parser1 = JsonParser::new(); - - let _ = parser1 - .parse_incremental(r#"{"name": invalid}"#, &tools) - .await; - - // Use a new parser instance for clean state - let mut parser2 = JsonParser::new(); - let result = parser2 - .parse_incremental(r#"{"name": "test", "arguments": {}}"#, &tools) - .await - .unwrap(); - - assert!(!result.calls.is_empty(), "Should parse valid JSON"); - assert_eq!(result.calls[0].name, Some("test".to_string())); -} - -#[tokio::test] -async fn test_streaming_with_unicode_chunks() { - let tools = create_test_tools(); - - let mut parser = JsonParser::new(); - - let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#; - - let result = parser.parse_incremental(full_input, &tools).await.unwrap(); - - assert!(!result.calls.is_empty(), "Should have parsed a tool call"); - - // Check if we got the tool name - if let Some(name) = &result.calls[0].name { - assert_eq!(name, "translate"); + for chunk in chunks { + let result = parser.parse_incremental(&chunk, &tools).await.unwrap(); + for call in result.calls { + if call.name.is_some() { + got_tool_name = true; + } + } } - // In streaming mode, need to make another call to get parameters - let result2 = parser.parse_incremental("", &tools).await.unwrap(); - - // Parameters should be in either result.calls[1] or result2.calls[0] - let params = if result.calls.len() > 1 { - &result.calls[1].parameters - } else if !result2.calls.is_empty() { - &result2.calls[0].parameters - } else { - &result.calls[0].parameters - }; - - if !params.is_empty() { - let args: serde_json::Value = serde_json::from_str(params).unwrap(); - assert!(args["text"].as_str().unwrap().contains("世界")); - } -} - -#[tokio::test] -async fn test_streaming_with_partial_chunks() { - let mut parser = 
JsonParser::new(); - let tools = create_test_tools(); - - let partial = r#"{"#; - let result = parser.parse_incremental(partial, &tools).await.unwrap(); - assert!( - result.calls.is_empty(), - "Should return empty calls for just opening brace" - ); - - let mut parser2 = JsonParser::new(); - let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#; - let result = parser2.parse_incremental(complete, &tools).await.unwrap(); - - assert!( - !result.calls.is_empty(), - "Expected tool call for complete JSON" - ); - assert_eq!(result.calls[0].name.as_ref().unwrap(), "get_weather"); - - // In streaming mode, need to make another call to get parameters - let result2 = parser2.parse_incremental("", &tools).await.unwrap(); - - // Parameters should be in either result.calls[1] or result2.calls[0] - let params = if result.calls.len() > 1 { - &result.calls[1].parameters - } else if !result2.calls.is_empty() { - &result2.calls[0].parameters - } else { - &result.calls[0].parameters - }; - - if !params.is_empty() { - let args: serde_json::Value = serde_json::from_str(params).unwrap(); - assert_eq!(args["location"], "SF"); - } - - // The PartialJson parser can complete partial JSON by filling in missing values - let mut parser3 = JsonParser::new(); - let partial_with_name = r#"{"name": "test", "argum"#; - let result = parser3 - .parse_incremental(partial_with_name, &tools) - .await - .unwrap(); - - // Parser behavior may vary - either complete with partial data or wait for more - if !result.calls.is_empty() { - assert_eq!(result.calls[0].name.as_ref().unwrap(), "test"); - } + assert!(got_tool_name, "Should have parsed with unicode"); }