[router] improve reasoning parser lock and reduce req cloning (#11336)
This commit is contained in:
@@ -2,7 +2,9 @@
|
|||||||
// Now with parser pooling support for efficient reuse across requests.
|
// Now with parser pooling support for efficient reuse across requests.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use crate::reasoning_parser::parsers::{
|
use crate::reasoning_parser::parsers::{
|
||||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||||
@@ -11,6 +13,7 @@ use crate::reasoning_parser::parsers::{
|
|||||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
|
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
|
||||||
|
|
||||||
/// Type alias for pooled parser instances.
|
/// Type alias for pooled parser instances.
|
||||||
|
/// Uses tokio::Mutex to avoid blocking the async executor.
|
||||||
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
|
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
|
||||||
|
|
||||||
/// Type alias for parser creator functions.
|
/// Type alias for parser creator functions.
|
||||||
@@ -301,8 +304,8 @@ mod tests {
|
|||||||
assert_eq!(glm45.model_type(), "glm45");
|
assert_eq!(glm45.model_type(), "glm45");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_pooled_parser_reuse() {
|
async fn test_pooled_parser_reuse() {
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Get the same parser twice - should be the same instance
|
// Get the same parser twice - should be the same instance
|
||||||
@@ -317,20 +320,18 @@ mod tests {
|
|||||||
assert!(!Arc::ptr_eq(&parser1, &parser3));
|
assert!(!Arc::ptr_eq(&parser1, &parser3));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_pooled_parser_concurrent_access() {
|
async fn test_pooled_parser_concurrent_access() {
|
||||||
use std::thread;
|
|
||||||
|
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.get_pooled("deepseek-r1");
|
let parser = factory.get_pooled("deepseek-r1");
|
||||||
|
|
||||||
// Spawn multiple threads that use the same parser
|
// Spawn multiple async tasks that use the same parser
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..3 {
|
for i in 0..3 {
|
||||||
let parser_clone = Arc::clone(&parser);
|
let parser_clone = Arc::clone(&parser);
|
||||||
let handle = thread::spawn(move || {
|
let handle = tokio::spawn(async move {
|
||||||
let mut parser = parser_clone.lock().unwrap();
|
let mut parser = parser_clone.lock().await;
|
||||||
let input = format!("thread {} reasoning</think>answer", i);
|
let input = format!("thread {} reasoning</think>answer", i);
|
||||||
let result = parser.detect_and_parse_reasoning(&input).unwrap();
|
let result = parser.detect_and_parse_reasoning(&input).unwrap();
|
||||||
assert_eq!(result.normal_text, "answer");
|
assert_eq!(result.normal_text, "answer");
|
||||||
@@ -339,14 +340,14 @@ mod tests {
|
|||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all threads to complete
|
// Wait for all tasks to complete
|
||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.join().unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_pool_clearing() {
|
async fn test_pool_clearing() {
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Get a pooled parser
|
// Get a pooled parser
|
||||||
@@ -362,8 +363,8 @@ mod tests {
|
|||||||
assert!(!Arc::ptr_eq(&parser1, &parser2));
|
assert!(!Arc::ptr_eq(&parser1, &parser2));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_passthrough_parser_pooling() {
|
async fn test_passthrough_parser_pooling() {
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Unknown models should get passthrough parser
|
// Unknown models should get passthrough parser
|
||||||
@@ -373,19 +374,18 @@ mod tests {
|
|||||||
// Both should use the same passthrough parser instance
|
// Both should use the same passthrough parser instance
|
||||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||||
|
|
||||||
let parser = parser1.lock().unwrap();
|
let parser = parser1.lock().await;
|
||||||
assert_eq!(parser.model_type(), "passthrough");
|
assert_eq!(parser.model_type(), "passthrough");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
|
||||||
fn test_high_concurrency_parser_access() {
|
async fn test_high_concurrency_parser_access() {
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::thread;
|
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let num_threads = 100;
|
let num_tasks = 100;
|
||||||
let requests_per_thread = 50;
|
let requests_per_task = 50;
|
||||||
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
||||||
|
|
||||||
// Track successful operations
|
// Track successful operations
|
||||||
@@ -395,36 +395,25 @@ mod tests {
|
|||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for thread_id in 0..num_threads {
|
for task_id in 0..num_tasks {
|
||||||
let factory = factory.clone();
|
let factory = factory.clone();
|
||||||
let models = models.clone();
|
let models = models.clone();
|
||||||
let success_count = Arc::clone(&success_count);
|
let success_count = Arc::clone(&success_count);
|
||||||
let error_count = Arc::clone(&error_count);
|
let error_count = Arc::clone(&error_count);
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = tokio::spawn(async move {
|
||||||
for request_id in 0..requests_per_thread {
|
for request_id in 0..requests_per_task {
|
||||||
// Rotate through different models
|
// Rotate through different models
|
||||||
let model = &models[(thread_id + request_id) % models.len()];
|
let model = &models[(task_id + request_id) % models.len()];
|
||||||
let parser = factory.get_pooled(model);
|
let parser = factory.get_pooled(model);
|
||||||
|
|
||||||
// Use blocking lock - this is the realistic scenario
|
// Use async lock - tokio::Mutex doesn't poison
|
||||||
// In production, requests would wait for the parser to be available
|
let mut p = parser.lock().await;
|
||||||
// Handle poisoned locks gracefully
|
|
||||||
let mut p = match parser.lock() {
|
|
||||||
Ok(guard) => guard,
|
|
||||||
Err(_poisoned) => {
|
|
||||||
// Lock was poisoned by a panicking thread
|
|
||||||
// In production, we might want to recreate the parser
|
|
||||||
// For testing, we'll just skip this iteration
|
|
||||||
error_count.fetch_add(1, Ordering::Relaxed);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Simulate realistic parsing work with substantial text
|
// Simulate realistic parsing work with substantial text
|
||||||
// Typical reasoning can be 500-5000 tokens
|
// Typical reasoning can be 500-5000 tokens
|
||||||
let reasoning_text = format!(
|
let reasoning_text = format!(
|
||||||
"Thread {} is processing request {}. Let me think through this step by step. \
|
"Task {} is processing request {}. Let me think through this step by step. \
|
||||||
First, I need to understand the problem. The problem involves analyzing data \
|
First, I need to understand the problem. The problem involves analyzing data \
|
||||||
and making calculations. Let me break this down: \n\
|
and making calculations. Let me break this down: \n\
|
||||||
1. Initial analysis shows that we have multiple variables to consider. \
|
1. Initial analysis shows that we have multiple variables to consider. \
|
||||||
@@ -436,19 +425,19 @@ mod tests {
|
|||||||
7. Validating against known constraints... \
|
7. Validating against known constraints... \
|
||||||
8. The conclusion follows logically from premises A, B, and C. \
|
8. The conclusion follows logically from premises A, B, and C. \
|
||||||
This reasoning chain demonstrates the validity of our approach.",
|
This reasoning chain demonstrates the validity of our approach.",
|
||||||
thread_id, request_id, thread_id, request_id, thread_id * request_id
|
task_id, request_id, task_id, request_id, task_id * request_id
|
||||||
);
|
);
|
||||||
|
|
||||||
let answer_text = format!(
|
let answer_text = format!(
|
||||||
"Based on my analysis, the answer for thread {} request {} is: \
|
"Based on my analysis, the answer for task {} request {} is: \
|
||||||
The solution involves multiple steps as outlined in the reasoning. \
|
The solution involves multiple steps as outlined in the reasoning. \
|
||||||
The final result is {} with confidence level high. \
|
The final result is {} with confidence level high. \
|
||||||
This conclusion is supported by rigorous mathematical analysis \
|
This conclusion is supported by rigorous mathematical analysis \
|
||||||
and has been validated against multiple test cases. \
|
and has been validated against multiple test cases. \
|
||||||
The implementation should handle edge cases appropriately.",
|
The implementation should handle edge cases appropriately.",
|
||||||
thread_id,
|
task_id,
|
||||||
request_id,
|
request_id,
|
||||||
thread_id * request_id
|
task_id * request_id
|
||||||
);
|
);
|
||||||
|
|
||||||
let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
|
let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
|
||||||
@@ -456,16 +445,14 @@ mod tests {
|
|||||||
match p.detect_and_parse_reasoning(&input) {
|
match p.detect_and_parse_reasoning(&input) {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
||||||
assert!(result
|
assert!(result.normal_text.contains(&format!("task {}", task_id)));
|
||||||
.normal_text
|
|
||||||
.contains(&format!("thread {}", thread_id)));
|
|
||||||
|
|
||||||
// For parsers that accumulate reasoning (stream_reasoning=false)
|
// For parsers that accumulate reasoning (stream_reasoning=false)
|
||||||
// the reasoning_text should be populated
|
// the reasoning_text should be populated
|
||||||
if !result.reasoning_text.is_empty() {
|
if !result.reasoning_text.is_empty() {
|
||||||
assert!(result
|
assert!(result
|
||||||
.reasoning_text
|
.reasoning_text
|
||||||
.contains(&format!("Thread {}", thread_id)));
|
.contains(&format!("Task {}", task_id)));
|
||||||
assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
|
assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,20 +473,20 @@ mod tests {
|
|||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all threads
|
// Wait for all tasks
|
||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.join().unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
let total_requests = num_threads * requests_per_thread;
|
let total_requests = num_tasks * requests_per_task;
|
||||||
let successes = success_count.load(Ordering::Relaxed);
|
let successes = success_count.load(Ordering::Relaxed);
|
||||||
let errors = error_count.load(Ordering::Relaxed);
|
let errors = error_count.load(Ordering::Relaxed);
|
||||||
|
|
||||||
// Print stats for debugging
|
// Print stats for debugging
|
||||||
println!(
|
println!(
|
||||||
"High concurrency test: {} threads, {} requests each",
|
"High concurrency test: {} tasks, {} requests each",
|
||||||
num_threads, requests_per_thread
|
num_tasks, requests_per_task
|
||||||
);
|
);
|
||||||
println!(
|
println!(
|
||||||
"Completed in {:?}, {} successes, {} errors",
|
"Completed in {:?}, {} successes, {} errors",
|
||||||
@@ -523,42 +510,40 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
fn test_concurrent_pool_modifications() {
|
async fn test_concurrent_pool_modifications() {
|
||||||
use std::thread;
|
|
||||||
|
|
||||||
let factory = ReasoningParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
// Thread 1: Continuously get parsers
|
// Task 1: Continuously get parsers
|
||||||
let factory1 = factory.clone();
|
let factory1 = factory.clone();
|
||||||
handles.push(thread::spawn(move || {
|
handles.push(tokio::spawn(async move {
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
let _parser = factory1.get_pooled("deepseek-r1");
|
let _parser = factory1.get_pooled("deepseek-r1");
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Thread 2: Continuously clear pool
|
// Task 2: Continuously clear pool
|
||||||
let factory2 = factory.clone();
|
let factory2 = factory.clone();
|
||||||
handles.push(thread::spawn(move || {
|
handles.push(tokio::spawn(async move {
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
factory2.clear_pool();
|
factory2.clear_pool();
|
||||||
thread::sleep(std::time::Duration::from_micros(100));
|
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Thread 3: Get different parsers
|
// Task 3: Get different parsers
|
||||||
let factory3 = factory.clone();
|
let factory3 = factory.clone();
|
||||||
handles.push(thread::spawn(move || {
|
handles.push(tokio::spawn(async move {
|
||||||
for i in 0..100 {
|
for i in 0..100 {
|
||||||
let models = ["qwen3", "kimi", "unknown"];
|
let models = ["qwen3", "kimi", "unknown"];
|
||||||
let _parser = factory3.get_pooled(models[i % 3]);
|
let _parser = factory3.get_pooled(models[i % 3]);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Wait for all threads - should not deadlock or panic
|
// Wait for all tasks - should not deadlock or panic
|
||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.join().unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,9 +48,10 @@ pub struct RequestInput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Request type variants
|
/// Request type variants
|
||||||
|
/// Using Arc instead of Box to enable cheap cloning for background tasks
|
||||||
pub enum RequestType {
|
pub enum RequestType {
|
||||||
Chat(Box<ChatCompletionRequest>),
|
Chat(Arc<ChatCompletionRequest>),
|
||||||
Generate(Box<GenerateRequest>),
|
Generate(Arc<GenerateRequest>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Shared components (injected once at creation)
|
/// Shared components (injected once at creation)
|
||||||
@@ -181,14 +182,14 @@ pub struct StreamingState {
|
|||||||
impl RequestContext {
|
impl RequestContext {
|
||||||
/// Create context for chat completion request
|
/// Create context for chat completion request
|
||||||
pub fn for_chat(
|
pub fn for_chat(
|
||||||
request: ChatCompletionRequest,
|
request: Arc<ChatCompletionRequest>,
|
||||||
headers: Option<HeaderMap>,
|
headers: Option<HeaderMap>,
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
components: Arc<SharedComponents>,
|
components: Arc<SharedComponents>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
input: RequestInput {
|
input: RequestInput {
|
||||||
request_type: RequestType::Chat(Box::new(request)),
|
request_type: RequestType::Chat(request),
|
||||||
headers,
|
headers,
|
||||||
model_id,
|
model_id,
|
||||||
},
|
},
|
||||||
@@ -199,14 +200,14 @@ impl RequestContext {
|
|||||||
|
|
||||||
/// Create context for generate request
|
/// Create context for generate request
|
||||||
pub fn for_generate(
|
pub fn for_generate(
|
||||||
request: GenerateRequest,
|
request: Arc<GenerateRequest>,
|
||||||
headers: Option<HeaderMap>,
|
headers: Option<HeaderMap>,
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
components: Arc<SharedComponents>,
|
components: Arc<SharedComponents>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
input: RequestInput {
|
input: RequestInput {
|
||||||
request_type: RequestType::Generate(Box::new(request)),
|
request_type: RequestType::Generate(request),
|
||||||
headers,
|
headers,
|
||||||
model_id,
|
model_id,
|
||||||
},
|
},
|
||||||
@@ -228,6 +229,14 @@ impl RequestContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get Arc clone of chat request (panics if not chat)
|
||||||
|
pub fn chat_request_arc(&self) -> Arc<ChatCompletionRequest> {
|
||||||
|
match &self.input.request_type {
|
||||||
|
RequestType::Chat(req) => Arc::clone(req),
|
||||||
|
_ => panic!("Expected chat request"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get generate request (panics if not generate)
|
/// Get generate request (panics if not generate)
|
||||||
pub fn generate_request(&self) -> &GenerateRequest {
|
pub fn generate_request(&self) -> &GenerateRequest {
|
||||||
match &self.input.request_type {
|
match &self.input.request_type {
|
||||||
@@ -236,6 +245,14 @@ impl RequestContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get Arc clone of generate request (panics if not generate)
|
||||||
|
pub fn generate_request_arc(&self) -> Arc<GenerateRequest> {
|
||||||
|
match &self.input.request_type {
|
||||||
|
RequestType::Generate(req) => Arc::clone(req),
|
||||||
|
_ => panic!("Expected generate request"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if request is streaming
|
/// Check if request is streaming
|
||||||
pub fn is_streaming(&self) -> bool {
|
pub fn is_streaming(&self) -> bool {
|
||||||
match &self.input.request_type {
|
match &self.input.request_type {
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ impl GrpcPDRouter {
|
|||||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||||
self.pipeline
|
self.pipeline
|
||||||
.execute_generate(
|
.execute_generate(
|
||||||
body.clone(),
|
Arc::new(body.clone()),
|
||||||
headers.cloned(),
|
headers.cloned(),
|
||||||
model_id.map(|s| s.to_string()),
|
model_id.map(|s| s.to_string()),
|
||||||
self.shared_components.clone(),
|
self.shared_components.clone(),
|
||||||
@@ -152,7 +152,7 @@ impl GrpcPDRouter {
|
|||||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||||
self.pipeline
|
self.pipeline
|
||||||
.execute_chat(
|
.execute_chat(
|
||||||
body.clone(),
|
Arc::new(body.clone()),
|
||||||
headers.cloned(),
|
headers.cloned(),
|
||||||
model_id.map(|s| s.to_string()),
|
model_id.map(|s| s.to_string()),
|
||||||
self.shared_components.clone(),
|
self.shared_components.clone(),
|
||||||
|
|||||||
@@ -58,16 +58,17 @@ impl PipelineStage for PreparationStage {
|
|||||||
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
|
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
|
||||||
debug!("Stage {}: Processing request", self.name());
|
debug!("Stage {}: Processing request", self.name());
|
||||||
|
|
||||||
// Clone the request to avoid borrowing issues
|
// Clone Arc before match to avoid borrow checker issues
|
||||||
match &ctx.input.request_type {
|
// (matching borrows ctx, but prepare_* methods need mutable borrow)
|
||||||
RequestType::Chat(request) => {
|
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
|
||||||
let request_clone = request.clone();
|
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
|
||||||
self.prepare_chat(ctx, &request_clone).await?;
|
|
||||||
}
|
if is_chat {
|
||||||
RequestType::Generate(request) => {
|
let request_arc = ctx.chat_request_arc();
|
||||||
let request_clone = request.clone();
|
self.prepare_chat(ctx, &request_arc).await?;
|
||||||
self.prepare_generate(ctx, &request_clone).await?;
|
} else {
|
||||||
}
|
let request_arc = ctx.generate_request_arc();
|
||||||
|
self.prepare_generate(ctx, &request_arc).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
@@ -820,7 +821,7 @@ impl ResponseProcessingStage {
|
|||||||
return Ok(Some(
|
return Ok(Some(
|
||||||
self.streaming_processor.clone().process_streaming_response(
|
self.streaming_processor.clone().process_streaming_response(
|
||||||
execution_result,
|
execution_result,
|
||||||
ctx.chat_request().clone(),
|
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
|
||||||
dispatch.clone(),
|
dispatch.clone(),
|
||||||
),
|
),
|
||||||
));
|
));
|
||||||
@@ -865,9 +866,7 @@ impl ResponseProcessingStage {
|
|||||||
return Err(utils::internal_error_static("No responses from server"));
|
return Err(utils::internal_error_static("No responses from server"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone chat_request to avoid borrow checker conflict
|
let chat_request = ctx.chat_request_arc();
|
||||||
// (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder)
|
|
||||||
let chat_request = ctx.chat_request().clone();
|
|
||||||
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
|
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
|
||||||
|
|
||||||
let stop_decoder = ctx
|
let stop_decoder = ctx
|
||||||
@@ -959,13 +958,11 @@ impl ResponseProcessingStage {
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||||
|
|
||||||
let generate_request = ctx.generate_request().clone();
|
|
||||||
|
|
||||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||||
return Ok(Some(
|
return Ok(Some(
|
||||||
self.streaming_processor.clone().process_streaming_generate(
|
self.streaming_processor.clone().process_streaming_generate(
|
||||||
execution_result,
|
execution_result,
|
||||||
generate_request,
|
ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
|
||||||
dispatch.clone(),
|
dispatch.clone(),
|
||||||
),
|
),
|
||||||
));
|
));
|
||||||
@@ -1193,8 +1190,8 @@ impl ChatCompletionPipeline {
|
|||||||
/// Execute the complete pipeline for a chat request
|
/// Execute the complete pipeline for a chat request
|
||||||
pub async fn execute_chat(
|
pub async fn execute_chat(
|
||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: Arc<ChatCompletionRequest>,
|
||||||
headers: Option<axum::http::HeaderMap>,
|
headers: Option<http::HeaderMap>,
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
components: Arc<SharedComponents>,
|
components: Arc<SharedComponents>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
@@ -1243,8 +1240,8 @@ impl ChatCompletionPipeline {
|
|||||||
/// Execute the complete pipeline for a generate request
|
/// Execute the complete pipeline for a generate request
|
||||||
pub async fn execute_generate(
|
pub async fn execute_generate(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: Arc<GenerateRequest>,
|
||||||
headers: Option<axum::http::HeaderMap>,
|
headers: Option<http::HeaderMap>,
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
components: Arc<SharedComponents>,
|
components: Arc<SharedComponents>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
|
|||||||
@@ -97,9 +97,7 @@ impl ResponseProcessor {
|
|||||||
&original_request.model,
|
&original_request.model,
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut parser = pooled_parser
|
let mut parser = pooled_parser.lock().await;
|
||||||
.lock()
|
|
||||||
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
|
|
||||||
match parser.detect_and_parse_reasoning(&processed_text) {
|
match parser.detect_and_parse_reasoning(&processed_text) {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
if !result.reasoning_text.is_empty() {
|
if !result.reasoning_text.is_empty() {
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ impl GrpcRouter {
|
|||||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||||
self.pipeline
|
self.pipeline
|
||||||
.execute_chat(
|
.execute_chat(
|
||||||
body.clone(),
|
Arc::new(body.clone()),
|
||||||
headers.cloned(),
|
headers.cloned(),
|
||||||
model_id.map(|s| s.to_string()),
|
model_id.map(|s| s.to_string()),
|
||||||
self.shared_components.clone(),
|
self.shared_components.clone(),
|
||||||
@@ -149,7 +149,7 @@ impl GrpcRouter {
|
|||||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||||
self.pipeline
|
self.pipeline
|
||||||
.execute_generate(
|
.execute_generate(
|
||||||
body.clone(),
|
Arc::new(body.clone()),
|
||||||
headers.cloned(),
|
headers.cloned(),
|
||||||
model_id.map(|s| s.to_string()),
|
model_id.map(|s| s.to_string()),
|
||||||
self.shared_components.clone(),
|
self.shared_components.clone(),
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ impl StreamingProcessor {
|
|||||||
pub fn process_streaming_response(
|
pub fn process_streaming_response(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
execution_result: context::ExecutionResult,
|
execution_result: context::ExecutionResult,
|
||||||
chat_request: ChatCompletionRequest,
|
chat_request: Arc<ChatCompletionRequest>,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
@@ -156,7 +156,7 @@ impl StreamingProcessor {
|
|||||||
mut grpc_stream: Streaming<proto::GenerateResponse>,
|
mut grpc_stream: Streaming<proto::GenerateResponse>,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||||
original_request: ChatCompletionRequest,
|
original_request: Arc<ChatCompletionRequest>,
|
||||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
// Extract request parameters
|
// Extract request parameters
|
||||||
@@ -176,7 +176,7 @@ impl StreamingProcessor {
|
|||||||
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
|
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
|
||||||
|
|
||||||
// Parser state (lazy initialization per index)
|
// Parser state (lazy initialization per index)
|
||||||
type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
|
type PooledReasoningParser = Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>;
|
||||||
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
|
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
|
||||||
|
|
||||||
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
|
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
|
||||||
@@ -186,6 +186,9 @@ impl StreamingProcessor {
|
|||||||
// Per-index stop decoders (each index needs its own state for n>1 support)
|
// Per-index stop decoders (each index needs its own state for n>1 support)
|
||||||
let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
|
let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
|
||||||
|
|
||||||
|
// Reusable SSE formatting buffer to avoid allocations per chunk
|
||||||
|
let mut sse_buffer = Vec::with_capacity(512);
|
||||||
|
|
||||||
// Use dispatch metadata for consistent response fields
|
// Use dispatch metadata for consistent response fields
|
||||||
let request_id = &dispatch.request_id;
|
let request_id = &dispatch.request_id;
|
||||||
let model = &dispatch.model;
|
let model = &dispatch.model;
|
||||||
@@ -262,7 +265,8 @@ impl StreamingProcessor {
|
|||||||
}],
|
}],
|
||||||
usage: None,
|
usage: None,
|
||||||
};
|
};
|
||||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
|
Self::format_sse_chunk_into(&mut sse_buffer, &first_chunk);
|
||||||
|
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||||
.map_err(|_| "Failed to send first chunk".to_string())?;
|
.map_err(|_| "Failed to send first chunk".to_string())?;
|
||||||
is_firsts.insert(index, false);
|
is_firsts.insert(index, false);
|
||||||
}
|
}
|
||||||
@@ -282,9 +286,11 @@ impl StreamingProcessor {
|
|||||||
model,
|
model,
|
||||||
created,
|
created,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
);
|
)
|
||||||
|
.await;
|
||||||
if let Some(chunk) = reasoning_chunk {
|
if let Some(chunk) = reasoning_chunk {
|
||||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
|
||||||
|
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||||
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
|
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
|
||||||
}
|
}
|
||||||
delta = normal_text;
|
delta = normal_text;
|
||||||
@@ -314,7 +320,8 @@ impl StreamingProcessor {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
for chunk in tool_chunks {
|
for chunk in tool_chunks {
|
||||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
|
||||||
|
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||||
.map_err(|_| "Failed to send tool call chunk".to_string())?;
|
.map_err(|_| "Failed to send tool call chunk".to_string())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,7 +342,8 @@ impl StreamingProcessor {
|
|||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
choice_logprobs,
|
choice_logprobs,
|
||||||
);
|
);
|
||||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
|
Self::format_sse_chunk_into(&mut sse_buffer, &content_chunk);
|
||||||
|
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||||
.map_err(|_| "Failed to send content chunk".to_string())?;
|
.map_err(|_| "Failed to send content chunk".to_string())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,7 +537,7 @@ impl StreamingProcessor {
|
|||||||
decode_stream: Streaming<proto::GenerateResponse>,
|
decode_stream: Streaming<proto::GenerateResponse>,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||||
original_request: ChatCompletionRequest,
|
original_request: Arc<ChatCompletionRequest>,
|
||||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
// Phase 1.5: Collect input_logprobs from prefill stream if requested
|
// Phase 1.5: Collect input_logprobs from prefill stream if requested
|
||||||
@@ -561,7 +569,7 @@ impl StreamingProcessor {
|
|||||||
pub fn process_streaming_generate(
|
pub fn process_streaming_generate(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
execution_result: context::ExecutionResult,
|
execution_result: context::ExecutionResult,
|
||||||
generate_request: GenerateRequest,
|
generate_request: Arc<GenerateRequest>,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let return_logprob = generate_request.return_logprob;
|
let return_logprob = generate_request.return_logprob;
|
||||||
@@ -946,11 +954,11 @@ impl StreamingProcessor {
|
|||||||
|
|
||||||
/// Helper: Process reasoning content in streaming mode
|
/// Helper: Process reasoning content in streaming mode
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn process_reasoning_stream(
|
async fn process_reasoning_stream(
|
||||||
&self,
|
&self,
|
||||||
delta: &str,
|
delta: &str,
|
||||||
index: u32,
|
index: u32,
|
||||||
reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
|
reasoning_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>>,
|
||||||
request_id: &str,
|
request_id: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
created: u64,
|
created: u64,
|
||||||
@@ -967,7 +975,7 @@ impl StreamingProcessor {
|
|||||||
|
|
||||||
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||||
let (parse_result, in_reasoning) = {
|
let (parse_result, in_reasoning) = {
|
||||||
let mut parser = pooled_parser.lock().unwrap();
|
let mut parser = pooled_parser.lock().await;
|
||||||
let result = parser.parse_reasoning_streaming_incremental(delta);
|
let result = parser.parse_reasoning_streaming_incremental(delta);
|
||||||
let in_reasoning = parser.is_in_reasoning();
|
let in_reasoning = parser.is_in_reasoning();
|
||||||
(result, in_reasoning)
|
(result, in_reasoning)
|
||||||
@@ -1134,15 +1142,20 @@ impl StreamingProcessor {
|
|||||||
(false, chunks)
|
(false, chunks)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Format a response as SSE chunk
|
/// Format a response as SSE chunk into a reusable buffer
|
||||||
fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String {
|
/// This avoids allocations by reusing the same buffer across multiple chunks
|
||||||
match serde_json::to_string(chunk) {
|
#[inline]
|
||||||
Ok(json) => format!("data: {}\n\n", json),
|
fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &ChatCompletionStreamResponse) {
|
||||||
Err(e) => {
|
buffer.clear();
|
||||||
error!("Failed to serialize SSE chunk: {}", e);
|
buffer.extend_from_slice(b"data: ");
|
||||||
format!("data: {}\n\n", json!({"error": "serialization_failed"}))
|
if let Err(e) = serde_json::to_writer(&mut *buffer, chunk) {
|
||||||
}
|
error!("Failed to serialize SSE chunk: {}", e);
|
||||||
|
buffer.clear();
|
||||||
|
buffer.extend_from_slice(b"data: ");
|
||||||
|
let error_msg = json!({"error": "serialization_failed"}).to_string();
|
||||||
|
buffer.extend_from_slice(error_msg.as_bytes());
|
||||||
}
|
}
|
||||||
|
buffer.extend_from_slice(b"\n\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a content chunk response
|
/// Create a content chunk response
|
||||||
|
|||||||
Reference in New Issue
Block a user