Add back data parallelism (#1635)
This commit is contained in:
@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
|
||||
gpu_id=tp_rank,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=server_args.tp_size,
|
||||
nccl_port=port_args.nccl_ports[0],
|
||||
nccl_port=port_args.nccl_port,
|
||||
server_args=server_args,
|
||||
)
|
||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||
|
||||
177
python/sglang/srt/managers/data_parallel_controller.py
Normal file
177
python/sglang/srt/managers/data_parallel_controller.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from enum import Enum, auto
|
||||
|
||||
import zmq
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
TokenizedRewardReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
kill_parent_process,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadBalanceMethod(Enum):
    """Strategies for distributing requests across data parallel workers."""

    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        """Look up a member by its case-insensitive name.

        Raises:
            ValueError: If ``method`` does not name a member.
        """
        normalized = method.upper()
        try:
            return cls[normalized]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {normalized}") from exc
|
||||
|
||||
|
||||
class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers.

    Each worker is one tensor-parallel group of scheduler processes.  The
    controller pulls tokenized requests from the tokenizer over a zmq PULL
    socket and forwards each one to a worker selected by the configured
    load-balance method; non-request control messages are broadcast to all
    workers.
    """

    def __init__(self, server_args, port_args) -> None:
        # Parse args
        self.server_args = server_args
        self.port_args = port_args
        self.load_balance_method = LoadBalanceMethod.from_str(
            server_args.load_balance_method
        )

        # Init inter-process communication: one PULL socket receiving from the
        # tokenizer, plus one PUSH socket per data parallel worker (created in
        # launch_tensor_parallel_group), hence 1 + dp_size I/O threads.
        self.context = zmq.Context(1 + server_args.dp_size)
        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Dispatch method
        self.round_robin_counter = 0
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

        # Start data parallel workers. Each worker gets its own PortArgs (and
        # thus its own nccl port / scheduler input ipc name) but shares the
        # detokenizer ipc name so all workers feed one detokenizer.
        base_gpu_id = 0
        self.workers = []
        for dp_rank in range(server_args.dp_size):
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

            send_to = self.launch_tensor_parallel_group(
                server_args,
                tmp_port_args,
                base_gpu_id,
                dp_rank,
            )

            self.workers.append(send_to)
            # The next data parallel worker uses the next tp_size GPUs.
            base_gpu_id += server_args.tp_size

    def launch_tensor_parallel_group(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        base_gpu_id: int,
        dp_rank: int,
    ):
        """Launch one tensor-parallel group of scheduler processes.

        Returns a zmq PUSH socket connected to the group's scheduler input
        ipc endpoint, to be used for dispatching requests to this worker.
        Blocks until every scheduler process reports (via its pipe) that the
        model finished loading.
        """
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        # Only the tp ranks that live on this node are launched here.
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        send_to = self.context.socket(zmq.PUSH)
        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Wait for model to finish loading
        for i in range(len(scheduler_pipe_readers)):
            scheduler_pipe_readers[i].recv()

        return send_to

    def round_robin_scheduler(self, req):
        """Send ``req`` to the next worker in round-robin order."""
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()

    def event_loop(self):
        """Receive requests from the tokenizer and dispatch them forever."""
        while True:
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # No pending message; drain loop is done for this round.
                    break

                if isinstance(
                    recv_req,
                    (
                        TokenizedGenerateReqInput,
                        TokenizedEmbeddingReqInput,
                        TokenizedRewardReqInput,
                    ),
                ):
                    self.dispatching(recv_req)
                else:
                    # Send other control messages to all workers.
                    # BUG FIX: self.workers holds zmq PUSH sockets (see
                    # launch_tensor_parallel_group), which have no ``.queue``
                    # attribute — ``worker.queue.put(recv_req)`` raised
                    # AttributeError. Broadcast with send_pyobj, consistent
                    # with round_robin_scheduler.
                    for worker in self.workers:
                        worker.send_pyobj(recv_req)
|
||||
|
||||
|
||||
def run_data_parallel_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Entry point for the data parallel controller subprocess.

    Configures logging, constructs the controller, signals readiness over
    ``pipe_writer``, and runs the event loop forever.  Any unhandled
    exception is logged and the parent process is killed.
    """
    configure_logger(server_args)
    suppress_other_loggers()

    try:
        controller = DataParallelController(server_args, port_args)
        pipe_writer.send("ready")
        controller.event_loop()
    except Exception:
        traceback_msg = get_exception_traceback()
        logger.error(traceback_msg)
        kill_parent_process()
|
||||
@@ -142,7 +142,7 @@ class Scheduler:
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
nccl_port=port_args.nccl_ports[0],
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
|
||||
@@ -1042,9 +1042,14 @@ def run_scheduler_process(
|
||||
port_args: PortArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
):
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
if dp_rank is None:
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
else:
|
||||
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
||||
|
||||
suppress_other_loggers()
|
||||
|
||||
try:
|
||||
|
||||
@@ -141,7 +141,7 @@ class ModelRunner:
|
||||
self.init_attention_backend()
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
logger.info("Init torch distributed begin.")
|
||||
# Init torch distributed
|
||||
if self.device == "cuda":
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
|
||||
@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.data_parallel_controller import (
|
||||
run_data_parallel_controller_process,
|
||||
)
|
||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
@@ -337,30 +340,40 @@ def launch_engine(
|
||||
server_args.model_path, server_args.tokenizer_path
|
||||
)
|
||||
|
||||
# Launch tensor parallel scheduler processes
|
||||
scheduler_procs = []
|
||||
scheduler_pipe_readers = []
|
||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||
tp_rank_range = range(
|
||||
tp_size_per_node * server_args.node_rank,
|
||||
tp_size_per_node * (server_args.node_rank + 1),
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
if server_args.dp_size == 1:
|
||||
# Launch tensor parallel scheduler processes
|
||||
scheduler_procs = []
|
||||
scheduler_pipe_readers = []
|
||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||
tp_rank_range = range(
|
||||
tp_size_per_node * server_args.node_rank,
|
||||
tp_size_per_node * (server_args.node_rank + 1),
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = tp_rank % tp_size_per_node
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||
)
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
scheduler_pipe_readers.append(reader)
|
||||
|
||||
if server_args.node_rank >= 1:
|
||||
# For other nodes, they do not need to run tokenizer or detokenizer,
|
||||
# so they can just wait here.
|
||||
while True:
|
||||
pass
|
||||
else:
|
||||
# Launch the data parallel controller
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = tp_rank % tp_size_per_node
|
||||
scheduler_pipe_readers = [reader]
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, writer),
|
||||
target=run_data_parallel_controller_process,
|
||||
args=(server_args, port_args, writer),
|
||||
)
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
scheduler_pipe_readers.append(reader)
|
||||
|
||||
if server_args.node_rank >= 1:
|
||||
# For other nodes, they do not need to run tokenizer or detokenizer,
|
||||
# so they can just wait here.
|
||||
while True:
|
||||
pass
|
||||
|
||||
# Launch detokenizer process
|
||||
detoken_proc = mp.Process(
|
||||
|
||||
@@ -574,7 +574,7 @@ class ServerArgs:
|
||||
self.tp_size % self.nnodes == 0
|
||||
), "tp_size must be divisible by number of nodes"
|
||||
assert not (
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
self.dp_size > 1 and self.nnodes != 1
|
||||
), "multi-node data parallel is not supported"
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
@@ -583,11 +583,6 @@ class ServerArgs:
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
|
||||
assert self.dp_size == 1, (
|
||||
"The support for data parallelism is temporarily disabled during refactor. "
|
||||
"Please use sglang<=0.3.2 or wait for later updates."
|
||||
)
|
||||
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
self.lora_paths = {}
|
||||
@@ -626,8 +621,8 @@ class PortArgs:
|
||||
# The ipc filename for detokenizer to receive inputs from scheduler (zmq)
|
||||
detokenizer_ipc_name: str
|
||||
|
||||
# The port for nccl initialization for multiple TP groups (torch.dist)
|
||||
nccl_ports: List[int]
|
||||
# The port for nccl initialization (torch.dist)
|
||||
nccl_port: int
|
||||
|
||||
@staticmethod
|
||||
def init_new(server_args) -> "PortArgs":
|
||||
@@ -641,7 +636,7 @@ class PortArgs:
|
||||
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
nccl_ports=[port],
|
||||
nccl_port=port,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user