From b520958ec8f35b0c8dfc365ede757e9b730da6e9 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 9 Oct 2025 09:13:57 -0700 Subject: [PATCH] [router][grpc] Replace fake health check with correct ones (#11387) --- .../srt/entrypoints/grpc_request_manager.py | 14 +-- python/sglang/srt/entrypoints/grpc_server.py | 94 ++++++++++++++++++- 2 files changed, 97 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index c719f7c45..7351f4de3 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) +from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_zmq_socket, kill_process_tree from sglang.utils import get_exception_traceback @@ -338,12 +339,9 @@ class GrpcRequestManager: break except asyncio.TimeoutError: - # Timeout waiting for response - abort and cleanup - logger.warning( - f"Timeout waiting for response for request {request_id}" - ) - await self.abort_request(request_id) - return + # Timeout is for periodic client cancellation check + # Continue waiting for scheduler response + continue finally: # Always clean up request state when exiting @@ -412,6 +410,10 @@ class GrpcRequestManager: async def abort_request(self, request_id: str) -> bool: """Abort a running request.""" + # Skip aborting health check requests (they clean themselves up) + if request_id.startswith("HEALTH_CHECK"): + return False + if request_id not in self.rid_to_state: return False diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index e94e0e813..c3c813a3a 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -197,7 +197,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) context: grpc.aio.ServicerContext, ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]: """Handle generation requests with streaming responses.""" - logger.debug(f"Receive generation request: {request.request_id}") + logger.info(f"Receive generation request: {request.request_id}") try: # Convert gRPC request to internal format @@ -211,6 +211,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ) async for output in response_generator: + # Check if client cancelled before processing/yielding + if context.cancelled(): + logger.info(f"Client cancelled request {request.request_id}") + # Explicitly abort the request to notify scheduler + await self.request_manager.abort_request(request.request_id) + break + # Handle batch responses (for n>1 non-streaming) if isinstance(output, list): for batch_output in output: @@ -268,7 +275,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) _context: grpc.aio.ServicerContext, ) -> sglang_scheduler_pb2.EmbedResponse: """Handle embedding requests.""" - logger.debug(f"Receive embedding request: {request.request_id}") + logger.info(f"Receive embedding request: {request.request_id}") try: # Convert request @@ -313,9 +320,86 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) request: sglang_scheduler_pb2.HealthCheckRequest, context: grpc.aio.ServicerContext, ) -> sglang_scheduler_pb2.HealthCheckResponse: - """Health check - always returns healthy after server started.""" + """ + Check the health of the inference server by sending a special request to generate one token. + Similar to HTTP server's /health endpoint. + """ + logger.info("Receive health check request") + + if self.request_manager.gracefully_exit: + logger.info( + "Health check request received during shutdown. Returning unhealthy." + ) + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Server is shutting down" + ) + + # Create a special health check request + rid = f"HEALTH_CHECK_{time.time()}" + sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0) + sampling_params.normalize(tokenizer=None) + + # Create health check request + is_generation = self.scheduler_info.get("is_generation", True) + if is_generation: + health_req = TokenizedGenerateReqInput( + rid=rid, + input_text="", + input_ids=[0], + sampling_params=sampling_params, + return_logprob=False, + logprob_start_len=-1, + top_logprobs_num=0, + stream=False, + mm_inputs=None, + token_ids_logprob=None, + ) + # Set disaggregation params if needed + if self.server_args.disaggregation_mode != DisaggregationMode.NULL: + health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST + health_req.bootstrap_room = 0 + else: + health_req = TokenizedEmbeddingReqInput( + rid=rid, + input_text="", + input_ids=[0], + ) + + # Submit health check request + async def run_health_check(): + try: + async for _ in self.request_manager.generate_request( + obj=health_req, + request_id=rid, + ): + # Got at least one response, server is healthy + return True + except Exception as e: + logger.warning(f"Health check failed: {e}") + return False + return False + + task = asyncio.create_task(run_health_check()) + + # Wait for response with timeout + tic = time.time() + while time.time() < tic + HEALTH_CHECK_TIMEOUT: + await asyncio.sleep(1) + # Check if we got a response from scheduler + if self.request_manager.last_receive_tstamp > tic: + task.cancel() + # Clean up health check state + self.request_manager._cleanup_request_state(rid) + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=True, message="Health check passed" + ) + + # Timeout - server not responding + task.cancel() + self.request_manager._cleanup_request_state(rid) + logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s") return sglang_scheduler_pb2.HealthCheckResponse( - healthy=True, message="Health check passed" + healthy=False, message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s" ) async def Abort( @@ -324,7 +408,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) _context: grpc.aio.ServicerContext, ) -> sglang_scheduler_pb2.AbortResponse: """Abort an ongoing request.""" - logger.debug(f"Receive abort request: {request.request_id}") + logger.info(f"Receive abort request: {request.request_id}") try: success = await self.request_manager.abort_request(request.request_id)