diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs
index f7ea9f3fa..7e2367f7b 100644
--- a/sgl-router/src/reasoning_parser/factory.rs
+++ b/sgl-router/src/reasoning_parser/factory.rs
@@ -2,7 +2,9 @@
 // Now with parser pooling support for efficient reuse across requests.
 
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex, RwLock};
+use std::sync::{Arc, RwLock};
+
+use tokio::sync::Mutex;
 
 use crate::reasoning_parser::parsers::{
     BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
@@ -11,6 +13,7 @@ use crate::reasoning_parser::parsers::{
 use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
 
 /// Type alias for pooled parser instances.
+/// Uses tokio::Mutex to avoid blocking the async executor.
 pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
 
 /// Type alias for parser creator functions.
@@ -301,8 +304,8 @@ mod tests {
         assert_eq!(glm45.model_type(), "glm45");
     }
 
-    #[test]
-    fn test_pooled_parser_reuse() {
+    #[tokio::test]
+    async fn test_pooled_parser_reuse() {
         let factory = ReasoningParserFactory::new();
 
         // Get the same parser twice - should be the same instance
@@ -317,20 +320,18 @@ mod tests {
         assert!(!Arc::ptr_eq(&parser1, &parser3));
     }
 
-    #[test]
-    fn test_pooled_parser_concurrent_access() {
-        use std::thread;
-
+    #[tokio::test]
+    async fn test_pooled_parser_concurrent_access() {
         let factory = ReasoningParserFactory::new();
         let parser = factory.get_pooled("deepseek-r1");
 
-        // Spawn multiple threads that use the same parser
+        // Spawn multiple async tasks that use the same parser
         let mut handles = vec![];
         for i in 0..3 {
             let parser_clone = Arc::clone(&parser);
-            let handle = thread::spawn(move || {
-                let mut parser = parser_clone.lock().unwrap();
+            let handle = tokio::spawn(async move {
+                let mut parser = parser_clone.lock().await;
                 let input = format!("<think>thread {} reasoning</think>answer", i);
                 let result = parser.detect_and_parse_reasoning(&input).unwrap();
                 assert_eq!(result.normal_text, "answer");
@@ -339,14 +340,14 @@ mod tests {
             handles.push(handle);
         }
 
-        // Wait for all threads to complete
+        // Wait for all tasks to complete
         for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
         }
     }
 
-    #[test]
-    fn test_pool_clearing() {
+    #[tokio::test]
+    async fn test_pool_clearing() {
         let factory = ReasoningParserFactory::new();
 
         // Get a pooled parser
@@ -362,8 +363,8 @@ mod tests {
         assert!(!Arc::ptr_eq(&parser1, &parser2));
     }
 
-    #[test]
-    fn test_passthrough_parser_pooling() {
+    #[tokio::test]
+    async fn test_passthrough_parser_pooling() {
         let factory = ReasoningParserFactory::new();
 
         // Unknown models should get passthrough parser
@@ -373,19 +374,18 @@ mod tests {
         // Both should use the same passthrough parser instance
         assert!(Arc::ptr_eq(&parser1, &parser2));
 
-        let parser = parser1.lock().unwrap();
+        let parser = parser1.lock().await;
         assert_eq!(parser.model_type(), "passthrough");
     }
 
-    #[test]
-    fn test_high_concurrency_parser_access() {
+    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+    async fn test_high_concurrency_parser_access() {
         use std::sync::atomic::{AtomicUsize, Ordering};
-        use std::thread;
         use std::time::Instant;
 
         let factory = ReasoningParserFactory::new();
-        let num_threads = 100;
-        let requests_per_thread = 50;
+        let num_tasks = 100;
+        let requests_per_task = 50;
         let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
 
         // Track successful operations
@@ -395,36 +395,25 @@ mod tests {
         let start = Instant::now();
         let mut handles = vec![];
 
-        for thread_id in 0..num_threads {
+        for task_id in 0..num_tasks {
             let factory = factory.clone();
             let models = models.clone();
             let success_count = Arc::clone(&success_count);
             let error_count = Arc::clone(&error_count);
 
-            let handle = thread::spawn(move || {
-                for request_id in 0..requests_per_thread {
+            let handle = tokio::spawn(async move {
+                for request_id in 0..requests_per_task {
                     // Rotate through different models
-                    let model = &models[(thread_id + request_id) % models.len()];
+                    let model = &models[(task_id + request_id) % models.len()];
                     let parser = factory.get_pooled(model);
 
-                    // Use blocking lock - this is the realistic scenario
-                    // In production, requests would wait for the parser to be available
-                    // Handle poisoned locks gracefully
-                    let mut p = match parser.lock() {
-                        Ok(guard) => guard,
-                        Err(_poisoned) => {
-                            // Lock was poisoned by a panicking thread
-                            // In production, we might want to recreate the parser
-                            // For testing, we'll just skip this iteration
-                            error_count.fetch_add(1, Ordering::Relaxed);
-                            continue;
-                        }
-                    };
+                    // Use async lock - tokio::Mutex doesn't poison
+                    let mut p = parser.lock().await;
 
                     // Simulate realistic parsing work with substantial text
                     // Typical reasoning can be 500-5000 tokens
                     let reasoning_text = format!(
-                        "Thread {} is processing request {}. Let me think through this step by step. \
+                        "Task {} is processing request {}. Let me think through this step by step. \
                          First, I need to understand the problem. The problem involves analyzing data \
                          and making calculations. Let me break this down: \n\
                          1. Initial analysis shows that we have multiple variables to consider. \
@@ -436,19 +425,19 @@ mod tests {
                          7. Validating against known constraints... \
                          8. The conclusion follows logically from premises A, B, and C. \
                          This reasoning chain demonstrates the validity of our approach.",
-                        thread_id, request_id, thread_id, request_id, thread_id * request_id
+                        task_id, request_id, task_id, request_id, task_id * request_id
                     );
 
                     let answer_text = format!(
-                        "Based on my analysis, the answer for thread {} request {} is: \
+                        "Based on my analysis, the answer for task {} request {} is: \
                          The solution involves multiple steps as outlined in the reasoning. \
                          The final result is {} with confidence level high. \
                          This conclusion is supported by rigorous mathematical analysis \
                          and has been validated against multiple test cases. \
                          The implementation should handle edge cases appropriately.",
-                        thread_id,
+                        task_id,
                         request_id,
-                        thread_id * request_id
+                        task_id * request_id
                     );
 
                     let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
@@ -456,16 +445,14 @@ mod tests {
                     match p.detect_and_parse_reasoning(&input) {
                         Ok(result) => {
                             // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
-                            assert!(result
-                                .normal_text
-                                .contains(&format!("thread {}", thread_id)));
+                            assert!(result.normal_text.contains(&format!("task {}", task_id)));
 
                             // For parsers that accumulate reasoning (stream_reasoning=false)
                             // the reasoning_text should be populated
                             if !result.reasoning_text.is_empty() {
                                 assert!(result
                                     .reasoning_text
-                                    .contains(&format!("Thread {}", thread_id)));
+                                    .contains(&format!("Task {}", task_id)));
                                 assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                             }
@@ -486,20 +473,20 @@ mod tests {
             handles.push(handle);
         }
 
-        // Wait for all threads
+        // Wait for all tasks
        for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
        }
 
         let duration = start.elapsed();
-        let total_requests = num_threads * requests_per_thread;
+        let total_requests = num_tasks * requests_per_task;
         let successes = success_count.load(Ordering::Relaxed);
         let errors = error_count.load(Ordering::Relaxed);
 
         // Print stats for debugging
         println!(
-            "High concurrency test: {} threads, {} requests each",
-            num_threads, requests_per_thread
+            "High concurrency test: {} tasks, {} requests each",
+            num_tasks, requests_per_task
         );
         println!(
             "Completed in {:?}, {} successes, {} errors",
@@ -523,42 +510,40 @@ mod tests {
         );
     }
 
-    #[test]
-    fn test_concurrent_pool_modifications() {
-        use std::thread;
-
+    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+    async fn test_concurrent_pool_modifications() {
         let factory = ReasoningParserFactory::new();
         let mut handles = vec![];
 
-        // Thread 1: Continuously get parsers
+        // Task 1: Continuously get parsers
         let factory1 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
             for _ in 0..100 {
                 let _parser = factory1.get_pooled("deepseek-r1");
             }
         }));
 
-        // Thread 2: Continuously clear pool
+        // Task 2: Continuously clear pool
         let factory2 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
             for _ in 0..10 {
                 factory2.clear_pool();
-                thread::sleep(std::time::Duration::from_micros(100));
+                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
             }
         }));
 
-        // Thread 3: Get different parsers
+        // Task 3: Get different parsers
         let factory3 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
             for i in 0..100 {
                 let models = ["qwen3", "kimi", "unknown"];
                 let _parser = factory3.get_pooled(models[i % 3]);
             }
         }));
 
-        // Wait for all threads - should not deadlock or panic
+        // Wait for all tasks - should not deadlock or panic
         for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
         }
     }
 }
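Why the factory moves to `tokio::sync::Mutex`: a `std::sync::MutexGuard` blocks the executor thread while held, cannot be held across an `.await` point, and poisons the lock if a holder panics, which is what the deleted poisoned-lock recovery code in the tests dealt with. Below is a minimal sketch of the pooling pattern under those assumptions; `Pool` and `Parser` are hypothetical stand-ins for `ReasoningParserFactory` and `Box<dyn ReasoningParser>`, not the crate's actual API.

```rust
// Minimal sketch (assumes the `tokio` crate with the `sync`, `rt`, and
// `macros` features enabled; `Pool`/`Parser` are illustrative only).
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use tokio::sync::Mutex;

#[derive(Default)]
struct Parser;

#[derive(Default, Clone)]
struct Pool {
    parsers: Arc<RwLock<HashMap<String, Arc<Mutex<Parser>>>>>,
}

impl Pool {
    fn get_pooled(&self, model: &str) -> Arc<Mutex<Parser>> {
        // Fast path: reuse an existing instance (the read guard drops when
        // the `if let` scope ends).
        if let Some(p) = self.parsers.read().unwrap().get(model) {
            return Arc::clone(p);
        }
        // Slow path: create and cache (a real pool would also handle the
        // race where two tasks insert the same model concurrently).
        let parser = Arc::new(Mutex::new(Parser));
        self.parsers
            .write()
            .unwrap()
            .insert(model.to_string(), Arc::clone(&parser));
        parser
    }
}

#[tokio::main]
async fn main() {
    let pool = Pool::default();
    let parser = pool.get_pooled("deepseek-r1");
    // The async guard may be held across .await points; contended tasks
    // queue cooperatively instead of blocking a worker thread, and a panic
    // while locked does not poison the mutex for later requests.
    let _guard = parser.lock().await;
}
```

Note that the outer map deliberately stays a `std::sync::RwLock`: it is only held for a pointer lookup and is never held across an `.await`.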
diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs
index 5c6bb7a99..bc9f3c7a5 100644
--- a/sgl-router/src/routers/grpc/context.rs
+++ b/sgl-router/src/routers/grpc/context.rs
@@ -48,9 +48,10 @@ pub struct RequestInput {
 }
 
 /// Request type variants
+/// Using Arc instead of Box to enable cheap cloning for background tasks
 pub enum RequestType {
-    Chat(Box<ChatCompletionRequest>),
-    Generate(Box<GenerateRequest>),
+    Chat(Arc<ChatCompletionRequest>),
+    Generate(Arc<GenerateRequest>),
 }
 
 /// Shared components (injected once at creation)
@@ -181,14 +182,14 @@ pub struct StreamingState {
 impl RequestContext {
     /// Create context for chat completion request
     pub fn for_chat(
-        request: ChatCompletionRequest,
+        request: Arc<ChatCompletionRequest>,
         headers: Option<HeaderMap>,
         model_id: Option<String>,
         components: Arc<SharedComponents>,
     ) -> Self {
         Self {
             input: RequestInput {
-                request_type: RequestType::Chat(Box::new(request)),
+                request_type: RequestType::Chat(request),
                 headers,
                 model_id,
             },
@@ -199,14 +200,14 @@ impl RequestContext {
 
     /// Create context for generate request
     pub fn for_generate(
-        request: GenerateRequest,
+        request: Arc<GenerateRequest>,
         headers: Option<HeaderMap>,
         model_id: Option<String>,
         components: Arc<SharedComponents>,
     ) -> Self {
         Self {
             input: RequestInput {
-                request_type: RequestType::Generate(Box::new(request)),
+                request_type: RequestType::Generate(request),
                 headers,
                 model_id,
             },
@@ -228,6 +229,14 @@ impl RequestContext {
         }
     }
 
+    /// Get Arc clone of chat request (panics if not chat)
+    pub fn chat_request_arc(&self) -> Arc<ChatCompletionRequest> {
+        match &self.input.request_type {
+            RequestType::Chat(req) => Arc::clone(req),
+            _ => panic!("Expected chat request"),
+        }
+    }
+
     /// Get generate request (panics if not generate)
     pub fn generate_request(&self) -> &GenerateRequest {
         match &self.input.request_type {
@@ -236,6 +245,14 @@ impl RequestContext {
         }
     }
 
+    /// Get Arc clone of generate request (panics if not generate)
+    pub fn generate_request_arc(&self) -> Arc<GenerateRequest> {
+        match &self.input.request_type {
+            RequestType::Generate(req) => Arc::clone(req),
+            _ => panic!("Expected generate request"),
+        }
+    }
+
     /// Check if request is streaming
     pub fn is_streaming(&self) -> bool {
         match &self.input.request_type {
diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs
index a387c16b6..0fc29a6c5 100644
--- a/sgl-router/src/routers/grpc/pd_router.rs
+++ b/sgl-router/src/routers/grpc/pd_router.rs
@@ -129,7 +129,7 @@ impl GrpcPDRouter {
         // Use pipeline for ALL requests (streaming and non-streaming)
         self.pipeline
             .execute_generate(
-                body.clone(),
+                Arc::new(body.clone()),
                 headers.cloned(),
                 model_id.map(|s| s.to_string()),
                 self.shared_components.clone(),
@@ -152,7 +152,7 @@ impl GrpcPDRouter {
         // Use pipeline for ALL requests (streaming and non-streaming)
         self.pipeline
             .execute_chat(
-                body.clone(),
+                Arc::new(body.clone()),
                 headers.cloned(),
                 model_id.map(|s| s.to_string()),
                 self.shared_components.clone(),
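Both routers still pay exactly one deep copy at the boundary (`body.clone()` into `Arc::new`); what changes is that every downstream consumer (pipeline stages, spawned streaming tasks) now shares that single allocation instead of re-cloning it. A self-contained sketch of the cost model, with a hypothetical `ChatRequest` standing in for `ChatCompletionRequest`:

```rust
// Cloning an Arc copies one pointer and bumps a refcount, no matter how
// large the request is; cloning the request itself would copy every message.
use std::sync::Arc;

#[derive(Clone)]
struct ChatRequest {
    messages: Vec<String>, // can be tens of KB for long conversations
}

fn main() {
    let request = Arc::new(ChatRequest {
        messages: vec!["a".repeat(64 * 1024)], // ~64 KB payload
    });

    // Handing the request to a background task: pointer copy only.
    let for_task = Arc::clone(&request);
    let handle = std::thread::spawn(move || for_task.messages.len());

    assert_eq!(handle.join().unwrap(), 1);
    assert_eq!(Arc::strong_count(&request), 1); // the task's clone is gone
}
```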
diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs
index 01c97df63..17e88bd8a 100644
--- a/sgl-router/src/routers/grpc/pipeline.rs
+++ b/sgl-router/src/routers/grpc/pipeline.rs
@@ -58,16 +58,17 @@ impl PipelineStage for PreparationStage {
     async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
         debug!("Stage {}: Processing request", self.name());
 
-        // Clone the request to avoid borrowing issues
-        match &ctx.input.request_type {
-            RequestType::Chat(request) => {
-                let request_clone = request.clone();
-                self.prepare_chat(ctx, &request_clone).await?;
-            }
-            RequestType::Generate(request) => {
-                let request_clone = request.clone();
-                self.prepare_generate(ctx, &request_clone).await?;
-            }
+        // Clone Arc before match to avoid borrow checker issues
+        // (matching borrows ctx, but prepare_* methods need mutable borrow)
+        // Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
+        let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
+
+        if is_chat {
+            let request_arc = ctx.chat_request_arc();
+            self.prepare_chat(ctx, &request_arc).await?;
+        } else {
+            let request_arc = ctx.generate_request_arc();
+            self.prepare_generate(ctx, &request_arc).await?;
         }
 
         Ok(None)
@@ -820,7 +821,7 @@ impl ResponseProcessingStage {
         return Ok(Some(
             self.streaming_processor.clone().process_streaming_response(
                 execution_result,
-                ctx.chat_request().clone(),
+                ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
                 dispatch.clone(),
             ),
         ));
@@ -865,9 +866,7 @@ impl ResponseProcessingStage {
             return Err(utils::internal_error_static("No responses from server"));
         }
 
-        // Clone chat_request to avoid borrow checker conflict
-        // (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder)
-        let chat_request = ctx.chat_request().clone();
+        let chat_request = ctx.chat_request_arc();
         let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
 
         let stop_decoder = ctx
@@ -959,13 +958,11 @@ impl ResponseProcessingStage {
             .as_ref()
             .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
 
-        let generate_request = ctx.generate_request().clone();
-
         // Streaming: Use StreamingProcessor and return SSE response (done)
         return Ok(Some(
             self.streaming_processor.clone().process_streaming_generate(
                 execution_result,
-                generate_request,
+                ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
                 dispatch.clone(),
             ),
         ));
@@ -1193,8 +1190,8 @@ impl ChatCompletionPipeline {
     /// Execute the complete pipeline for a chat request
     pub async fn execute_chat(
         &self,
-        request: ChatCompletionRequest,
+        request: Arc<ChatCompletionRequest>,
         headers: Option<HeaderMap>,
         model_id: Option<String>,
         components: Arc<SharedComponents>,
     ) -> Response {
@@ -1243,8 +1240,8 @@ impl ChatCompletionPipeline {
     /// Execute the complete pipeline for a generate request
     pub async fn execute_generate(
         &self,
-        request: GenerateRequest,
+        request: Arc<GenerateRequest>,
         headers: Option<HeaderMap>,
         model_id: Option<String>,
         components: Arc<SharedComponents>,
     ) -> Response {
diff --git a/sgl-router/src/routers/grpc/processing.rs b/sgl-router/src/routers/grpc/processing.rs
index 7451236fb..91f831663 100644
--- a/sgl-router/src/routers/grpc/processing.rs
+++ b/sgl-router/src/routers/grpc/processing.rs
@@ -97,9 +97,7 @@ impl ResponseProcessor {
             &original_request.model,
         );
 
-        let mut parser = pooled_parser
-            .lock()
-            .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
+        let mut parser = pooled_parser.lock().await;
         match parser.detect_and_parse_reasoning(&processed_text) {
             Ok(result) => {
                 if !result.reasoning_text.is_empty() {
diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs
index 9e95e49f4..c35358209 100644
--- a/sgl-router/src/routers/grpc/router.rs
+++ b/sgl-router/src/routers/grpc/router.rs
@@ -129,7 +129,7 @@ impl GrpcRouter {
         // Use pipeline for ALL requests (streaming and non-streaming)
         self.pipeline
             .execute_chat(
-                body.clone(),
+                Arc::new(body.clone()),
                 headers.cloned(),
                 model_id.map(|s| s.to_string()),
                 self.shared_components.clone(),
@@ -149,7 +149,7 @@ impl GrpcRouter {
         // Use pipeline for ALL requests (streaming and non-streaming)
         self.pipeline
             .execute_generate(
-                body.clone(),
+                Arc::new(body.clone()),
                 headers.cloned(),
                 model_id.map(|s| s.to_string()),
                 self.shared_components.clone(),
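The `matches!` + `chat_request_arc()` sequence in `PreparationStage::execute` is a borrow-splitting idiom: matching on `&ctx.input.request_type` keeps `ctx` immutably borrowed for the whole arm, so `prepare_chat(ctx, ...)`, which takes `&mut RequestContext`, cannot be called inside it. Cloning the `Arc` out first ends the immutable borrow. A compilable sketch with simplified, hypothetical types:

```rust
// Pared-down version of the RequestContext borrow-split (illustrative types).
use std::sync::Arc;

struct Ctx {
    request: Arc<String>,    // stands in for RequestType::Chat(Arc<...>)
    prepared: Option<usize>, // stands in for mutable pipeline state
}

impl Ctx {
    fn request_arc(&self) -> Arc<String> {
        Arc::clone(&self.request) // cheap: one pointer + refcount bump
    }

    fn prepare(&mut self, req: &str) {
        self.prepared = Some(req.len()); // needs &mut self
    }
}

fn main() {
    let mut ctx = Ctx {
        request: Arc::new("hello".to_string()),
        prepared: None,
    };

    // The immutable borrow of ctx ends as soon as the Arc is cloned out...
    let req = ctx.request_arc();
    // ...so the mutable borrow below is legal.
    ctx.prepare(&req);

    assert_eq!(ctx.prepared, Some(5));
}
```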
diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs
index fc9a8a68b..86d81f6c3 100644
--- a/sgl-router/src/routers/grpc/streaming.rs
+++ b/sgl-router/src/routers/grpc/streaming.rs
@@ -66,7 +66,7 @@ impl StreamingProcessor {
     pub fn process_streaming_response(
         self: Arc<Self>,
         execution_result: context::ExecutionResult,
-        chat_request: ChatCompletionRequest,
+        chat_request: Arc<ChatCompletionRequest>,
         dispatch: context::DispatchMetadata,
     ) -> Response {
         use bytes::Bytes;
@@ -156,7 +156,7 @@ impl StreamingProcessor {
         mut grpc_stream: Streaming<proto::GenerateResponse>,
         dispatch: context::DispatchMetadata,
         stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
-        original_request: ChatCompletionRequest,
+        original_request: Arc<ChatCompletionRequest>,
         tx: &UnboundedSender<Result<Bytes, String>>,
     ) -> Result<(), String> {
         // Extract request parameters
@@ -176,7 +176,7 @@ impl StreamingProcessor {
         let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
 
         // Parser state (lazy initialization per index)
-        type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
+        type PooledReasoningParser = Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>;
         let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
 
         type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
@@ -186,6 +186,9 @@ impl StreamingProcessor {
         // Per-index stop decoders (each index needs its own state for n>1 support)
         let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
 
+        // Reusable SSE formatting buffer to avoid allocations per chunk
+        let mut sse_buffer = Vec::with_capacity(512);
+
         // Use dispatch metadata for consistent response fields
         let request_id = &dispatch.request_id;
         let model = &dispatch.model;
@@ -262,7 +265,8 @@ impl StreamingProcessor {
                     }],
                     usage: None,
                 };
-                tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
+                Self::format_sse_chunk_into(&mut sse_buffer, &first_chunk);
+                tx.send(Ok(Bytes::from(sse_buffer.clone())))
                     .map_err(|_| "Failed to send first chunk".to_string())?;
                 is_firsts.insert(index, false);
             }
@@ -282,9 +286,11 @@ impl StreamingProcessor {
                     model,
                     created,
                     system_fingerprint,
-                );
+                )
+                .await;
                 if let Some(chunk) = reasoning_chunk {
-                    tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
+                    Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
+                    tx.send(Ok(Bytes::from(sse_buffer.clone())))
                         .map_err(|_| "Failed to send reasoning chunk".to_string())?;
                 }
                 delta = normal_text;
@@ -314,7 +320,8 @@ impl StreamingProcessor {
                     .await;
 
                 for chunk in tool_chunks {
-                    tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
+                    Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
+                    tx.send(Ok(Bytes::from(sse_buffer.clone())))
                         .map_err(|_| "Failed to send tool call chunk".to_string())?;
                 }
 
@@ -335,7 +342,8 @@ impl StreamingProcessor {
                     system_fingerprint,
                     choice_logprobs,
                 );
-                tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
+                Self::format_sse_chunk_into(&mut sse_buffer, &content_chunk);
+                tx.send(Ok(Bytes::from(sse_buffer.clone())))
                     .map_err(|_| "Failed to send content chunk".to_string())?;
             }
         }
@@ -529,7 +537,7 @@ impl StreamingProcessor {
         decode_stream: Streaming<proto::GenerateResponse>,
         dispatch: context::DispatchMetadata,
         stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
-        original_request: ChatCompletionRequest,
+        original_request: Arc<ChatCompletionRequest>,
         tx: &UnboundedSender<Result<Bytes, String>>,
     ) -> Result<(), String> {
         // Phase 1.5: Collect input_logprobs from prefill stream if requested
@@ -561,7 +569,7 @@ impl StreamingProcessor {
     pub fn process_streaming_generate(
         self: Arc<Self>,
         execution_result: context::ExecutionResult,
-        generate_request: GenerateRequest,
+        generate_request: Arc<GenerateRequest>,
         dispatch: context::DispatchMetadata,
     ) -> Response {
         let return_logprob = generate_request.return_logprob;
@@ -946,11 +954,11 @@ impl StreamingProcessor {
     /// Helper: Process reasoning content in streaming mode
     #[allow(clippy::too_many_arguments)]
-    fn process_reasoning_stream(
+    async fn process_reasoning_stream(
         &self,
         delta: &str,
         index: u32,
-        reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
+        reasoning_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>>,
         request_id: &str,
         model: &str,
         created: u64,
@@ -967,7 +975,7 @@ impl StreamingProcessor {
 
         if let Some(pooled_parser) = reasoning_parsers.get(&index) {
             let (parse_result, in_reasoning) = {
-                let mut parser = pooled_parser.lock().unwrap();
+                let mut parser = pooled_parser.lock().await;
                 let result = parser.parse_reasoning_streaming_incremental(delta);
                 let in_reasoning = parser.is_in_reasoning();
                 (result, in_reasoning)
@@ -1134,15 +1142,20 @@ impl StreamingProcessor {
         (false, chunks)
     }
 
-    /// Format a response as SSE chunk
-    fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String {
-        match serde_json::to_string(chunk) {
-            Ok(json) => format!("data: {}\n\n", json),
-            Err(e) => {
-                error!("Failed to serialize SSE chunk: {}", e);
-                format!("data: {}\n\n", json!({"error": "serialization_failed"}))
-            }
+    /// Format a response as SSE chunk into a reusable buffer
+    /// This avoids allocations by reusing the same buffer across multiple chunks
+    #[inline]
+    fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &ChatCompletionStreamResponse) {
+        buffer.clear();
+        buffer.extend_from_slice(b"data: ");
+        if let Err(e) = serde_json::to_writer(&mut *buffer, chunk) {
+            error!("Failed to serialize SSE chunk: {}", e);
+            buffer.clear();
+            buffer.extend_from_slice(b"data: ");
+            let error_msg = json!({"error": "serialization_failed"}).to_string();
+            buffer.extend_from_slice(error_msg.as_bytes());
         }
+        buffer.extend_from_slice(b"\n\n");
     }
 
     /// Create a content chunk response
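On `format_sse_chunk_into`: `serde_json::to_writer` serializes straight into the reused `Vec<u8>` (a `Vec<u8>` implements `io::Write`), so the per-chunk `String` allocations of the old `to_string` + `format!` path disappear; the remaining per-chunk cost is the `sse_buffer.clone()` handed to the channel, which copies only the bytes actually written. A runnable sketch of the pattern, with a hypothetical `Chunk` type (assumes `serde` with the derive feature and `serde_json`):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Chunk<'a> {
    id: &'a str,
    delta: &'a str,
}

// Mirrors the shape of the router's helper: clear, write the SSE prefix,
// serialize in place, append the event terminator. Unlike the router's
// version, this sketch simply panics on a serialization error.
fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &Chunk) {
    buffer.clear(); // keeps the capacity from previous chunks
    buffer.extend_from_slice(b"data: ");
    serde_json::to_writer(&mut *buffer, chunk).expect("serialization failed");
    buffer.extend_from_slice(b"\n\n");
}

fn main() {
    let mut buffer = Vec::with_capacity(512);
    for _ in 0..3 {
        let chunk = Chunk { id: "chatcmpl-1", delta: "token" };
        format_sse_chunk_into(&mut buffer, &chunk);
        assert!(buffer.starts_with(b"data: {"));
        // The router then sends Bytes::from(buffer.clone()) on its channel;
        // after warm-up the buffer itself never reallocates.
    }
}
```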