Improve process creation (#1534)
This commit is contained in:
@@ -16,13 +16,12 @@ limitations under the License.
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import resource
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
@@ -36,7 +35,6 @@ import torch.distributed as dist
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
default_cache_dir,
|
||||
@@ -539,89 +537,6 @@ class CustomCacheManager(FileCacheManager):
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
|
||||
|
||||
def get_ip_address(ifname):
    """
    Get the IPv4 address assigned to a network interface (Linux only).

    :param ifname: Name of the network interface (e.g., 'eth0')
    :return: Dotted-quad IPv4 address string of the interface
    :raises OSError: if the interface does not exist or has no IPv4 address
    """
    # Fix: the original leaked the socket (never closed). The socket is only
    # needed as an fd for the ioctl, so close it as soon as the call returns.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        # SIOCGIFADDR fills an ifreq struct; bytes 20:24 of the result hold
        # the packed 4-byte IPv4 address (sin_addr of the sockaddr_in).
        packed_ip = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack("256s", bytes(ifname[:15], "utf-8")),
        )[20:24]
    return socket.inet_ntoa(packed_ip)
|
||||
|
||||
|
||||
def send_addrs_to_rank_0(model_port_args, server_args):
    """Send this node's IP address and tensor-parallel ports to node rank 0.

    Runs on every node with node_rank != 0 (and dp_size == 1).  Packs the
    node's IPv4 address (as 4 ints) followed by its TP ports into a single
    int tensor and sends it to rank 0 over a temporary gloo process group.
    The sequence of collectives here (send, then barrier) must mirror
    receive_addrs() executed on rank 0.
    """
    assert server_args.node_rank != 0 and server_args.dp_size == 1

    # Interface to read the local IP from: SGLANG_SOCKET_IFNAME wins,
    # then NCCL_SOCKET_IFNAME, defaulting to eth0.
    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)

    # Each node owns tp_size // nnodes consecutive TP ranks; record this
    # node's IP for its own slice of the TP-ip table.
    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
    # Encode "a.b.c.d" as 4 ints so the address can travel in an int tensor
    # alongside the ports.
    ip_addr = [int(x) for x in ip_addr.split(".")]
    addrs_tensor = torch.tensor(
        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
    )

    # Temporary gloo group used only for this bootstrap exchange.
    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    dist.send(addrs_tensor, dst=0)
    print(
        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
    )

    # Barrier keeps the group alive until rank 0 has received from all
    # other nodes, then tear the bootstrap group down.
    dist.barrier()
    dist.destroy_process_group()
|
||||
|
||||
|
||||
def receive_addrs(model_port_args, server_args):
    """Collect IP addresses and tensor-parallel ports from all other nodes.

    Runs on node rank 0 only (dp_size must be 1).  Records rank 0's own IP
    for its TP slice, then receives one int tensor from every other node
    (4 ints of IPv4 address + that node's TP ports) and fills in the
    corresponding slices of model_tp_ips / model_tp_ports.  The sequence of
    collectives (recv per node, then barrier) must mirror
    send_addrs_to_rank_0() executed on the other nodes.
    """
    assert server_args.node_rank == 0 and server_args.dp_size == 1

    # Interface to read the local IP from: SGLANG_SOCKET_IFNAME wins,
    # then NCCL_SOCKET_IFNAME, defaulting to eth0.
    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)

    # Each node owns tp_size // nnodes consecutive TP ranks; rank 0 fills
    # its own slice directly.
    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

    # Temporary gloo group used only for this bootstrap exchange.
    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )

    for src_rank in range(1, server_args.nnodes):
        # Wire format per node: 4 ints of IPv4 address + num_tp_ports ports.
        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
        dist.recv(tensor, src=src_rank)
        ip = ".".join([str(x) for x in tensor[:4].tolist()])
        ports = tensor[4:].tolist()
        # Write the sender's IP/ports into its slice of the TP tables.
        model_port_args.model_tp_ips[
            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
        ] = [ip] * num_tp_ports
        model_port_args.model_tp_ports[
            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
        ] = ports
        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")

    # Barrier releases the senders waiting in send_addrs_to_rank_0, then
    # tear the bootstrap group down.
    dist.barrier()
    dist.destroy_process_group()
|
||||
|
||||
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
@@ -645,24 +560,16 @@ def add_api_key_middleware(app, api_key: str):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    """Resolve model and tokenizer paths, downloading via ModelScope if enabled.

    When the SGLANG_USE_MODELSCOPE environment variable is set and
    model_path does not exist locally, both paths are resolved through
    modelscope.snapshot_download.  The tokenizer download skips weight
    files (*.bin, *.safetensors) since only tokenizer assets are needed.
    Otherwise the paths are returned unchanged.

    :param model_path: Local path or ModelScope model id.
    :param tokenizer_path: Local path or ModelScope tokenizer id.
    :return: Tuple of (model_path, tokenizer_path), possibly rewritten to
        the downloaded snapshot directories.
    """
    if "SGLANG_USE_MODELSCOPE" in os.environ:
        if not os.path.exists(model_path):
            # Imported lazily so modelscope stays an optional dependency.
            from modelscope import snapshot_download

            model_path = snapshot_download(model_path)
            tokenizer_path = snapshot_download(
                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
            )
    return model_path, tokenizer_path
|
||||
|
||||
|
||||
def configure_logger(server_args, prefix: str = ""):
|
||||
@@ -704,3 +611,37 @@ def set_weight_attrs(
|
||||
for key, value in weight_attrs.items():
|
||||
assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def broadcast_pyobj(
    data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
):
    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.

    Protocol: rank 0 first broadcasts a 1-element long tensor holding the
    pickled payload size (0 means "nothing to send"), then broadcasts the
    payload bytes.  Both branches must issue the same collectives in the
    same order, so keep them symmetric when editing.

    :param data: Python objects to broadcast (meaningful on rank 0 only).
    :param rank: This process's rank within dist_group.
    :param dist_group: Process group used for the broadcasts.
    :return: On rank 0, `data` unchanged; on other ranks, the received
        (unpickled) list, or [] when rank 0 announced size 0.
    """
    if rank == 0:
        if len(data) == 0:
            # Empty payload: announce size 0 and skip the data broadcast.
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
        else:
            # NOTE(review): pickle is acceptable here only because the
            # peers are trusted ranks of the same job, not external input.
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(list(serialized_data))
            tensor_size = torch.tensor([size], dtype=torch.long)

            dist.broadcast(tensor_size, src=0, group=dist_group)
            dist.broadcast(tensor_data, src=0, group=dist_group)
        return data
    else:
        # Receive the size first so a matching byte buffer can be allocated.
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        size = tensor_size.item()

        if size == 0:
            return []

        # uint8 buffer matches the ByteTensor broadcast by rank 0.
        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.broadcast(tensor_data, src=0, group=dist_group)

        serialized_data = bytes(tensor_data.tolist())
        data = pickle.loads(serialized_data)
        return data
||||
|
||||
Reference in New Issue
Block a user