[grpc] Support gRPC standard health check (#11955)
This commit is contained in:
@@ -72,6 +72,7 @@ dependencies = [
|
|||||||
"grpcio==1.75.1", # keep it align with compile_proto.py
|
"grpcio==1.75.1", # keep it align with compile_proto.py
|
||||||
"grpcio-tools==1.75.1", # keep it align with compile_proto.py
|
"grpcio-tools==1.75.1", # keep it align with compile_proto.py
|
||||||
"grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
|
"grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
|
||||||
|
"grpcio-health-checking==1.75.1", # required for Kubernetes gRPC health probes
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -12,166 +12,35 @@ import signal
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
from typing import AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
from google.protobuf.json_format import MessageToDict
|
from google.protobuf.json_format import MessageToDict
|
||||||
from google.protobuf.struct_pb2 import Struct
|
from google.protobuf.struct_pb2 import Struct
|
||||||
from google.protobuf.timestamp_pb2 import Timestamp
|
from google.protobuf.timestamp_pb2 import Timestamp
|
||||||
|
from grpc_health.v1 import health_pb2_grpc
|
||||||
from grpc_reflection.v1alpha import reflection
|
from grpc_reflection.v1alpha import reflection
|
||||||
|
|
||||||
import sglang
|
import sglang
|
||||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||||
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||||
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
|
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
|
||||||
from sglang.srt.managers.data_parallel_controller import (
|
from sglang.srt.grpc.health_servicer import SGLangHealthServicer
|
||||||
run_data_parallel_controller_process,
|
from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
|
||||||
)
|
|
||||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import kill_process_tree
|
||||||
configure_logger,
|
|
||||||
kill_process_tree,
|
|
||||||
prepare_model_and_tokenizer,
|
|
||||||
)
|
|
||||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||||
|
|
||||||
|
|
||||||
def _run_scheduler_with_signal_handling(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Wrapper for run_scheduler_process that ignores SIGINT.
|
|
||||||
|
|
||||||
The scheduler process should not handle Ctrl+C - it should only terminate
|
|
||||||
when the parent gRPC server exits (via kill_itself_when_parent_died).
|
|
||||||
"""
|
|
||||||
# Ignore SIGINT in this subprocess - let the parent handle it
|
|
||||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
|
||||||
|
|
||||||
# Now run the actual scheduler process
|
|
||||||
run_scheduler_process(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _launch_scheduler_process_only(
|
|
||||||
server_args: ServerArgs,
|
|
||||||
port_args: Optional[PortArgs] = None,
|
|
||||||
) -> Tuple[Dict, PortArgs, list]:
|
|
||||||
"""
|
|
||||||
Launch only the scheduler process(es) without tokenizer/detokenizer.
|
|
||||||
Returns scheduler info, port args, and list of scheduler processes.
|
|
||||||
"""
|
|
||||||
# Configure global environment
|
|
||||||
configure_logger(server_args)
|
|
||||||
server_args.check_server_args()
|
|
||||||
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
|
|
||||||
mp.set_start_method("spawn", force=True)
|
|
||||||
|
|
||||||
# Allocate ports for inter-process communications
|
|
||||||
if port_args is None:
|
|
||||||
port_args = PortArgs.init_new(server_args)
|
|
||||||
logger.info(f"{server_args=}")
|
|
||||||
|
|
||||||
# Prepare model and tokenizer paths
|
|
||||||
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
|
||||||
server_args.model_path, server_args.tokenizer_path
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler_procs = []
|
|
||||||
if server_args.dp_size == 1:
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
|
||||||
enable=server_args.enable_memory_saver
|
|
||||||
)
|
|
||||||
scheduler_pipe_readers = []
|
|
||||||
|
|
||||||
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
|
||||||
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
|
|
||||||
tp_rank_range = range(
|
|
||||||
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
|
|
||||||
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
|
|
||||||
pp_rank_range = range(
|
|
||||||
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
|
|
||||||
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
for pp_rank in pp_rank_range:
|
|
||||||
for tp_rank in tp_rank_range:
|
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
|
||||||
gpu_id = (
|
|
||||||
server_args.base_gpu_id
|
|
||||||
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
|
|
||||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
|
||||||
)
|
|
||||||
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
|
|
||||||
proc = mp.Process(
|
|
||||||
target=_run_scheduler_with_signal_handling,
|
|
||||||
args=(
|
|
||||||
server_args,
|
|
||||||
port_args,
|
|
||||||
gpu_id,
|
|
||||||
tp_rank,
|
|
||||||
moe_ep_rank,
|
|
||||||
pp_rank,
|
|
||||||
None,
|
|
||||||
writer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
with memory_saver_adapter.configure_subprocess():
|
|
||||||
proc.start()
|
|
||||||
scheduler_procs.append(proc)
|
|
||||||
scheduler_pipe_readers.append(reader)
|
|
||||||
else:
|
|
||||||
# Launch the data parallel controller
|
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
|
||||||
scheduler_pipe_readers = [reader]
|
|
||||||
proc = mp.Process(
|
|
||||||
target=run_data_parallel_controller_process,
|
|
||||||
args=(server_args, port_args, writer),
|
|
||||||
)
|
|
||||||
proc.start()
|
|
||||||
scheduler_procs.append(proc)
|
|
||||||
|
|
||||||
# TODO(CatherineSue): handle cases for multi-node
|
|
||||||
|
|
||||||
# Wait for all scheduler processes to be ready
|
|
||||||
scheduler_infos = []
|
|
||||||
for i, reader in enumerate(scheduler_pipe_readers):
|
|
||||||
try:
|
|
||||||
data = reader.recv()
|
|
||||||
except EOFError:
|
|
||||||
logger.error(
|
|
||||||
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
|
|
||||||
)
|
|
||||||
scheduler_procs[i].join()
|
|
||||||
logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
|
|
||||||
raise RuntimeError(f"Failed to initialize scheduler rank {i}")
|
|
||||||
|
|
||||||
if data.get("status") != "ready":
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
|
|
||||||
)
|
|
||||||
scheduler_infos.append(data)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return the first scheduler's info (they should all be the same)
|
|
||||||
return scheduler_infos[0], port_args, scheduler_procs
|
|
||||||
|
|
||||||
|
|
||||||
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
|
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
|
||||||
"""
|
"""
|
||||||
Standalone gRPC service implementation using GrpcRequestManager.
|
Standalone gRPC service implementation using GrpcRequestManager.
|
||||||
@@ -184,6 +53,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
model_info: Dict,
|
model_info: Dict,
|
||||||
scheduler_info: Dict,
|
scheduler_info: Dict,
|
||||||
|
health_servicer: Optional[SGLangHealthServicer] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the standalone gRPC service."""
|
"""Initialize the standalone gRPC service."""
|
||||||
self.request_manager = request_manager
|
self.request_manager = request_manager
|
||||||
@@ -191,6 +61,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
self.model_info = model_info
|
self.model_info = model_info
|
||||||
self.scheduler_info = scheduler_info
|
self.scheduler_info = scheduler_info
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
|
self.health_servicer = health_servicer
|
||||||
|
|
||||||
# Start the request manager's event loop using auto_create_handle_loop
|
# Start the request manager's event loop using auto_create_handle_loop
|
||||||
self.request_manager.auto_create_handle_loop()
|
self.request_manager.auto_create_handle_loop()
|
||||||
@@ -817,6 +688,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
"""Shutdown the service."""
|
"""Shutdown the service."""
|
||||||
logger.info("Shutting down gRPC service")
|
logger.info("Shutting down gRPC service")
|
||||||
|
|
||||||
|
# Mark health service as NOT_SERVING before shutdown
|
||||||
|
if self.health_servicer:
|
||||||
|
self.health_servicer.set_not_serving()
|
||||||
|
|
||||||
# Shutdown request manager (handles its own tasks)
|
# Shutdown request manager (handles its own tasks)
|
||||||
await self.request_manager.shutdown()
|
await self.request_manager.shutdown()
|
||||||
|
|
||||||
@@ -839,7 +714,7 @@ async def serve_grpc(
|
|||||||
|
|
||||||
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
|
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
|
||||||
logger.info("Launching scheduler process(es)...")
|
logger.info("Launching scheduler process(es)...")
|
||||||
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
|
scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -876,18 +751,27 @@ async def serve_grpc(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add service
|
# Create standard health service (for Kubernetes probes)
|
||||||
|
health_servicer = SGLangHealthServicer(
|
||||||
|
request_manager=request_manager,
|
||||||
|
scheduler_info=scheduler_info,
|
||||||
|
)
|
||||||
|
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
|
||||||
|
|
||||||
|
# Add SGLang service
|
||||||
servicer = SGLangSchedulerServicer(
|
servicer = SGLangSchedulerServicer(
|
||||||
request_manager=request_manager,
|
request_manager=request_manager,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
scheduler_info=scheduler_info,
|
scheduler_info=scheduler_info,
|
||||||
|
health_servicer=health_servicer,
|
||||||
)
|
)
|
||||||
sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
|
sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
|
||||||
|
|
||||||
# Enable reflection
|
# Enable reflection
|
||||||
SERVICE_NAMES = (
|
SERVICE_NAMES = (
|
||||||
sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
|
sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
|
||||||
|
"grpc.health.v1.Health",
|
||||||
reflection.SERVICE_NAME,
|
reflection.SERVICE_NAME,
|
||||||
)
|
)
|
||||||
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
||||||
@@ -902,7 +786,7 @@ async def serve_grpc(
|
|||||||
# Start warmup in a separate thread
|
# Start warmup in a separate thread
|
||||||
warmup_thread = threading.Thread(
|
warmup_thread = threading.Thread(
|
||||||
target=_wait_and_warmup_grpc,
|
target=_wait_and_warmup_grpc,
|
||||||
args=(server_args, None),
|
args=(server_args, None, health_servicer),
|
||||||
)
|
)
|
||||||
warmup_thread.start()
|
warmup_thread.start()
|
||||||
|
|
||||||
@@ -1103,6 +987,7 @@ def _execute_grpc_server_warmup(
|
|||||||
def _wait_and_warmup_grpc(
|
def _wait_and_warmup_grpc(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
pipe_finish_writer: Optional[mp.connection.Connection],
|
pipe_finish_writer: Optional[mp.connection.Connection],
|
||||||
|
health_servicer: Optional[SGLangHealthServicer] = None,
|
||||||
):
|
):
|
||||||
"""Wait for gRPC server to be ready and execute warmup."""
|
"""Wait for gRPC server to be ready and execute warmup."""
|
||||||
if not server_args.skip_server_warmup:
|
if not server_args.skip_server_warmup:
|
||||||
@@ -1111,6 +996,11 @@ def _wait_and_warmup_grpc(
|
|||||||
else:
|
else:
|
||||||
logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
|
logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
|
||||||
|
|
||||||
|
# Mark health service as SERVING after warmup completes
|
||||||
|
if health_servicer:
|
||||||
|
health_servicer.set_serving()
|
||||||
|
logger.info("Health service marked as SERVING")
|
||||||
|
|
||||||
logger.info("The server is fired up and ready to roll!")
|
logger.info("The server is fired up and ready to roll!")
|
||||||
|
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
|
|||||||
189
python/sglang/srt/grpc/health_servicer.py
Normal file
189
python/sglang/srt/grpc/health_servicer.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Standard gRPC health check service implementation for Kubernetes probes.
|
||||||
|
|
||||||
|
This module implements the grpc.health.v1.Health service protocol, enabling
|
||||||
|
native Kubernetes gRPC health probes for liveness and readiness checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
|
||||||
|
"""
|
||||||
|
Standard gRPC health check service implementation for Kubernetes probes.
|
||||||
|
Implements grpc.health.v1.Health protocol.
|
||||||
|
|
||||||
|
Supports two service levels:
|
||||||
|
1. Overall server health (service="") - for liveness probes
|
||||||
|
2. SGLang service health (service="sglang.grpc.scheduler.SglangScheduler") - for readiness probes
|
||||||
|
|
||||||
|
Health status lifecycle:
|
||||||
|
- NOT_SERVING: Initial state, model loading, or shutting down
|
||||||
|
- SERVING: Model loaded and ready to serve requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Service names we support
|
||||||
|
OVERALL_SERVER = "" # Empty string for overall server health
|
||||||
|
SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler"
|
||||||
|
|
||||||
|
def __init__(self, request_manager, scheduler_info: dict):
|
||||||
|
"""
|
||||||
|
Initialize health servicer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_manager: GrpcRequestManager instance for checking server state
|
||||||
|
scheduler_info: Dict containing scheduler metadata
|
||||||
|
"""
|
||||||
|
self.request_manager = request_manager
|
||||||
|
self.scheduler_info = scheduler_info
|
||||||
|
self._serving_status = {}
|
||||||
|
|
||||||
|
# Initially set to NOT_SERVING until model is loaded
|
||||||
|
self._serving_status[self.OVERALL_SERVER] = (
|
||||||
|
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
self._serving_status[self.SGLANG_SERVICE] = (
|
||||||
|
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Standard gRPC health service initialized")
|
||||||
|
|
||||||
|
def set_serving(self):
|
||||||
|
"""Mark services as SERVING - call this after model is loaded."""
|
||||||
|
self._serving_status[self.OVERALL_SERVER] = (
|
||||||
|
health_pb2.HealthCheckResponse.SERVING
|
||||||
|
)
|
||||||
|
self._serving_status[self.SGLANG_SERVICE] = (
|
||||||
|
health_pb2.HealthCheckResponse.SERVING
|
||||||
|
)
|
||||||
|
logger.info("Health service status set to SERVING")
|
||||||
|
|
||||||
|
def set_not_serving(self):
|
||||||
|
"""Mark services as NOT_SERVING - call this during shutdown."""
|
||||||
|
self._serving_status[self.OVERALL_SERVER] = (
|
||||||
|
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
self._serving_status[self.SGLANG_SERVICE] = (
|
||||||
|
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
logger.info("Health service status set to NOT_SERVING")
|
||||||
|
|
||||||
|
async def Check(
|
||||||
|
self,
|
||||||
|
request: health_pb2.HealthCheckRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> health_pb2.HealthCheckResponse:
|
||||||
|
"""
|
||||||
|
Standard health check for Kubernetes probes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Contains service name ("" for overall, or specific service)
|
||||||
|
context: gRPC context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResponse with SERVING/NOT_SERVING/SERVICE_UNKNOWN status
|
||||||
|
"""
|
||||||
|
service_name = request.service
|
||||||
|
logger.debug(f"Health check request for service: '{service_name}'")
|
||||||
|
|
||||||
|
# Check if shutting down
|
||||||
|
if self.request_manager.gracefully_exit:
|
||||||
|
logger.debug("Health check: Server is shutting down")
|
||||||
|
return health_pb2.HealthCheckResponse(
|
||||||
|
status=health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
# Overall server health - just check if process is alive
|
||||||
|
if service_name == self.OVERALL_SERVER:
|
||||||
|
status = self._serving_status.get(
|
||||||
|
self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Overall health check: {health_pb2.HealthCheckResponse.ServingStatus.Name(status)}"
|
||||||
|
)
|
||||||
|
return health_pb2.HealthCheckResponse(status=status)
|
||||||
|
|
||||||
|
# Specific service health - check if ready to serve
|
||||||
|
elif service_name == self.SGLANG_SERVICE:
|
||||||
|
# Additional checks for service readiness
|
||||||
|
|
||||||
|
# Check base status first
|
||||||
|
base_status = self._serving_status.get(
|
||||||
|
self.SGLANG_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
if base_status != health_pb2.HealthCheckResponse.SERVING:
|
||||||
|
logger.debug("Service health check: NOT_SERVING (base status)")
|
||||||
|
return health_pb2.HealthCheckResponse(status=base_status)
|
||||||
|
|
||||||
|
# Check if scheduler is responsive (received data recently)
|
||||||
|
time_since_last_receive = (
|
||||||
|
time.time() - self.request_manager.last_receive_tstamp
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no recent activity and we have active requests, might be stuck
|
||||||
|
# NOTE: 30s timeout is hardcoded. This is more conservative than
|
||||||
|
# HEALTH_CHECK_TIMEOUT (20s) used for custom HealthCheck RPC.
|
||||||
|
# Consider making this configurable via environment variable in the future
|
||||||
|
# if different workloads need different responsiveness thresholds.
|
||||||
|
if (
|
||||||
|
time_since_last_receive > 30
|
||||||
|
and len(self.request_manager.rid_to_state) > 0
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Service health check: Scheduler not responsive "
|
||||||
|
f"({time_since_last_receive:.1f}s since last receive, "
|
||||||
|
f"{len(self.request_manager.rid_to_state)} pending requests)"
|
||||||
|
)
|
||||||
|
return health_pb2.HealthCheckResponse(
|
||||||
|
status=health_pb2.HealthCheckResponse.NOT_SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Service health check: SERVING")
|
||||||
|
return health_pb2.HealthCheckResponse(
|
||||||
|
status=health_pb2.HealthCheckResponse.SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unknown service
|
||||||
|
else:
|
||||||
|
logger.debug(f"Health check for unknown service: '{service_name}'")
|
||||||
|
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||||
|
context.set_details(f"Unknown service: {service_name}")
|
||||||
|
return health_pb2.HealthCheckResponse(
|
||||||
|
status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
|
||||||
|
)
|
||||||
|
|
||||||
|
async def Watch(
|
||||||
|
self,
|
||||||
|
request: health_pb2.HealthCheckRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> AsyncIterator[health_pb2.HealthCheckResponse]:
|
||||||
|
"""
|
||||||
|
Streaming health check - sends updates when status changes.
|
||||||
|
|
||||||
|
For now, just send current status once (Kubernetes doesn't use Watch).
|
||||||
|
A full implementation would monitor status changes and stream updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Contains service name
|
||||||
|
context: gRPC context
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
HealthCheckResponse messages when status changes
|
||||||
|
"""
|
||||||
|
service_name = request.service
|
||||||
|
logger.debug(f"Health watch request for service: '{service_name}'")
|
||||||
|
|
||||||
|
# Send current status
|
||||||
|
response = await self.Check(request, context)
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Note: Full Watch implementation would monitor status changes
|
||||||
|
# and stream updates. For K8s probes, Check is sufficient.
|
||||||
181
python/sglang/srt/grpc/scheduler_launcher.py
Normal file
181
python/sglang/srt/grpc/scheduler_launcher.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
Scheduler process management for gRPC server.
|
||||||
|
|
||||||
|
This module handles launching and managing scheduler processes for the gRPC server,
|
||||||
|
including tensor parallelism, pipeline parallelism, and data parallelism configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import signal
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
|
run_data_parallel_controller_process,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_scheduler_with_signal_handling(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Wrapper for run_scheduler_process that ignores SIGINT.
|
||||||
|
|
||||||
|
The scheduler process should not handle Ctrl+C - it should only terminate
|
||||||
|
when the parent gRPC server exits (via kill_itself_when_parent_died).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Positional arguments for run_scheduler_process
|
||||||
|
**kwargs: Keyword arguments for run_scheduler_process
|
||||||
|
"""
|
||||||
|
# Ignore SIGINT in this subprocess - let the parent handle it
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||||
|
|
||||||
|
# Now run the actual scheduler process
|
||||||
|
run_scheduler_process(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_scheduler_process_only(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: Optional[PortArgs] = None,
|
||||||
|
) -> Tuple[Dict, PortArgs, List[mp.Process]]:
|
||||||
|
"""
|
||||||
|
Launch only the scheduler process(es) without tokenizer/detokenizer.
|
||||||
|
|
||||||
|
This function handles all scheduler startup logic including:
|
||||||
|
- Tensor parallelism (tp_size)
|
||||||
|
- Pipeline parallelism (pp_size)
|
||||||
|
- Data parallelism (dp_size)
|
||||||
|
- Multi-node distributed setup
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_args: Server configuration
|
||||||
|
port_args: Port configuration (created if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (scheduler_info, port_args, scheduler_processes):
|
||||||
|
- scheduler_info: Dict with model metadata and configuration
|
||||||
|
- port_args: Port configuration used for IPC
|
||||||
|
- scheduler_processes: List of launched scheduler Process objects
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If any scheduler process fails to initialize
|
||||||
|
"""
|
||||||
|
# Configure global environment
|
||||||
|
configure_logger(server_args)
|
||||||
|
server_args.check_server_args()
|
||||||
|
|
||||||
|
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
# Allocate ports for inter-process communications
|
||||||
|
if port_args is None:
|
||||||
|
port_args = PortArgs.init_new(server_args)
|
||||||
|
logger.info(f"{server_args=}")
|
||||||
|
|
||||||
|
# Prepare model and tokenizer paths
|
||||||
|
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
||||||
|
server_args.model_path, server_args.tokenizer_path
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_procs = []
|
||||||
|
|
||||||
|
if server_args.dp_size == 1:
|
||||||
|
# Single data parallel group - launch TP/PP schedulers
|
||||||
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
|
enable=server_args.enable_memory_saver
|
||||||
|
)
|
||||||
|
scheduler_pipe_readers = []
|
||||||
|
|
||||||
|
# Calculate TP/PP distribution across nodes
|
||||||
|
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||||
|
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
|
||||||
|
tp_rank_range = range(
|
||||||
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
|
||||||
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
|
||||||
|
pp_rank_range = range(
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch scheduler for each TP/PP rank combination
|
||||||
|
for pp_rank in pp_rank_range:
|
||||||
|
for tp_rank in tp_rank_range:
|
||||||
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
|
# Calculate GPU ID for this rank
|
||||||
|
gpu_id = (
|
||||||
|
server_args.base_gpu_id
|
||||||
|
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
|
||||||
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate MoE expert parallel rank
|
||||||
|
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
|
||||||
|
|
||||||
|
# Create scheduler process
|
||||||
|
proc = mp.Process(
|
||||||
|
target=run_scheduler_with_signal_handling,
|
||||||
|
args=(
|
||||||
|
server_args,
|
||||||
|
port_args,
|
||||||
|
gpu_id,
|
||||||
|
tp_rank,
|
||||||
|
moe_ep_rank,
|
||||||
|
pp_rank,
|
||||||
|
None, # dp_rank
|
||||||
|
writer,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with memory_saver_adapter.configure_subprocess():
|
||||||
|
proc.start()
|
||||||
|
|
||||||
|
scheduler_procs.append(proc)
|
||||||
|
scheduler_pipe_readers.append(reader)
|
||||||
|
else:
|
||||||
|
# Data parallelism - launch data parallel controller
|
||||||
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
|
scheduler_pipe_readers = [reader]
|
||||||
|
|
||||||
|
proc = mp.Process(
|
||||||
|
target=run_data_parallel_controller_process,
|
||||||
|
args=(server_args, port_args, writer),
|
||||||
|
)
|
||||||
|
proc.start()
|
||||||
|
scheduler_procs.append(proc)
|
||||||
|
|
||||||
|
# TODO(CatherineSue): handle cases for multi-node
|
||||||
|
|
||||||
|
# Wait for all scheduler processes to be ready
|
||||||
|
scheduler_infos = []
|
||||||
|
for i, reader in enumerate(scheduler_pipe_readers):
|
||||||
|
try:
|
||||||
|
data = reader.recv()
|
||||||
|
except EOFError:
|
||||||
|
logger.error(
|
||||||
|
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
|
||||||
|
)
|
||||||
|
scheduler_procs[i].join()
|
||||||
|
logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
|
||||||
|
raise RuntimeError(f"Failed to initialize scheduler rank {i}")
|
||||||
|
|
||||||
|
if data.get("status") != "ready":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
scheduler_infos.append(data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the first scheduler's info (they should all be the same)
|
||||||
|
return scheduler_infos[0], port_args, scheduler_procs
|
||||||
Reference in New Issue
Block a user