Improve tensor parallel performance (#625)

Co-authored-by: Mingyi <wisclmy0611@gmail.com>
This commit is contained in:
Ying Sheng
2024-07-15 07:10:51 -07:00
committed by GitHub
parent 5ac8b80677
commit 6a2941f4d0
10 changed files with 171 additions and 81 deletions

View File

@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller.manager_single import (
launch_tp_servers,
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.controller.tp_worker import ModelTpService
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
receive_addrs,
send_addrs_to_rank_0,
start_rpyc_service_process,
)
from sglang.utils import get_exception_traceback
@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
model_port_args=model_port_args,
)
# TODO multi-node dp is not supported
assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
# Handle multi-node tp
if server_args.nnodes > 1:
assert server_args.dp_size == 1, "Multi-node dp is not supported."
if server_args.node_rank != 0:
send_addrs_to_rank_0(model_port_args[0], server_args)
else:
receive_addrs(model_port_args[0], server_args)
for i in range(tp_size_local):
start_rpyc_service_process(
ModelTpService, model_port_args[0].model_tp_ports[i]
)
if server_args.node_rank != 0:
logger.info(
f"[node_rank={server_args.node_rank}]: Listen for connections..."
)
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(range(server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local))
procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
port_args.model_port_args[0], model_overide_args)
while True:
pass