Multi-node Tensor Parallelism (#550)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Ying Sheng
2024-06-17 20:41:24 -07:00
committed by GitHub
parent 53a7ebd89a
commit 09593e9bc9
10 changed files with 167 additions and 46 deletions

View File

@@ -37,7 +37,8 @@ from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_process,
start_rpyc_service_process,
connect_rpyc_service,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
else:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
rets = executor.map(
lambda args: start_rpyc_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
)
self.model_services = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
if server_args.nnodes == 1:
self.procs = list(executor.map(
lambda args: start_rpyc_service_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
))
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
else:
addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
self.model_services = list(executor.map(
lambda args: connect_rpyc_service(*args), addrs))
# Init model
def init_model(i):
@@ -787,7 +793,7 @@ class ModelTpClient:
model_overide_args,
)
self.model_servers = executor.map(init_model, range(self.tp_size))
self.model_servers = list(executor.map(init_model, range(self.tp_size)))
# Wrap functions
def async_wrap(func_name):