Add back data parallelism (#1635)
python/sglang/srt/managers/data_parallel_controller.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A controller that dispatches requests to multiple data parallel workers."""

import logging
import multiprocessing as mp
from enum import Enum, auto

import zmq

from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_parent_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class LoadBalanceMethod(Enum):
    """Load balance method."""

    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        method = method.upper()
        try:
            return cls[method]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {method}") from exc


class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

    def __init__(self, server_args, port_args) -> None:
        # Parse args
        self.server_args = server_args
        self.port_args = port_args
        self.load_balance_method = LoadBalanceMethod.from_str(
            server_args.load_balance_method
        )

        # Init inter-process communication
        self.context = zmq.Context(1 + server_args.dp_size)
        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Dispatch method
        self.round_robin_counter = 0
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

        # Start data parallel workers
        base_gpu_id = 0
        self.workers = []
        for dp_rank in range(server_args.dp_size):
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

            send_to = self.launch_tensor_parallel_group(
                server_args,
                tmp_port_args,
                base_gpu_id,
                dp_rank,
            )

            self.workers.append(send_to)
            base_gpu_id += server_args.tp_size
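            # With base_gpu_id advancing by tp_size per replica, each DP rank
            # claims a contiguous GPU block: e.g. dp_size=2, tp_size=4 gives
            # rank 0 -> GPUs 0-3 and rank 1 -> GPUs 4-7 (illustrative shapes,
            # not a configuration taken from this commit).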

    def launch_tensor_parallel_group(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        base_gpu_id: int,
        dp_rank: int,
    ):
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        send_to = self.context.socket(zmq.PUSH)
        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Wait for model to finish loading
        for reader in scheduler_pipe_readers:
            reader.recv()

        return send_to

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()
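
`shortest_queue_scheduler` is declared but left unimplemented in this commit. For intuition only, a minimal sketch of what it could look like if the controller tracked an in-flight count per replica; the `self.queue_depths` bookkeeping below is hypothetical and does not exist here:

    def shortest_queue_scheduler(self, req):
        # Hypothetical: route to the replica with the fewest requests in flight.
        idx = min(range(len(self.workers)), key=lambda i: self.queue_depths[i])
        self.workers[idx].send_pyobj(req)
        self.queue_depths[idx] += 1  # would also need decrementing on completion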

    def event_loop(self):
        while True:
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break

                if isinstance(
                    recv_req,
                    (
                        TokenizedGenerateReqInput,
                        TokenizedEmbeddingReqInput,
                        TokenizedRewardReqInput,
                    ),
                ):
                    self.dispatching(recv_req)
                else:
                    # Send other control messages to all workers
                    for worker in self.workers:
                        worker.send_pyobj(recv_req)
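
Note the asymmetry above: tokenized generate, embedding, and reward requests go through `self.dispatching` to exactly one replica, while anything else is treated as a control message and broadcast to every worker, so state-changing commands (presumably flush-cache and abort style requests) reach all replicas.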


def run_data_parallel_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    configure_logger(server_args)
    suppress_other_loggers()

    try:
        controller = DataParallelController(server_args, port_args)
        pipe_writer.send("ready")
        controller.event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
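
The `pipe_writer.send("ready")` handshake mirrors the per-scheduler pipes above: the parent blocks until every replica has finished loading its model. A minimal sketch of the launching side, assuming it follows the same `mp.Pipe` pattern (the launcher is not part of this file):

    import multiprocessing as mp

    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=run_data_parallel_controller_process,
        args=(server_args, port_args, writer),
    )
    proc.start()
    reader.recv()  # returns "ready" once all DP workers are up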

python/sglang/srt/managers/scheduler.py

@@ -142,7 +142,7 @@ class Scheduler:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
-            nccl_port=port_args.nccl_ports[0],
+            nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group

@@ -1042,9 +1042,14 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    dp_rank: Optional[int],
     pipe_writer,
 ):
-    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    if dp_rank is None:
+        configure_logger(server_args, prefix=f" TP{tp_rank}")
+    else:
+        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+
     suppress_other_loggers()

     try:
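
With this change, a run with `dp_size=2` and `tp_size=2` would log under prefixes like ` DP0 TP0`, ` DP0 TP1`, ` DP1 TP0`, and ` DP1 TP1`, while single-replica runs keep the old ` TP{rank}` form (prefixes derived from the format strings above; the sizes are an illustrative example).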