[router] improve reasoning parser lock and reduce req cloning (#11336)
This commit is contained in:
@@ -48,9 +48,10 @@ pub struct RequestInput {
|
||||
}
|
||||
|
||||
/// Request type variants: the payload of the request being handled, which is
/// either a chat completion request or a generate request.
|
||||
/// Using Arc instead of Box to enable cheap cloning for background tasks
/// (cloning an `Arc` yields another handle to the same request rather than
/// copying the request itself).
|
||||
pub enum RequestType {
|
||||
// NOTE(review): the two `Box` variants below look like the pre-change side of
// this diff, superseded by the `Arc` variants that follow — confirm.
Chat(Box<ChatCompletionRequest>),
|
||||
Generate(Box<GenerateRequest>),
|
||||
/// Chat completion request, held behind a shared `Arc` handle.
Chat(Arc<ChatCompletionRequest>),
|
||||
/// Generate request, held behind a shared `Arc` handle.
Generate(Arc<GenerateRequest>),
|
||||
}
|
||||
|
||||
/// Shared components (injected once at creation)
|
||||
@@ -181,14 +182,14 @@ pub struct StreamingState {
|
||||
impl RequestContext {
|
||||
/// Create context for chat completion request
|
||||
pub fn for_chat(
|
||||
request: ChatCompletionRequest,
|
||||
request: Arc<ChatCompletionRequest>,
|
||||
headers: Option<HeaderMap>,
|
||||
model_id: Option<String>,
|
||||
components: Arc<SharedComponents>,
|
||||
) -> Self {
|
||||
Self {
|
||||
input: RequestInput {
|
||||
request_type: RequestType::Chat(Box::new(request)),
|
||||
request_type: RequestType::Chat(request),
|
||||
headers,
|
||||
model_id,
|
||||
},
|
||||
@@ -199,14 +200,14 @@ impl RequestContext {
|
||||
|
||||
/// Create context for generate request
|
||||
pub fn for_generate(
|
||||
request: GenerateRequest,
|
||||
request: Arc<GenerateRequest>,
|
||||
headers: Option<HeaderMap>,
|
||||
model_id: Option<String>,
|
||||
components: Arc<SharedComponents>,
|
||||
) -> Self {
|
||||
Self {
|
||||
input: RequestInput {
|
||||
request_type: RequestType::Generate(Box::new(request)),
|
||||
request_type: RequestType::Generate(request),
|
||||
headers,
|
||||
model_id,
|
||||
},
|
||||
@@ -228,6 +229,14 @@ impl RequestContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get Arc clone of chat request (panics if not chat)
|
||||
pub fn chat_request_arc(&self) -> Arc<ChatCompletionRequest> {
|
||||
match &self.input.request_type {
|
||||
RequestType::Chat(req) => Arc::clone(req),
|
||||
_ => panic!("Expected chat request"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get generate request (panics if not generate)
|
||||
pub fn generate_request(&self) -> &GenerateRequest {
|
||||
match &self.input.request_type {
|
||||
@@ -236,6 +245,14 @@ impl RequestContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get Arc clone of generate request (panics if not generate)
|
||||
pub fn generate_request_arc(&self) -> Arc<GenerateRequest> {
|
||||
match &self.input.request_type {
|
||||
RequestType::Generate(req) => Arc::clone(req),
|
||||
_ => panic!("Expected generate request"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if request is streaming
|
||||
pub fn is_streaming(&self) -> bool {
|
||||
match &self.input.request_type {
|
||||
|
||||
@@ -129,7 +129,7 @@ impl GrpcPDRouter {
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_generate(
|
||||
body.clone(),
|
||||
Arc::new(body.clone()),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
@@ -152,7 +152,7 @@ impl GrpcPDRouter {
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_chat(
|
||||
body.clone(),
|
||||
Arc::new(body.clone()),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
|
||||
@@ -58,16 +58,17 @@ impl PipelineStage for PreparationStage {
|
||||
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
|
||||
debug!("Stage {}: Processing request", self.name());
|
||||
|
||||
// Clone the request to avoid borrowing issues
|
||||
match &ctx.input.request_type {
|
||||
RequestType::Chat(request) => {
|
||||
let request_clone = request.clone();
|
||||
self.prepare_chat(ctx, &request_clone).await?;
|
||||
}
|
||||
RequestType::Generate(request) => {
|
||||
let request_clone = request.clone();
|
||||
self.prepare_generate(ctx, &request_clone).await?;
|
||||
}
|
||||
// Clone Arc before match to avoid borrow checker issues
|
||||
// (matching borrows ctx, but prepare_* methods need mutable borrow)
|
||||
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
|
||||
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
|
||||
|
||||
if is_chat {
|
||||
let request_arc = ctx.chat_request_arc();
|
||||
self.prepare_chat(ctx, &request_arc).await?;
|
||||
} else {
|
||||
let request_arc = ctx.generate_request_arc();
|
||||
self.prepare_generate(ctx, &request_arc).await?;
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
@@ -820,7 +821,7 @@ impl ResponseProcessingStage {
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_response(
|
||||
execution_result,
|
||||
ctx.chat_request().clone(),
|
||||
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
|
||||
dispatch.clone(),
|
||||
),
|
||||
));
|
||||
@@ -865,9 +866,7 @@ impl ResponseProcessingStage {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
// Clone chat_request to avoid borrow checker conflict
|
||||
// (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder)
|
||||
let chat_request = ctx.chat_request().clone();
|
||||
let chat_request = ctx.chat_request_arc();
|
||||
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
|
||||
|
||||
let stop_decoder = ctx
|
||||
@@ -959,13 +958,11 @@ impl ResponseProcessingStage {
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
let generate_request = ctx.generate_request().clone();
|
||||
|
||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_generate(
|
||||
execution_result,
|
||||
generate_request,
|
||||
ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
|
||||
dispatch.clone(),
|
||||
),
|
||||
));
|
||||
@@ -1193,8 +1190,8 @@ impl ChatCompletionPipeline {
|
||||
/// Execute the complete pipeline for a chat request
|
||||
pub async fn execute_chat(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
headers: Option<axum::http::HeaderMap>,
|
||||
request: Arc<ChatCompletionRequest>,
|
||||
headers: Option<http::HeaderMap>,
|
||||
model_id: Option<String>,
|
||||
components: Arc<SharedComponents>,
|
||||
) -> Response {
|
||||
@@ -1243,8 +1240,8 @@ impl ChatCompletionPipeline {
|
||||
/// Execute the complete pipeline for a generate request
|
||||
pub async fn execute_generate(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
headers: Option<axum::http::HeaderMap>,
|
||||
request: Arc<GenerateRequest>,
|
||||
headers: Option<http::HeaderMap>,
|
||||
model_id: Option<String>,
|
||||
components: Arc<SharedComponents>,
|
||||
) -> Response {
|
||||
|
||||
@@ -97,9 +97,7 @@ impl ResponseProcessor {
|
||||
&original_request.model,
|
||||
);
|
||||
|
||||
let mut parser = pooled_parser
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
|
||||
let mut parser = pooled_parser.lock().await;
|
||||
match parser.detect_and_parse_reasoning(&processed_text) {
|
||||
Ok(result) => {
|
||||
if !result.reasoning_text.is_empty() {
|
||||
|
||||
@@ -129,7 +129,7 @@ impl GrpcRouter {
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_chat(
|
||||
body.clone(),
|
||||
Arc::new(body.clone()),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
@@ -149,7 +149,7 @@ impl GrpcRouter {
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_generate(
|
||||
body.clone(),
|
||||
Arc::new(body.clone()),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
|
||||
@@ -66,7 +66,7 @@ impl StreamingProcessor {
|
||||
pub fn process_streaming_response(
|
||||
self: Arc<Self>,
|
||||
execution_result: context::ExecutionResult,
|
||||
chat_request: ChatCompletionRequest,
|
||||
chat_request: Arc<ChatCompletionRequest>,
|
||||
dispatch: context::DispatchMetadata,
|
||||
) -> Response {
|
||||
use bytes::Bytes;
|
||||
@@ -156,7 +156,7 @@ impl StreamingProcessor {
|
||||
mut grpc_stream: Streaming<proto::GenerateResponse>,
|
||||
dispatch: context::DispatchMetadata,
|
||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||
original_request: ChatCompletionRequest,
|
||||
original_request: Arc<ChatCompletionRequest>,
|
||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
// Extract request parameters
|
||||
@@ -176,7 +176,7 @@ impl StreamingProcessor {
|
||||
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
|
||||
|
||||
// Parser state (lazy initialization per index)
|
||||
type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
|
||||
type PooledReasoningParser = Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>;
|
||||
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
|
||||
|
||||
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
|
||||
@@ -186,6 +186,9 @@ impl StreamingProcessor {
|
||||
// Per-index stop decoders (each index needs its own state for n>1 support)
|
||||
let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
|
||||
|
||||
// Reusable SSE formatting buffer to avoid allocations per chunk
|
||||
let mut sse_buffer = Vec::with_capacity(512);
|
||||
|
||||
// Use dispatch metadata for consistent response fields
|
||||
let request_id = &dispatch.request_id;
|
||||
let model = &dispatch.model;
|
||||
@@ -262,7 +265,8 @@ impl StreamingProcessor {
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
|
||||
Self::format_sse_chunk_into(&mut sse_buffer, &first_chunk);
|
||||
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||
.map_err(|_| "Failed to send first chunk".to_string())?;
|
||||
is_firsts.insert(index, false);
|
||||
}
|
||||
@@ -282,9 +286,11 @@ impl StreamingProcessor {
|
||||
model,
|
||||
created,
|
||||
system_fingerprint,
|
||||
);
|
||||
)
|
||||
.await;
|
||||
if let Some(chunk) = reasoning_chunk {
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||
Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
|
||||
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
|
||||
}
|
||||
delta = normal_text;
|
||||
@@ -314,7 +320,8 @@ impl StreamingProcessor {
|
||||
.await;
|
||||
|
||||
for chunk in tool_chunks {
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||
Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
|
||||
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||
.map_err(|_| "Failed to send tool call chunk".to_string())?;
|
||||
}
|
||||
|
||||
@@ -335,7 +342,8 @@ impl StreamingProcessor {
|
||||
system_fingerprint,
|
||||
choice_logprobs,
|
||||
);
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
|
||||
Self::format_sse_chunk_into(&mut sse_buffer, &content_chunk);
|
||||
tx.send(Ok(Bytes::from(sse_buffer.clone())))
|
||||
.map_err(|_| "Failed to send content chunk".to_string())?;
|
||||
}
|
||||
}
|
||||
@@ -529,7 +537,7 @@ impl StreamingProcessor {
|
||||
decode_stream: Streaming<proto::GenerateResponse>,
|
||||
dispatch: context::DispatchMetadata,
|
||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||
original_request: ChatCompletionRequest,
|
||||
original_request: Arc<ChatCompletionRequest>,
|
||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
// Phase 1.5: Collect input_logprobs from prefill stream if requested
|
||||
@@ -561,7 +569,7 @@ impl StreamingProcessor {
|
||||
pub fn process_streaming_generate(
|
||||
self: Arc<Self>,
|
||||
execution_result: context::ExecutionResult,
|
||||
generate_request: GenerateRequest,
|
||||
generate_request: Arc<GenerateRequest>,
|
||||
dispatch: context::DispatchMetadata,
|
||||
) -> Response {
|
||||
let return_logprob = generate_request.return_logprob;
|
||||
@@ -946,11 +954,11 @@ impl StreamingProcessor {
|
||||
|
||||
/// Helper: Process reasoning content in streaming mode
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn process_reasoning_stream(
|
||||
async fn process_reasoning_stream(
|
||||
&self,
|
||||
delta: &str,
|
||||
index: u32,
|
||||
reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
|
||||
reasoning_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>>,
|
||||
request_id: &str,
|
||||
model: &str,
|
||||
created: u64,
|
||||
@@ -967,7 +975,7 @@ impl StreamingProcessor {
|
||||
|
||||
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||
let (parse_result, in_reasoning) = {
|
||||
let mut parser = pooled_parser.lock().unwrap();
|
||||
let mut parser = pooled_parser.lock().await;
|
||||
let result = parser.parse_reasoning_streaming_incremental(delta);
|
||||
let in_reasoning = parser.is_in_reasoning();
|
||||
(result, in_reasoning)
|
||||
@@ -1134,15 +1142,20 @@ impl StreamingProcessor {
|
||||
(false, chunks)
|
||||
}
|
||||
|
||||
/// Format a response as SSE chunk
|
||||
fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String {
|
||||
match serde_json::to_string(chunk) {
|
||||
Ok(json) => format!("data: {}\n\n", json),
|
||||
Err(e) => {
|
||||
error!("Failed to serialize SSE chunk: {}", e);
|
||||
format!("data: {}\n\n", json!({"error": "serialization_failed"}))
|
||||
}
|
||||
/// Format a response as SSE chunk into a reusable buffer
|
||||
/// This avoids allocations by reusing the same buffer across multiple chunks
|
||||
#[inline]
|
||||
fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &ChatCompletionStreamResponse) {
|
||||
buffer.clear();
|
||||
buffer.extend_from_slice(b"data: ");
|
||||
if let Err(e) = serde_json::to_writer(&mut *buffer, chunk) {
|
||||
error!("Failed to serialize SSE chunk: {}", e);
|
||||
buffer.clear();
|
||||
buffer.extend_from_slice(b"data: ");
|
||||
let error_msg = json!({"error": "serialization_failed"}).to_string();
|
||||
buffer.extend_from_slice(error_msg.as_bytes());
|
||||
}
|
||||
buffer.extend_from_slice(b"\n\n");
|
||||
}
|
||||
|
||||
/// Create a content chunk response
|
||||
|
||||
Reference in New Issue
Block a user