From c495833186e7571463c7b5db2864928ac645046f Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 10 Oct 2025 20:43:38 -0400 Subject: [PATCH] [router] leverage RAII to actively cancel request during client disconnect (#11399) --- .../srt/entrypoints/grpc_request_manager.py | 116 +++++++++++---- python/sglang/srt/entrypoints/grpc_server.py | 7 - .../src/grpc_client/sglang_scheduler.rs | 136 ++++++++++++++++-- sgl-router/src/routers/grpc/context.rs | 9 +- sgl-router/src/routers/grpc/pipeline.rs | 50 +++++-- sgl-router/src/routers/grpc/streaming.rs | 51 +++++-- sgl-router/src/routers/grpc/utils.rs | 6 +- 7 files changed, 297 insertions(+), 78 deletions(-) diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 7351f4de3..a8acb4bc4 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -319,13 +319,8 @@ class GrpcRequestManager: is_stream = getattr(obj, "stream", False) while True: - # Client cancelled - notify scheduler and exit - if grpc_context and grpc_context.cancelled(): - await self.abort_request(request_id) - return - try: - response = await asyncio.wait_for(state.out_queue.get(), timeout=4) + response = await state.out_queue.get() if is_stream: yield response @@ -338,10 +333,11 @@ class GrpcRequestManager: yield final_response break - except asyncio.TimeoutError: - # Timeout is for periodic client cancellation check - # Continue waiting for scheduler response - continue + except asyncio.CancelledError: + # Task was cancelled by gRPC framework when client disconnected + logger.info(f"Request {request_id} cancelled by client") + await self.abort_request(request_id) + raise # Re-raise to let gRPC server handle cleanup finally: # Always clean up request state when exiting @@ -409,31 +405,31 @@ class GrpcRequestManager: return future async def abort_request(self, request_id: str) -> bool: - """Abort a running request.""" + """Abort a running request. + + Sends abort request to scheduler and marks local state as finished + to stop processing any further outputs from the scheduler. + """ # Skip aborting health check requests (they clean themselves up) if request_id.startswith("HEALTH_CHECK"): return False - if request_id not in self.rid_to_state: - return False - - # Send abort to scheduler - abort_req = AbortReq(rid=request_id) - try: - await self._send_to_scheduler(abort_req) - except Exception as e: - logger.error(f"Failed to send abort request: {e}") - return False - - # Mark as finished + # Mark state as finished immediately to stop processing scheduler outputs state = self.rid_to_state.get(request_id) if state: state.finished = True state.stream_finished = True - state.event.set() + logger.debug(f"Marked request {request_id} as aborted locally") - # Send abort notification to output queue - await state.out_queue.put({"error": "Request aborted", "abort": True}) + # Send abort to scheduler - the scheduler will send AbortReq back + # which will be handled by _handle_abort_req + abort_req = AbortReq(rid=request_id) + try: + await self._send_to_scheduler(abort_req) + logger.debug(f"Sent abort to scheduler for request {request_id}") + except Exception as e: + logger.error(f"Failed to send abort request to scheduler: {e}") + return False return True @@ -460,6 +456,8 @@ class GrpcRequestManager: await self._handle_embedding_output(recv_obj) elif isinstance(recv_obj, HealthCheckOutput): await self._handle_health_check_output(recv_obj) + elif isinstance(recv_obj, AbortReq): + await self._handle_abort_req(recv_obj) else: logger.warning(f"Unknown output type: {type(recv_obj)}") @@ -541,6 +539,11 @@ class GrpcRequestManager: state = self.rid_to_state[rid] + # Skip if already aborted/finished locally (client cancelled) + if state.finished: + logger.debug(f"Skipping output for aborted request {rid}") + continue + # Update metrics now = time.time() if state.first_token_time == 0.0: @@ -713,6 +716,67 @@ class GrpcRequestManager: state.finished_time = time.time() state.event.set() + async def _handle_abort_req(self, recv_obj: AbortReq): + """Handle abort request from scheduler. + + The scheduler sends AbortReq back to notify us that a request was aborted, + either due to explicit abort_request() call or scheduler-initiated abort + (priority preemption, queue full, KV cache pressure, etc). + """ + # Skip health check requests + if recv_obj.rid.startswith("HEALTH_CHECK"): + return + + # Check if request still exists + if recv_obj.rid not in self.rid_to_state: + logger.debug( + f"Abort request for {recv_obj.rid} not in local state (may have already finished or not started yet)" + ) + return + + state = self.rid_to_state[recv_obj.rid] + + # Mark as finished + state.finished = True + state.stream_finished = True + + # Create abort response + if recv_obj.finished_reason: + # Scheduler provided a specific finish reason (e.g., priority preemption, queue full) + abort_response = { + "request_id": recv_obj.rid, + "error": recv_obj.finished_reason.get("message", "Request aborted"), + "finished": True, + "meta_info": { + "id": recv_obj.rid, + "finish_reason": recv_obj.finished_reason, + }, + } + else: + # Generic abort (e.g., explicit abort_request call) + abort_response = { + "request_id": recv_obj.rid, + "error": "Request aborted", + "finished": True, + "meta_info": { + "id": recv_obj.rid, + "finish_reason": { + "type": "abort", + "message": "Abort before prefill", + }, + "prompt_tokens": 0, + "completion_tokens": 0, + }, + } + + # Send abort notification to output queue + await state.out_queue.put(abort_response) + + # Wake up any waiting coroutines + state.event.set() + + logger.debug(f"Handled abort request for {recv_obj.rid}") + async def _send_to_scheduler(self, obj): """Send an object to the scheduler via ZMQ.""" try: diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index f9c7c72fd..4841092b5 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -211,13 +211,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ) async for output in response_generator: - # Check if client cancelled before processing/yielding - if context.cancelled(): - logger.info(f"Client cancelled request {request.request_id}") - # Explicitly abort the request to notify scheduler - await self.request_manager.abort_request(request.request_id) - break - # Handle batch responses (for n>1 non-streaming) if isinstance(output, list): for batch_output in output: diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 799db14a4..3086abea6 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -1,7 +1,11 @@ use std::convert::TryFrom; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use tonic::{transport::Channel, Request, Streaming}; -use tracing::debug; +use tracing::{debug, warn}; use crate::protocols::spec::{ ChatCompletionRequest, GenerateRequest, ResponseFormat, @@ -16,6 +20,92 @@ pub mod proto { // The generated module structure depends on the package name in the .proto file // package sglang.grpc.scheduler; generates a nested module structure +/// A smart wrapper around Streaming that automatically +/// sends abort when dropped (e.g., due to client disconnection or early termination). +/// +/// This leverages Rust's RAII pattern to ensure cleanup happens automatically, +/// regardless of how the stream is dropped (panic, early return, client disconnect, etc.). +pub struct AbortOnDropStream { + inner: Streaming, + request_id: String, + client: SglangSchedulerClient, + aborted: Arc, +} + +impl AbortOnDropStream { + /// Create a new auto-aborting stream wrapper + pub fn new( + stream: Streaming, + request_id: String, + client: SglangSchedulerClient, + ) -> Self { + debug!("Created AbortOnDropStream for request {}", request_id); + Self { + inner: stream, + request_id, + client, + aborted: Arc::new(AtomicBool::new(false)), + } + } + + /// Manually mark the request as completed to prevent abort on drop. + /// Call this when the request completes successfully to avoid unnecessary abort RPC. + pub fn mark_completed(&self) { + // Use Release ordering to ensure that this write is visible to other threads + // that use Acquire on the same atomic variable + self.aborted.store(true, Ordering::Release); + debug!("Request {} marked as completed", self.request_id); + } +} + +impl Drop for AbortOnDropStream { + fn drop(&mut self) { + // Atomically check and set the aborted flag using compare_exchange. + // If compare_exchange fails, it means the flag was already true (from mark_completed), + // so we don't need to send abort. AcqRel is used for success to synchronize with + // mark_completed's Release, and Acquire for failure to see writes from mark_completed. + if self + .aborted + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + return; + } + + let client = self.client.clone(); + let request_id = self.request_id.clone(); + + // Spawn a background task to send abort (since Drop is sync but abort_request is async) + tokio::spawn(async move { + debug!( + "Stream dropped without completion for request {}, sending abort", + request_id + ); + // Clone request_id for the error message since abort_request takes ownership + let request_id_for_log = request_id.clone(); + if let Err(e) = client + .abort_request(request_id, "Stream dropped".to_string()) + .await + { + warn!( + "Failed to send abort on drop for request {}: {}", + request_id_for_log, e + ); + } + }); + } +} + +// Implement Stream trait to make AbortOnDropStream work like the original Streaming +impl futures::Stream for AbortOnDropStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Delegate to the inner stream + Pin::new(&mut self.inner).poll_next(cx) + } +} + /// gRPC client for SGLang scheduler #[derive(Clone)] pub struct SglangSchedulerClient { @@ -35,7 +125,7 @@ impl SglangSchedulerClient { }; let channel = Channel::from_shared(http_endpoint)? - .timeout(Duration::from_secs(3600)) + .timeout(Duration::from_secs(600)) // 10 minute timeout for connection .http2_keep_alive_interval(Duration::from_secs(30)) .keep_alive_timeout(Duration::from_secs(10)) .keep_alive_while_idle(true) @@ -52,15 +142,26 @@ impl SglangSchedulerClient { Ok(Self { client }) } - /// Submit a generation request (returns streaming response) + /// Submit a generation request (returns auto-aborting streaming response) + /// + /// The returned stream automatically sends an abort request when dropped, + /// ensuring proper cleanup even if the HTTP client disconnects or an error occurs. + /// Call `mark_completed()` on the stream after successful completion to prevent + /// unnecessary abort RPCs. pub async fn generate( &self, req: proto::GenerateRequest, - ) -> Result, Box> { + ) -> Result> { + let request_id = req.request_id.clone(); let mut client = self.client.clone(); let request = Request::new(req); let response = client.generate(request).await?; - Ok(response.into_inner()) + + Ok(AbortOnDropStream::new( + response.into_inner(), + request_id, + self.clone(), + )) } /// Perform health check @@ -68,12 +169,8 @@ impl SglangSchedulerClient { &self, ) -> Result> { debug!("Sending health check request"); - let request = Request::new(proto::HealthCheckRequest { - tokenized: Some(proto::TokenizedInput { - original_text: "Hello".to_string(), - input_ids: vec![9906], // Mock token ID for "Hello" - }), - }); + // Server ignores the request body and creates its own health check internally + let request = Request::new(proto::HealthCheckRequest { tokenized: None }); let mut client = self.client.clone(); let response = client.health_check(request).await?; @@ -87,10 +184,23 @@ impl SglangSchedulerClient { request_id: String, reason: String, ) -> Result<(), Box> { - let request = Request::new(proto::AbortRequest { request_id, reason }); + debug!( + "Sending abort request for {} (reason: {})", + request_id, reason + ); + let request = Request::new(proto::AbortRequest { + request_id: request_id.clone(), + reason, + }); let mut client = self.client.clone(); - client.abort(request).await?; + let response = client.abort(request).await?; + debug!( + "Abort response for {}: success={}, message={}", + request_id, + response.get_ref().success, + response.get_ref().message + ); Ok(()) } diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs index 50e713fe2..edd5a94d7 100644 --- a/sgl-router/src/routers/grpc/context.rs +++ b/sgl-router/src/routers/grpc/context.rs @@ -371,16 +371,17 @@ impl ClientSelection { // Execution and Response Types // ============================================================================ -use tonic::codec::Streaming; +use crate::grpc_client::sglang_scheduler::AbortOnDropStream; /// Result of request execution (streams from workers) +/// Uses AbortOnDropStream to automatically abort on cancellation pub enum ExecutionResult { Single { - stream: Streaming, + stream: AbortOnDropStream, }, Dual { - prefill: Streaming, - decode: Box>, + prefill: AbortOnDropStream, + decode: Box, }, } diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index 7f8ed2387..a4d8eb617 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -816,16 +816,27 @@ impl ResponseProcessingStage { // Collect all responses from the execution result let all_responses = match execution_result { - ExecutionResult::Single { stream } => { - utils::collect_stream_responses(stream, "Single").await? + ExecutionResult::Single { mut stream } => { + let responses = utils::collect_stream_responses(&mut stream, "Single").await?; + stream.mark_completed(); + responses } - ExecutionResult::Dual { prefill, decode } => { - // Collect prefill for input_logprobs - let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?; + ExecutionResult::Dual { + mut prefill, + decode, + } => { + // Collect prefill for input_logprobs (don't mark completed yet) + let prefill_responses = + utils::collect_stream_responses(&mut prefill, "Prefill").await?; - // Collect decode for actual output + // Collect decode for actual output (don't mark completed yet) + let mut decode_stream = *decode; let mut decode_responses = - utils::collect_stream_responses(*decode, "Decode").await?; + utils::collect_stream_responses(&mut decode_stream, "Decode").await?; + + // Mark both streams as completed now that both succeeded + prefill.mark_completed(); + decode_stream.mark_completed(); // Merge prefill input_logprobs if requested if request_logprobs { @@ -952,16 +963,27 @@ impl ResponseProcessingStage { // Non-streaming: Collect all responses let request_logprobs = ctx.generate_request().return_logprob; let all_responses = match execution_result { - ExecutionResult::Single { stream } => { - utils::collect_stream_responses(stream, "Single").await? + ExecutionResult::Single { mut stream } => { + let responses = utils::collect_stream_responses(&mut stream, "Single").await?; + stream.mark_completed(); + responses } - ExecutionResult::Dual { prefill, decode } => { - // Collect prefill for input_logprobs - let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?; + ExecutionResult::Dual { + mut prefill, + decode, + } => { + // Collect prefill for input_logprobs (don't mark completed yet) + let prefill_responses = + utils::collect_stream_responses(&mut prefill, "Prefill").await?; - // Collect decode for actual output + // Collect decode for actual output (don't mark completed yet) + let mut decode_stream = *decode; let mut decode_responses = - utils::collect_stream_responses(*decode, "Decode").await?; + utils::collect_stream_responses(&mut decode_stream, "Decode").await?; + + // Mark both streams as completed now that both succeeded + prefill.mark_completed(); + decode_stream.mark_completed(); // Merge prefill input_logprobs if requested if request_logprobs { diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index 1e7707767..d932c5818 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -14,7 +14,6 @@ use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; -use tonic::codec::Streaming; use tracing::{debug, error, warn}; use super::context; @@ -153,7 +152,7 @@ impl StreamingProcessor { /// Process streaming chunks from a single stream (Regular mode) pub async fn process_streaming_chunks( &self, - mut grpc_stream: Streaming, + mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, dispatch: context::DispatchMetadata, stop_params: (Option, Option>, bool, bool), original_request: Arc, @@ -571,14 +570,17 @@ impl StreamingProcessor { } } + // Mark stream as completed successfully to prevent abort on drop + grpc_stream.mark_completed(); + Ok(()) } /// Process dual streaming chunks (prefill + decode) - PD mode pub async fn process_dual_streaming_chunks( &self, - mut prefill_stream: Streaming, - decode_stream: Streaming, + mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, + decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, dispatch: context::DispatchMetadata, stop_params: (Option, Option>, bool, bool), original_request: Arc, @@ -603,8 +605,18 @@ impl StreamingProcessor { } // Phase 2-5: Process decode stream (same as single mode) - self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx) - .await + // Note: decode_stream will be marked completed inside process_streaming_chunks + let result = self + .process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx) + .await; + + // Mark prefill stream as completed AFTER decode completes successfully + // This ensures that if client disconnects during decode, BOTH streams send abort + if result.is_ok() { + prefill_stream.mark_completed(); + } + + result } /// Process streaming generate response and return SSE response @@ -687,7 +699,7 @@ impl StreamingProcessor { /// Process streaming chunks for generate endpoint (no tool/reasoning parsing) async fn process_generate_streaming( tokenizer: Arc, - mut stream: Streaming, + mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, request_id: String, weight_version: String, _include_logprobs: bool, @@ -782,14 +794,17 @@ impl StreamingProcessor { } } + // Mark stream as completed successfully to prevent abort on drop + stream.mark_completed(); + Ok(()) } /// Process dual streaming for generate endpoint (PD mode with logprobs support) async fn process_generate_streaming_dual( tokenizer: Arc, - mut prefill_stream: Streaming, - decode_stream: Streaming, + mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, + decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, request_id: String, weight_version: String, return_logprob: bool, @@ -821,7 +836,8 @@ impl StreamingProcessor { }; // Process decode stream with input_logprobs prepended - Self::process_generate_streaming_with_input_logprobs( + // Note: decode_stream will be marked completed inside the function + let result = Self::process_generate_streaming_with_input_logprobs( tokenizer, decode_stream, request_id, @@ -830,13 +846,21 @@ impl StreamingProcessor { input_token_logprobs, tx, ) - .await + .await; + + // Mark prefill stream as completed AFTER decode completes successfully + // This ensures that if client disconnects during decode, BOTH streams send abort + if result.is_ok() { + prefill_stream.mark_completed(); + } + + result } /// Process generate streaming with optional input_logprobs async fn process_generate_streaming_with_input_logprobs( tokenizer: Arc, - mut stream: Streaming, + mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream, request_id: String, weight_version: String, _include_logprobs: bool, @@ -957,6 +981,9 @@ impl StreamingProcessor { } } + // Mark stream as completed successfully to prevent abort on drop + stream.mark_completed(); + Ok(()) } diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index 4422671bf..b217ba815 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -2,6 +2,7 @@ use super::ProcessedMessages; use crate::core::Worker; +use crate::grpc_client::sglang_scheduler::AbortOnDropStream; use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::protocols::spec::{ ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse, @@ -20,7 +21,6 @@ use futures::StreamExt; use serde_json::{json, Map, Value}; use std::collections::HashMap; use std::sync::Arc; -use tonic::codec::Streaming; use tracing::{error, warn}; use uuid::Uuid; @@ -590,7 +590,7 @@ pub fn parse_json_schema_response( /// * `Ok(Vec)` - All complete responses collected from the stream /// * `Err(Response)` - Error response if the stream fails or returns an error pub async fn collect_stream_responses( - mut stream: Streaming, + stream: &mut AbortOnDropStream, worker_name: &str, ) -> Result, Response> { use proto::generate_response::Response::*; @@ -606,6 +606,7 @@ pub async fn collect_stream_responses( } Some(Error(err)) => { error!("{} error: {}", worker_name, err.message); + // Don't mark as completed - let Drop send abort for error cases return Err(internal_error_message(format!( "{} generation failed: {}", worker_name, err.message @@ -621,6 +622,7 @@ pub async fn collect_stream_responses( } Err(e) => { error!("{} stream error: {:?}", worker_name, e); + // Don't mark as completed - let Drop send abort for error cases return Err(internal_error_message(format!( "{} stream failed: {}", worker_name, e