Improve process creation (#1534)

Lianmin Zheng
2024-09-29 02:36:12 -07:00
committed by GitHub
parent fd9ad817ec
commit 048685430d
15 changed files with 270 additions and 677 deletions


@@ -16,13 +16,12 @@ limitations under the License.
"""Common utilities."""
import base64
-import fcntl
import logging
import os
+import pickle
import random
import resource
import socket
-import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
@@ -36,7 +35,6 @@ import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from torch import nn
-from torch.nn.parameter import Parameter
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
@@ -539,89 +537,6 @@ class CustomCacheManager(FileCacheManager):
            raise RuntimeError("Could not create or locate cache dir")


-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)


-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()


-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
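The removed helpers above exchanged each worker node's IP address and tensor-parallel ports with rank 0 over a temporary gloo process group, packing everything into a single int tensor. A minimal sketch of that encoding, with made-up address and port values:

import torch

# Hypothetical values for illustration only.
ip_addr = "10.0.0.5"             # e.g. what get_ip_address("eth0") would return
model_tp_ports = [30001, 30002]  # ports owned by this node

# Encode: the four IPv4 octets followed by the ports, as one int tensor.
addrs_tensor = torch.tensor(
    [int(x) for x in ip_addr.split(".")] + model_tp_ports, dtype=torch.int
)
# tensor([   10,     0,     0,     5, 30001, 30002], dtype=torch.int32)

# Decode on rank 0: the first four entries are the IP, the rest are ports.
ip = ".".join(str(x) for x in addrs_tensor[:4].tolist())
ports = addrs_tensor[4:].tolist()
assert ip == "10.0.0.5" and ports == model_tp_ports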
@@ -645,24 +560,16 @@ def add_api_key_middleware(app, api_key: str):
        return await call_next(request)


-def prepare_model(model_path: str):
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    if "SGLANG_USE_MODELSCOPE" in os.environ:
        if not os.path.exists(model_path):
            from modelscope import snapshot_download

-            return snapshot_download(model_path)
-    return model_path
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
            )
-    return tokenizer_path
+    return model_path, tokenizer_path
def configure_logger(server_args, prefix: str = ""):
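The merged helper resolves the model and tokenizer paths in one place instead of two separate functions. A small usage sketch, with a placeholder model id:

# Placeholder ModelScope model id; a local directory path works the same way.
path = "qwen/Qwen2-7B-Instruct"

# Without SGLANG_USE_MODELSCOPE in the environment, both paths pass through unchanged.
model_path, tokenizer_path = prepare_model_and_tokenizer(path, path)

# With SGLANG_USE_MODELSCOPE set and no local copy present, both are resolved via
# modelscope.snapshot_download, and the tokenizer download skips *.bin / *.safetensors.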
@@ -704,3 +611,37 @@ def set_weight_attrs(
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, value)


+def broadcast_pyobj(
+    data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
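The new broadcast_pyobj pickles an arbitrary Python list on rank 0 and ships it as a byte tensor, broadcasting the size first and the payload second so receivers can allocate the right buffer. A minimal usage sketch, assuming one process per rank (launcher command, group, and payload below are illustrative, not part of the diff):

import torch.distributed as dist

# Launch with e.g. `torchrun --nproc_per_node=2 demo.py` (illustrative).
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
cpu_group = dist.new_group(backend="gloo")  # CPU group for the byte tensors

# Rank 0 supplies the payload; other ranks pass an empty list, which is ignored.
payload = [{"rid": 0, "text": "hello"}] if rank == 0 else []
objs = broadcast_pyobj(payload, rank, cpu_group)
# Every rank now holds the list that rank 0 pickled and broadcast.

dist.destroy_process_group()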