[router][grpc] Replace fake health check with a real one (#11387)
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
@@ -338,12 +339,9 @@ class GrpcRequestManager:
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout waiting for response - abort and cleanup
|
||||
logger.warning(
|
||||
f"Timeout waiting for response for request {request_id}"
|
||||
)
|
||||
await self.abort_request(request_id)
|
||||
return
|
||||
# Timeout is for periodic client cancellation check
|
||||
# Continue waiting for scheduler response
|
||||
continue
|
||||
|
||||
finally:
|
||||
# Always clean up request state when exiting
|
||||
@@ -412,6 +410,10 @@ class GrpcRequestManager:
|
||||
|
||||
async def abort_request(self, request_id: str) -> bool:
|
||||
"""Abort a running request."""
|
||||
# Skip aborting health check requests (they clean themselves up)
|
||||
if request_id.startswith("HEALTH_CHECK"):
|
||||
return False
|
||||
|
||||
if request_id not in self.rid_to_state:
|
||||
return False
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
|
||||
"""Handle generation requests with streaming responses."""
|
||||
logger.debug(f"Receive generation request: {request.request_id}")
|
||||
logger.info(f"Receive generation request: {request.request_id}")
|
||||
|
||||
try:
|
||||
# Convert gRPC request to internal format
|
||||
@@ -211,6 +211,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
)
|
||||
|
||||
async for output in response_generator:
|
||||
# Check if client cancelled before processing/yielding
|
||||
if context.cancelled():
|
||||
logger.info(f"Client cancelled request {request.request_id}")
|
||||
# Explicitly abort the request to notify scheduler
|
||||
await self.request_manager.abort_request(request.request_id)
|
||||
break
|
||||
|
||||
# Handle batch responses (for n>1 non-streaming)
|
||||
if isinstance(output, list):
|
||||
for batch_output in output:
|
||||
@@ -268,7 +275,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
_context: grpc.aio.ServicerContext,
|
||||
) -> sglang_scheduler_pb2.EmbedResponse:
|
||||
"""Handle embedding requests."""
|
||||
logger.debug(f"Receive embedding request: {request.request_id}")
|
||||
logger.info(f"Receive embedding request: {request.request_id}")
|
||||
|
||||
try:
|
||||
# Convert request
|
||||
@@ -313,9 +320,86 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
request: sglang_scheduler_pb2.HealthCheckRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> sglang_scheduler_pb2.HealthCheckResponse:
|
||||
"""Health check - always returns healthy after server started."""
|
||||
"""
|
||||
Check the health of the inference server by sending a special request to generate one token.
|
||||
Similar to HTTP server's /health endpoint.
|
||||
"""
|
||||
logger.info("Receive health check request")
|
||||
|
||||
if self.request_manager.gracefully_exit:
|
||||
logger.info(
|
||||
"Health check request received during shutdown. Returning unhealthy."
|
||||
)
|
||||
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||
healthy=False, message="Server is shutting down"
|
||||
)
|
||||
|
||||
# Create a special health check request
|
||||
rid = f"HEALTH_CHECK_{time.time()}"
|
||||
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
|
||||
sampling_params.normalize(tokenizer=None)
|
||||
|
||||
# Create health check request
|
||||
is_generation = self.scheduler_info.get("is_generation", True)
|
||||
if is_generation:
|
||||
health_req = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_text="",
|
||||
input_ids=[0],
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=False,
|
||||
logprob_start_len=-1,
|
||||
top_logprobs_num=0,
|
||||
stream=False,
|
||||
mm_inputs=None,
|
||||
token_ids_logprob=None,
|
||||
)
|
||||
# Set disaggregation params if needed
|
||||
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
|
||||
health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||
health_req.bootstrap_room = 0
|
||||
else:
|
||||
health_req = TokenizedEmbeddingReqInput(
|
||||
rid=rid,
|
||||
input_text="",
|
||||
input_ids=[0],
|
||||
)
|
||||
|
||||
# Submit health check request
|
||||
async def run_health_check():
    """Run the health-check probe and report whether anything came back.

    Iterates the request manager's response stream for the probe request.
    Returns True once the stream yields a first item, and False when the
    stream finishes without yielding or when an exception is raised (the
    exception is logged as a warning, not propagated).
    """
    got_response = False
    try:
        probe_stream = self.request_manager.generate_request(
            obj=health_req,
            request_id=rid,
        )
        async for _ in probe_stream:
            # A single yielded item is enough to consider the server healthy.
            got_response = True
            break
    except Exception as err:
        logger.warning(f"Health check failed: {err}")
        got_response = False
    return got_response
|
||||
|
||||
task = asyncio.create_task(run_health_check())
|
||||
|
||||
# Wait for response with timeout
|
||||
tic = time.time()
|
||||
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
|
||||
await asyncio.sleep(1)
|
||||
# Check if we got a response from scheduler
|
||||
if self.request_manager.last_receive_tstamp > tic:
|
||||
task.cancel()
|
||||
# Clean up health check state
|
||||
self.request_manager._cleanup_request_state(rid)
|
||||
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||
healthy=True, message="Health check passed"
|
||||
)
|
||||
|
||||
# Timeout - server not responding
|
||||
task.cancel()
|
||||
self.request_manager._cleanup_request_state(rid)
|
||||
logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s")
|
||||
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||
healthy=True, message="Health check passed"
|
||||
healthy=False, message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s"
|
||||
)
|
||||
|
||||
async def Abort(
|
||||
@@ -324,7 +408,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
_context: grpc.aio.ServicerContext,
|
||||
) -> sglang_scheduler_pb2.AbortResponse:
|
||||
"""Abort an ongoing request."""
|
||||
logger.debug(f"Receive abort request: {request.request_id}")
|
||||
logger.info(f"Receive abort request: {request.request_id}")
|
||||
|
||||
try:
|
||||
success = await self.request_manager.abort_request(request.request_id)
|
||||
|
||||
Reference in New Issue
Block a user