Add back data parallelism (#1635)

This commit is contained in:
Lianmin Zheng
2024-10-11 07:22:48 -07:00
committed by GitHub
parent 5d09ca5735
commit 23cc66f7b6
7 changed files with 228 additions and 39 deletions

View File

@@ -0,0 +1,177 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""A controller that dispatches requests to multiple data parallel workers."""
import logging
import multiprocessing as mp
from enum import Enum, auto
import zmq
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_parent_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args, port_args) -> None:
# Parse args
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Start data parallel workers
base_gpu_id = 0
self.workers = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
self.workers.append(send_to)
base_gpu_id += server_args.tp_size
def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
# Launch tensor parallel scheduler processes
scheduler_procs = []
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
send_to = self.context.socket(zmq.PUSH)
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
# Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv()
return send_to
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()
def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedRewardReqInput,
),
):
self.dispatching(recv_req)
else:
# Send other control messages to all workers
for worker in self.workers:
worker.queue.put(recv_req)
def run_data_parallel_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
configure_logger(server_args)
suppress_other_loggers()
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send("ready")
controller.event_loop()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()

View File

@@ -142,7 +142,7 @@ class Scheduler:
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=port_args.nccl_ports[0],
nccl_port=port_args.nccl_port,
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
@@ -1042,9 +1042,14 @@ def run_scheduler_process(
port_args: PortArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
pipe_writer,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}")
else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
suppress_other_loggers()
try: