Multi-node Tensor Parallelism (#550)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
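This commit adds multi-node tensor parallelism: one tensor-parallel group can now span several nodes, selected by the new `--nnodes`, `--node-rank`, and `--nccl-init-addr` arguments. As a hedged usage sketch (not part of the diff: the model path and rendezvous address are placeholders, and it assumes `model_overide_args` has a default):

```python
# Sketch: launching a tp_size=4 server across 2 nodes (2 GPUs each).
# Run once per node with the appropriate node_rank; all values are placeholders.
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    server_args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
        tp_size=4,                          # total TP degree across all nodes
        nnodes=2,                           # new field: number of nodes
        node_rank=0,                        # new field: 0 here, 1 on the other node
        nccl_init_addr="192.0.2.1:20000",   # new field: rank-0 host:port rendezvous
    )
    launch_server(server_args, None)  # assumes model_overide_args defaults to None
```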
@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
 
 
@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
 
 
@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")
 
     url = f"{args.host}:{args.port}"
-    a = random.randint(0, 1 << 20)
+    a = 20
    max_new_tokens = 256
    prompt = f"{a, }"
 
@@ -2,7 +2,8 @@
 
 import argparse
 
-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -76,8 +76,9 @@ def start_controller_process(
     )
 
     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
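The first argument to `ModelTpClient` changes from global TP ranks to per-node local GPU indices. A quick illustration with assumed sizes:

```python
# With tp_size=8 over nnodes=2, each node drives its local GPUs 0..3:
nnodes = 2
tp_size_local = 8 // nnodes
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
assert gpu_ids == [0, 1, 2, 3, 0, 1, 2, 3]  # local CUDA device per TP rank
```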
@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
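The `get_num_kv_heads(tp_size)` helper matters for GQA checkpoints whose KV-head count is smaller than the TP degree, where plain integer division underflows to zero. A hedged illustration (the clamping behavior is assumed, not shown in this diff):

```python
num_key_value_heads, tp_size = 8, 16  # e.g., a GQA model sharded 16 ways
assert num_key_value_heads // tp_size == 0          # old formula: zero-size KV cells
assert max(1, num_key_value_heads // tp_size) == 1  # assumed effect of the helper
```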
@@ -324,7 +328,7 @@ class ModelRunner:
 
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))
 
                 # Init model
                 def init_model(i):
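The client now separates spawning from dialing: with one node it spawns local services and connects to `localhost`, and with several nodes it connects to the `(ip, port)` pairs gathered at startup. A small illustration with placeholder addresses:

```python
model_tp_ports = [30011, 30012]
# nnodes == 1: services were just spawned in-process on this machine
addrs_single = [("localhost", p) for p in model_tp_ports]
# nnodes > 1: IPs were exchanged beforehand (see send_addrs_to_rank_0/receive_addrs)
model_tp_ips = ["10.0.0.1", "10.0.0.2"]  # placeholder node IPs
addrs_multi = list(zip(model_tp_ips, model_tp_ports))
assert addrs_multi == [("10.0.0.1", 30011), ("10.0.0.2", 30012)]
```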
@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )
 
-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
 
                 # Wrap functions
                 def async_wrap(func_name):
@@ -71,7 +71,11 @@ class ModelConfig:
             return 1
 
         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
             return getattr(
                 self.hf_config.attn_config,
                 "kv_n_heads",
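The split is needed because, per this hunk, MPT exposes `attn_config` as a dict while DBRX exposes it as an object, so the two lookups differ. Toy stand-ins to show the distinction:

```python
mpt_attn_config = {"kv_n_heads": 8}   # MPT: dict-style access

class DbrxAttnConfig:                 # DBRX: attribute-style access
    kv_n_heads = 8

assert mpt_attn_config["kv_n_heads"] == 8
assert getattr(DbrxAttnConfig, "kv_n_heads", None) == 8
```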
@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
 
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     load_chat_template_for_openai_api(server_args.chat_template)
 
     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )
 
     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
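The port arithmetic carves `additional_ports` into per-dp-group slices of `tp_size_local + 1` ports (one NCCL port followed by the tp-worker service ports), after the first three ports, which go to other server components. A worked example with assumed values:

```python
tp_size_local, dp_size = 4, 2
ports = list(range(30010, 30010 + 3 + dp_size * (tp_size_local + 1)))  # placeholders
for i in range(dp_size):
    nccl_port = ports[3 + i * (tp_size_local + 1)]
    model_tp_ports = ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)]
    print(i, nccl_port, model_tp_ports)
# i=0 -> nccl = ports[3], tp ports = ports[4:8]
# i=1 -> nccl = ports[8], tp ports = ports[9:13]
```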
@@ -175,6 +182,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )
 
+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
@@ -56,6 +56,11 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@ class ServerArgs:
             ],
         )
 
+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server."
+        )
+        parser.add_argument(
+            "--nnodes",
+            type=int,
+            default=1,
+            help="Number of nodes"
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank."
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
 
 
@@ -1,11 +1,13 @@
 """Common utilities."""
 
 import base64
+import fcntl
 import logging
 import multiprocessing
 import os
 import random
 import socket
+import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
     return image, image_size
 
 
-def init_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def connect_to_rpyc_service(port, host="localhost"):
-    time.sleep(1)
-
+def connect_rpyc_service(host, port):
     repeat_count = 0
     while repeat_count < 20:
         try:
@@ -399,22 +385,33 @@ def connect_to_rpyc_service(port, host="localhost"):
                 },
             )
             break
-        except ConnectionRefusedError:
+        except ConnectionRefusedError as e:
             time.sleep(1)
             repeat_count += 1
     if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
+        raise RuntimeError(f"Connect rpyc error: {e}")
 
     return con.root
 
 
-def start_rpyc_process(service: rpyc.Service, port: int):
-    # Return the proxy and the process
-    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
     proc.start()
-    proxy = connect_to_rpyc_service(port)
-    assert proc.is_alive()
-    return proxy, proc
+    return proc
 
 
 def suppress_other_loggers():
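`start_rpyc_process` is thus split into `start_rpyc_service` (runs a `ThreadedServer` in the current process) and `start_rpyc_service_process` (forks it), with connection handled separately by `connect_rpyc_service` so that services on other nodes can be dialed by address. A hedged, self-contained usage sketch (toy service and port):

```python
import rpyc
from sglang.srt.utils import connect_rpyc_service, start_rpyc_service_process

class EchoService(rpyc.Service):
    def exposed_echo(self, x):
        return x

if __name__ == "__main__":
    proc = start_rpyc_service_process(EchoService, 18861)  # spawn server process
    svc = connect_rpyc_service("localhost", 18861)         # dial by (host, port)
    assert svc.echo("ping") == "ping"
    proc.terminate()
```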
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
             )
         response = await call_next(request)
         return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
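The address exchange rides on a throwaway gloo process group: each non-zero rank encodes its interface IP as four int octets followed by its `tp_size_local` service ports and sends the tensor to rank 0, which rewrites its `model_tp_ips`/`model_tp_ports` slices accordingly. A worked example of that wire format (placeholder values, no process group needed):

```python
import torch

ip_addr, ports = "10.0.0.2", [30011, 30012, 30013, 30014]  # placeholders
payload = torch.tensor([int(x) for x in ip_addr.split(".")] + ports, dtype=torch.int)

octets, recv_ports = payload[:4].tolist(), payload[4:].tolist()
assert ".".join(map(str, octets)) == ip_addr
assert recv_ports == ports
```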