[grpc] Support gRPC standard health check (#11955)
@@ -72,6 +72,7 @@ dependencies = [
    "grpcio==1.75.1", # keep it align with compile_proto.py
    "grpcio-tools==1.75.1", # keep it align with compile_proto.py
    "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
    "grpcio-health-checking==1.75.1", # required for Kubernetes gRPC health probes
]

[project.optional-dependencies]
@@ -12,166 +12,35 @@ import signal
import threading
import time
from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple
from typing import AsyncIterator, Dict, Optional

import grpc
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp
from grpc_health.v1 import health_pb2_grpc
from grpc_reflection.v1alpha import reflection

import sglang
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.grpc.health_servicer import SGLangHealthServicer
from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_process_tree,
    prepare_model_and_tokenizer,
)
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

def _run_scheduler_with_signal_handling(*args, **kwargs):
    """
    Wrapper for run_scheduler_process that ignores SIGINT.

    The scheduler process should not handle Ctrl+C - it should only terminate
    when the parent gRPC server exits (via kill_itself_when_parent_died).
    """
    # Ignore SIGINT in this subprocess - let the parent handle it
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # Now run the actual scheduler process
    run_scheduler_process(*args, **kwargs)

def _launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
    """
    Launch only the scheduler process(es) without tokenizer/detokenizer.
    Returns scheduler info, port args, and list of scheduler processes.
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # Prepare model and tokenizer paths
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []
    if server_args.dp_size == 1:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=_run_scheduler_with_signal_handling,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,
                        writer,
                    ),
                )

                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    # TODO(CatherineSue): handle cases for multi-node

    # Wait for all scheduler processes to be ready
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Return the first scheduler's info (they should all be the same)
    return scheduler_infos[0], port_args, scheduler_procs

class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
    """
    Standalone gRPC service implementation using GrpcRequestManager.
@@ -184,6 +53,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
        server_args: ServerArgs,
        model_info: Dict,
        scheduler_info: Dict,
        health_servicer: Optional[SGLangHealthServicer] = None,
    ):
        """Initialize the standalone gRPC service."""
        self.request_manager = request_manager
@@ -191,6 +61,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
        self.model_info = model_info
        self.scheduler_info = scheduler_info
        self.start_time = time.time()
        self.health_servicer = health_servicer

        # Start the request manager's event loop using auto_create_handle_loop
        self.request_manager.auto_create_handle_loop()
@@ -817,6 +688,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
        """Shutdown the service."""
        logger.info("Shutting down gRPC service")

        # Mark health service as NOT_SERVING before shutdown
        if self.health_servicer:
            self.health_servicer.set_not_serving()

        # Shutdown request manager (handles its own tasks)
        await self.request_manager.shutdown()
@@ -839,7 +714,7 @@ async def serve_grpc(

    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
    scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
        server_args=server_args,
    )
@@ -876,18 +751,27 @@ async def serve_grpc(
        ],
    )

    # Add service
    # Create standard health service (for Kubernetes probes)
    health_servicer = SGLangHealthServicer(
        request_manager=request_manager,
        scheduler_info=scheduler_info,
    )
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    # Add SGLang service
    servicer = SGLangSchedulerServicer(
        request_manager=request_manager,
        server_args=server_args,
        model_info=model_info,
        scheduler_info=scheduler_info,
        health_servicer=health_servicer,
    )
    sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)

    # Enable reflection
    SERVICE_NAMES = (
        sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
        "grpc.health.v1.Health",
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(SERVICE_NAMES, server)
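Aside (not part of this commit): the snippet below is a minimal client sketch of what a Kubernetes gRPC probe exercises against the health service registered above. The target address is an assumption for illustration.

import asyncio

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc


async def probe(target: str = "localhost:30000") -> None:
    # Assumed address/port; use the gRPC server's actual listen address.
    async with grpc.aio.insecure_channel(target) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        # Liveness: overall server health (service="").
        live = await stub.Check(health_pb2.HealthCheckRequest(service=""))
        # Readiness: SGLang scheduler service health.
        ready = await stub.Check(
            health_pb2.HealthCheckRequest(
                service="sglang.grpc.scheduler.SglangScheduler"
            )
        )
        name = health_pb2.HealthCheckResponse.ServingStatus.Name
        print("liveness:", name(live.status), "readiness:", name(ready.status))


if __name__ == "__main__":
    asyncio.run(probe())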
@@ -902,7 +786,7 @@ async def serve_grpc(
    # Start warmup in a separate thread
    warmup_thread = threading.Thread(
        target=_wait_and_warmup_grpc,
        args=(server_args, None),
        args=(server_args, None, health_servicer),
    )
    warmup_thread.start()

@@ -1103,6 +987,7 @@ def _execute_grpc_server_warmup(
def _wait_and_warmup_grpc(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
    health_servicer: Optional[SGLangHealthServicer] = None,
):
    """Wait for gRPC server to be ready and execute warmup."""
    if not server_args.skip_server_warmup:
@@ -1111,6 +996,11 @@ def _wait_and_warmup_grpc(
    else:
        logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")

    # Mark health service as SERVING after warmup completes
    if health_servicer:
        health_servicer.set_serving()
        logger.info("Health service marked as SERVING")

    logger.info("The server is fired up and ready to roll!")

    if pipe_finish_writer is not None:
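Aside (not part of this commit): because the health servicer starts as NOT_SERVING and only flips to SERVING after warmup, deployment tooling can gate traffic on readiness. A hedged sketch of such a wait loop; the target address, poll interval, and timeout are assumptions.

import asyncio
import time

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc


async def wait_until_serving(target: str, timeout_s: float = 300.0) -> bool:
    # Poll the readiness service until it reports SERVING or the deadline passes.
    deadline = time.monotonic() + timeout_s
    async with grpc.aio.insecure_channel(target) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        request = health_pb2.HealthCheckRequest(
            service="sglang.grpc.scheduler.SglangScheduler"
        )
        while time.monotonic() < deadline:
            try:
                response = await stub.Check(request, timeout=5.0)
                if response.status == health_pb2.HealthCheckResponse.SERVING:
                    return True
            except grpc.aio.AioRpcError:
                pass  # server not reachable yet, or still loading the model
            await asyncio.sleep(2.0)
    return False


if __name__ == "__main__":
    ok = asyncio.run(wait_until_serving("localhost:30000"))
    print("ready" if ok else "timed out waiting for SERVING")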
python/sglang/srt/grpc/health_servicer.py (new file, 189 lines)
@@ -0,0 +1,189 @@
"""
|
||||
Standard gRPC health check service implementation for Kubernetes probes.
|
||||
|
||||
This module implements the grpc.health.v1.Health service protocol, enabling
|
||||
native Kubernetes gRPC health probes for liveness and readiness checks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
import grpc
|
||||
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
|
||||
"""
|
||||
Standard gRPC health check service implementation for Kubernetes probes.
|
||||
Implements grpc.health.v1.Health protocol.
|
||||
|
||||
Supports two service levels:
|
||||
1. Overall server health (service="") - for liveness probes
|
||||
2. SGLang service health (service="sglang.grpc.scheduler.SglangScheduler") - for readiness probes
|
||||
|
||||
Health status lifecycle:
|
||||
- NOT_SERVING: Initial state, model loading, or shutting down
|
||||
- SERVING: Model loaded and ready to serve requests
|
||||
"""
|
||||
|
||||
# Service names we support
|
||||
OVERALL_SERVER = "" # Empty string for overall server health
|
||||
SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler"
|
||||
|
||||
def __init__(self, request_manager, scheduler_info: dict):
|
||||
"""
|
||||
Initialize health servicer.
|
||||
|
||||
Args:
|
||||
request_manager: GrpcRequestManager instance for checking server state
|
||||
scheduler_info: Dict containing scheduler metadata
|
||||
"""
|
||||
self.request_manager = request_manager
|
||||
self.scheduler_info = scheduler_info
|
||||
self._serving_status = {}
|
||||
|
||||
# Initially set to NOT_SERVING until model is loaded
|
||||
self._serving_status[self.OVERALL_SERVER] = (
|
||||
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||
)
|
||||
self._serving_status[self.SGLANG_SERVICE] = (
|
||||
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||
)
|
||||
|
||||
logger.info("Standard gRPC health service initialized")
|
||||
|
||||
def set_serving(self):
|
||||
"""Mark services as SERVING - call this after model is loaded."""
|
||||
self._serving_status[self.OVERALL_SERVER] = (
|
||||
health_pb2.HealthCheckResponse.SERVING
|
||||
)
|
||||
self._serving_status[self.SGLANG_SERVICE] = (
|
||||
health_pb2.HealthCheckResponse.SERVING
|
||||
)
|
||||
logger.info("Health service status set to SERVING")
|
||||
|
||||
def set_not_serving(self):
|
||||
"""Mark services as NOT_SERVING - call this during shutdown."""
|
||||
self._serving_status[self.OVERALL_SERVER] = (
|
||||
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||
)
|
||||
self._serving_status[self.SGLANG_SERVICE] = (
|
||||
health_pb2.HealthCheckResponse.NOT_SERVING
|
||||
)
|
||||
logger.info("Health service status set to NOT_SERVING")
|
||||
|
||||
    async def Check(
        self,
        request: health_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> health_pb2.HealthCheckResponse:
        """
        Standard health check for Kubernetes probes.

        Args:
            request: Contains service name ("" for overall, or specific service)
            context: gRPC context

        Returns:
            HealthCheckResponse with SERVING/NOT_SERVING/SERVICE_UNKNOWN status
        """
        service_name = request.service
        logger.debug(f"Health check request for service: '{service_name}'")

        # Check if shutting down
        if self.request_manager.gracefully_exit:
            logger.debug("Health check: Server is shutting down")
            return health_pb2.HealthCheckResponse(
                status=health_pb2.HealthCheckResponse.NOT_SERVING
            )

        # Overall server health - just check if process is alive
        if service_name == self.OVERALL_SERVER:
            status = self._serving_status.get(
                self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING
            )
            logger.debug(
                f"Overall health check: {health_pb2.HealthCheckResponse.ServingStatus.Name(status)}"
            )
            return health_pb2.HealthCheckResponse(status=status)

        # Specific service health - check if ready to serve
        elif service_name == self.SGLANG_SERVICE:
            # Additional checks for service readiness

            # Check base status first
            base_status = self._serving_status.get(
                self.SGLANG_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING
            )

            if base_status != health_pb2.HealthCheckResponse.SERVING:
                logger.debug("Service health check: NOT_SERVING (base status)")
                return health_pb2.HealthCheckResponse(status=base_status)

            # Check if scheduler is responsive (received data recently)
            time_since_last_receive = (
                time.time() - self.request_manager.last_receive_tstamp
            )

            # If no recent activity and we have active requests, might be stuck
            # NOTE: 30s timeout is hardcoded. This is more conservative than
            # HEALTH_CHECK_TIMEOUT (20s) used for custom HealthCheck RPC.
            # Consider making this configurable via environment variable in the future
            # if different workloads need different responsiveness thresholds.
            if (
                time_since_last_receive > 30
                and len(self.request_manager.rid_to_state) > 0
            ):
                logger.warning(
                    f"Service health check: Scheduler not responsive "
                    f"({time_since_last_receive:.1f}s since last receive, "
                    f"{len(self.request_manager.rid_to_state)} pending requests)"
                )
                return health_pb2.HealthCheckResponse(
                    status=health_pb2.HealthCheckResponse.NOT_SERVING
                )

            logger.debug("Service health check: SERVING")
            return health_pb2.HealthCheckResponse(
                status=health_pb2.HealthCheckResponse.SERVING
            )

        # Unknown service
        else:
            logger.debug(f"Health check for unknown service: '{service_name}'")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Unknown service: {service_name}")
            return health_pb2.HealthCheckResponse(
                status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
            )
    async def Watch(
        self,
        request: health_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[health_pb2.HealthCheckResponse]:
        """
        Streaming health check - sends updates when status changes.

        For now, just send current status once (Kubernetes doesn't use Watch).
        A full implementation would monitor status changes and stream updates.

        Args:
            request: Contains service name
            context: gRPC context

        Yields:
            HealthCheckResponse messages when status changes
        """
        service_name = request.service
        logger.debug(f"Health watch request for service: '{service_name}'")

        # Send current status
        response = await self.Check(request, context)
        yield response

        # Note: Full Watch implementation would monitor status changes
        # and stream updates. For K8s probes, Check is sufficient.
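Aside (not part of the new file): a sketch of exercising the servicer's lifecycle directly, using types.SimpleNamespace as a stand-in for GrpcRequestManager (only the attributes Check() reads are faked; the known-service paths never touch the context argument).

import asyncio
import time
from types import SimpleNamespace

from grpc_health.v1 import health_pb2

from sglang.srt.grpc.health_servicer import SGLangHealthServicer


async def main() -> None:
    fake_manager = SimpleNamespace(
        gracefully_exit=False,
        last_receive_tstamp=time.time(),
        rid_to_state={},
    )
    servicer = SGLangHealthServicer(request_manager=fake_manager, scheduler_info={})

    # Before warmup: both services report NOT_SERVING.
    resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), None)
    assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING

    # After warmup calls set_serving(), readiness flips to SERVING.
    servicer.set_serving()
    resp = await servicer.Check(
        health_pb2.HealthCheckRequest(service="sglang.grpc.scheduler.SglangScheduler"),
        None,
    )
    assert resp.status == health_pb2.HealthCheckResponse.SERVING


asyncio.run(main())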
python/sglang/srt/grpc/scheduler_launcher.py (new file, 181 lines)
@@ -0,0 +1,181 @@
"""
|
||||
Scheduler process management for gRPC server.
|
||||
|
||||
This module handles launching and managing scheduler processes for the gRPC server,
|
||||
including tensor parallelism, pipeline parallelism, and data parallelism configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from sglang.srt.managers.data_parallel_controller import (
|
||||
run_data_parallel_controller_process,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_scheduler_with_signal_handling(*args, **kwargs):
|
||||
"""
|
||||
Wrapper for run_scheduler_process that ignores SIGINT.
|
||||
|
||||
The scheduler process should not handle Ctrl+C - it should only terminate
|
||||
when the parent gRPC server exits (via kill_itself_when_parent_died).
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for run_scheduler_process
|
||||
**kwargs: Keyword arguments for run_scheduler_process
|
||||
"""
|
||||
# Ignore SIGINT in this subprocess - let the parent handle it
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
# Now run the actual scheduler process
|
||||
run_scheduler_process(*args, **kwargs)
|
||||
|
||||
|
||||
def launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, List[mp.Process]]:
    """
    Launch only the scheduler process(es) without tokenizer/detokenizer.

    This function handles all scheduler startup logic including:
    - Tensor parallelism (tp_size)
    - Pipeline parallelism (pp_size)
    - Data parallelism (dp_size)
    - Multi-node distributed setup

    Args:
        server_args: Server configuration
        port_args: Port configuration (created if None)

    Returns:
        Tuple of (scheduler_info, port_args, scheduler_processes):
        - scheduler_info: Dict with model metadata and configuration
        - port_args: Port configuration used for IPC
        - scheduler_processes: List of launched scheduler Process objects

    Raises:
        RuntimeError: If any scheduler process fails to initialize
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()

    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # Prepare model and tokenizer paths
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []

    if server_args.dp_size == 1:
        # Single data parallel group - launch TP/PP schedulers
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        # Calculate TP/PP distribution across nodes
        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )
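        # Worked example (illustrative numbers, not from this commit): with
        # nnodes=2, pp_size=2, tp_size=8, node_rank=1:
        #   nnodes_per_tp_group = max(2 // 2, 1) = 1
        #   tp_size_per_node    = 8 // 1 = 8          -> tp_rank_range = range(0, 8)
        #   pp_size_per_node    = max(2 // 2, 1) = 1  -> pp_rank_range = range(1, 2)
        # i.e. node 0 hosts pp_rank 0 and node 1 hosts pp_rank 1, each with tp
        # ranks 0..7; with base_gpu_id=0 and gpu_id_step=1, the gpu_id formula
        # below maps tp rank k on each node to GPU k.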
        # Launch scheduler for each TP/PP rank combination
        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)

                # Calculate GPU ID for this rank
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )

                # Calculate MoE expert parallel rank
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

                # Create scheduler process
                proc = mp.Process(
                    target=run_scheduler_with_signal_handling,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,  # dp_rank
                        writer,
                    ),
                )

                with memory_saver_adapter.configure_subprocess():
                    proc.start()

                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)
    else:
        # Data parallelism - launch data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]

        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    # TODO(CatherineSue): handle cases for multi-node

    # Wait for all scheduler processes to be ready
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Return the first scheduler's info (they should all be the same)
    return scheduler_infos[0], port_args, scheduler_procs
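Aside (not part of the new file): the reader.recv() loop above relies on each child process sending a dict with at least {"status": "ready"} (plus the model/config metadata returned as scheduler_info) over its pipe once initialization succeeds. A self-contained sketch of that handshake with a dummy worker, for illustration only:

import multiprocessing as mp


def _dummy_worker(writer) -> None:
    # Heavy initialization would happen here before reporting readiness.
    writer.send({"status": "ready"})  # the real scheduler also sends metadata


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=_dummy_worker, args=(writer,))
    proc.start()
    data = reader.recv()
    assert data.get("status") == "ready"
    proc.join()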