[router] improve reasoning parser lock and reduce req cloning (#11336)
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
// Now with parser pooling support for efficient reuse across requests.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::reasoning_parser::parsers::{
|
||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||
@@ -11,6 +13,7 @@ use crate::reasoning_parser::parsers::{
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
/// Uses tokio::Mutex to avoid blocking the async executor.
|
||||
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
|
||||
|
||||
/// Type alias for parser creator functions.
|
||||
@@ -301,8 +304,8 @@ mod tests {
|
||||
assert_eq!(glm45.model_type(), "glm45");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pooled_parser_reuse() {
|
||||
#[tokio::test]
|
||||
async fn test_pooled_parser_reuse() {
|
||||
let factory = ReasoningParserFactory::new();
|
||||
|
||||
// Get the same parser twice - should be the same instance
|
||||
@@ -317,20 +320,18 @@ mod tests {
|
||||
assert!(!Arc::ptr_eq(&parser1, &parser3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pooled_parser_concurrent_access() {
|
||||
use std::thread;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pooled_parser_concurrent_access() {
|
||||
let factory = ReasoningParserFactory::new();
|
||||
let parser = factory.get_pooled("deepseek-r1");
|
||||
|
||||
// Spawn multiple threads that use the same parser
|
||||
// Spawn multiple async tasks that use the same parser
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..3 {
|
||||
let parser_clone = Arc::clone(&parser);
|
||||
let handle = thread::spawn(move || {
|
||||
let mut parser = parser_clone.lock().unwrap();
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut parser = parser_clone.lock().await;
|
||||
let input = format!("thread {} reasoning</think>answer", i);
|
||||
let result = parser.detect_and_parse_reasoning(&input).unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
@@ -339,14 +340,14 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads to complete
|
||||
// Wait for all tasks to complete
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_clearing() {
|
||||
#[tokio::test]
|
||||
async fn test_pool_clearing() {
|
||||
let factory = ReasoningParserFactory::new();
|
||||
|
||||
// Get a pooled parser
|
||||
@@ -362,8 +363,8 @@ mod tests {
|
||||
assert!(!Arc::ptr_eq(&parser1, &parser2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_passthrough_parser_pooling() {
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_parser_pooling() {
|
||||
let factory = ReasoningParserFactory::new();
|
||||
|
||||
// Unknown models should get passthrough parser
|
||||
@@ -373,19 +374,18 @@ mod tests {
|
||||
// Both should use the same passthrough parser instance
|
||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||
|
||||
let parser = parser1.lock().unwrap();
|
||||
let parser = parser1.lock().await;
|
||||
assert_eq!(parser.model_type(), "passthrough");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_high_concurrency_parser_access() {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
|
||||
async fn test_high_concurrency_parser_access() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
let factory = ReasoningParserFactory::new();
|
||||
let num_threads = 100;
|
||||
let requests_per_thread = 50;
|
||||
let num_tasks = 100;
|
||||
let requests_per_task = 50;
|
||||
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
||||
|
||||
// Track successful operations
|
||||
@@ -395,36 +395,25 @@ mod tests {
|
||||
let start = Instant::now();
|
||||
let mut handles = vec![];
|
||||
|
||||
for thread_id in 0..num_threads {
|
||||
for task_id in 0..num_tasks {
|
||||
let factory = factory.clone();
|
||||
let models = models.clone();
|
||||
let success_count = Arc::clone(&success_count);
|
||||
let error_count = Arc::clone(&error_count);
|
||||
|
||||
let handle = thread::spawn(move || {
|
||||
for request_id in 0..requests_per_thread {
|
||||
let handle = tokio::spawn(async move {
|
||||
for request_id in 0..requests_per_task {
|
||||
// Rotate through different models
|
||||
let model = &models[(thread_id + request_id) % models.len()];
|
||||
let model = &models[(task_id + request_id) % models.len()];
|
||||
let parser = factory.get_pooled(model);
|
||||
|
||||
// Use blocking lock - this is the realistic scenario
|
||||
// In production, requests would wait for the parser to be available
|
||||
// Handle poisoned locks gracefully
|
||||
let mut p = match parser.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_poisoned) => {
|
||||
// Lock was poisoned by a panicking thread
|
||||
// In production, we might want to recreate the parser
|
||||
// For testing, we'll just skip this iteration
|
||||
error_count.fetch_add(1, Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Use async lock - tokio::Mutex doesn't poison
|
||||
let mut p = parser.lock().await;
|
||||
|
||||
// Simulate realistic parsing work with substantial text
|
||||
// Typical reasoning can be 500-5000 tokens
|
||||
let reasoning_text = format!(
|
||||
"Thread {} is processing request {}. Let me think through this step by step. \
|
||||
"Task {} is processing request {}. Let me think through this step by step. \
|
||||
First, I need to understand the problem. The problem involves analyzing data \
|
||||
and making calculations. Let me break this down: \n\
|
||||
1. Initial analysis shows that we have multiple variables to consider. \
|
||||
@@ -436,19 +425,19 @@ mod tests {
|
||||
7. Validating against known constraints... \
|
||||
8. The conclusion follows logically from premises A, B, and C. \
|
||||
This reasoning chain demonstrates the validity of our approach.",
|
||||
thread_id, request_id, thread_id, request_id, thread_id * request_id
|
||||
task_id, request_id, task_id, request_id, task_id * request_id
|
||||
);
|
||||
|
||||
let answer_text = format!(
|
||||
"Based on my analysis, the answer for thread {} request {} is: \
|
||||
"Based on my analysis, the answer for task {} request {} is: \
|
||||
The solution involves multiple steps as outlined in the reasoning. \
|
||||
The final result is {} with confidence level high. \
|
||||
This conclusion is supported by rigorous mathematical analysis \
|
||||
and has been validated against multiple test cases. \
|
||||
The implementation should handle edge cases appropriately.",
|
||||
thread_id,
|
||||
task_id,
|
||||
request_id,
|
||||
thread_id * request_id
|
||||
task_id * request_id
|
||||
);
|
||||
|
||||
let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
|
||||
@@ -456,16 +445,14 @@ mod tests {
|
||||
match p.detect_and_parse_reasoning(&input) {
|
||||
Ok(result) => {
|
||||
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
||||
assert!(result
|
||||
.normal_text
|
||||
.contains(&format!("thread {}", thread_id)));
|
||||
assert!(result.normal_text.contains(&format!("task {}", task_id)));
|
||||
|
||||
// For parsers that accumulate reasoning (stream_reasoning=false)
|
||||
// the reasoning_text should be populated
|
||||
if !result.reasoning_text.is_empty() {
|
||||
assert!(result
|
||||
.reasoning_text
|
||||
.contains(&format!("Thread {}", thread_id)));
|
||||
.contains(&format!("Task {}", task_id)));
|
||||
assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
|
||||
}
|
||||
|
||||
@@ -486,20 +473,20 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
let total_requests = num_threads * requests_per_thread;
|
||||
let total_requests = num_tasks * requests_per_task;
|
||||
let successes = success_count.load(Ordering::Relaxed);
|
||||
let errors = error_count.load(Ordering::Relaxed);
|
||||
|
||||
// Print stats for debugging
|
||||
println!(
|
||||
"High concurrency test: {} threads, {} requests each",
|
||||
num_threads, requests_per_thread
|
||||
"High concurrency test: {} tasks, {} requests each",
|
||||
num_tasks, requests_per_task
|
||||
);
|
||||
println!(
|
||||
"Completed in {:?}, {} successes, {} errors",
|
||||
@@ -523,42 +510,40 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_pool_modifications() {
|
||||
use std::thread;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_concurrent_pool_modifications() {
|
||||
let factory = ReasoningParserFactory::new();
|
||||
let mut handles = vec![];
|
||||
|
||||
// Thread 1: Continuously get parsers
|
||||
// Task 1: Continuously get parsers
|
||||
let factory1 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
handles.push(tokio::spawn(async move {
|
||||
for _ in 0..100 {
|
||||
let _parser = factory1.get_pooled("deepseek-r1");
|
||||
}
|
||||
}));
|
||||
|
||||
// Thread 2: Continuously clear pool
|
||||
// Task 2: Continuously clear pool
|
||||
let factory2 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
handles.push(tokio::spawn(async move {
|
||||
for _ in 0..10 {
|
||||
factory2.clear_pool();
|
||||
thread::sleep(std::time::Duration::from_micros(100));
|
||||
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
|
||||
}
|
||||
}));
|
||||
|
||||
// Thread 3: Get different parsers
|
||||
// Task 3: Get different parsers
|
||||
let factory3 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
handles.push(tokio::spawn(async move {
|
||||
for i in 0..100 {
|
||||
let models = ["qwen3", "kimi", "unknown"];
|
||||
let _parser = factory3.get_pooled(models[i % 3]);
|
||||
}
|
||||
}));
|
||||
|
||||
// Wait for all threads - should not deadlock or panic
|
||||
// Wait for all tasks - should not deadlock or panic
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user