[router] leverage RAII to actively cancel request during client disconnect (#11399)
This commit is contained in:
@@ -371,16 +371,17 @@ impl ClientSelection {
|
||||
// Execution and Response Types
|
||||
// ============================================================================
|
||||
|
||||
use tonic::codec::Streaming;
|
||||
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
|
||||
|
||||
/// Result of request execution (streams from workers)
|
||||
/// Each stream is an `AbortOnDropStream`: dropping it without first marking it completed sends an abort
|
||||
pub enum ExecutionResult {
|
||||
Single {
|
||||
stream: Streaming<proto::GenerateResponse>,
|
||||
stream: AbortOnDropStream,
|
||||
},
|
||||
Dual {
|
||||
prefill: Streaming<proto::GenerateResponse>,
|
||||
decode: Box<Streaming<proto::GenerateResponse>>,
|
||||
prefill: AbortOnDropStream,
|
||||
decode: Box<AbortOnDropStream>,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -816,16 +816,27 @@ impl ResponseProcessingStage {
|
||||
|
||||
// Collect all responses from the execution result
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { stream } => {
|
||||
utils::collect_stream_responses(stream, "Single").await?
|
||||
ExecutionResult::Single { mut stream } => {
|
||||
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||
stream.mark_completed();
|
||||
responses
|
||||
}
|
||||
ExecutionResult::Dual { prefill, decode } => {
|
||||
// Collect prefill for input_logprobs
|
||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
||||
ExecutionResult::Dual {
|
||||
mut prefill,
|
||||
decode,
|
||||
} => {
|
||||
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||
let prefill_responses =
|
||||
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output
|
||||
// Collect decode for actual output (don't mark completed yet)
|
||||
let mut decode_stream = *decode;
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
||||
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||
|
||||
// Mark both streams as completed now that both succeeded
|
||||
prefill.mark_completed();
|
||||
decode_stream.mark_completed();
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
@@ -952,16 +963,27 @@ impl ResponseProcessingStage {
|
||||
// Non-streaming: Collect all responses
|
||||
let request_logprobs = ctx.generate_request().return_logprob;
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { stream } => {
|
||||
utils::collect_stream_responses(stream, "Single").await?
|
||||
ExecutionResult::Single { mut stream } => {
|
||||
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||
stream.mark_completed();
|
||||
responses
|
||||
}
|
||||
ExecutionResult::Dual { prefill, decode } => {
|
||||
// Collect prefill for input_logprobs
|
||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
||||
ExecutionResult::Dual {
|
||||
mut prefill,
|
||||
decode,
|
||||
} => {
|
||||
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||
let prefill_responses =
|
||||
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output
|
||||
// Collect decode for actual output (don't mark completed yet)
|
||||
let mut decode_stream = *decode;
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
||||
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||
|
||||
// Mark both streams as completed now that both succeeded
|
||||
prefill.mark_completed();
|
||||
decode_stream.mark_completed();
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
|
||||
@@ -14,7 +14,6 @@ use std::sync::Arc;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tonic::codec::Streaming;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::context;
|
||||
@@ -153,7 +152,7 @@ impl StreamingProcessor {
|
||||
/// Process streaming chunks from a single stream (Regular mode)
|
||||
pub async fn process_streaming_chunks(
|
||||
&self,
|
||||
mut grpc_stream: Streaming<proto::GenerateResponse>,
|
||||
mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
dispatch: context::DispatchMetadata,
|
||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||
original_request: Arc<ChatCompletionRequest>,
|
||||
@@ -571,14 +570,17 @@ impl StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark stream as completed successfully to prevent abort on drop
|
||||
grpc_stream.mark_completed();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process dual streaming chunks (prefill + decode) - PD mode
|
||||
pub async fn process_dual_streaming_chunks(
|
||||
&self,
|
||||
mut prefill_stream: Streaming<proto::GenerateResponse>,
|
||||
decode_stream: Streaming<proto::GenerateResponse>,
|
||||
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
dispatch: context::DispatchMetadata,
|
||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||
original_request: Arc<ChatCompletionRequest>,
|
||||
@@ -603,8 +605,18 @@ impl StreamingProcessor {
|
||||
}
|
||||
|
||||
// Phase 2-5: Process decode stream (same as single mode)
|
||||
self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
|
||||
.await
|
||||
// Note: decode_stream will be marked completed inside process_streaming_chunks
|
||||
let result = self
|
||||
.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
|
||||
.await;
|
||||
|
||||
// Mark prefill stream as completed AFTER decode completes successfully
|
||||
// This ensures that if the client disconnects during decode, both streams send an abort
|
||||
if result.is_ok() {
|
||||
prefill_stream.mark_completed();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Process streaming generate response and return SSE response
|
||||
@@ -687,7 +699,7 @@ impl StreamingProcessor {
|
||||
/// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
|
||||
async fn process_generate_streaming(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
_include_logprobs: bool,
|
||||
@@ -782,14 +794,17 @@ impl StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark stream as completed successfully to prevent abort on drop
|
||||
stream.mark_completed();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process dual streaming for generate endpoint (PD mode with logprobs support)
|
||||
async fn process_generate_streaming_dual(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut prefill_stream: Streaming<proto::GenerateResponse>,
|
||||
decode_stream: Streaming<proto::GenerateResponse>,
|
||||
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
return_logprob: bool,
|
||||
@@ -821,7 +836,8 @@ impl StreamingProcessor {
|
||||
};
|
||||
|
||||
// Process decode stream with input_logprobs prepended
|
||||
Self::process_generate_streaming_with_input_logprobs(
|
||||
// Note: decode_stream will be marked completed inside the function
|
||||
let result = Self::process_generate_streaming_with_input_logprobs(
|
||||
tokenizer,
|
||||
decode_stream,
|
||||
request_id,
|
||||
@@ -830,13 +846,21 @@ impl StreamingProcessor {
|
||||
input_token_logprobs,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
.await;
|
||||
|
||||
// Mark prefill stream as completed AFTER decode completes successfully
|
||||
// This ensures that if the client disconnects during decode, both streams send an abort
|
||||
if result.is_ok() {
|
||||
prefill_stream.mark_completed();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Process generate streaming with optional input_logprobs
|
||||
async fn process_generate_streaming_with_input_logprobs(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
_include_logprobs: bool,
|
||||
@@ -957,6 +981,9 @@ impl StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark stream as completed successfully to prevent abort on drop
|
||||
stream.mark_completed();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use super::ProcessedMessages;
|
||||
use crate::core::Worker;
|
||||
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
|
||||
@@ -20,7 +21,6 @@ use futures::StreamExt;
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tonic::codec::Streaming;
|
||||
use tracing::{error, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -590,7 +590,7 @@ pub fn parse_json_schema_response(
|
||||
/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
|
||||
/// * `Err(Response)` - Error response if the stream fails or returns an error
|
||||
pub async fn collect_stream_responses(
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
stream: &mut AbortOnDropStream,
|
||||
worker_name: &str,
|
||||
) -> Result<Vec<proto::GenerateComplete>, Response> {
|
||||
use proto::generate_response::Response::*;
|
||||
@@ -606,6 +606,7 @@ pub async fn collect_stream_responses(
|
||||
}
|
||||
Some(Error(err)) => {
|
||||
error!("{} error: {}", worker_name, err.message);
|
||||
// Don't mark as completed - let Drop send abort for error cases
|
||||
return Err(internal_error_message(format!(
|
||||
"{} generation failed: {}",
|
||||
worker_name, err.message
|
||||
@@ -621,6 +622,7 @@ pub async fn collect_stream_responses(
|
||||
}
|
||||
Err(e) => {
|
||||
error!("{} stream error: {:?}", worker_name, e);
|
||||
// Don't mark as completed - let Drop send abort for error cases
|
||||
return Err(internal_error_message(format!(
|
||||
"{} stream failed: {}",
|
||||
worker_name, e
|
||||
|
||||
Reference in New Issue
Block a user