Add back data parallelism (#1635)

commit 23cc66f7b6
parent 5d09ca5735
Author: Lianmin Zheng
Date:   2024-10-11 07:22:48 -07:00
Committed by: GitHub

7 changed files with 228 additions and 39 deletions


@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
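The newly imported run_data_parallel_controller_process is the entry point used in the else branch of the second hunk below: a controller that fans incoming requests out across dp_size replicas. As a rough illustration of the idea (not SGLang's actual implementation, which lives in sglang/srt/managers/data_parallel_controller.py and is assumed to involve per-replica schedulers and IPC sockets), here is a minimal round-robin dispatcher sketch; the class name, queue representation, and payloads are all hypothetical:

```python
# Hypothetical sketch of round-robin data parallel dispatch.
from typing import Any, List


class RoundRobinDispatcher:
    """Fan requests out across dp_size replica queues, one request at a time."""

    def __init__(self, dp_size: int) -> None:
        self.queues: List[List[Any]] = [[] for _ in range(dp_size)]
        self.next_replica = 0

    def dispatch(self, request: Any) -> int:
        """Append the request to the next replica's queue and return its index."""
        target = self.next_replica
        self.queues[target].append(request)
        self.next_replica = (target + 1) % len(self.queues)
        return target


# Usage: eight requests land evenly on two replicas, alternating 0, 1, 0, 1, ...
dispatcher = RoundRobinDispatcher(dp_size=2)
targets = [dispatcher.dispatch(f"req-{i}") for i in range(8)]
assert targets == [0, 1] * 4
```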
@@ -337,30 +340,40 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
-    # Launch tensor parallel scheduler processes
-    scheduler_procs = []
-    scheduler_pipe_readers = []
-    tp_size_per_node = server_args.tp_size // server_args.nnodes
-    tp_rank_range = range(
-        tp_size_per_node * server_args.node_rank,
-        tp_size_per_node * (server_args.node_rank + 1),
-    )
-    for tp_rank in tp_rank_range:
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = tp_rank % tp_size_per_node
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, writer),
-        )
-        proc.start()
-        scheduler_procs.append(proc)
-        scheduler_pipe_readers.append(reader)
-
-    if server_args.node_rank >= 1:
-        # For other nodes, they do not need to run tokenizer or detokenizer,
-        # so they can just wait here.
-        while True:
-            pass
+    if server_args.dp_size == 1:
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        if server_args.node_rank >= 1:
+            # For other nodes, they do not need to run tokenizer or detokenizer,
+            # so they can just wait here.
+            while True:
+                pass
+    else:
+        # Launch the data parallel controller
+        reader, writer = mp.Pipe(duplex=False)
+        scheduler_pipe_readers = [reader]
+        proc = mp.Process(
+            target=run_data_parallel_controller_process,
+            args=(server_args, port_args, writer),
+        )
+        proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
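Both branches above follow the same launch pattern: the parent creates a one-way pipe per child with mp.Pipe(duplex=False), each child process signals readiness on its writer end, and the parent blocks on the readers before proceeding. A self-contained sketch of that handshake, where the worker body and the "ready" payload are assumptions for illustration (the real workers send their own status objects):

```python
import multiprocessing as mp


def worker(rank: int, writer) -> None:
    # A real scheduler would load its model shard here before reporting in.
    writer.send({"rank": rank, "status": "ready"})  # hypothetical payload
    writer.close()


def main() -> None:
    procs, readers = [], []
    for rank in range(4):
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=worker, args=(rank, writer))
        proc.start()
        procs.append(proc)
        readers.append(reader)

    # Block until every child reports in, mirroring how launch_engine
    # waits on scheduler_pipe_readers before serving traffic.
    for reader in readers:
        print(reader.recv())

    for proc in procs:
        proc.join()


if __name__ == "__main__":
    main()
```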