[router] leverage RAII to actively cancel request during client disconnect (#11399)
This commit is contained in:
@@ -319,13 +319,8 @@ class GrpcRequestManager:
|
|||||||
is_stream = getattr(obj, "stream", False)
|
is_stream = getattr(obj, "stream", False)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Client cancelled - notify scheduler and exit
|
|
||||||
if grpc_context and grpc_context.cancelled():
|
|
||||||
await self.abort_request(request_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
|
response = await state.out_queue.get()
|
||||||
|
|
||||||
if is_stream:
|
if is_stream:
|
||||||
yield response
|
yield response
|
||||||
@@ -338,10 +333,11 @@ class GrpcRequestManager:
|
|||||||
yield final_response
|
yield final_response
|
||||||
break
|
break
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.CancelledError:
|
||||||
# Timeout is for periodic client cancellation check
|
# Task was cancelled by gRPC framework when client disconnected
|
||||||
# Continue waiting for scheduler response
|
logger.info(f"Request {request_id} cancelled by client")
|
||||||
continue
|
await self.abort_request(request_id)
|
||||||
|
raise # Re-raise to let gRPC server handle cleanup
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Always clean up request state when exiting
|
# Always clean up request state when exiting
|
||||||
@@ -409,31 +405,31 @@ class GrpcRequestManager:
|
|||||||
return future
|
return future
|
||||||
|
|
||||||
async def abort_request(self, request_id: str) -> bool:
|
async def abort_request(self, request_id: str) -> bool:
|
||||||
"""Abort a running request."""
|
"""Abort a running request.
|
||||||
|
|
||||||
|
Sends abort request to scheduler and marks local state as finished
|
||||||
|
to stop processing any further outputs from the scheduler.
|
||||||
|
"""
|
||||||
# Skip aborting health check requests (they clean themselves up)
|
# Skip aborting health check requests (they clean themselves up)
|
||||||
if request_id.startswith("HEALTH_CHECK"):
|
if request_id.startswith("HEALTH_CHECK"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if request_id not in self.rid_to_state:
|
# Mark state as finished immediately to stop processing scheduler outputs
|
||||||
return False
|
|
||||||
|
|
||||||
# Send abort to scheduler
|
|
||||||
abort_req = AbortReq(rid=request_id)
|
|
||||||
try:
|
|
||||||
await self._send_to_scheduler(abort_req)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to send abort request: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Mark as finished
|
|
||||||
state = self.rid_to_state.get(request_id)
|
state = self.rid_to_state.get(request_id)
|
||||||
if state:
|
if state:
|
||||||
state.finished = True
|
state.finished = True
|
||||||
state.stream_finished = True
|
state.stream_finished = True
|
||||||
state.event.set()
|
logger.debug(f"Marked request {request_id} as aborted locally")
|
||||||
|
|
||||||
# Send abort notification to output queue
|
# Send abort to scheduler - the scheduler will send AbortReq back
|
||||||
await state.out_queue.put({"error": "Request aborted", "abort": True})
|
# which will be handled by _handle_abort_req
|
||||||
|
abort_req = AbortReq(rid=request_id)
|
||||||
|
try:
|
||||||
|
await self._send_to_scheduler(abort_req)
|
||||||
|
logger.debug(f"Sent abort to scheduler for request {request_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send abort request to scheduler: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -460,6 +456,8 @@ class GrpcRequestManager:
|
|||||||
await self._handle_embedding_output(recv_obj)
|
await self._handle_embedding_output(recv_obj)
|
||||||
elif isinstance(recv_obj, HealthCheckOutput):
|
elif isinstance(recv_obj, HealthCheckOutput):
|
||||||
await self._handle_health_check_output(recv_obj)
|
await self._handle_health_check_output(recv_obj)
|
||||||
|
elif isinstance(recv_obj, AbortReq):
|
||||||
|
await self._handle_abort_req(recv_obj)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown output type: {type(recv_obj)}")
|
logger.warning(f"Unknown output type: {type(recv_obj)}")
|
||||||
|
|
||||||
@@ -541,6 +539,11 @@ class GrpcRequestManager:
|
|||||||
|
|
||||||
state = self.rid_to_state[rid]
|
state = self.rid_to_state[rid]
|
||||||
|
|
||||||
|
# Skip if already aborted/finished locally (client cancelled)
|
||||||
|
if state.finished:
|
||||||
|
logger.debug(f"Skipping output for aborted request {rid}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if state.first_token_time == 0.0:
|
if state.first_token_time == 0.0:
|
||||||
@@ -713,6 +716,67 @@ class GrpcRequestManager:
|
|||||||
state.finished_time = time.time()
|
state.finished_time = time.time()
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
async def _handle_abort_req(self, recv_obj: AbortReq):
|
||||||
|
"""Handle abort request from scheduler.
|
||||||
|
|
||||||
|
The scheduler sends AbortReq back to notify us that a request was aborted,
|
||||||
|
either due to explicit abort_request() call or scheduler-initiated abort
|
||||||
|
(priority preemption, queue full, KV cache pressure, etc).
|
||||||
|
"""
|
||||||
|
# Skip health check requests
|
||||||
|
if recv_obj.rid.startswith("HEALTH_CHECK"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if request still exists
|
||||||
|
if recv_obj.rid not in self.rid_to_state:
|
||||||
|
logger.debug(
|
||||||
|
f"Abort request for {recv_obj.rid} not in local state (may have already finished or not started yet)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
state = self.rid_to_state[recv_obj.rid]
|
||||||
|
|
||||||
|
# Mark as finished
|
||||||
|
state.finished = True
|
||||||
|
state.stream_finished = True
|
||||||
|
|
||||||
|
# Create abort response
|
||||||
|
if recv_obj.finished_reason:
|
||||||
|
# Scheduler provided a specific finish reason (e.g., priority preemption, queue full)
|
||||||
|
abort_response = {
|
||||||
|
"request_id": recv_obj.rid,
|
||||||
|
"error": recv_obj.finished_reason.get("message", "Request aborted"),
|
||||||
|
"finished": True,
|
||||||
|
"meta_info": {
|
||||||
|
"id": recv_obj.rid,
|
||||||
|
"finish_reason": recv_obj.finished_reason,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Generic abort (e.g., explicit abort_request call)
|
||||||
|
abort_response = {
|
||||||
|
"request_id": recv_obj.rid,
|
||||||
|
"error": "Request aborted",
|
||||||
|
"finished": True,
|
||||||
|
"meta_info": {
|
||||||
|
"id": recv_obj.rid,
|
||||||
|
"finish_reason": {
|
||||||
|
"type": "abort",
|
||||||
|
"message": "Abort before prefill",
|
||||||
|
},
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send abort notification to output queue
|
||||||
|
await state.out_queue.put(abort_response)
|
||||||
|
|
||||||
|
# Wake up any waiting coroutines
|
||||||
|
state.event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Handled abort request for {recv_obj.rid}")
|
||||||
|
|
||||||
async def _send_to_scheduler(self, obj):
|
async def _send_to_scheduler(self, obj):
|
||||||
"""Send an object to the scheduler via ZMQ."""
|
"""Send an object to the scheduler via ZMQ."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -211,13 +211,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
)
|
)
|
||||||
|
|
||||||
async for output in response_generator:
|
async for output in response_generator:
|
||||||
# Check if client cancelled before processing/yielding
|
|
||||||
if context.cancelled():
|
|
||||||
logger.info(f"Client cancelled request {request.request_id}")
|
|
||||||
# Explicitly abort the request to notify scheduler
|
|
||||||
await self.request_manager.abort_request(request.request_id)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Handle batch responses (for n>1 non-streaming)
|
# Handle batch responses (for n>1 non-streaming)
|
||||||
if isinstance(output, list):
|
if isinstance(output, list):
|
||||||
for batch_output in output:
|
for batch_output in output:
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tonic::{transport::Channel, Request, Streaming};
|
use tonic::{transport::Channel, Request, Streaming};
|
||||||
use tracing::debug;
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
||||||
@@ -16,6 +20,92 @@ pub mod proto {
|
|||||||
// The generated module structure depends on the package name in the .proto file
|
// The generated module structure depends on the package name in the .proto file
|
||||||
// package sglang.grpc.scheduler; generates a nested module structure
|
// package sglang.grpc.scheduler; generates a nested module structure
|
||||||
|
|
||||||
|
/// A smart wrapper around Streaming<GenerateResponse> that automatically
|
||||||
|
/// sends abort when dropped (e.g., due to client disconnection or early termination).
|
||||||
|
///
|
||||||
|
/// This leverages Rust's RAII pattern to ensure cleanup happens automatically,
|
||||||
|
/// regardless of how the stream is dropped (panic, early return, client disconnect, etc.).
|
||||||
|
pub struct AbortOnDropStream {
|
||||||
|
inner: Streaming<proto::GenerateResponse>,
|
||||||
|
request_id: String,
|
||||||
|
client: SglangSchedulerClient,
|
||||||
|
aborted: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AbortOnDropStream {
|
||||||
|
/// Create a new auto-aborting stream wrapper
|
||||||
|
pub fn new(
|
||||||
|
stream: Streaming<proto::GenerateResponse>,
|
||||||
|
request_id: String,
|
||||||
|
client: SglangSchedulerClient,
|
||||||
|
) -> Self {
|
||||||
|
debug!("Created AbortOnDropStream for request {}", request_id);
|
||||||
|
Self {
|
||||||
|
inner: stream,
|
||||||
|
request_id,
|
||||||
|
client,
|
||||||
|
aborted: Arc::new(AtomicBool::new(false)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually mark the request as completed to prevent abort on drop.
|
||||||
|
/// Call this when the request completes successfully to avoid unnecessary abort RPC.
|
||||||
|
pub fn mark_completed(&self) {
|
||||||
|
// Use Release ordering to ensure that this write is visible to other threads
|
||||||
|
// that use Acquire on the same atomic variable
|
||||||
|
self.aborted.store(true, Ordering::Release);
|
||||||
|
debug!("Request {} marked as completed", self.request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for AbortOnDropStream {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// Atomically check and set the aborted flag using compare_exchange.
|
||||||
|
// If compare_exchange fails, it means the flag was already true (from mark_completed),
|
||||||
|
// so we don't need to send abort. AcqRel is used for success to synchronize with
|
||||||
|
// mark_completed's Release, and Acquire for failure to see writes from mark_completed.
|
||||||
|
if self
|
||||||
|
.aborted
|
||||||
|
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let client = self.client.clone();
|
||||||
|
let request_id = self.request_id.clone();
|
||||||
|
|
||||||
|
// Spawn a background task to send abort (since Drop is sync but abort_request is async)
|
||||||
|
tokio::spawn(async move {
|
||||||
|
debug!(
|
||||||
|
"Stream dropped without completion for request {}, sending abort",
|
||||||
|
request_id
|
||||||
|
);
|
||||||
|
// Clone request_id for the error message since abort_request takes ownership
|
||||||
|
let request_id_for_log = request_id.clone();
|
||||||
|
if let Err(e) = client
|
||||||
|
.abort_request(request_id, "Stream dropped".to_string())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!(
|
||||||
|
"Failed to send abort on drop for request {}: {}",
|
||||||
|
request_id_for_log, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Stream trait to make AbortOnDropStream work like the original Streaming
|
||||||
|
impl futures::Stream for AbortOnDropStream {
|
||||||
|
type Item = Result<proto::GenerateResponse, tonic::Status>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
// Delegate to the inner stream
|
||||||
|
Pin::new(&mut self.inner).poll_next(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// gRPC client for SGLang scheduler
|
/// gRPC client for SGLang scheduler
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SglangSchedulerClient {
|
pub struct SglangSchedulerClient {
|
||||||
@@ -35,7 +125,7 @@ impl SglangSchedulerClient {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let channel = Channel::from_shared(http_endpoint)?
|
let channel = Channel::from_shared(http_endpoint)?
|
||||||
.timeout(Duration::from_secs(3600))
|
.timeout(Duration::from_secs(600)) // 10 minute timeout for connection
|
||||||
.http2_keep_alive_interval(Duration::from_secs(30))
|
.http2_keep_alive_interval(Duration::from_secs(30))
|
||||||
.keep_alive_timeout(Duration::from_secs(10))
|
.keep_alive_timeout(Duration::from_secs(10))
|
||||||
.keep_alive_while_idle(true)
|
.keep_alive_while_idle(true)
|
||||||
@@ -52,15 +142,26 @@ impl SglangSchedulerClient {
|
|||||||
Ok(Self { client })
|
Ok(Self { client })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Submit a generation request (returns streaming response)
|
/// Submit a generation request (returns auto-aborting streaming response)
|
||||||
|
///
|
||||||
|
/// The returned stream automatically sends an abort request when dropped,
|
||||||
|
/// ensuring proper cleanup even if the HTTP client disconnects or an error occurs.
|
||||||
|
/// Call `mark_completed()` on the stream after successful completion to prevent
|
||||||
|
/// unnecessary abort RPCs.
|
||||||
pub async fn generate(
|
pub async fn generate(
|
||||||
&self,
|
&self,
|
||||||
req: proto::GenerateRequest,
|
req: proto::GenerateRequest,
|
||||||
) -> Result<Streaming<proto::GenerateResponse>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let request_id = req.request_id.clone();
|
||||||
let mut client = self.client.clone();
|
let mut client = self.client.clone();
|
||||||
let request = Request::new(req);
|
let request = Request::new(req);
|
||||||
let response = client.generate(request).await?;
|
let response = client.generate(request).await?;
|
||||||
Ok(response.into_inner())
|
|
||||||
|
Ok(AbortOnDropStream::new(
|
||||||
|
response.into_inner(),
|
||||||
|
request_id,
|
||||||
|
self.clone(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform health check
|
/// Perform health check
|
||||||
@@ -68,12 +169,8 @@ impl SglangSchedulerClient {
|
|||||||
&self,
|
&self,
|
||||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
debug!("Sending health check request");
|
debug!("Sending health check request");
|
||||||
let request = Request::new(proto::HealthCheckRequest {
|
// Server ignores the request body and creates its own health check internally
|
||||||
tokenized: Some(proto::TokenizedInput {
|
let request = Request::new(proto::HealthCheckRequest { tokenized: None });
|
||||||
original_text: "Hello".to_string(),
|
|
||||||
input_ids: vec![9906], // Mock token ID for "Hello"
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut client = self.client.clone();
|
let mut client = self.client.clone();
|
||||||
let response = client.health_check(request).await?;
|
let response = client.health_check(request).await?;
|
||||||
@@ -87,10 +184,23 @@ impl SglangSchedulerClient {
|
|||||||
request_id: String,
|
request_id: String,
|
||||||
reason: String,
|
reason: String,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let request = Request::new(proto::AbortRequest { request_id, reason });
|
debug!(
|
||||||
|
"Sending abort request for {} (reason: {})",
|
||||||
|
request_id, reason
|
||||||
|
);
|
||||||
|
let request = Request::new(proto::AbortRequest {
|
||||||
|
request_id: request_id.clone(),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
|
||||||
let mut client = self.client.clone();
|
let mut client = self.client.clone();
|
||||||
client.abort(request).await?;
|
let response = client.abort(request).await?;
|
||||||
|
debug!(
|
||||||
|
"Abort response for {}: success={}, message={}",
|
||||||
|
request_id,
|
||||||
|
response.get_ref().success,
|
||||||
|
response.get_ref().message
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -371,16 +371,17 @@ impl ClientSelection {
|
|||||||
// Execution and Response Types
|
// Execution and Response Types
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
use tonic::codec::Streaming;
|
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
|
||||||
|
|
||||||
/// Result of request execution (streams from workers)
|
/// Result of request execution (streams from workers)
|
||||||
|
/// Uses AbortOnDropStream to automatically abort on cancellation
|
||||||
pub enum ExecutionResult {
|
pub enum ExecutionResult {
|
||||||
Single {
|
Single {
|
||||||
stream: Streaming<proto::GenerateResponse>,
|
stream: AbortOnDropStream,
|
||||||
},
|
},
|
||||||
Dual {
|
Dual {
|
||||||
prefill: Streaming<proto::GenerateResponse>,
|
prefill: AbortOnDropStream,
|
||||||
decode: Box<Streaming<proto::GenerateResponse>>,
|
decode: Box<AbortOnDropStream>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -816,16 +816,27 @@ impl ResponseProcessingStage {
|
|||||||
|
|
||||||
// Collect all responses from the execution result
|
// Collect all responses from the execution result
|
||||||
let all_responses = match execution_result {
|
let all_responses = match execution_result {
|
||||||
ExecutionResult::Single { stream } => {
|
ExecutionResult::Single { mut stream } => {
|
||||||
utils::collect_stream_responses(stream, "Single").await?
|
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||||
|
stream.mark_completed();
|
||||||
|
responses
|
||||||
}
|
}
|
||||||
ExecutionResult::Dual { prefill, decode } => {
|
ExecutionResult::Dual {
|
||||||
// Collect prefill for input_logprobs
|
mut prefill,
|
||||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
decode,
|
||||||
|
} => {
|
||||||
|
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||||
|
let prefill_responses =
|
||||||
|
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||||
|
|
||||||
// Collect decode for actual output
|
// Collect decode for actual output (don't mark completed yet)
|
||||||
|
let mut decode_stream = *decode;
|
||||||
let mut decode_responses =
|
let mut decode_responses =
|
||||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||||
|
|
||||||
|
// Mark both streams as completed now that both succeeded
|
||||||
|
prefill.mark_completed();
|
||||||
|
decode_stream.mark_completed();
|
||||||
|
|
||||||
// Merge prefill input_logprobs if requested
|
// Merge prefill input_logprobs if requested
|
||||||
if request_logprobs {
|
if request_logprobs {
|
||||||
@@ -952,16 +963,27 @@ impl ResponseProcessingStage {
|
|||||||
// Non-streaming: Collect all responses
|
// Non-streaming: Collect all responses
|
||||||
let request_logprobs = ctx.generate_request().return_logprob;
|
let request_logprobs = ctx.generate_request().return_logprob;
|
||||||
let all_responses = match execution_result {
|
let all_responses = match execution_result {
|
||||||
ExecutionResult::Single { stream } => {
|
ExecutionResult::Single { mut stream } => {
|
||||||
utils::collect_stream_responses(stream, "Single").await?
|
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||||
|
stream.mark_completed();
|
||||||
|
responses
|
||||||
}
|
}
|
||||||
ExecutionResult::Dual { prefill, decode } => {
|
ExecutionResult::Dual {
|
||||||
// Collect prefill for input_logprobs
|
mut prefill,
|
||||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
decode,
|
||||||
|
} => {
|
||||||
|
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||||
|
let prefill_responses =
|
||||||
|
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||||
|
|
||||||
// Collect decode for actual output
|
// Collect decode for actual output (don't mark completed yet)
|
||||||
|
let mut decode_stream = *decode;
|
||||||
let mut decode_responses =
|
let mut decode_responses =
|
||||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||||
|
|
||||||
|
// Mark both streams as completed now that both succeeded
|
||||||
|
prefill.mark_completed();
|
||||||
|
decode_stream.mark_completed();
|
||||||
|
|
||||||
// Merge prefill input_logprobs if requested
|
// Merge prefill input_logprobs if requested
|
||||||
if request_logprobs {
|
if request_logprobs {
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ use std::sync::Arc;
|
|||||||
use tokio::sync::mpsc::UnboundedSender;
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tonic::codec::Streaming;
|
|
||||||
use tracing::{debug, error, warn};
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
use super::context;
|
use super::context;
|
||||||
@@ -153,7 +152,7 @@ impl StreamingProcessor {
|
|||||||
/// Process streaming chunks from a single stream (Regular mode)
|
/// Process streaming chunks from a single stream (Regular mode)
|
||||||
pub async fn process_streaming_chunks(
|
pub async fn process_streaming_chunks(
|
||||||
&self,
|
&self,
|
||||||
mut grpc_stream: Streaming<proto::GenerateResponse>,
|
mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||||
original_request: Arc<ChatCompletionRequest>,
|
original_request: Arc<ChatCompletionRequest>,
|
||||||
@@ -571,14 +570,17 @@ impl StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark stream as completed successfully to prevent abort on drop
|
||||||
|
grpc_stream.mark_completed();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process dual streaming chunks (prefill + decode) - PD mode
|
/// Process dual streaming chunks (prefill + decode) - PD mode
|
||||||
pub async fn process_dual_streaming_chunks(
|
pub async fn process_dual_streaming_chunks(
|
||||||
&self,
|
&self,
|
||||||
mut prefill_stream: Streaming<proto::GenerateResponse>,
|
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
decode_stream: Streaming<proto::GenerateResponse>,
|
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
dispatch: context::DispatchMetadata,
|
dispatch: context::DispatchMetadata,
|
||||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||||
original_request: Arc<ChatCompletionRequest>,
|
original_request: Arc<ChatCompletionRequest>,
|
||||||
@@ -603,8 +605,18 @@ impl StreamingProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Phase 2-5: Process decode stream (same as single mode)
|
// Phase 2-5: Process decode stream (same as single mode)
|
||||||
self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
|
// Note: decode_stream will be marked completed inside process_streaming_chunks
|
||||||
.await
|
let result = self
|
||||||
|
.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Mark prefill stream as completed AFTER decode completes successfully
|
||||||
|
// This ensures that if client disconnects during decode, BOTH streams send abort
|
||||||
|
if result.is_ok() {
|
||||||
|
prefill_stream.mark_completed();
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process streaming generate response and return SSE response
|
/// Process streaming generate response and return SSE response
|
||||||
@@ -687,7 +699,7 @@ impl StreamingProcessor {
|
|||||||
/// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
|
/// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
|
||||||
async fn process_generate_streaming(
|
async fn process_generate_streaming(
|
||||||
tokenizer: Arc<dyn Tokenizer>,
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
mut stream: Streaming<proto::GenerateResponse>,
|
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
weight_version: String,
|
weight_version: String,
|
||||||
_include_logprobs: bool,
|
_include_logprobs: bool,
|
||||||
@@ -782,14 +794,17 @@ impl StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark stream as completed successfully to prevent abort on drop
|
||||||
|
stream.mark_completed();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process dual streaming for generate endpoint (PD mode with logprobs support)
|
/// Process dual streaming for generate endpoint (PD mode with logprobs support)
|
||||||
async fn process_generate_streaming_dual(
|
async fn process_generate_streaming_dual(
|
||||||
tokenizer: Arc<dyn Tokenizer>,
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
mut prefill_stream: Streaming<proto::GenerateResponse>,
|
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
decode_stream: Streaming<proto::GenerateResponse>,
|
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
weight_version: String,
|
weight_version: String,
|
||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
@@ -821,7 +836,8 @@ impl StreamingProcessor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Process decode stream with input_logprobs prepended
|
// Process decode stream with input_logprobs prepended
|
||||||
Self::process_generate_streaming_with_input_logprobs(
|
// Note: decode_stream will be marked completed inside the function
|
||||||
|
let result = Self::process_generate_streaming_with_input_logprobs(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
decode_stream,
|
decode_stream,
|
||||||
request_id,
|
request_id,
|
||||||
@@ -830,13 +846,21 @@ impl StreamingProcessor {
|
|||||||
input_token_logprobs,
|
input_token_logprobs,
|
||||||
tx,
|
tx,
|
||||||
)
|
)
|
||||||
.await
|
.await;
|
||||||
|
|
||||||
|
// Mark prefill stream as completed AFTER decode completes successfully
|
||||||
|
// This ensures that if client disconnects during decode, BOTH streams send abort
|
||||||
|
if result.is_ok() {
|
||||||
|
prefill_stream.mark_completed();
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process generate streaming with optional input_logprobs
|
/// Process generate streaming with optional input_logprobs
|
||||||
async fn process_generate_streaming_with_input_logprobs(
|
async fn process_generate_streaming_with_input_logprobs(
|
||||||
tokenizer: Arc<dyn Tokenizer>,
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
mut stream: Streaming<proto::GenerateResponse>,
|
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
weight_version: String,
|
weight_version: String,
|
||||||
_include_logprobs: bool,
|
_include_logprobs: bool,
|
||||||
@@ -957,6 +981,9 @@ impl StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark stream as completed successfully to prevent abort on drop
|
||||||
|
stream.mark_completed();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
use super::ProcessedMessages;
|
use super::ProcessedMessages;
|
||||||
use crate::core::Worker;
|
use crate::core::Worker;
|
||||||
|
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
|
||||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
|
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
|
||||||
@@ -20,7 +21,6 @@ use futures::StreamExt;
|
|||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tonic::codec::Streaming;
|
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -590,7 +590,7 @@ pub fn parse_json_schema_response(
|
|||||||
/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
|
/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
|
||||||
/// * `Err(Response)` - Error response if the stream fails or returns an error
|
/// * `Err(Response)` - Error response if the stream fails or returns an error
|
||||||
pub async fn collect_stream_responses(
|
pub async fn collect_stream_responses(
|
||||||
mut stream: Streaming<proto::GenerateResponse>,
|
stream: &mut AbortOnDropStream,
|
||||||
worker_name: &str,
|
worker_name: &str,
|
||||||
) -> Result<Vec<proto::GenerateComplete>, Response> {
|
) -> Result<Vec<proto::GenerateComplete>, Response> {
|
||||||
use proto::generate_response::Response::*;
|
use proto::generate_response::Response::*;
|
||||||
@@ -606,6 +606,7 @@ pub async fn collect_stream_responses(
|
|||||||
}
|
}
|
||||||
Some(Error(err)) => {
|
Some(Error(err)) => {
|
||||||
error!("{} error: {}", worker_name, err.message);
|
error!("{} error: {}", worker_name, err.message);
|
||||||
|
// Don't mark as completed - let Drop send abort for error cases
|
||||||
return Err(internal_error_message(format!(
|
return Err(internal_error_message(format!(
|
||||||
"{} generation failed: {}",
|
"{} generation failed: {}",
|
||||||
worker_name, err.message
|
worker_name, err.message
|
||||||
@@ -621,6 +622,7 @@ pub async fn collect_stream_responses(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("{} stream error: {:?}", worker_name, e);
|
error!("{} stream error: {:?}", worker_name, e);
|
||||||
|
// Don't mark as completed - let Drop send abort for error cases
|
||||||
return Err(internal_error_message(format!(
|
return Err(internal_error_message(format!(
|
||||||
"{} stream failed: {}",
|
"{} stream failed: {}",
|
||||||
worker_name, e
|
worker_name, e
|
||||||
|
|||||||
Reference in New Issue
Block a user