From 4540a4666a112a82dcf21505b781f3e31e50d178 Mon Sep 17 00:00:00 2001 From: ybyang <10629930+whybeyoung@users.noreply.github.com> Date: Sun, 20 Jul 2025 09:10:00 +0800 Subject: [PATCH] [Feature] Simple Improve Health Check Mechanism for Production-Grade Stability (#8115) Signed-off-by: ybyang --- python/sglang/srt/entrypoints/engine.py | 4 ++ python/sglang/srt/entrypoints/http_server.py | 57 ++++++++++++++++--- python/sglang/srt/managers/io_struct.py | 6 ++ python/sglang/srt/managers/scheduler.py | 3 + .../sglang/srt/managers/tokenizer_manager.py | 7 ++- python/sglang/srt/utils.py | 16 ++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 990fac9a1..957d85aa5 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -65,6 +65,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, + ServerStatus, assert_pkg_version, configure_logger, get_zmq_socket, @@ -73,6 +74,7 @@ from sglang.srt.utils import ( launch_dummy_health_check_server, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, + report_health, set_prometheus_multiproc_dir, set_ulimit, ) @@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs): def sigchld_handler(signum, frame): pid, exitcode = os.waitpid(0, os.WNOHANG) if exitcode != 0: + report_health(ServerStatus.Crashed, server_args.host, server_args.port) logger.warning( f"Child process unexpectedly failed with {exitcode=}. {pid=}" ) @@ -674,6 +677,7 @@ def _set_envs_and_config(server_args: ServerArgs): logger.error( "Received sigquit from a child process. It usually means the child failed." 
) + report_health(ServerStatus.Crashed, server_args.host, server_args.port) kill_process_tree(os.getpid()) signal.signal(signal.SIGQUIT, sigquit_handler) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 43819e1a6..f880c4aa5 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -77,6 +77,7 @@ from sglang.srt.managers.io_struct import ( ParseFunctionCallReq, ProfileReqInput, ReleaseMemoryOccupationReqInput, + ReportHealthInput, ResumeMemoryOccupationReqInput, SeparateReasoningReqInput, SetInternalStateReq, @@ -93,6 +94,7 @@ from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + ServerStatus, add_api_key_middleware, add_prometheus_middleware, delete_directory, @@ -220,8 +222,32 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) @app.get("/health") async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) + """Check the status of the http server.""" + code = HTTPStatus.SERVICE_UNAVAILABLE.value + if _global_state.tokenizer_manager.server_status == ServerStatus.Up: + code = HTTPStatus.OK.value + return Response( + status_code=code, + content=json.dumps( + {"status": _global_state.tokenizer_manager.server_status.value} + ), + ) + + +@app.post("/health") +async def health_update(obj: ReportHealthInput, request: Request) -> Response: + """Update the Status of the http server.""" + try: + server_status = ServerStatus(obj.status) + _global_state.tokenizer_manager.server_status = server_status + if server_status != ServerStatus.Up: + return Response( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg + ) + except Exception as e: + logger.error(e) + return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value) + return Response(status_code=HTTPStatus.OK.value)
@app.get("/health_generate") @@ -256,7 +281,7 @@ async def health_generate(request: Request) -> Response: if _global_state.tokenizer_manager.last_receive_tstamp > tic: task.cancel() _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.health_check_failed = False + _global_state.tokenizer_manager.server_status = ServerStatus.Up return Response(status_code=200) task.cancel() @@ -270,7 +295,7 @@ async def health_generate(request: Request) -> Response: f"last_heartbeat time: {last_receive_time}" ) _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.health_check_failed = True + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy return Response(status_code=503) @@ -1022,9 +1047,13 @@ def _execute_server_warmup( headers=headers, timeout=600, ) - assert res.status_code == 200, f"{res}" + if res.status_code == 200: + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy + logger.info(f"{res}") else: - logger.info(f"Start of prefill warmup ...") + logger.info(f"Start of prefill/decode warmup ...") json_data = { "sampling_params": { "temperature": 0.0, @@ -1046,15 +1075,25 @@ def _execute_server_warmup( headers=headers, timeout=1800, # because of deep gemm precache is very long if not precache. 
) - logger.info( - f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" - ) + if res.status_code == 200: + logger.info( + f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}" + ) + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: + logger.info( + "Prefill disaggregation mode warm Up Failed, status code: {}".format( + res.status_code + ) + ) + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy except Exception: last_traceback = get_exception_traceback() if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) logger.error(f"Initialization failed. warmup error: {last_traceback}") + _global_state.tokenizer_manager.server_status = ServerStatus.Crashed kill_process_tree(os.getpid()) return False diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e1d1075a..b8332fdf6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1083,3 +1083,9 @@ class LoRAUpdateResult: LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult + + +@dataclass +class ReportHealthInput: + status: str + msg: Optional[str] = "" diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e6dd80d71..aee1596db 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -143,6 +143,7 @@ from sglang.srt.two_batch_overlap import TboDPAttentionPreparer from sglang.srt.utils import ( DeepEPMode, DynamicGradMode, + ServerStatus, broadcast_pyobj, configure_gc_logger, configure_logger, @@ -154,6 +155,7 @@ from sglang.srt.utils import ( kill_itself_when_parent_died, point_to_point_pyobj, pyspy_dump_schedulers, + report_health, require_mlp_sync, require_mlp_tp_gather, set_gpu_proc_affinity, @@ -2964,4 +2966,5 @@ def run_scheduler_process( except Exception: traceback = get_exception_traceback() 
logger.error(f"Scheduler hit an exception: {traceback}") + report_health(ServerStatus.Crashed, server_args.host, server_args.port) parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 631d23f17..a0f66419e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -116,6 +116,7 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( + ServerStatus, dataclass_to_string_truncated, get_bool_env_var, get_zmq_socket, @@ -173,6 +174,9 @@ class TokenizerManager: server_args: ServerArgs, port_args: PortArgs, ): + # Server Status + self.server_status = ServerStatus.Starting + # Parse args self.server_args = server_args self.enable_metrics = server_args.enable_metrics @@ -251,7 +255,6 @@ class TokenizerManager: # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} - self.health_check_failed = False self.gracefully_exit = False self.last_receive_tstamp = 0 self.dump_requests_folder = "" # By default do not dump @@ -1332,7 +1335,7 @@ class TokenizerManager: while True: remain_num_req = len(self.rid_to_state) - if self.health_check_failed: + if not self.server_status.is_healthy(): # if health check failed, we should exit immediately logger.error( "Signal SIGTERM received while health check failed. Exiting...
remaining number of requests: %d", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23960a8c1..03565a018 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -93,6 +93,29 @@ time_infos = {} HIP_FP8_E4M3_FNUZ_MAX = 224.0 +class ServerStatus(Enum): + Up = "Up" + Starting = "Starting" + UnHealthy = "UnHealthy" + Crashed = "Crashed" + + def is_healthy(self) -> bool: + return self == ServerStatus.Up + + +def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""): + # Best-effort notification: this runs on crash/signal paths where the HTTP + # server may already be down, so it must never raise or hang the caller. + try: + requests.post( + f"http://{host}:{http_port}/health", + json={"status": status.value, "msg": msg}, + timeout=5, + ) + except Exception: + pass + + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip def is_hip() -> bool: return torch.version.hip is not None