Add back data parallelism (#1635)
This commit is contained in:
11
.github/workflows/pr-test.yml
vendored
11
.github/workflows/pr-test.yml
vendored
@@ -255,12 +255,11 @@ jobs:
|
|||||||
python3 test_mla.py
|
python3 test_mla.py
|
||||||
python3 test_mla_fp8.py
|
python3 test_mla_fp8.py
|
||||||
|
|
||||||
# Temporarily disabled
|
- name: Evaluate Data Parallelism Accuracy (DP=2)
|
||||||
#- name: Evaluate Data Parallelism Accuracy (TP=2)
|
timeout-minutes: 10
|
||||||
# timeout-minutes: 10
|
run: |
|
||||||
# run: |
|
cd test/srt
|
||||||
# cd test/srt
|
python3 test_data_parallelism.py
|
||||||
# python3 test_data_parallelism.py
|
|
||||||
|
|
||||||
finish:
|
finish:
|
||||||
needs: [
|
needs: [
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
|
|||||||
gpu_id=tp_rank,
|
gpu_id=tp_rank,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=port_args.nccl_ports[0],
|
nccl_port=port_args.nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||||
|
|||||||
177
python/sglang/srt/managers/data_parallel_controller.py
Normal file
177
python/sglang/srt/managers/data_parallel_controller.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
TokenizedRewardReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
configure_logger,
|
||||||
|
kill_parent_process,
|
||||||
|
suppress_other_loggers,
|
||||||
|
)
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBalanceMethod(Enum):
    """Strategies the controller can use to pick a data parallel worker."""

    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        """Return the member whose name matches ``method``, ignoring case.

        Raises:
            ValueError: if no member has that (upper-cased) name. The
                underlying ``KeyError`` is attached as the cause.
        """
        name = method.upper()
        try:
            member = cls.__members__[name]
        except KeyError as err:
            raise ValueError(f"Invalid load balance method: {name}") from err
        return member
|
||||||
|
|
||||||
|
|
||||||
|
class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers.

    The controller binds a PULL socket on the tokenizer-facing ipc endpoint,
    launches one tensor parallel group of scheduler processes per data
    parallel rank, and keeps one PUSH socket per group in ``self.workers``.
    """

    def __init__(self, server_args: "ServerArgs", port_args: "PortArgs") -> None:
        """Set up the sockets and start every data parallel worker group.

        Args:
            server_args: Server configuration; ``dp_size``, ``tp_size`` and
                ``load_balance_method`` are read here.
            port_args: Ports/ipc names of the parent server. Its
                ``scheduler_input_ipc_name`` is where this controller
                receives requests, and its ``detokenizer_ipc_name`` is
                shared with every worker group.

        Raises:
            ValueError: if ``server_args.load_balance_method`` is not a
                valid ``LoadBalanceMethod`` name.
        """
        # Parse args
        self.server_args = server_args
        self.port_args = port_args
        self.load_balance_method = LoadBalanceMethod.from_str(
            server_args.load_balance_method
        )

        # Init inter-process communication
        self.context = zmq.Context(1 + server_args.dp_size)
        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Dispatch method
        self.round_robin_counter = 0
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

        # Start data parallel workers
        base_gpu_id = 0
        self.workers = []
        for dp_rank in range(server_args.dp_size):
            # Each group gets freshly allocated port args of its own, but all
            # groups keep sending their outputs to the same detokenizer.
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

            send_to = self.launch_tensor_parallel_group(
                server_args,
                tmp_port_args,
                base_gpu_id,
                dp_rank,
            )

            self.workers.append(send_to)
            # The next group starts on the GPUs after this one.
            base_gpu_id += server_args.tp_size

    def launch_tensor_parallel_group(
        self,
        server_args: "ServerArgs",
        port_args: "PortArgs",
        base_gpu_id: int,
        dp_rank: int,
    ):
        """Start the scheduler processes of one tensor parallel group.

        Args:
            server_args: Server configuration (``tp_size``, ``nnodes`` and
                ``node_rank`` select the tp ranks launched on this node).
            port_args: Port arguments dedicated to this group.
            base_gpu_id: GPU index of the group's first local tp rank.
            dp_rank: Data parallel rank passed through to each scheduler.

        Returns:
            A zmq PUSH socket connected to the group's scheduler input
            endpoint. It is returned only after every launched scheduler
            has sent a message on its pipe.
        """
        # Launch tensor parallel scheduler processes
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
            scheduler_pipe_readers.append(reader)

        send_to = self.context.socket(zmq.PUSH)
        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Wait for model to finish loading
        for reader in scheduler_pipe_readers:
            reader.recv()

        return send_to

    def round_robin_scheduler(self, req):
        """Send ``req`` to the next worker in cyclic order."""
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def shortest_queue_scheduler(self, input_requests):
        """Placeholder for shortest-queue dispatch; not implemented yet."""
        raise NotImplementedError()

    def event_loop(self):
        """Forward incoming messages to the workers forever.

        Tokenized generate/embedding/reward requests go to a single worker
        chosen by the configured dispatch method; any other message is
        treated as a control message and broadcast to every worker.
        """
        # NOTE(review): the receive is non-blocking, so this loop spins while
        # the socket is idle. Left as-is; confirm whether that is intended.
        while True:
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break

                if isinstance(
                    recv_req,
                    (
                        TokenizedGenerateReqInput,
                        TokenizedEmbeddingReqInput,
                        TokenizedRewardReqInput,
                    ),
                ):
                    self.dispatching(recv_req)
                else:
                    # Send other control messages to all workers.
                    # Workers are zmq PUSH sockets, so use send_pyobj like
                    # the dispatch path does (they have no ``queue``).
                    for worker in self.workers:
                        worker.send_pyobj(recv_req)
|
||||||
|
|
||||||
|
|
||||||
|
def run_data_parallel_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Process entry point: build the controller and run its event loop.

    Sends ``"ready"`` on ``pipe_writer`` once the controller has been
    constructed. If anything raises, the traceback is logged and the parent
    process is killed.
    """
    configure_logger(server_args)
    suppress_other_loggers()

    try:
        dp_controller = DataParallelController(server_args, port_args)
        pipe_writer.send("ready")
        dp_controller.event_loop()
    except Exception:
        logger.error(get_exception_traceback())
        kill_parent_process()
|
||||||
@@ -142,7 +142,7 @@ class Scheduler:
|
|||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
nccl_port=port_args.nccl_ports[0],
|
nccl_port=port_args.nccl_port,
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
@@ -1042,9 +1042,14 @@ def run_scheduler_process(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
dp_rank: Optional[int],
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
):
|
):
|
||||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
if dp_rank is None:
|
||||||
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||||
|
else:
|
||||||
|
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
||||||
|
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class ModelRunner:
|
|||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
torch.cuda.set_device(self.gpu_id)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
|
|||||||
@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|||||||
|
|
||||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
|
run_data_parallel_controller_process,
|
||||||
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
@@ -337,30 +340,40 @@ def launch_engine(
|
|||||||
server_args.model_path, server_args.tokenizer_path
|
server_args.model_path, server_args.tokenizer_path
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch tensor parallel scheduler processes
|
if server_args.dp_size == 1:
|
||||||
scheduler_procs = []
|
# Launch tensor parallel scheduler processes
|
||||||
scheduler_pipe_readers = []
|
scheduler_procs = []
|
||||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
scheduler_pipe_readers = []
|
||||||
tp_rank_range = range(
|
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||||
tp_size_per_node * server_args.node_rank,
|
tp_rank_range = range(
|
||||||
tp_size_per_node * (server_args.node_rank + 1),
|
tp_size_per_node * server_args.node_rank,
|
||||||
)
|
tp_size_per_node * (server_args.node_rank + 1),
|
||||||
for tp_rank in tp_rank_range:
|
)
|
||||||
|
for tp_rank in tp_rank_range:
|
||||||
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
|
gpu_id = tp_rank % tp_size_per_node
|
||||||
|
proc = mp.Process(
|
||||||
|
target=run_scheduler_process,
|
||||||
|
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||||
|
)
|
||||||
|
proc.start()
|
||||||
|
scheduler_procs.append(proc)
|
||||||
|
scheduler_pipe_readers.append(reader)
|
||||||
|
|
||||||
|
if server_args.node_rank >= 1:
|
||||||
|
# For other nodes, they do not need to run tokenizer or detokenizer,
|
||||||
|
# so they can just wait here.
|
||||||
|
while True:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Launch the data parallel controller
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
gpu_id = tp_rank % tp_size_per_node
|
scheduler_pipe_readers = [reader]
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=run_scheduler_process,
|
target=run_data_parallel_controller_process,
|
||||||
args=(server_args, port_args, gpu_id, tp_rank, writer),
|
args=(server_args, port_args, writer),
|
||||||
)
|
)
|
||||||
proc.start()
|
proc.start()
|
||||||
scheduler_procs.append(proc)
|
|
||||||
scheduler_pipe_readers.append(reader)
|
|
||||||
|
|
||||||
if server_args.node_rank >= 1:
|
|
||||||
# For other nodes, they do not need to run tokenizer or detokenizer,
|
|
||||||
# so they can just wait here.
|
|
||||||
while True:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Launch detokenizer process
|
# Launch detokenizer process
|
||||||
detoken_proc = mp.Process(
|
detoken_proc = mp.Process(
|
||||||
|
|||||||
@@ -574,7 +574,7 @@ class ServerArgs:
|
|||||||
self.tp_size % self.nnodes == 0
|
self.tp_size % self.nnodes == 0
|
||||||
), "tp_size must be divisible by number of nodes"
|
), "tp_size must be divisible by number of nodes"
|
||||||
assert not (
|
assert not (
|
||||||
self.dp_size > 1 and self.node_rank is not None
|
self.dp_size > 1 and self.nnodes != 1
|
||||||
), "multi-node data parallel is not supported"
|
), "multi-node data parallel is not supported"
|
||||||
assert (
|
assert (
|
||||||
self.max_loras_per_batch > 0
|
self.max_loras_per_batch > 0
|
||||||
@@ -583,11 +583,6 @@ class ServerArgs:
|
|||||||
and (self.lora_paths is None or self.disable_radix_cache)
|
and (self.lora_paths is None or self.disable_radix_cache)
|
||||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||||
|
|
||||||
assert self.dp_size == 1, (
|
|
||||||
"The support for data parallelism is temporarily disabled during refactor. "
|
|
||||||
"Please use sglang<=0.3.2 or wait for later updates."
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(self.lora_paths, list):
|
if isinstance(self.lora_paths, list):
|
||||||
lora_paths = self.lora_paths
|
lora_paths = self.lora_paths
|
||||||
self.lora_paths = {}
|
self.lora_paths = {}
|
||||||
@@ -626,8 +621,8 @@ class PortArgs:
|
|||||||
# The ipc filename for detokenizer to receive inputs from scheduler (zmq)
|
# The ipc filename for detokenizer to receive inputs from scheduler (zmq)
|
||||||
detokenizer_ipc_name: str
|
detokenizer_ipc_name: str
|
||||||
|
|
||||||
# The port for nccl initialization for multiple TP groups (torch.dist)
|
# The port for nccl initialization (torch.dist)
|
||||||
nccl_ports: List[int]
|
nccl_port: int
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_new(server_args) -> "PortArgs":
|
def init_new(server_args) -> "PortArgs":
|
||||||
@@ -641,7 +636,7 @@ class PortArgs:
|
|||||||
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||||
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||||
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||||
nccl_ports=[port],
|
nccl_port=port,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user