[router] leverage RAII to actively cancel request during client disconnect (#11399)
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tonic::{transport::Channel, Request, Streaming};
|
||||
use tracing::debug;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
||||
@@ -16,6 +20,92 @@ pub mod proto {
|
||||
// The generated module structure depends on the package name in the .proto file
|
||||
// package sglang.grpc.scheduler; generates a nested module structure
|
||||
|
||||
/// A smart wrapper around Streaming<GenerateResponse> that automatically
|
||||
/// sends abort when dropped (e.g., due to client disconnection or early termination).
|
||||
///
|
||||
/// This leverages Rust's RAII pattern to ensure cleanup happens automatically,
|
||||
/// regardless of how the stream is dropped (panic, early return, client disconnect, etc.).
|
||||
pub struct AbortOnDropStream {
|
||||
inner: Streaming<proto::GenerateResponse>,
|
||||
request_id: String,
|
||||
client: SglangSchedulerClient,
|
||||
aborted: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AbortOnDropStream {
|
||||
/// Create a new auto-aborting stream wrapper
|
||||
pub fn new(
|
||||
stream: Streaming<proto::GenerateResponse>,
|
||||
request_id: String,
|
||||
client: SglangSchedulerClient,
|
||||
) -> Self {
|
||||
debug!("Created AbortOnDropStream for request {}", request_id);
|
||||
Self {
|
||||
inner: stream,
|
||||
request_id,
|
||||
client,
|
||||
aborted: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Manually mark the request as completed to prevent abort on drop.
|
||||
/// Call this when the request completes successfully to avoid unnecessary abort RPC.
|
||||
pub fn mark_completed(&self) {
|
||||
// Use Release ordering to ensure that this write is visible to other threads
|
||||
// that use Acquire on the same atomic variable
|
||||
self.aborted.store(true, Ordering::Release);
|
||||
debug!("Request {} marked as completed", self.request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AbortOnDropStream {
|
||||
fn drop(&mut self) {
|
||||
// Atomically check and set the aborted flag using compare_exchange.
|
||||
// If compare_exchange fails, it means the flag was already true (from mark_completed),
|
||||
// so we don't need to send abort. AcqRel is used for success to synchronize with
|
||||
// mark_completed's Release, and Acquire for failure to see writes from mark_completed.
|
||||
if self
|
||||
.aborted
|
||||
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let client = self.client.clone();
|
||||
let request_id = self.request_id.clone();
|
||||
|
||||
// Spawn a background task to send abort (since Drop is sync but abort_request is async)
|
||||
tokio::spawn(async move {
|
||||
debug!(
|
||||
"Stream dropped without completion for request {}, sending abort",
|
||||
request_id
|
||||
);
|
||||
// Clone request_id for the error message since abort_request takes ownership
|
||||
let request_id_for_log = request_id.clone();
|
||||
if let Err(e) = client
|
||||
.abort_request(request_id, "Stream dropped".to_string())
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"Failed to send abort on drop for request {}: {}",
|
||||
request_id_for_log, e
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Stream trait to make AbortOnDropStream work like the original Streaming
|
||||
impl futures::Stream for AbortOnDropStream {
|
||||
type Item = Result<proto::GenerateResponse, tonic::Status>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
// Delegate to the inner stream
|
||||
Pin::new(&mut self.inner).poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// gRPC client for SGLang scheduler
|
||||
#[derive(Clone)]
|
||||
pub struct SglangSchedulerClient {
|
||||
@@ -35,7 +125,7 @@ impl SglangSchedulerClient {
|
||||
};
|
||||
|
||||
let channel = Channel::from_shared(http_endpoint)?
|
||||
.timeout(Duration::from_secs(3600))
|
||||
.timeout(Duration::from_secs(600)) // 10 minute timeout applied to each request
|
||||
.http2_keep_alive_interval(Duration::from_secs(30))
|
||||
.keep_alive_timeout(Duration::from_secs(10))
|
||||
.keep_alive_while_idle(true)
|
||||
@@ -52,15 +142,26 @@ impl SglangSchedulerClient {
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Submit a generation request (returns streaming response)
|
||||
/// Submit a generation request (returns auto-aborting streaming response)
|
||||
///
|
||||
/// The returned stream automatically sends an abort request when dropped,
|
||||
/// ensuring proper cleanup even if the HTTP client disconnects or an error occurs.
|
||||
/// Call `mark_completed()` on the stream after successful completion to prevent
|
||||
/// unnecessary abort RPCs.
|
||||
pub async fn generate(
|
||||
&self,
|
||||
req: proto::GenerateRequest,
|
||||
) -> Result<Streaming<proto::GenerateResponse>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
) -> Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let request_id = req.request_id.clone();
|
||||
let mut client = self.client.clone();
|
||||
let request = Request::new(req);
|
||||
let response = client.generate(request).await?;
|
||||
Ok(response.into_inner())
|
||||
|
||||
Ok(AbortOnDropStream::new(
|
||||
response.into_inner(),
|
||||
request_id,
|
||||
self.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Perform health check
|
||||
@@ -68,12 +169,8 @@ impl SglangSchedulerClient {
|
||||
&self,
|
||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
debug!("Sending health check request");
|
||||
let request = Request::new(proto::HealthCheckRequest {
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "Hello".to_string(),
|
||||
input_ids: vec![9906], // Mock token ID for "Hello"
|
||||
}),
|
||||
});
|
||||
// Server ignores the request body and creates its own health check internally
|
||||
let request = Request::new(proto::HealthCheckRequest { tokenized: None });
|
||||
|
||||
let mut client = self.client.clone();
|
||||
let response = client.health_check(request).await?;
|
||||
@@ -87,10 +184,23 @@ impl SglangSchedulerClient {
|
||||
request_id: String,
|
||||
reason: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let request = Request::new(proto::AbortRequest { request_id, reason });
|
||||
debug!(
|
||||
"Sending abort request for {} (reason: {})",
|
||||
request_id, reason
|
||||
);
|
||||
let request = Request::new(proto::AbortRequest {
|
||||
request_id: request_id.clone(),
|
||||
reason,
|
||||
});
|
||||
|
||||
let mut client = self.client.clone();
|
||||
client.abort(request).await?;
|
||||
let response = client.abort(request).await?;
|
||||
debug!(
|
||||
"Abort response for {}: success={}, message={}",
|
||||
request_id,
|
||||
response.get_ref().success,
|
||||
response.get_ref().message
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user