Sync from v0.13

@@ -1,27 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import dataclasses
import json
import os
import pickle
import socket
import sys
import time
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import timedelta
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    _get_default_timeout,
    _unregister_process_group,
)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer

from .parallel_state import get_cpu_world_group, get_local_rank

logger = init_logger(__name__)

# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
# on those versions (the fix shipped in 3.11.1 and 3.10.8, which is exactly
# what the version check below encodes).
USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)) or (
    sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8
)


def sched_yield():
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)
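
# A minimal sketch (illustrative, not part of the synced file) of the polling
# pattern sched_yield() supports: spin on a condition, yielding the CPU between
# checks instead of sleeping for a fixed interval. The helper name wait_until
# is hypothetical.
#
#   def wait_until(predicate, timeout: float = 5.0) -> bool:
#       deadline = time.monotonic() + timeout
#       while time.monotonic() < deadline:
#           if predicate():
#               return True
#           sched_yield()
#       return False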


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )
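
# For example (illustrative): ensure_divisibility(12, 4) passes silently, while
# ensure_divisibility(10, 3) raises AssertionError("10 is not divisible by 3").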


def divide(numerator, denominator):
@@ -36,16 +69,16 @@ def split_tensor_along_last_dim(
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.
|
||||
"""Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
@@ -59,79 +92,454 @@ def split_tensor_along_last_dim(
    return tensor_list
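

# Illustrative usage (not part of the synced file): splitting a (2, 6) tensor
# into 3 partitions returns three views of shape (2, 2); pass
# contiguous_split_chunks=True when the chunks feed kernels that require
# contiguous memory.
#
#   chunks = split_tensor_along_last_dim(torch.arange(12).reshape(2, 6), 3)
#   [tuple(c.shape) for c in chunks]  # -> [(2, 2), (2, 2), (2, 2)]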


# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
    dev_i = f"musa:{idx_a}"
    dev_j = f"musa:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()


def get_pp_indices(
    num_hidden_layers: int, pp_rank: int, pp_size: int
) -> tuple[int, int]:
    """Try to evenly distribute layers across partitions.

    If the number of layers is not divisible by the number of partitions,
    the remaining layers are evenly distributed across all but the last
    partition. The last partition is excluded because it often contains an
    additional norm layer and we are attempting to balance compute.

    If `pp_size > 2` and the number of remaining layers is
    `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
    across the middle partitions. The first and last partitions are excluded
    because they contain the input and output embeddings respectively and we
    are attempting to reduce maximum memory consumption across partitions.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
    if partition_list_str is not None:
        try:
            partitions = [int(layer) for layer in partition_list_str.split(",")]
        except ValueError as err:
            raise ValueError(
                "Invalid partition string: {}".format(partition_list_str)
            ) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
    else:
        layers_per_partition = num_hidden_layers // pp_size
        partitions = [layers_per_partition for _ in range(pp_size)]

        if remaining_layers := num_hidden_layers % pp_size:
            for i in range(2, remaining_layers + 2):
                partitions[-i] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION environment variable",
                ",".join(str(p) for p in partitions),
            )

    start_layer = sum(partitions[:pp_rank])
    end_layer = start_layer + partitions[pp_rank]

    return (start_layer, end_layer)
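

# Worked example (illustrative, not part of the synced file): 13 hidden layers
# over pp_size=4 give a base of 3 layers per stage and one remainder layer,
# which lands on the second-to-last stage, so the per-rank
# (start_layer, end_layer) ranges are:
#
#   [get_pp_indices(13, r, 4) for r in range(4)]
#   # -> [(0, 3), (3, 6), (6, 10), (10, 13)]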


# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
#    P2P access to all other GPUs. Unfortunately, the test might cost many
#    (world_size * world_size) cuda context, and reduce the memory available
#    for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
#    process and broadcasted to all other processes. This still requires
#    #world_size of cuda context, belonging to the master process, on each GPU.
# 3. we can have a cache file, that records the p2p access status. The first
#    time the master process checks the p2p access, it will generate the cache
#    file, at the cost of #world_size of cuda context. Later on, all processes
#    can read the cache file to check the p2p access status without any cost of
#    additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j."""

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.musa.device_count()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
    path = os.path.expanduser(
        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if (not is_distributed or get_local_rank() == 0) \
            and (not os.path.exists(path)):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache = {}
        for _i in range(num_dev):
            for _j in range(num_dev):
                # on some platforms, P2P support might be buggy and we need
                # additional checks. See also:
                # https://github.com/vllm-project/vllm/issues/2728
                cache[f"{_i}->{_j}"] = torch.musa.can_device_access_peer(
                    _i, _j) and _can_actually_p2p(_i, _j)
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]
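
# Illustrative cache file for two visible devices with working bidirectional
# P2P (the layout follows the code above; the values are made up):
#
#   {"0->0": true, "0->1": true, "1->0": true, "1->1": true}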


@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """

    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store

    # stores a reference to the socket so that the file descriptor stays alive
    socket: socket.socket | None

    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> counter
    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    broadcast_send_counter: int = 0
    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
        default_factory=dict
    )

    # A deque to store the data entries, with key and timestamp.
    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
        )
        self.recv_src_counter[src] += 1
        return obj

    def broadcast_obj(self, obj: Any | None, src: int) -> Any:
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
            key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
            key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs
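
    # Illustrative trace (not part of the synced file), world_size=3: every
    # rank calls all_gather_obj(rank). The loop performs three sequential
    # broadcast_obj rounds, one rooted at each rank in turn, so every rank
    # finishes with the same list [0, 1, 2].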

    def barrier(self, timeout: float = 30.0):
        """A robust barrier to synchronize all ranks.

        Uses a multi-phase approach to ensure all processes reach the barrier
        before proceeding:

        1. Each process signals it has reached the barrier

        2. Each process signals that it has confirmed the arrival of all other
           ranks.

        3. Rank 0 waits for all other ranks to signal their departure to ensure
           that all ranks have departed the barrier first.

        Args:
            timeout: Maximum time in seconds to wait for each phase

        Raises:
            RuntimeError: If coordination fails or times out
        """
        # Generate a barrier ID that is globally unique
        try:
            if self.rank == 0:
                barrier_id = f"barrier_{uuid.uuid4()}"
                self.broadcast_obj(barrier_id, src=0)
            else:
                barrier_id = self.broadcast_obj(None, src=0)
        except Exception as e:
            raise RuntimeError("Failed to broadcast barrier_id") from e

        # Phase 1: Signal arrival at barrier
        # Wait for all processes to arrive
        # We need all ranks to confirm the arrival of all other ranks.
        # This is the key synchronization point.
        arrival_key = f"arrival_{barrier_id}_{self.rank}"
        try:
            self.store.set(arrival_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier arrival") from e

        start_time = time.time()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < self.world_size:
            # Check for timeout
            cur_time = time.time()
            if cur_time - start_time > timeout:
                raise RuntimeError(f"Barrier timed out after {timeout:.2f} seconds")

            # Check for each process
            for i in range(self.world_size):
                if i in processes_arrived:
                    continue

                key = f"arrival_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_arrived.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_arrived) < self.world_size:
                sched_yield()

        # Phase 2: Signal departure from barrier
        # We only care to block at this stage in rank 0, which runs the
        # server side of the TCPStore. We want to make sure that all
        # clients have departed the barrier before rank 0 in case the
        # next thing after the barrier is a shutdown, including tearing
        # down the TCPStore. Other ranks can exit the barrier immediately
        # after signaling their departure.
        departure_key = f"departure_{barrier_id}_{self.rank}"
        try:
            self.store.set(departure_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier departure") from e

        if self.rank != 0:
            return

        # Make rank 0 wait for all processes to signal departure
        start_time = time.time()
        processes_departed: set[int] = set()

        while len(processes_departed) < self.world_size:
            # Check for timeout
            if time.time() - start_time > timeout:
                raise RuntimeError(
                    f"Barrier departure timed out after {timeout:.2f} seconds"
                )

            # Check for each process
            for i in range(self.world_size):
                if i in processes_departed:
                    continue

                key = f"departure_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_departed.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_departed) < self.world_size:
                sched_yield()

        # Clean up keys to avoid leaking memory in the store
        for i in range(self.world_size):
            try:
                self.store.delete_key(f"arrival_{barrier_id}_{i}")
            except Exception:
                logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}")

            try:
                self.store.delete_key(f"departure_{barrier_id}_{i}")
            except Exception:
                logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}")
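
        # Store-key lifecycle of one barrier call (illustrative), with
        # world_size=2 and barrier id <id>:
        #   1. rank 0 sets arrival_<id>_0; rank 1 sets arrival_<id>_1
        #   2. both ranks poll until they have seen both arrival keys
        #   3. both ranks set departure_<id>_{rank}; only rank 0 (the TCPStore
        #      server) waits for both departure keys, so the store is not torn
        #      down while a client is still inside the barrier
        #   4. rank 0 deletes all four keys to avoid leaking store memory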

    @staticmethod
    def create(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does
        not pollute the global state.

        If process A and process B call `torch.distributed.init_process_group`
        to form a group, and we then want to form another group with processes
        A, B, C, and D, PyTorch does not allow it: A and B have already formed
        a group, and C and D cannot join that group. This function is a
        workaround for the issue.

        `torch.distributed.init_process_group` is a global call, while this
        function is a stateless call. It returns a `StatelessProcessGroup`
        object that can be used to exchange metadata. With this function,
        process A and process B can call `StatelessProcessGroup.create` to
        form a group, and then processes A, B, C, and D can call
        `StatelessProcessGroup.create` to form another group.
        """ # noqa
        launch_server = rank == 0
        if launch_server:
            # listen on the specified interface (instead of 0.0.0.0)
            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listen_socket.bind((host, port))
            listen_socket.listen()
            listen_fd = listen_socket.fileno()
        else:
            listen_socket = None
            listen_fd = None

        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=launch_server,
            timeout=timedelta(seconds=store_timeout),
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
            master_listen_fd=listen_fd,
        )

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            socket=listen_socket,
            data_expiration_seconds=data_expiration_seconds,
        )
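
    # A minimal usage sketch (illustrative, not part of the synced file): two
    # local processes form a group and exchange a pickled dict; the port is an
    # arbitrary free port and the worker name is hypothetical.
    #
    #   def worker(rank: int) -> None:
    #       pg = StatelessProcessGroup.create(
    #           host="127.0.0.1", port=29555, rank=rank, world_size=2
    #       )
    #       if rank == 0:
    #           pg.send_obj({"ready": True}, dst=1)
    #       else:
    #           assert pg.recv_obj(src=0) == {"ready": True}
    #       pg.barrier()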


def init_gloo_process_group(
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """
    Stateless init ProcessGroup with gloo backend compatible with
    different torch versions.
    """
    with suppress_stdout():
        if is_torch_equal_or_newer("2.6"):
            pg = ProcessGroup(
                prefix_store,
                group_rank,
                group_size,
            )
        else:
            options = ProcessGroup.Options(backend="gloo")
            pg = ProcessGroup(
                prefix_store,
                group_rank,
                group_size,
                options,
            )
    from torch.distributed.distributed_c10d import ProcessGroupGloo

    backend_class = ProcessGroupGloo(
        prefix_store, group_rank, group_size, timeout=timeout
    )
    backend_type = ProcessGroup.BackendType.GLOO
    device = torch.device("cpu")
    if is_torch_equal_or_newer("2.6"):
        # _set_default_backend is supported in torch >= 2.6
        pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)
    return pg
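

# Illustrative follow-up (not part of the synced file): a group created this
# way supports collectives that only need the group-local rank, e.g. via the
# raw c10d binding (list-of-tensors form; exact overloads vary by torch
# version):
#
#   t = torch.ones(1)
#   pg.allreduce([t]).wait()  # t == world_size afterwards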


def stateless_init_torch_distributed_process_group(
    host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because they depend on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have processes
    1, 2, ..., 8 that want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel between
    processes 9 and 10, without affecting the communication channel among
    processes 1, 2, ..., 8?

    One possible solution is to figure out if processes 9 and 10 are the same
    as processes 1 and 5 beforehand, and then form a communication channel
    based on that information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with processes 1, 2,
    ..., 8, and then use this function to form another communication channel
    with processes 9 and 10. This way, regardless of whether processes 9 and 10
    are the same as processes 1 and 5, the main communication channel is
    always formed with processes 1, 2, ..., 8, and the additional communication
    channel is formed with processes 9 and 10.
    """
    init_method = get_tcp_uri(host, port)
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout)
    )
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)
    try:
        from vllm.platforms import current_platform

        return current_platform.stateless_init_device_torch_dist_pg(
            backend=backend,
            prefix_store=prefix_store,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timeout,
        )
    except NotImplementedError:
        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
        # will raise a NotImplementedError. In this case, we fall back to gloo.
        return init_gloo_process_group(
            prefix_store=prefix_store,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timeout,
        )


def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
    """
    Destroy ProcessGroup returned by
    stateless_init_torch_distributed_process_group().
    """
    if is_torch_equal_or_newer("2.7"):
        pg.shutdown()
    else:
        # Lazy import for non-CUDA backends.
        from torch.distributed.distributed_c10d import _shutdown_backend

        _shutdown_backend(pg)

    _unregister_process_group(pg.group_name)
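

# Illustrative lifecycle (not part of the synced file); with world_size=1 this
# runs standalone on one host, and the port is an arbitrary free port:
#
#   pg = stateless_init_torch_distributed_process_group(
#       host="127.0.0.1", port=29556, rank=0, world_size=1, backend="gloo"
#   )
#   ...  # use pg for collectives
#   stateless_destroy_torch_distributed_process_group(pg)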