[router] leverage RAII to actively cancel request during client disconnect (#11399)

This commit is contained in:
Simo Lin
2025-10-10 20:43:38 -04:00
committed by GitHub
parent 2eeb27515a
commit c495833186
7 changed files with 297 additions and 78 deletions

View File

@@ -319,13 +319,8 @@ class GrpcRequestManager:
is_stream = getattr(obj, "stream", False)
while True:
# Client cancelled - notify scheduler and exit
if grpc_context and grpc_context.cancelled():
await self.abort_request(request_id)
return
try:
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
response = await state.out_queue.get()
if is_stream:
yield response
@@ -338,10 +333,11 @@ class GrpcRequestManager:
yield final_response
break
except asyncio.TimeoutError:
# Timeout is for periodic client cancellation check
# Continue waiting for scheduler response
continue
except asyncio.CancelledError:
# Task was cancelled by gRPC framework when client disconnected
logger.info(f"Request {request_id} cancelled by client")
await self.abort_request(request_id)
raise # Re-raise to let gRPC server handle cleanup
finally:
# Always clean up request state when exiting
@@ -409,31 +405,31 @@ class GrpcRequestManager:
return future
async def abort_request(self, request_id: str) -> bool:
"""Abort a running request."""
"""Abort a running request.
Marks local state as finished to stop processing any further outputs
from the scheduler, then sends an abort request to the scheduler.
"""
# Skip aborting health check requests (they clean themselves up)
if request_id.startswith("HEALTH_CHECK"):
return False
if request_id not in self.rid_to_state:
return False
# Send abort to scheduler
abort_req = AbortReq(rid=request_id)
try:
await self._send_to_scheduler(abort_req)
except Exception as e:
logger.error(f"Failed to send abort request: {e}")
return False
# Mark as finished
# Mark state as finished immediately to stop processing scheduler outputs
state = self.rid_to_state.get(request_id)
if state:
state.finished = True
state.stream_finished = True
state.event.set()
logger.debug(f"Marked request {request_id} as aborted locally")
# Send abort notification to output queue
await state.out_queue.put({"error": "Request aborted", "abort": True})
# Send abort to scheduler - the scheduler will send AbortReq back
# which will be handled by _handle_abort_req
abort_req = AbortReq(rid=request_id)
try:
await self._send_to_scheduler(abort_req)
logger.debug(f"Sent abort to scheduler for request {request_id}")
except Exception as e:
logger.error(f"Failed to send abort request to scheduler: {e}")
return False
return True
@@ -460,6 +456,8 @@ class GrpcRequestManager:
await self._handle_embedding_output(recv_obj)
elif isinstance(recv_obj, HealthCheckOutput):
await self._handle_health_check_output(recv_obj)
elif isinstance(recv_obj, AbortReq):
await self._handle_abort_req(recv_obj)
else:
logger.warning(f"Unknown output type: {type(recv_obj)}")
@@ -541,6 +539,11 @@ class GrpcRequestManager:
state = self.rid_to_state[rid]
# Skip if already aborted/finished locally (client cancelled)
if state.finished:
logger.debug(f"Skipping output for aborted request {rid}")
continue
# Update metrics
now = time.time()
if state.first_token_time == 0.0:
@@ -713,6 +716,67 @@ class GrpcRequestManager:
state.finished_time = time.time()
state.event.set()
async def _handle_abort_req(self, recv_obj: AbortReq):
"""Handle abort request from scheduler.
The scheduler sends AbortReq back to notify us that a request was aborted,
either due to explicit abort_request() call or scheduler-initiated abort
(priority preemption, queue full, KV cache pressure, etc.).
"""
# Skip health check requests
if recv_obj.rid.startswith("HEALTH_CHECK"):
return
# Check if request still exists
if recv_obj.rid not in self.rid_to_state:
logger.debug(
f"Abort request for {recv_obj.rid} not in local state (may have already finished or not started yet)"
)
return
state = self.rid_to_state[recv_obj.rid]
# Mark as finished
state.finished = True
state.stream_finished = True
# Create abort response
if recv_obj.finished_reason:
# Scheduler provided a specific finish reason (e.g., priority preemption, queue full)
abort_response = {
"request_id": recv_obj.rid,
"error": recv_obj.finished_reason.get("message", "Request aborted"),
"finished": True,
"meta_info": {
"id": recv_obj.rid,
"finish_reason": recv_obj.finished_reason,
},
}
else:
# Generic abort (e.g., explicit abort_request call)
abort_response = {
"request_id": recv_obj.rid,
"error": "Request aborted",
"finished": True,
"meta_info": {
"id": recv_obj.rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
},
"prompt_tokens": 0,
"completion_tokens": 0,
},
}
# Send abort notification to output queue
await state.out_queue.put(abort_response)
# Wake up any waiting coroutines
state.event.set()
logger.debug(f"Handled abort request for {recv_obj.rid}")
async def _send_to_scheduler(self, obj):
"""Send an object to the scheduler via ZMQ."""
try:

View File

@@ -211,13 +211,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
)
async for output in response_generator:
# Check if client cancelled before processing/yielding
if context.cancelled():
logger.info(f"Client cancelled request {request.request_id}")
# Explicitly abort the request to notify scheduler
await self.request_manager.abort_request(request.request_id)
break
# Handle batch responses (for n>1 non-streaming)
if isinstance(output, list):
for batch_output in output: