Multi-node Tensor Parallelism (#550)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):
@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

         # Wrap functions
         def async_wrap(func_name):
|
||||
Reference in New Issue
Block a user