adapt vllm distributed module to sglang (#2244)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
118
python/sglang/srt/_custom_ops.py
Normal file
118
python/sglang/srt/_custom_ops.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
|
||||
from sglang.srt.utils import is_hpu
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if not is_hpu():
|
||||
try:
|
||||
import custom_ar
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
|
||||
def hint_on_error(fn):
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
except NotImplementedError as e:
|
||||
msg = (
|
||||
"Error in calling custom op %s: %s\n"
|
||||
"Not implemented or built, mostly likely because the current current device "
|
||||
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
|
||||
"incorrectly while building)"
|
||||
)
|
||||
logger.error(msg, fn.__name__, e)
|
||||
raise NotImplementedError(msg % (fn.__name__, e)) from e
|
||||
except AttributeError as e:
|
||||
msg = (
|
||||
"Error in calling custom op %s: %s\n"
|
||||
"Possibly you have built or installed an obsolete version of vllm.\n"
|
||||
"Please try a clean build and install of vllm,"
|
||||
"or remove old built files such as vllm/*cpython*.so and build/ ."
|
||||
)
|
||||
logger.error(msg, fn.__name__, e)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# custom ar
|
||||
def init_custom_ar(
|
||||
ipc_tensors: List[torch.Tensor],
|
||||
rank_data: torch.Tensor,
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops._C_vllm_ar.init_custom_ar(
|
||||
ipc_tensors, rank_data, rank, full_nvlink
|
||||
)
|
||||
|
||||
|
||||
def all_reduce(
|
||||
fa: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
reg_buffer: int,
|
||||
reg_buffer_sz_bytes: int,
|
||||
) -> None:
|
||||
torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
|
||||
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops._C_vllm_ar.dispose(fa)
|
||||
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops._C_vllm_ar.meta_size()
|
||||
|
||||
|
||||
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
|
||||
return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
|
||||
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
|
||||
# TODO: remove this in v0.6.0
|
||||
names_and_values = globals()
|
||||
names_and_values_to_update = {}
|
||||
# prepare variables to avoid dict size change during iteration
|
||||
k, v, arg = None, None, None
|
||||
fn_type = type(lambda x: x)
|
||||
for k, v in names_and_values.items():
|
||||
# find functions that are defined in this file and have torch.Tensor
|
||||
# in their annotations. `arg == "torch.Tensor"` is used to handle
|
||||
# the case when users use `import __annotations__` to turn type
|
||||
# hints into strings.
|
||||
if (
|
||||
isinstance(v, fn_type)
|
||||
and v.__code__.co_filename == __file__
|
||||
and any(
|
||||
arg is torch.Tensor or arg == "torch.Tensor"
|
||||
for arg in v.__annotations__.values()
|
||||
)
|
||||
):
|
||||
names_and_values_to_update[k] = hint_on_error(v)
|
||||
|
||||
names_and_values.update(names_and_values_to_update)
|
||||
del names_and_values_to_update, names_and_values, v, k, fn_type
|
||||
3
python/sglang/srt/distributed/__init__.py
Normal file
3
python/sglang/srt/distributed/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
34
python/sglang/srt/distributed/communication_op.py
Normal file
34
python/sglang/srt/distributed/communication_op.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
return get_tp_group().all_reduce(input_)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(
|
||||
input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Gather the input tensor across model parallel group."""
|
||||
return get_tp_group().gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
||||
@@ -0,0 +1,182 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(
|
||||
lib_name
|
||||
), f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function(
|
||||
"cudaMalloc",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
|
||||
),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function(
|
||||
"cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
|
||||
),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function(
|
||||
"cudaMemcpy",
|
||||
cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
|
||||
),
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function(
|
||||
"cudaIpcGetMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
|
||||
),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function(
|
||||
"cudaIpcOpenMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
|
||||
),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
assert so_file is not None, "libcudart is not loaded in the current process"
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(
|
||||
self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
|
||||
) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
|
||||
)
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
|
||||
)
|
||||
)
|
||||
return devPtr
|
||||
@@ -0,0 +1,352 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional, TypeVar, Union
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
||||
SGLANG_SKIP_P2P_CHECK = os.getenv("SGLANG_SKIP_P2P_CHECK", "0") == "1"
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if SGLANG_SKIP_P2P_CHECK:
|
||||
logger.info("Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-cuda environment
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "CustomAllreduce should be attached to a non-NCCL group."
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes."
|
||||
)
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size,
|
||||
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert is_cuda()
|
||||
|
||||
full_nvlink = is_full_nvlink(physical_device_ids)
|
||||
if world_size > 2 and not full_nvlink:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
"specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
if not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(
|
||||
ops.meta_size() + max_size, group=group
|
||||
)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int, group: Optional[ProcessGroup] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a shared buffer and returns a list of pointers
|
||||
representing the buffer on all processes in the group.
|
||||
"""
|
||||
lib = CudaRTLibrary()
|
||||
pointer = lib.cudaMalloc(size_in_bytes)
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: List[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer.value) # type: ignore
|
||||
else:
|
||||
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
|
||||
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(
|
||||
pointers: List[int], group: Optional[ProcessGroup] = None
|
||||
) -> None:
|
||||
rank = dist.get_rank(group=group)
|
||||
lib = CudaRTLibrary()
|
||||
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
|
||||
):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
|
||||
)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs)
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
@@ -0,0 +1,291 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import lru_cache
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.utils import cuda_device_count_stateless
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_environment_variables(envs: Dict[str, str]):
|
||||
for k, v in envs.items():
|
||||
if k in os.environ and os.environ[k] != v:
|
||||
logger.warning(
|
||||
"Overwriting environment variable %s " "from '%s' to '%s'",
|
||||
k,
|
||||
os.environ[k],
|
||||
v,
|
||||
)
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
def producer(
|
||||
batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(
|
||||
batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(
|
||||
target=producer,
|
||||
args=(
|
||||
batch_src,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_tgt = smp.Process(
|
||||
target=consumer,
|
||||
args=(
|
||||
batch_tgt,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: List[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.",
|
||||
src,
|
||||
tgt,
|
||||
)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
# VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT
|
||||
# "~/.cache/vllm" -> "~/.cache/sglang"
|
||||
SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang")
|
||||
path = os.path.join(
|
||||
SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from sglang.srt.distributed.parallel_state import get_world_group
|
||||
|
||||
if (not is_distributed or get_world_group().local_rank == 0) and (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: Dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run(
|
||||
[sys.executable, __file__], input=input_bytes, capture_output=True
|
||||
)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}"
|
||||
) from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
logger.info("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
@@ -0,0 +1,48 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.utils import is_hpu
|
||||
|
||||
if is_hpu():
|
||||
import habana_frameworks.torch as htorch # noqa: F401
|
||||
|
||||
|
||||
class HpuCommunicator:
|
||||
|
||||
def __init__(self, group: ProcessGroup):
|
||||
if not is_hpu():
|
||||
self.disabled = True
|
||||
return
|
||||
self.disabled = False
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||
# (which is required for tensor parallel HPUGraph inference)
|
||||
htorch.core.mark_step()
|
||||
dist.all_reduce(x, group=self.group)
|
||||
return x
|
||||
|
||||
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += x.dim()
|
||||
input_size = x.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
(world_size,) + input_size, dtype=x.dtype, device=x.device
|
||||
)
|
||||
# All-gather.
|
||||
htorch.core.mark_step()
|
||||
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
204
python/sglang/srt/distributed/device_communicators/pynccl.py
Normal file
204
python/sglang/srt/distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum,
|
||||
ncclUniqueId,
|
||||
)
|
||||
from sglang.srt.distributed.utils import StatelessProcessGroup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "PyNcclCommunicator should be attached to a non-NCCL group."
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank
|
||||
)
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
self.stream.synchronize()
|
||||
del data
|
||||
|
||||
# by default it is disabled, e.g. in profiling models and prefill phase.
|
||||
# to use it, use under `with obj.change_state(enable=True)`, usually
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def all_reduce(
|
||||
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||
):
|
||||
"""
|
||||
A context manager to change the state of the communicator.
|
||||
"""
|
||||
if enable is None:
|
||||
# guess a default value when not specified
|
||||
enable = self.available
|
||||
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
|
||||
old_disable = self.disabled
|
||||
old_stream = self.stream
|
||||
|
||||
self.stream = stream
|
||||
self.disabled = not enable
|
||||
yield
|
||||
|
||||
self.disabled = old_disable
|
||||
self.stream = old_stream
|
||||
@@ -0,0 +1,362 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
|
||||
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `SGLANG_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_nccl_library() -> str:
|
||||
"""
|
||||
We either use the library file specified by the `SGLANG_NCCL_SO_PATH`
|
||||
environment variable, or we find the library file brought by PyTorch.
|
||||
After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
|
||||
found by `ctypes` automatically.
|
||||
"""
|
||||
|
||||
# so_file can be set to None in sglang
|
||||
so_file = os.environ.get("SGLANG_NCCL_SO_PATH", None)
|
||||
|
||||
# manually load the nccl library
|
||||
if so_file:
|
||||
logger.info(
|
||||
"Found nccl from environment variable SGLANG_NCCL_SO_PATH=%s", so_file
|
||||
)
|
||||
else:
|
||||
if torch.version.cuda is not None:
|
||||
so_file = "libnccl.so.2"
|
||||
elif torch.version.hip is not None:
|
||||
so_file = "librccl.so.1"
|
||||
else:
|
||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||
logger.info("Found nccl from library %s", so_file)
|
||||
return so_file
|
||||
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function(
|
||||
"ncclCommInitRank",
|
||||
ncclResult_t,
|
||||
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
|
||||
),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclAllReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function(
|
||||
"ncclSend",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function(
|
||||
"ncclRecv",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"If you already have the library, please set the "
|
||||
"environment variable SGLANG_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.",
|
||||
so_file,
|
||||
platform.platform(),
|
||||
)
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(
|
||||
self, world_size: int, unique_id: ncclUniqueId, rank: int
|
||||
) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommInitRank"](
|
||||
ctypes.byref(comm), world_size, unique_id, rank
|
||||
)
|
||||
)
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclAllReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclSend(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
dest: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
|
||||
)
|
||||
|
||||
def ncclRecv(
|
||||
self,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
src: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
|
||||
)
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary",
|
||||
"ncclDataTypeEnum",
|
||||
"ncclRedOpTypeEnum",
|
||||
"ncclUniqueId",
|
||||
"ncclComm_t",
|
||||
"cudaStream_t",
|
||||
"buffer_type",
|
||||
]
|
||||
@@ -0,0 +1,568 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
||||
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
||||
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_ip() -> str:
|
||||
# SGLANG_HOST_IP env can be ignore
|
||||
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
|
||||
if host_ip:
|
||||
return host_ip
|
||||
|
||||
# IP is not set, try to get it from the network interface
|
||||
|
||||
# try ipv4
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# try ipv6
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||
# Google's public DNS server, see
|
||||
# https://developers.google.com/speed/public-dns/docs/using#addresses
|
||||
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
warnings.warn(
|
||||
"Failed to get the IP address, using 0.0.0.0 by default."
|
||||
"The value can be set by the environment variable"
|
||||
" SGLANG_HOST_IP or HOST_IP.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "0.0.0.0"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
|
||||
port = os.getenv("SGLANG_PORT")
|
||||
if port is not None:
|
||||
while True:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", port))
|
||||
return port
|
||||
except OSError:
|
||||
port += 1 # Increment port number if already in use
|
||||
logger.info("Port %d is already in use, trying port %d", port - 1, port)
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
Essentially, it is a queue where only one will `enqueue` and multiple
|
||||
will `dequeue`. The max size of each item, together with the max number
|
||||
of items that can be stored in the buffer are known in advance.
|
||||
In this case, we don't need to synchronize the access to
|
||||
the buffer.
|
||||
|
||||
Buffer memory layout:
|
||||
data metadata
|
||||
| |
|
||||
| (current_idx) | (current_idx)
|
||||
v v
|
||||
+-------------------------------+----------------------------------------+
|
||||
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
|
||||
+-------------------------------+----------------------------------------+
|
||||
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
|
||||
|
||||
metadata memory layout: each byte is a flag, the first byte is the written
|
||||
flag, and the rest are reader flags. The flags are set to 0 by default.
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
|
||||
The state of metadata is as follows:
|
||||
|
||||
(case 1) 0???...???: the block is not written yet, cannot read, can write
|
||||
(case 2) 1000...000: the block is just written, can read, cannot write
|
||||
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
|
||||
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
|
||||
|
||||
State transition for readers:
|
||||
|
||||
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
|
||||
Only after the caller finishes reading the block, the reader can mark the block as read.
|
||||
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
|
||||
|
||||
State transition for writer:
|
||||
|
||||
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
|
||||
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
|
||||
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
|
||||
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
|
||||
|
||||
During creation, `name` is None and the buffer is created. We can pass the
|
||||
created object to other processes by pickling it. The other processes will
|
||||
get the name of the shared memory and open it, so that they can access the
|
||||
same shared memory buffer.
|
||||
""" # noqa
|
||||
self.n_reader = n_reader
|
||||
self.metadata_size = 1 + n_reader
|
||||
self.max_chunk_bytes = max_chunk_bytes
|
||||
self.max_chunks = max_chunks
|
||||
self.total_bytes_of_buffer = (
|
||||
self.max_chunk_bytes + self.metadata_size
|
||||
) * self.max_chunks
|
||||
self.data_offset = 0
|
||||
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
||||
|
||||
if name is None:
|
||||
# we are creating a buffer
|
||||
self.is_creator = True
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.total_bytes_of_buffer
|
||||
)
|
||||
# initialize the metadata section to 0
|
||||
with memoryview(
|
||||
self.shared_memory.buf[self.metadata_offset :]
|
||||
) as metadata_buffer:
|
||||
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
self.is_creator = False
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch(
|
||||
"multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None,
|
||||
):
|
||||
try:
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
assert self.shared_memory.size == self.total_bytes_of_buffer
|
||||
except FileNotFoundError:
|
||||
# we might deserialize the object in a different node
|
||||
# in this case, this object is not used,
|
||||
# and we should suppress the error
|
||||
pass
|
||||
|
||||
def __reduce__(self):
|
||||
return (
|
||||
self.__class__,
|
||||
(
|
||||
self.n_reader,
|
||||
self.max_chunk_bytes,
|
||||
self.max_chunks,
|
||||
self.shared_memory.name,
|
||||
),
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_creator:
|
||||
self.shared_memory.unlink()
|
||||
|
||||
@contextmanager
|
||||
def get_data(self, current_idx: int):
|
||||
start = self.data_offset + current_idx * self.max_chunk_bytes
|
||||
end = start + self.max_chunk_bytes
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
@contextmanager
|
||||
def get_metadata(self, current_idx: int):
|
||||
start = self.metadata_offset + current_idx * self.metadata_size
|
||||
end = start + self.metadata_size
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
connect_ip: str
|
||||
local_reader_ranks: List[int] = field(default_factory=list)
|
||||
|
||||
buffer: Optional[ShmRingBuffer] = None
|
||||
local_subscribe_port: Optional[int] = None
|
||||
remote_subscribe_port: Optional[int] = None
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: Optional[List[int]] = None,
|
||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: Optional[str] = None,
|
||||
):
|
||||
if local_reader_ranks is None:
|
||||
local_reader_ranks = list(range(n_local_reader))
|
||||
else:
|
||||
assert len(local_reader_ranks) == n_local_reader
|
||||
self.n_local_reader = n_local_reader
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
if connect_ip is None:
|
||||
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
|
||||
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
# for local readers, we will:
|
||||
# 1. create a shared memory ring buffer to communicate small data
|
||||
# 2. create a publish-subscribe socket to communicate large data
|
||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)
|
||||
|
||||
# XPUB is very similar to PUB,
|
||||
# except that it can receive subscription messages
|
||||
# to confirm the number of subscribers
|
||||
self.local_socket = context.socket(XPUB)
|
||||
# set the verbose option so that we can receive every subscription
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_port = get_open_port()
|
||||
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
|
||||
logger.debug("Binding to %s", socket_addr)
|
||||
self.local_socket.bind(socket_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_port = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
|
||||
else:
|
||||
remote_subscribe_port = None
|
||||
self.remote_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
self._is_local_reader = False
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.handle = Handle(
|
||||
connect_ip=connect_ip,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer=self.buffer,
|
||||
local_subscribe_port=local_subscribe_port,
|
||||
remote_subscribe_port=remote_subscribe_port,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
|
||||
def export_handle(self) -> Handle:
|
||||
return self.handle
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
|
||||
self = MessageQueue.__new__(MessageQueue)
|
||||
self.handle = handle
|
||||
self._is_writer = False
|
||||
|
||||
context = Context()
|
||||
|
||||
if rank in handle.local_reader_ranks:
|
||||
assert handle.buffer is not None
|
||||
self.buffer = handle.buffer
|
||||
self.current_idx = 0
|
||||
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
||||
self._is_local_reader = True
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.local_socket = context.socket(SUB)
|
||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
self.local_reader_rank = -1
|
||||
self._is_local_reader = False
|
||||
self._is_remote_reader = True
|
||||
|
||||
self.local_socket = None
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if is_valid_ipv6_address(handle.connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
"""This is a collective operation. All processes (including the
|
||||
readers and the writer) should call this function.
|
||||
"""
|
||||
if self._is_writer:
|
||||
# wait for all readers to connect
|
||||
|
||||
# local readers
|
||||
for i in range(self.n_local_reader):
|
||||
# wait for subscription messages from all local readers
|
||||
self.local_socket.recv()
|
||||
if self.n_local_reader > 0:
|
||||
# send a message to all local readers
|
||||
# to make sure the publish channel is working
|
||||
self.local_socket.send(b"READY")
|
||||
|
||||
# remote readers
|
||||
for i in range(self.n_remote_reader):
|
||||
# wait for subscription messages from all remote readers
|
||||
self.remote_socket.recv()
|
||||
if self.n_remote_reader > 0:
|
||||
# send a message to all remote readers
|
||||
# to make sure the publish channel is working
|
||||
self.remote_socket.send(b"READY")
|
||||
elif self._is_local_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.local_socket.recv()
|
||||
assert recv == b"READY"
|
||||
elif self._is_remote_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_count = sum(metadata_buffer[1:])
|
||||
written_flag = metadata_buffer[0]
|
||||
if written_flag and read_count != self.buffer.n_reader:
|
||||
# this block is written and not read by all readers
|
||||
# for writers, `self.current_idx` is the next block to write
|
||||
# if this block is not ready to write,
|
||||
# we need to wait until it is read by all readers
|
||||
|
||||
# Release the processor to other threads
|
||||
os.sched_yield()
|
||||
|
||||
# if we wait for a long time, we should warn the user
|
||||
if (
|
||||
time.monotonic() - start_time
|
||||
> SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
|
||||
):
|
||||
logger.warning(
|
||||
"No available block found in %s second. ",
|
||||
SGLANG_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
continue
|
||||
# found a block that is either
|
||||
# (1) not written
|
||||
# (2) read by all readers
|
||||
|
||||
# mark the block as not written
|
||||
metadata_buffer[0] = 0
|
||||
# let caller write to the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has written to the buffer
|
||||
# NOTE: order is important here
|
||||
# first set the read flags to 0
|
||||
# then set the written flag to 1
|
||||
# otherwise, the readers may think they already read the block
|
||||
for i in range(1, self.buffer.n_reader + 1):
|
||||
# set read flag to 0, meaning it is not read yet
|
||||
metadata_buffer[i] = 0
|
||||
# mark the block as written
|
||||
metadata_buffer[0] = 1
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(self):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||
written_flag = metadata_buffer[0]
|
||||
if not written_flag or read_flag:
|
||||
# this block is either
|
||||
# (1) not written
|
||||
# (2) already read by this reader
|
||||
|
||||
# for readers, `self.current_idx` is the next block to read
|
||||
# if this block is not ready,
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
os.sched_yield()
|
||||
|
||||
# if we wait for a long time, we should warn the user
|
||||
if (
|
||||
time.monotonic() - start_time
|
||||
> SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
|
||||
):
|
||||
logger.warning(
|
||||
"No available block found in %s second. ",
|
||||
SGLANG_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
# let caller read from the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has read from the buffer
|
||||
# set the read flag
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
def enqueue(self, obj):
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if self.n_local_reader > 0:
|
||||
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
||||
with self.acquire_write() as buf:
|
||||
buf[0] = 1 # overflow
|
||||
self.local_socket.send(serialized_obj)
|
||||
else:
|
||||
with self.acquire_write() as buf:
|
||||
buf[0] = 0 # not overflow
|
||||
buf[1 : len(serialized_obj) + 1] = serialized_obj
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send(serialized_obj)
|
||||
|
||||
def dequeue(self):
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read() as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
# no need to know the size of serialized object
|
||||
# pickle format contains the size information internally
|
||||
# see https://docs.python.org/3/library/pickle.html
|
||||
obj = pickle.loads(buf[1:])
|
||||
if overflow:
|
||||
recv = self.local_socket.recv()
|
||||
obj = pickle.loads(recv)
|
||||
elif self._is_remote_reader:
|
||||
recv = self.remote_socket.recv()
|
||||
obj = pickle.loads(recv)
|
||||
else:
|
||||
raise RuntimeError("Only readers can dequeue")
|
||||
return obj
|
||||
|
||||
def broadcast_object(self, obj=None):
|
||||
if self._is_writer:
|
||||
self.enqueue(obj)
|
||||
return obj
|
||||
else:
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(
|
||||
pg: ProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0
|
||||
) -> "MessageQueue":
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io: MessageQueue
|
||||
if group_rank == writer_rank:
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
dist.broadcast_object_list(
|
||||
[handle], src=global_ranks[writer_rank], group=pg
|
||||
)
|
||||
else:
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)
|
||||
handle = recv[0] # type: ignore
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
@@ -0,0 +1,47 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.utils import is_xpu
|
||||
|
||||
|
||||
class XpuCommunicator:
|
||||
|
||||
def __init__(self, group: ProcessGroup):
|
||||
if not is_xpu():
|
||||
self.disabled = True
|
||||
return
|
||||
self.disabled = False
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(x, group=self.group)
|
||||
return x
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1
|
||||
):
|
||||
# For xpu path, gather doesn't work properly together with ray
|
||||
# cluster so we use all_gather instead for now.
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
(self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.group
|
||||
)
|
||||
if rank_in_group == dst:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
1275
python/sglang/srt/distributed/parallel_state.py
Normal file
1275
python/sglang/srt/distributed/parallel_state.py
Normal file
File diff suppressed because it is too large
Load Diff
223
python/sglang/srt/distributed/utils.py
Normal file
223
python/sglang/srt/distributed/utils.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.distributed import TCPStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator
|
||||
)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
"""Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# NOTE: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
def get_pp_indices(
|
||||
num_hidden_layers: int, pp_rank: int, pp_size: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Try to evenly distribute layers across partitions.
|
||||
If the number of layers is not divisible by the number of partitions,
|
||||
the last partition will have the remaining layers.
|
||||
"""
|
||||
# partition_list_str can be set to None in sglang
|
||||
partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [int(layer) for layer in partition_list_str.split(",")]
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
"Invalid partition string: {}".format(partition_list_str)
|
||||
) from err
|
||||
if len(partitions) != pp_size:
|
||||
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||
start_layer = sum(partitions[:pp_rank])
|
||||
end_layer = start_layer + partitions[pp_rank]
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // pp_size
|
||||
start_layer = pp_rank * layers_per_partition
|
||||
end_layer = start_layer + layers_per_partition
|
||||
|
||||
if pp_rank == pp_size - 1:
|
||||
end_layer = num_hidden_layers
|
||||
|
||||
return (start_layer, end_layer)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
# dst rank -> counter
|
||||
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
# src rank -> counter
|
||||
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
broadcast_send_counter: int = 0
|
||||
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# A deque to store the data entries, with key and timestamp.
|
||||
entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.rank < self.world_size
|
||||
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def expire_data(self):
|
||||
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||
while self.entries:
|
||||
# check the oldest entry
|
||||
key, timestamp = self.entries[0]
|
||||
if time.time() - timestamp > self.data_expiration_seconds:
|
||||
self.store.delete_key(key)
|
||||
self.entries.popleft()
|
||||
else:
|
||||
break
|
||||
|
||||
def recv_obj(self, src: int) -> Any:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
|
||||
)
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
|
||||
"""Broadcast an object from a source rank to all other ranks.
|
||||
It does not clean up after all ranks have received the object.
|
||||
Use it for limited times, e.g., for initialization.
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return recv_obj
|
||||
|
||||
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||
"""All gather an object from all ranks."""
|
||||
gathered_objs = []
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
gathered_objs.append(obj)
|
||||
self.broadcast_obj(obj, src=self.rank)
|
||||
else:
|
||||
recv_obj = self.broadcast_obj(None, src=i)
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def barrier(self):
|
||||
"""A barrier to synchronize all ranks."""
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
self.broadcast_obj(None, src=self.rank)
|
||||
else:
|
||||
self.broadcast_obj(None, src=i)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
|
||||
If we have process A and process B called `torch.distributed.init_process_group`
|
||||
to form a group, and then we want to form another group with process A, B, C,
|
||||
D, it is not possible in PyTorch, because process A and process B have already
|
||||
formed a group, and process C and process D cannot join that group. This
|
||||
function is a workaround for this issue.
|
||||
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||
used for exchanging metadata. With this function, process A and process B
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
world_size=world_size,
|
||||
is_master=(rank == 0),
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
data_expiration_seconds=data_expiration_seconds,
|
||||
)
|
||||
@@ -30,6 +30,7 @@ import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
@@ -67,6 +68,22 @@ def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def is_cuda():
|
||||
return hasattr(torch, "cuda") and torch.cuda.is_available()
|
||||
|
||||
|
||||
def is_cuda_alike():
|
||||
return is_cuda() or is_hip()
|
||||
|
||||
|
||||
def is_hpu() -> bool:
|
||||
return hasattr(torch, "hpu") and torch.hpu.is_available()
|
||||
|
||||
|
||||
def is_xpu() -> bool:
|
||||
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
|
||||
|
||||
def is_flashinfer_available():
|
||||
"""
|
||||
Check whether flashinfer is available.
|
||||
@@ -967,6 +984,12 @@ def get_device_name(device_id: int = 0) -> str:
|
||||
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
# Some backends use pytorch version < 2.4.0 which doesn't
|
||||
# support `torch.library.custom_op`.
|
||||
def supports_custom_op() -> bool:
|
||||
return hasattr(torch.library, "custom_op")
|
||||
|
||||
|
||||
def direct_register_custom_op(
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
@@ -1043,3 +1066,45 @@ def set_gpu_proc_affinity(
|
||||
def get_bool_env_var(name: str, default: str = "false") -> bool:
|
||||
value = os.getenv(name, default)
|
||||
return value.lower() in ("true", "1")
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if is_hip():
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = (
|
||||
torch.cuda._device_count_amdsmi()
|
||||
if (hasattr(torch.cuda, "_device_count_amdsmi"))
|
||||
else -1
|
||||
)
|
||||
else:
|
||||
raw_count = torch.cuda._device_count_nvml()
|
||||
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
|
||||
return r
|
||||
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
|
||||
def cuda_device_count_stateless() -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of
|
||||
CUDA_VISIBLE_DEVICES at the time of call.
|
||||
|
||||
This should be used instead of torch.cuda.device_count()
|
||||
unless CUDA_VISIBLE_DEVICES has already been set to the desired
|
||||
value."""
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
||||
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
|
||||
Reference in New Issue
Block a user