Improve tensor parallel performance (#625)
Co-authored-by: Mingyi <wisclmy0611@gmail.com>
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
 from sglang.srt.managers.controller.manager_single import (
+    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     receive_addrs,
     send_addrs_to_rank_0,
-    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )
 
-    # TODO multi-node dp is not supported
-    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    # Handle multi-node tp
     if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            send_addrs_to_rank_0(model_port_args[0], server_args)
-        else:
-            receive_addrs(model_port_args[0], server_args)
-        for i in range(tp_size_local):
-            start_rpyc_service_process(
-                ModelTpService, model_port_args[0].model_tp_ports[i]
-            )
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
         if server_args.node_rank != 0:
-            logger.info(
-                f"[node_rank={server_args.node_rank}]: Listen for connections..."
-            )
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+                                       (server_args.node_rank + 1) * tp_size_local))
+            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                                      port_args.model_port_args[0], model_overide_args)
+            while True:
+                pass
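
The heart of the new path is the arithmetic that maps global tensor-parallel ranks onto per-node GPUs: each of the nnodes machines hosts a contiguous block of tp_size // nnodes ranks, and every machine numbers its GPUs from zero. A minimal standalone sketch of that mapping, using hypothetical values (tp_size=8, nnodes=2; neither value appears in the diff):

    # Sketch of the rank/GPU bookkeeping done in the new code above, with
    # assumed example values. Names mirror the diff, but this runs without
    # sglang installed.
    tp_size = 8   # total tensor-parallel world size (assumed example value)
    nnodes = 2    # number of nodes (assumed example value)

    tp_size_local = tp_size // nnodes  # ranks hosted per node -> 4

    for node_rank in range(nnodes):
        # Local GPU ids restart at 0 on every node, matching
        # `gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]`.
        gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
        # Each node serves a contiguous slice of the global tp ranks.
        tp_rank_range = list(range(node_rank * tp_size_local,
                                   (node_rank + 1) * tp_size_local))
        print(f"node {node_rank}: tp ranks {tp_rank_range} on local gpus "
              f"{[gpu_ids[r] for r in tp_rank_range]}")

    # node 0: tp ranks [0, 1, 2, 3] on local gpus [0, 1, 2, 3]
    # node 1: tp ranks [4, 5, 6, 7] on local gpus [0, 1, 2, 3]

So global rank r lands on GPU gpu_ids[r] of node r // tp_size_local. Node 0 goes on to serve HTTP as before, while every other node only hosts its slice of tp workers through launch_tp_servers and then parks in the busy-wait loop to keep those worker processes alive, replacing the old rpyc address exchange (send_addrs_to_rank_0 / receive_addrs) and the per-port start_rpyc_service_process loop.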
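For orientation, a two-node launch under this scheme could be driven roughly as follows. This is a sketch: the ServerArgs fields (tp_size, nnodes, node_rank, model_path) and the launch_server signature come from the diff itself, but the import paths, the placeholder model, and the None passed for pipe_finish_writer are assumptions, not the project's documented API.

    # Hypothetical per-node driver for the new multi-node tp path.
    # Run once on each machine, changing node_rank (0 on the head node).
    from sglang.srt.server import launch_server       # assumed module path
    from sglang.srt.server_args import ServerArgs     # assumed module path

    server_args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-hf",  # placeholder model
        tp_size=8,     # total tensor-parallel degree across the cluster
        nnodes=2,      # number of participating nodes
        node_rank=0,   # set to 1 when launching on the second node
    )

    # Node 0 starts the tokenizer/controller/detokenizer and serves HTTP;
    # any node_rank != 0 only hosts its tp rank slice via launch_tp_servers
    # and then blocks in the `while True: pass` loop shown in the diff.
    launch_server(server_args, None)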