# sglang/python/sglang/srt/managers/data_parallel_controller.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A controller that dispatches requests to multiple data parallel workers."""
import logging
import multiprocessing as mp
import signal
import threading
import time
from enum import Enum, auto
import psutil
import setproctitle
import zmq
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# Init inter-process communication
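# The zmq context is sized with one IO thread per data parallel worker
# plus one for the tokenizer input socket.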
self.context = zmq.Context(1 + server_args.dp_size)
if server_args.node_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Launch data parallel workers
self.scheduler_procs = []
self.workers = [None] * server_args.dp_size
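# control_message_step is the stride used in event_loop when broadcasting control
# messages, so that only the first worker of each group receives them.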
if server_args.enable_dp_attention:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
self.control_message_step = server_args.tp_size
else:
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
self.control_message_step = 1
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if server_args.node_rank == 0:
for dp_rank in range(server_args.dp_size):
self.workers[dp_rank] = get_zmq_socket(
self.context,
zmq.PUSH,
dp_port_args[dp_rank].scheduler_input_ipc_name,
True,
)
self.max_req_input_len = None
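# Launch one scheduler group per data parallel rank, each in its own thread with
# its own set of ports, and return the per-rank PortArgs used for dispatching.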
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
threads = []
sockets = []
dp_port_args = []
ready_events = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
dp_port_args.append(tmp_port_args)
# This port was already checked to be free in PortArgs.init_new.
# We hold it here so that the next DP worker gets a different port.
sockets.append(bind_port(tmp_port_args.nccl_port))
ready_event = threading.Event()
ready_events.append(ready_event)
# Create a thread for each worker
thread = threading.Thread(
target=self.launch_tensor_parallel_group_thread,
args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
)
threads.append(thread)
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
sock.close()
# Start all threads
for thread in threads:
thread.start()
for event in ready_events:
event.wait()
return dp_port_args
def launch_tensor_parallel_group_thread(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
ready_event: threading.Event,
):
self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
ready_event.set()
# This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
# function in scheduler.py will kill the scheduler.
while True:
time.sleep(30 * 24 * 3600)
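# With DP attention enabled, all data parallel ranks share a single tensor parallel
# group, so only one group is launched and per-rank port args are created for dispatching.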
def launch_dp_attention_schedulers(self, server_args, port_args):
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
dp_port_args = []
for dp_rank in range(server_args.dp_size):
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
return dp_port_args
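# Spawn one scheduler process per (pp_rank, tp_rank) pair hosted on this node and
# block until every scheduler reports that the model has finished loading.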
def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
if not server_args.enable_dp_attention:
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
scheduler_pipe_readers = []
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
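# The ranges above select the slice of TP/PP ranks owned by this node: each node
# hosts tp_size_per_node consecutive TP ranks and pp_size_per_node consecutive PP ranks.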
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# DP attention uses a different sharding logic.
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
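# gpu_id combines the server-wide base offset, the offset of this DP group, one block
# of tp_size_per_node GPUs per local PP rank, and the local TP rank scaled by gpu_id_step.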
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
rank_port_args,
gpu_id,
tp_rank,
pp_rank,
dp_rank,
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
# Wait for model to finish loading
scheduler_info = []
for i in range(len(scheduler_pipe_readers)):
scheduler_info.append(scheduler_pipe_readers[i].recv())
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
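# Dispatch: requests carrying an explicit data_parallel_rank go straight to that worker;
# otherwise workers are chosen round-robin, or, in disaggregation mode, deterministically
# from the request's bootstrap_room.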
def round_robin_scheduler(self, req: Req):
if self.server_args.disaggregation_mode == "null":
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
else:
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()
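# Main dispatch loop (node rank 0 only): drain everything queued on the tokenizer socket
# without blocking, forward tokenized requests via the configured load-balancing method,
# and broadcast any other control message to the first worker of each group.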
def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
),
):
self.dispatching(recv_req)
else:
# Send other control messages to the first worker of each TP group
for worker in self.workers[:: self.control_message_step]:
worker.send_pyobj(recv_req)
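# Process entry point: build the controller, report readiness and model limits through
# pipe_writer, run the dispatch loop on node rank 0, wait on the local scheduler
# processes, and signal the parent with SIGQUIT if anything fails.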
def run_data_parallel_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
setproctitle.setproctitle("sglang::data_parallel_controller")
configure_logger(server_args)
parent_process = psutil.Process().parent()
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send(
{
"status": "ready",
"max_total_num_tokens": controller.max_total_num_tokens,
"max_req_input_len": controller.max_req_input_len,
}
)
if server_args.node_rank == 0:
controller.event_loop()
for proc in controller.scheduler_procs:
proc.join()
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
except Exception:
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
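
# Illustrative sketch (not part of this module) of how a launcher might start this
# controller when dp_size > 1; the caller-side wiring below is an assumption, not
# the actual SGLang entry point:
#
#     reader, writer = mp.Pipe(duplex=False)
#     proc = mp.Process(
#         target=run_data_parallel_controller_process,
#         args=(server_args, port_args, writer),
#     )
#     proc.start()
#     info = reader.recv()  # {"status": "ready", "max_total_num_tokens": ..., "max_req_input_len": ...}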