Multi-node Tensor Parallelism (#550)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import fcntl
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
@@ -369,23 +371,7 @@ def load_image(image_file):
|
||||
return image, image_size
|
||||
|
||||
|
||||
def init_rpyc_service(service: rpyc.Service, port: int):
    """Build an rpyc ``ThreadedServer`` for ``service`` on ``port`` and start it.

    :param service: The rpyc service to expose.
    :param port: Port number handed to the ``ThreadedServer``.

    NOTE(review): ``start()`` is the final call and presumably does not return
    while the server is running -- confirm against the rpyc documentation.
    """
    # Connection options: public attributes are reachable, pickling is
    # permitted, and the sync request timeout is set to 3600.
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        service=service, port=port, protocol_config=protocol_config
    )
    # Only warnings and above are emitted by the server's own logger.
    server.logger.setLevel(logging.WARN)
    server.start()
|
||||
|
||||
|
||||
def connect_to_rpyc_service(port, host="localhost"):
|
||||
time.sleep(1)
|
||||
|
||||
def connect_rpyc_service(host, port):
|
||||
repeat_count = 0
|
||||
while repeat_count < 20:
|
||||
try:
|
||||
@@ -399,22 +385,33 @@ def connect_to_rpyc_service(port, host="localhost"):
|
||||
},
|
||||
)
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
except ConnectionRefusedError as e:
|
||||
time.sleep(1)
|
||||
repeat_count += 1
|
||||
if repeat_count == 20:
|
||||
raise RuntimeError("init rpc env error!")
|
||||
raise RuntimeError(f"Connect rpyc error: {e}")
|
||||
|
||||
return con.root
|
||||
|
||||
|
||||
def start_rpyc_process(service: rpyc.Service, port: int):
|
||||
# Return the proxy and the process
|
||||
proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
|
||||
def start_rpyc_service(service: rpyc.Service, port: int):
    """Build an rpyc ``ThreadedServer`` for ``service`` on ``port`` and start it.

    :param service: The rpyc service to expose.
    :param port: Port number handed to the ``ThreadedServer``.

    NOTE(review): ``start()`` is the final call and presumably does not return
    while the server is running -- confirm against the rpyc documentation.
    """
    # Connection options: public attributes are reachable, pickling is
    # permitted, and the sync request timeout is set to 3600.
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        service=service, port=port, protocol_config=protocol_config
    )
    # Only warnings and above are emitted by the server's own logger.
    server.logger.setLevel(logging.WARN)
    server.start()
|
||||
|
||||
|
||||
def start_rpyc_service_process(service: rpyc.Service, port: int):
|
||||
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
|
||||
proc.start()
|
||||
proxy = connect_to_rpyc_service(port)
|
||||
assert proc.is_alive()
|
||||
return proxy, proc
|
||||
return proc
|
||||
|
||||
|
||||
def suppress_other_loggers():
|
||||
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
def get_ip_address(ifname):
    """
    Get the IP address of a network interface.

    :param ifname: Name of the network interface (e.g., 'eth0')
    :return: IP address of the network interface, as a dotted-quad string
    :raises OSError: If the socket cannot be created or the ioctl fails
        (for example when the interface does not exist).
    """
    # Truncate the *encoded* name so multi-byte characters cannot push it
    # past 15 bytes; "256s" then zero-pads the request buffer to 256 bytes.
    request = struct.pack("256s", ifname.encode("utf-8")[:15])

    # The context manager closes the socket on every path, including when
    # the ioctl raises, so no file descriptor is leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        reply = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            request,
        )

    # Bytes 20..23 of the reply are the packed IPv4 address.
    return socket.inet_ntoa(reply[20:24])
|
||||
|
||||
|
||||
def send_addrs_to_rank_0(model_port_args, server_args):
    """Send this node's IP address and tensor-parallel ports to node rank 0.

    The interface name is read from ``SGLANG_SOCKET_IFNAME``, then
    ``NCCL_SOCKET_IFNAME``, falling back to ``"eth0"``. The first
    ``tp_size // nnodes`` entries of ``model_port_args.model_tp_ips`` are set
    to this node's address. A temporary gloo process group is created on
    ``tcp://{server_args.nccl_init_addr}``, one int tensor holding the four
    IP octets followed by ``model_port_args.model_tp_ports`` is sent to rank 0,
    and the group is destroyed again.

    :param model_port_args: Object with ``model_tp_ips`` and ``model_tp_ports``
        lists; ``model_tp_ips`` is updated in place.
    :param server_args: Object providing ``node_rank``, ``dp_size``,
        ``tp_size``, ``nnodes`` and ``nccl_init_addr``.
    :raises AssertionError: If called on node rank 0 or with ``dp_size != 1``.
    """
    assert server_args.node_rank != 0 and server_args.dp_size == 1
    # Imported here so the function does not rely on a module-level
    # ``import torch`` (``import torch.distributed as dist`` only binds ``dist``).
    import torch
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)

    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

    # Wire format: four IPv4 octets followed by the port numbers.
    # NOTE(review): receive_addrs allocates 4 + num_tp_ports elements, so
    # model_tp_ports is expected to hold num_tp_ports entries -- confirm at
    # the call site.
    ip_octets = [int(x) for x in ip_addr.split(".")]
    addrs_tensor = torch.tensor(
        ip_octets + model_port_args.model_tp_ports, dtype=torch.int
    )

    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    try:
        dist.send(addrs_tensor, dst=0)
        print(
            f"Node {server_args.node_rank} sent: ip_address {ip_octets} "
            f"and ports {model_port_args.model_tp_ports}"
        )

        dist.barrier()
    finally:
        # Tear the temporary group down even if send/barrier failed, so a
        # later init_process_group call is not blocked by a stale group.
        dist.destroy_process_group()
|
||||
|
||||
|
||||
def receive_addrs(model_port_args, server_args):
    """Collect IP addresses and tensor-parallel ports from all other nodes.

    Runs on node rank 0. The interface name is read from
    ``SGLANG_SOCKET_IFNAME``, then ``NCCL_SOCKET_IFNAME``, falling back to
    ``"eth0"``; the first ``tp_size // nnodes`` entries of
    ``model_port_args.model_tp_ips`` are set to this node's address. A
    temporary gloo process group is created on
    ``tcp://{server_args.nccl_init_addr}`` and one tensor is received from
    each rank ``1 .. nnodes - 1`` in order. Each tensor is read as four IP
    octets followed by ``tp_size // nnodes`` ports, which are written into
    that rank's slice of ``model_tp_ips`` and ``model_tp_ports``.

    :param model_port_args: Object with ``model_tp_ips`` and ``model_tp_ports``
        lists; both are updated in place.
    :param server_args: Object providing ``node_rank``, ``dp_size``,
        ``tp_size``, ``nnodes`` and ``nccl_init_addr``.
    :raises AssertionError: If called on a non-zero rank or with ``dp_size != 1``.
    """
    assert server_args.node_rank == 0 and server_args.dp_size == 1
    # Imported here so the function does not rely on a module-level
    # ``import torch`` (``import torch.distributed as dist`` only binds ``dist``).
    import torch
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)

    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    try:
        for src_rank in range(1, server_args.nnodes):
            # Receive buffer: four IPv4 octets followed by num_tp_ports ports.
            tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
            dist.recv(tensor, src=src_rank)
            ip = ".".join([str(x) for x in tensor[:4].tolist()])
            ports = tensor[4:].tolist()

            # Each rank owns a contiguous slice of num_tp_ports entries.
            start = num_tp_ports * src_rank
            stop = num_tp_ports * (src_rank + 1)
            model_port_args.model_tp_ips[start:stop] = [ip] * num_tp_ports
            model_port_args.model_tp_ports[start:stop] = ports
            print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")

        dist.barrier()
    finally:
        # Tear the temporary group down even if recv/barrier failed, so a
        # later init_process_group call is not blocked by a stale group.
        dist.destroy_process_group()
|
||||
Reference in New Issue
Block a user