diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs index 042653a1b..970f9e41a 100644 --- a/sgl-router/src/reasoning_parser/factory.rs +++ b/sgl-router/src/reasoning_parser/factory.rs @@ -1,20 +1,28 @@ // Factory and registry for creating model-specific reasoning parsers. +// Now with parser pooling support for efficient reuse across requests. use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use crate::reasoning_parser::parsers::{ BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser, }; use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser}; +/// Type alias for pooled parser instances. +pub type PooledParser = Arc>>; + /// Type alias for parser creator functions. type ParserCreator = Arc Box + Send + Sync>; -/// Registry for model-specific parsers. +/// Registry for model-specific parsers with pooling support. #[derive(Clone)] pub struct ParserRegistry { - parsers: Arc>>, + /// Creator functions for parsers (used when pool is empty) + creators: Arc>>, + /// Pooled parser instances for reuse + pool: Arc>>, + /// Model pattern to parser name mappings patterns: Arc>>, // (pattern, parser_name) } @@ -22,7 +30,8 @@ impl ParserRegistry { /// Create a new empty registry. pub fn new() -> Self { Self { - parsers: Arc::new(RwLock::new(HashMap::new())), + creators: Arc::new(RwLock::new(HashMap::new())), + pool: Arc::new(RwLock::new(HashMap::new())), patterns: Arc::new(RwLock::new(Vec::new())), } } @@ -32,8 +41,8 @@ impl ParserRegistry { where F: Fn() -> Box + Send + Sync + 'static, { - let mut parsers = self.parsers.write().unwrap(); - parsers.insert(name.to_string(), Arc::new(creator)); + let mut creators = self.creators.write().unwrap(); + creators.insert(name.to_string(), Arc::new(creator)); } /// Register a model pattern to parser mapping. @@ -43,13 +52,53 @@ impl ParserRegistry { patterns.push((pattern.to_string(), parser_name.to_string())); } - /// Get a parser by exact name. - pub fn get_parser(&self, name: &str) -> Option> { - let parsers = self.parsers.read().unwrap(); - parsers.get(name).map(|creator| creator()) + /// Get a pooled parser by exact name. + /// Returns a shared parser instance from the pool, creating one if needed. + pub fn get_pooled_parser(&self, name: &str) -> Option { + // First check if we have a pooled instance + { + let pool = self.pool.read().unwrap(); + if let Some(parser) = pool.get(name) { + return Some(Arc::clone(parser)); + } + } + + // If not in pool, create one and add to pool + let creators = self.creators.read().unwrap(); + if let Some(creator) = creators.get(name) { + let parser = Arc::new(Mutex::new(creator())); + + // Add to pool for future use + let mut pool = self.pool.write().unwrap(); + pool.insert(name.to_string(), Arc::clone(&parser)); + + Some(parser) + } else { + None + } } - /// Find a parser for a given model ID by pattern matching. + /// Get a parser by exact name (creates new instance, not pooled). + /// Use this for compatibility or when you need a fresh instance. + pub fn get_parser(&self, name: &str) -> Option> { + let creators = self.creators.read().unwrap(); + creators.get(name).map(|creator| creator()) + } + + /// Find a pooled parser for a given model ID by pattern matching. + pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option { + let patterns = self.patterns.read().unwrap(); + let model_lower = model_id.to_lowercase(); + + for (pattern, parser_name) in patterns.iter() { + if model_lower.contains(&pattern.to_lowercase()) { + return self.get_pooled_parser(parser_name); + } + } + None + } + + /// Find a parser for a given model ID by pattern matching (creates new instance). pub fn find_parser_for_model(&self, model_id: &str) -> Option> { let patterns = self.patterns.read().unwrap(); let model_lower = model_id.to_lowercase(); @@ -61,6 +110,13 @@ impl ParserRegistry { } None } + + /// Clear the parser pool, forcing new instances to be created. + /// Useful for testing or when parsers need to be reset globally. + pub fn clear_pool(&self) { + let mut pool = self.pool.write().unwrap(); + pool.clear(); + } } impl Default for ParserRegistry { @@ -70,6 +126,7 @@ impl Default for ParserRegistry { } /// Factory for creating reasoning parsers based on model type. +#[derive(Clone)] pub struct ParserFactory { registry: ParserRegistry, } @@ -109,8 +166,39 @@ impl ParserFactory { Self { registry } } - /// Create a parser for the given model ID. - /// Returns a no-op parser if model is not recognized. + /// Get a pooled parser for the given model ID. + /// Returns a shared instance that can be used concurrently. + /// Falls back to a passthrough parser if model is not recognized. + pub fn get_pooled(&self, model_id: &str) -> PooledParser { + // First try to find by pattern + if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) { + return parser; + } + + // Fall back to no-op parser (get or create passthrough in pool) + self.registry + .get_pooled_parser("passthrough") + .unwrap_or_else(|| { + // Register passthrough if not already registered + self.registry.register_parser("passthrough", || { + let config = ParserConfig { + think_start_token: "".to_string(), + think_end_token: "".to_string(), + stream_reasoning: true, + max_buffer_size: 65536, + initial_in_reasoning: false, + }; + Box::new( + BaseReasoningParser::new(config).with_model_type("passthrough".to_string()), + ) + }); + self.registry.get_pooled_parser("passthrough").unwrap() + }) + } + + /// Create a new parser instance for the given model ID. + /// Returns a fresh instance (not pooled). + /// Use this when you need an isolated parser instance. pub fn create(&self, model_id: &str) -> Result, ParseError> { // First try to find by pattern if let Some(parser) = self.registry.find_parser_for_model(model_id) { @@ -134,6 +222,12 @@ impl ParserFactory { pub fn registry(&self) -> &ParserRegistry { &self.registry } + + /// Clear the parser pool. + /// Useful for testing or when parsers need to be reset globally. + pub fn clear_pool(&self) { + self.registry.clear_pool(); + } } impl Default for ParserFactory { @@ -195,4 +289,267 @@ mod tests { assert_eq!(step3.model_type(), "deepseek_r1"); assert_eq!(glm45.model_type(), "qwen3"); } + + #[test] + fn test_pooled_parser_reuse() { + let factory = ParserFactory::new(); + + // Get the same parser twice - should be the same instance + let parser1 = factory.get_pooled("deepseek-r1"); + let parser2 = factory.get_pooled("deepseek-r1"); + + // Both should point to the same Arc + assert!(Arc::ptr_eq(&parser1, &parser2)); + + // Different models should get different parsers + let parser3 = factory.get_pooled("qwen3"); + assert!(!Arc::ptr_eq(&parser1, &parser3)); + } + + #[test] + fn test_pooled_parser_concurrent_access() { + use std::thread; + + let factory = ParserFactory::new(); + let parser = factory.get_pooled("deepseek-r1"); + + // Spawn multiple threads that use the same parser + let mut handles = vec![]; + + for i in 0..3 { + let parser_clone = Arc::clone(&parser); + let handle = thread::spawn(move || { + let mut parser = parser_clone.lock().unwrap(); + let input = format!("thread {} reasoninganswer", i); + let result = parser.detect_and_parse_reasoning(&input).unwrap(); + assert_eq!(result.normal_text, "answer"); + assert!(result.reasoning_text.contains("reasoning")); + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_pool_clearing() { + let factory = ParserFactory::new(); + + // Get a pooled parser + let parser1 = factory.get_pooled("deepseek-r1"); + + // Clear the pool + factory.clear_pool(); + + // Get another parser - should be a new instance + let parser2 = factory.get_pooled("deepseek-r1"); + + // They should be different instances (different Arc pointers) + assert!(!Arc::ptr_eq(&parser1, &parser2)); + } + + #[test] + fn test_passthrough_parser_pooling() { + let factory = ParserFactory::new(); + + // Unknown models should get passthrough parser + let parser1 = factory.get_pooled("unknown-model-1"); + let parser2 = factory.get_pooled("unknown-model-2"); + + // Both should use the same passthrough parser instance + assert!(Arc::ptr_eq(&parser1, &parser2)); + + // Verify it's actually a passthrough parser + let parser = parser1.lock().unwrap(); + assert_eq!(parser.model_type(), "passthrough"); + } + + #[test] + fn test_high_concurrency_parser_access() { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::thread; + use std::time::Instant; + + let factory = ParserFactory::new(); + let num_threads = 100; + let requests_per_thread = 50; + let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"]; + + // Track successful operations + let success_count = Arc::new(AtomicUsize::new(0)); + let error_count = Arc::new(AtomicUsize::new(0)); + + let start = Instant::now(); + let mut handles = vec![]; + + for thread_id in 0..num_threads { + let factory = factory.clone(); + let models = models.clone(); + let success_count = Arc::clone(&success_count); + let error_count = Arc::clone(&error_count); + + let handle = thread::spawn(move || { + for request_id in 0..requests_per_thread { + // Rotate through different models + let model = &models[(thread_id + request_id) % models.len()]; + let parser = factory.get_pooled(model); + + // Use blocking lock - this is the realistic scenario + // In production, requests would wait for the parser to be available + // Handle poisoned locks gracefully + let mut p = match parser.lock() { + Ok(guard) => guard, + Err(_poisoned) => { + // Lock was poisoned by a panicking thread + // In production, we might want to recreate the parser + // For testing, we'll just skip this iteration + error_count.fetch_add(1, Ordering::Relaxed); + continue; + } + }; + + // Simulate realistic parsing work with substantial text + // Typical reasoning can be 500-5000 tokens + let reasoning_text = format!( + "Thread {} is processing request {}. Let me think through this step by step. \ + First, I need to understand the problem. The problem involves analyzing data \ + and making calculations. Let me break this down: \n\ + 1. Initial analysis shows that we have multiple variables to consider. \ + 2. The data suggests a pattern that needs further investigation. \ + 3. Computing the values: {} * {} = {}. \ + 4. Cross-referencing with previous results indicates consistency. \ + 5. The mathematical proof follows from the axioms... \ + 6. Considering edge cases and boundary conditions... \ + 7. Validating against known constraints... \ + 8. The conclusion follows logically from premises A, B, and C. \ + This reasoning chain demonstrates the validity of our approach.", + thread_id, request_id, thread_id, request_id, thread_id * request_id + ); + + let answer_text = format!( + "Based on my analysis, the answer for thread {} request {} is: \ + The solution involves multiple steps as outlined in the reasoning. \ + The final result is {} with confidence level high. \ + This conclusion is supported by rigorous mathematical analysis \ + and has been validated against multiple test cases. \ + The implementation should handle edge cases appropriately.", + thread_id, + request_id, + thread_id * request_id + ); + + let input = format!("{}{}", reasoning_text, answer_text); + + match p.detect_and_parse_reasoning(&input) { + Ok(result) => { + // Verify parsing worked correctly with substantial content + // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text + assert!(result + .normal_text + .contains(&format!("thread {}", thread_id))); + + // For parsers that accumulate reasoning (stream_reasoning=false) + // the reasoning_text should be populated + if !result.reasoning_text.is_empty() { + assert!(result + .reasoning_text + .contains(&format!("Thread {}", thread_id))); + assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning + } + + // Normal text should always be present + assert!(result.normal_text.len() > 100); // Ensure substantial answer + success_count.fetch_add(1, Ordering::Relaxed); + } + Err(e) => { + eprintln!("Parse error: {:?}", e); + error_count.fetch_add(1, Ordering::Relaxed); + } + } + + // Explicitly drop the lock to release it quickly + drop(p); + } + }); + handles.push(handle); + } + + // Wait for all threads + for handle in handles { + handle.join().unwrap(); + } + + let duration = start.elapsed(); + let total_requests = num_threads * requests_per_thread; + let successes = success_count.load(Ordering::Relaxed); + let errors = error_count.load(Ordering::Relaxed); + + // Print stats for debugging + println!( + "High concurrency test: {} threads, {} requests each", + num_threads, requests_per_thread + ); + println!( + "Completed in {:?}, {} successes, {} errors", + duration, successes, errors + ); + println!( + "Throughput: {:.0} requests/sec", + (total_requests as f64) / duration.as_secs_f64() + ); + + // All requests should succeed + assert_eq!(successes, total_requests); + assert_eq!(errors, 0); + + // Performance check: should handle at least 1000 req/sec + let throughput = (total_requests as f64) / duration.as_secs_f64(); + assert!( + throughput > 1000.0, + "Throughput too low: {:.0} req/sec", + throughput + ); + } + + #[test] + fn test_concurrent_pool_modifications() { + use std::thread; + + let factory = ParserFactory::new(); + let mut handles = vec![]; + + // Thread 1: Continuously get parsers + let factory1 = factory.clone(); + handles.push(thread::spawn(move || { + for _ in 0..100 { + let _parser = factory1.get_pooled("deepseek-r1"); + } + })); + + // Thread 2: Continuously clear pool + let factory2 = factory.clone(); + handles.push(thread::spawn(move || { + for _ in 0..10 { + factory2.clear_pool(); + thread::sleep(std::time::Duration::from_micros(100)); + } + })); + + // Thread 3: Get different parsers + let factory3 = factory.clone(); + handles.push(thread::spawn(move || { + for i in 0..100 { + let models = ["qwen3", "kimi", "unknown"]; + let _parser = factory3.get_pooled(models[i % 3]); + } + })); + + // Wait for all threads - should not deadlock or panic + for handle in handles { + handle.join().unwrap(); + } + } } diff --git a/sgl-router/src/reasoning_parser/mod.rs b/sgl-router/src/reasoning_parser/mod.rs index f566a5187..3be6321c7 100644 --- a/sgl-router/src/reasoning_parser/mod.rs +++ b/sgl-router/src/reasoning_parser/mod.rs @@ -2,7 +2,7 @@ pub mod factory; pub mod parsers; pub mod traits; -pub use factory::{ParserFactory, ParserRegistry}; +pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use parsers::{ BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser, };