diff --git a/python/pyproject.toml b/python/pyproject.toml
index 1a449599c..fa909c4b9 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -72,6 +72,7 @@ dependencies = [
     "grpcio==1.75.1", # keep it align with compile_proto.py
     "grpcio-tools==1.75.1", # keep it align with compile_proto.py
     "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
+    "grpcio-health-checking==1.75.1", # required for Kubernetes gRPC health probes
 ]
 
 [project.optional-dependencies]
diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py
index 70fc9c7a3..80fb178f0 100644
--- a/python/sglang/srt/entrypoints/grpc_server.py
+++ b/python/sglang/srt/entrypoints/grpc_server.py
@@ -12,166 +12,35 @@ import signal
 import threading
 import time
 from concurrent import futures
-from typing import AsyncIterator, Dict, Optional, Tuple
+from typing import AsyncIterator, Dict, Optional
 
 import grpc
 from google.protobuf.json_format import MessageToDict
 from google.protobuf.struct_pb2 import Struct
 from google.protobuf.timestamp_pb2 import Timestamp
+from grpc_health.v1 import health_pb2_grpc
 from grpc_reflection.v1alpha import reflection
 
 import sglang
 from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
 from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
 from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
-from sglang.srt.managers.data_parallel_controller import (
-    run_data_parallel_controller_process,
-)
+from sglang.srt.grpc.health_servicer import SGLangHealthServicer
+from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    configure_logger,
-    kill_process_tree,
-    prepare_model_and_tokenizer,
-)
-from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 
 
-def _run_scheduler_with_signal_handling(*args, **kwargs):
-    """
-    Wrapper for run_scheduler_process that ignores SIGINT.
-
-    The scheduler process should not handle Ctrl+C - it should only terminate
-    when the parent gRPC server exits (via kill_itself_when_parent_died).
-    """
-    # Ignore SIGINT in this subprocess - let the parent handle it
-    signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-    # Now run the actual scheduler process
-    run_scheduler_process(*args, **kwargs)
-
-
-def _launch_scheduler_process_only(
-    server_args: ServerArgs,
-    port_args: Optional[PortArgs] = None,
-) -> Tuple[Dict, PortArgs, list]:
-    """
-    Launch only the scheduler process(es) without tokenizer/detokenizer.
-    Returns scheduler info, port args, and list of scheduler processes.
-    """
-    # Configure global environment
-    configure_logger(server_args)
-    server_args.check_server_args()
-    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
-    mp.set_start_method("spawn", force=True)
-
-    # Allocate ports for inter-process communications
-    if port_args is None:
-        port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
-
-    # Prepare model and tokenizer paths
-    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
-        server_args.model_path, server_args.tokenizer_path
-    )
-
-    scheduler_procs = []
-    if server_args.dp_size == 1:
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=server_args.enable_memory_saver
-        )
-        scheduler_pipe_readers = []
-
-        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
-        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
-        tp_rank_range = range(
-            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
-            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
-        )
-
-        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
-        pp_rank_range = range(
-            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
-            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
-        )
-
-        for pp_rank in pp_rank_range:
-            for tp_rank in tp_rank_range:
-                reader, writer = mp.Pipe(duplex=False)
-                gpu_id = (
-                    server_args.base_gpu_id
-                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
-                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-                )
-                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
-                proc = mp.Process(
-                    target=_run_scheduler_with_signal_handling,
-                    args=(
-                        server_args,
-                        port_args,
-                        gpu_id,
-                        tp_rank,
-                        moe_ep_rank,
-                        pp_rank,
-                        None,
-                        writer,
-                    ),
-                )
-
-                with memory_saver_adapter.configure_subprocess():
-                    proc.start()
-                scheduler_procs.append(proc)
-                scheduler_pipe_readers.append(reader)
-    else:
-        # Launch the data parallel controller
-        reader, writer = mp.Pipe(duplex=False)
-        scheduler_pipe_readers = [reader]
-        proc = mp.Process(
-            target=run_data_parallel_controller_process,
-            args=(server_args, port_args, writer),
-        )
-        proc.start()
-        scheduler_procs.append(proc)
-
-    # TODO(CatherineSue): handle cases for multi-node
-
-    # Wait for all scheduler processes to be ready
-    scheduler_infos = []
-    for i, reader in enumerate(scheduler_pipe_readers):
-        try:
-            data = reader.recv()
-        except EOFError:
-            logger.error(
-                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
-            )
-            scheduler_procs[i].join()
-            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
-            raise RuntimeError(f"Failed to initialize scheduler rank {i}")
-
-        if data.get("status") != "ready":
-            raise RuntimeError(
-                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
-            )
-        scheduler_infos.append(data)
-
-    logger.info(
-        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
-    )
-
-    # Return the first scheduler's info (they should all be the same)
-    return scheduler_infos[0], port_args, scheduler_procs
-
-
 class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
     """
     Standalone gRPC service implementation using GrpcRequestManager.
@@ -184,6 +53,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         server_args: ServerArgs,
         model_info: Dict,
         scheduler_info: Dict,
+        health_servicer: Optional[SGLangHealthServicer] = None,
     ):
         """Initialize the standalone gRPC service."""
         self.request_manager = request_manager
@@ -191,6 +61,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         self.model_info = model_info
         self.scheduler_info = scheduler_info
         self.start_time = time.time()
+        self.health_servicer = health_servicer
 
         # Start the request manager's event loop using auto_create_handle_loop
         self.request_manager.auto_create_handle_loop()
@@ -817,6 +688,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         """Shutdown the service."""
         logger.info("Shutting down gRPC service")
 
+        # Mark health service as NOT_SERVING before shutdown
+        if self.health_servicer:
+            self.health_servicer.set_not_serving()
+
         # Shutdown request manager (handles its own tasks)
         await self.request_manager.shutdown()
 
@@ -839,7 +714,7 @@ async def serve_grpc(
 
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
-    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
+    scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
        server_args=server_args,
    )
 
@@ -876,18 +751,27 @@ async def serve_grpc(
         ],
     )
 
-    # Add service
+    # Create standard health service (for Kubernetes probes)
+    health_servicer = SGLangHealthServicer(
+        request_manager=request_manager,
+        scheduler_info=scheduler_info,
+    )
+    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
+
+    # Add SGLang service
     servicer = SGLangSchedulerServicer(
         request_manager=request_manager,
         server_args=server_args,
         model_info=model_info,
         scheduler_info=scheduler_info,
+        health_servicer=health_servicer,
     )
     sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
 
     # Enable reflection
     SERVICE_NAMES = (
         sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
+        "grpc.health.v1.Health",
         reflection.SERVICE_NAME,
     )
     reflection.enable_server_reflection(SERVICE_NAMES, server)
@@ -902,7 +786,7 @@ async def serve_grpc(
     # Start warmup in a separate thread
     warmup_thread = threading.Thread(
         target=_wait_and_warmup_grpc,
-        args=(server_args, None),
+        args=(server_args, None, health_servicer),
     )
     warmup_thread.start()
 
@@ -1103,6 +987,7 @@ def _execute_grpc_server_warmup(
 def _wait_and_warmup_grpc(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[mp.connection.Connection],
+    health_servicer: Optional[SGLangHealthServicer] = None,
 ):
     """Wait for gRPC server to be ready and execute warmup."""
     if not server_args.skip_server_warmup:
@@ -1111,6 +996,11 @@ def _wait_and_warmup_grpc(
     else:
         logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
 
+    # Mark health service as SERVING after warmup completes
+    if health_servicer:
+        health_servicer.set_serving()
+        logger.info("Health service marked as SERVING")
+
     logger.info("The server is fired up and ready to roll!")
 
     if pipe_finish_writer is not None:
diff --git a/python/sglang/srt/grpc/health_servicer.py b/python/sglang/srt/grpc/health_servicer.py
new file mode 100644
index 000000000..db3db2cc0
--- /dev/null
+++ b/python/sglang/srt/grpc/health_servicer.py
@@ -0,0 +1,189 @@
+"""
+Standard gRPC health check service implementation for Kubernetes probes.
+
+This module implements the grpc.health.v1.Health service protocol, enabling
+native Kubernetes gRPC health probes for liveness and readiness checks.
+"""
+
+import logging
+import time
+from typing import AsyncIterator
+
+import grpc
+from grpc_health.v1 import health_pb2, health_pb2_grpc
+
+logger = logging.getLogger(__name__)
+
+
+class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
+    """
+    Standard gRPC health check service implementation for Kubernetes probes.
+    Implements grpc.health.v1.Health protocol.
+
+    Supports two service levels:
+    1. Overall server health (service="") - for liveness probes
+    2. SGLang service health (service="sglang.grpc.scheduler.SglangScheduler") - for readiness probes
+
+    Health status lifecycle:
+    - NOT_SERVING: Initial state, model loading, or shutting down
+    - SERVING: Model loaded and ready to serve requests
+    """
+
+    # Service names we support
+    OVERALL_SERVER = ""  # Empty string for overall server health
+    SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler"
+
+    def __init__(self, request_manager, scheduler_info: dict):
+        """
+        Initialize health servicer.
+
+        Args:
+            request_manager: GrpcRequestManager instance for checking server state
+            scheduler_info: Dict containing scheduler metadata
+        """
+        self.request_manager = request_manager
+        self.scheduler_info = scheduler_info
+        self._serving_status = {}
+
+        # Initially set to NOT_SERVING until model is loaded
+        self._serving_status[self.OVERALL_SERVER] = (
+            health_pb2.HealthCheckResponse.NOT_SERVING
+        )
+        self._serving_status[self.SGLANG_SERVICE] = (
+            health_pb2.HealthCheckResponse.NOT_SERVING
+        )
+
+        logger.info("Standard gRPC health service initialized")
+
+    def set_serving(self):
+        """Mark services as SERVING - call this after model is loaded."""
+        self._serving_status[self.OVERALL_SERVER] = (
+            health_pb2.HealthCheckResponse.SERVING
+        )
+        self._serving_status[self.SGLANG_SERVICE] = (
+            health_pb2.HealthCheckResponse.SERVING
+        )
+        logger.info("Health service status set to SERVING")
+
+    def set_not_serving(self):
+        """Mark services as NOT_SERVING - call this during shutdown."""
+        self._serving_status[self.OVERALL_SERVER] = (
+            health_pb2.HealthCheckResponse.NOT_SERVING
+        )
+        self._serving_status[self.SGLANG_SERVICE] = (
+            health_pb2.HealthCheckResponse.NOT_SERVING
+        )
+        logger.info("Health service status set to NOT_SERVING")
+
+    async def Check(
+        self,
+        request: health_pb2.HealthCheckRequest,
+        context: grpc.aio.ServicerContext,
+    ) -> health_pb2.HealthCheckResponse:
+        """
+        Standard health check for Kubernetes probes.
+
+        Args:
+            request: Contains service name ("" for overall, or specific service)
+            context: gRPC context
+
+        Returns:
+            HealthCheckResponse with SERVING/NOT_SERVING/SERVICE_UNKNOWN status
+        """
+        service_name = request.service
+        logger.debug(f"Health check request for service: '{service_name}'")
+
+        # Check if shutting down
+        if self.request_manager.gracefully_exit:
+            logger.debug("Health check: Server is shutting down")
+            return health_pb2.HealthCheckResponse(
+                status=health_pb2.HealthCheckResponse.NOT_SERVING
+            )
+
+        # Overall server health - report the stored overall status (liveness)
+        if service_name == self.OVERALL_SERVER:
+            status = self._serving_status.get(
+                self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING
+            )
+            logger.debug(
+                f"Overall health check: {health_pb2.HealthCheckResponse.ServingStatus.Name(status)}"
+            )
+            return health_pb2.HealthCheckResponse(status=status)
+
+        # Specific service health - check if ready to serve
+        elif service_name == self.SGLANG_SERVICE:
+            # Additional checks for service readiness
+
+            # Check base status first
+            base_status = self._serving_status.get(
+                self.SGLANG_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING
+            )
+
+            if base_status != health_pb2.HealthCheckResponse.SERVING:
+                logger.debug("Service health check: NOT_SERVING (base status)")
+                return health_pb2.HealthCheckResponse(status=base_status)
+
+            # Check if scheduler is responsive (received data recently)
+            time_since_last_receive = (
+                time.time() - self.request_manager.last_receive_tstamp
+            )
+
+            # If there is no recent activity while requests are pending, the
+            # scheduler may be stuck. NOTE: the 30s threshold is hardcoded and is
+            # more lenient than the 20s HEALTH_CHECK_TIMEOUT used by the custom
+            # HealthCheck RPC; consider making it configurable via an environment
+            # variable if workloads need different responsiveness thresholds.
+            if (
+                time_since_last_receive > 30
+                and len(self.request_manager.rid_to_state) > 0
+            ):
+                logger.warning(
+                    f"Service health check: Scheduler not responsive "
+                    f"({time_since_last_receive:.1f}s since last receive, "
+                    f"{len(self.request_manager.rid_to_state)} pending requests)"
+                )
+                return health_pb2.HealthCheckResponse(
+                    status=health_pb2.HealthCheckResponse.NOT_SERVING
+                )
+
+            logger.debug("Service health check: SERVING")
+            return health_pb2.HealthCheckResponse(
+                status=health_pb2.HealthCheckResponse.SERVING
+            )
+
+        # Unknown service
+        else:
+            logger.debug(f"Health check for unknown service: '{service_name}'")
+            context.set_code(grpc.StatusCode.NOT_FOUND)
+            context.set_details(f"Unknown service: {service_name}")
+            return health_pb2.HealthCheckResponse(
+                status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
+            )
+
+    async def Watch(
+        self,
+        request: health_pb2.HealthCheckRequest,
+        context: grpc.aio.ServicerContext,
+    ) -> AsyncIterator[health_pb2.HealthCheckResponse]:
+        """
+        Streaming health check (grpc.health.v1.Health.Watch).
+
+        This implementation reports the current status once and then completes;
+        Kubernetes probes only use Check, so streaming updates are not needed here.
+
+        Args:
+            request: Contains service name
+            context: gRPC context
+
+        Yields:
+            A single HealthCheckResponse with the current status
+        """
+        service_name = request.service
+        logger.debug(f"Health watch request for service: '{service_name}'")
+
+        # Send the current status once
+        response = await self.Check(request, context)
+        yield response
+
+        # A full Watch implementation would keep the stream open and yield a new
+        # response whenever the serving status changes; Check covers K8s probes.
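
For illustration (not part of this patch): once the health servicer above is registered, Kubernetes 1.24+ native gRPC probes (livenessProbe.grpc / readinessProbe.grpc, each taking a port and an optional service name) can target the same endpoints. The sketch below is an equivalent standalone check in Python; the localhost:30000 target is an assumption, since the actual port comes from ServerArgs.

# probe_health.py - illustrative only; mirrors what a Kubernetes gRPC probe does.
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc


def probe(target: str = "localhost:30000") -> None:
    with grpc.insecure_channel(target) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        # Liveness: empty service name -> overall server health.
        live = stub.Check(health_pb2.HealthCheckRequest(service=""))
        # Readiness: the SGLang service name -> model loaded and scheduler responsive.
        ready = stub.Check(
            health_pb2.HealthCheckRequest(
                service="sglang.grpc.scheduler.SglangScheduler"
            )
        )
        name = health_pb2.HealthCheckResponse.ServingStatus.Name
        print(f"liveness={name(live.status)} readiness={name(ready.status)}")


if __name__ == "__main__":
    probe()

Until _wait_and_warmup_grpc calls set_serving(), both checks report NOT_SERVING, so pods are not marked Ready while the model is still loading.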
diff --git a/python/sglang/srt/grpc/scheduler_launcher.py b/python/sglang/srt/grpc/scheduler_launcher.py
new file mode 100644
index 000000000..77a62d8a6
--- /dev/null
+++ b/python/sglang/srt/grpc/scheduler_launcher.py
@@ -0,0 +1,181 @@
+"""
+Scheduler process management for gRPC server.
+
+This module handles launching and managing scheduler processes for the gRPC server,
+including tensor parallelism, pipeline parallelism, and data parallelism configurations.
+"""
+
+import logging
+import multiprocessing as mp
+import signal
+from typing import Dict, List, Optional, Tuple
+
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
+logger = logging.getLogger(__name__)
+
+
+def run_scheduler_with_signal_handling(*args, **kwargs):
+    """
+    Wrapper for run_scheduler_process that ignores SIGINT.
+
+    The scheduler process should not handle Ctrl+C - it should only terminate
+    when the parent gRPC server exits (via kill_itself_when_parent_died).
+
+    Args:
+        *args: Positional arguments for run_scheduler_process
+        **kwargs: Keyword arguments for run_scheduler_process
+    """
+    # Ignore SIGINT in this subprocess - let the parent handle it
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+    # Now run the actual scheduler process
+    run_scheduler_process(*args, **kwargs)
+
+
+def launch_scheduler_process_only(
+    server_args: ServerArgs,
+    port_args: Optional[PortArgs] = None,
+) -> Tuple[Dict, PortArgs, List[mp.Process]]:
+    """
+    Launch only the scheduler process(es) without tokenizer/detokenizer.
+
+    This function handles all scheduler startup logic including:
+    - Tensor parallelism (tp_size)
+    - Pipeline parallelism (pp_size)
+    - Data parallelism (dp_size)
+    - Multi-node distributed setup
+
+    Args:
+        server_args: Server configuration
+        port_args: Port configuration (created if None)
+
+    Returns:
+        Tuple of (scheduler_info, port_args, scheduler_processes):
+        - scheduler_info: Dict with model metadata and configuration
+        - port_args: Port configuration used for IPC
+        - scheduler_processes: List of launched scheduler Process objects
+
+    Raises:
+        RuntimeError: If any scheduler process fails to initialize
+    """
+    # Configure global environment
+    configure_logger(server_args)
+    server_args.check_server_args()
+
+    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
+    mp.set_start_method("spawn", force=True)
+
+    # Allocate ports for inter-process communications
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+    logger.info(f"{server_args=}")
+
+    # Prepare model and tokenizer paths
+    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+        server_args.model_path, server_args.tokenizer_path
+    )
+
+    scheduler_procs = []
+
+    if server_args.dp_size == 1:
+        # Single data parallel group - launch TP/PP schedulers
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+        scheduler_pipe_readers = []
+
+        # Calculate TP/PP distribution across nodes
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
+        tp_rank_range = range(
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        # Launch scheduler for each TP/PP rank combination
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+
+                # Calculate GPU ID for this rank
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+
+                # Calculate MoE expert parallel rank
+                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
+
+                # Create scheduler process
+                proc = mp.Process(
+                    target=run_scheduler_with_signal_handling,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        moe_ep_rank,
+                        pp_rank,
+                        None,  # dp_rank
+                        writer,
+                    ),
+                )
+
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
+    else:
+        # Data parallelism - launch data parallel controller
+        reader, writer = mp.Pipe(duplex=False)
+        scheduler_pipe_readers = [reader]
+
+        proc = mp.Process(
+            target=run_data_parallel_controller_process,
+            args=(server_args, port_args, writer),
+        )
+        proc.start()
+        scheduler_procs.append(proc)
+
+    # TODO(CatherineSue): handle cases for multi-node
+
+    # Wait for all scheduler processes to be ready
+    scheduler_infos = []
+    for i, reader in enumerate(scheduler_pipe_readers):
+        try:
+            data = reader.recv()
+        except EOFError:
+            logger.error(
+                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
+            )
+            scheduler_procs[i].join()
+            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
+            raise RuntimeError(f"Failed to initialize scheduler rank {i}")
+
+        if data.get("status") != "ready":
+            raise RuntimeError(
+                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
+            )
+        scheduler_infos.append(data)
+
+    logger.info(
+        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
+    )
+
+    # Return the first scheduler's info (they should all be the same)
+    return scheduler_infos[0], port_args, scheduler_procs