Add support for NCCL symmetric memory for TP allreduces (#8238)
This commit is contained in:
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
if self.rank == 0:
|
||||
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def register_comm_window_raw(self, ptr: int, size: int):
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.cuda.memory import CUDAPluggableAllocator
|
||||
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
nccl_allocator_source = """
|
||||
#include <nccl.h>
|
||||
extern "C" {
|
||||
|
||||
void* nccl_alloc_plug(size_t size, int device, void* stream) {
|
||||
void* ptr;
|
||||
ncclResult_t err = ncclMemAlloc(&ptr, size);
|
||||
return ptr;
|
||||
|
||||
}
|
||||
|
||||
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
|
||||
ncclResult_t err = ncclMemFree(ptr);
|
||||
}
|
||||
|
||||
}
|
||||
"""
|
||||
|
||||
_allocator = None
|
||||
_mem_pool = None
|
||||
_registered_base_addrs = set()
|
||||
_graph_pool_id = None
|
||||
|
||||
|
||||
def is_symmetric_memory_enabled():
|
||||
return global_server_args_dict["enable_symm_mem"]
|
||||
|
||||
|
||||
def set_graph_pool_id(graph_pool_id):
|
||||
global _graph_pool_id
|
||||
_graph_pool_id = graph_pool_id
|
||||
|
||||
|
||||
def get_nccl_mem_pool():
|
||||
global _allocator, _mem_pool
|
||||
if _mem_pool is None:
|
||||
out_dir = tempfile.gettempdir()
|
||||
nccl_allocator_libname = "nccl_allocator"
|
||||
torch.utils.cpp_extension.load_inline(
|
||||
name=nccl_allocator_libname,
|
||||
cpp_sources=nccl_allocator_source,
|
||||
with_cuda=True,
|
||||
extra_ldflags=["-lnccl"],
|
||||
verbose=True,
|
||||
is_python_module=False,
|
||||
build_directory=out_dir,
|
||||
)
|
||||
_allocator = CUDAPluggableAllocator(
|
||||
f"{out_dir}/{nccl_allocator_libname}.so",
|
||||
"nccl_alloc_plug",
|
||||
"nccl_free_plug",
|
||||
).allocator()
|
||||
_mem_pool = torch.cuda.MemPool(_allocator)
|
||||
return _mem_pool
|
||||
|
||||
|
||||
class use_symmetric_memory:
|
||||
def __init__(self, group_coordinator: GroupCoordinator):
|
||||
if not is_symmetric_memory_enabled():
|
||||
self.group_coordinator = None
|
||||
self._mem_pool_ctx = None
|
||||
self.is_graph_capture = None
|
||||
self.device = None
|
||||
self.pre_2_8_0 = None
|
||||
else:
|
||||
self.group_coordinator = group_coordinator
|
||||
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
|
||||
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
|
||||
self.device = torch.cuda.current_device()
|
||||
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
|
||||
|
||||
def __enter__(self):
|
||||
if not is_symmetric_memory_enabled():
|
||||
return self
|
||||
assert (
|
||||
self.group_coordinator.pynccl_comm is not None
|
||||
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
|
||||
assert (
|
||||
self.group_coordinator.pynccl_comm.nccl_version >= 22703
|
||||
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
|
||||
if self.is_graph_capture:
|
||||
assert (
|
||||
_graph_pool_id is not None
|
||||
), "graph_pool_id is not set under graph capture"
|
||||
# Pause graph memory pool to use symmetric memory with cuda graph
|
||||
if self.pre_2_8_0:
|
||||
torch._C._cuda_endAllocateCurrentStreamToPool(
|
||||
self.device, _graph_pool_id
|
||||
)
|
||||
else:
|
||||
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
|
||||
self._mem_pool_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def tag(self, tensor: torch.Tensor):
|
||||
if not is_symmetric_memory_enabled():
|
||||
return
|
||||
tensor.symmetric_memory = True
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if not is_symmetric_memory_enabled():
|
||||
return
|
||||
global _registered_base_addrs
|
||||
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
|
||||
for segment in get_nccl_mem_pool().snapshot():
|
||||
if segment["address"] not in _registered_base_addrs:
|
||||
if segment["stream"] == 0 and self.pre_2_8_0:
|
||||
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
|
||||
# See https://github.com/pytorch/pytorch/issues/152861
|
||||
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
|
||||
# WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
|
||||
continue
|
||||
self.group_coordinator.pynccl_comm.register_comm_window_raw(
|
||||
segment["address"], segment["total_size"]
|
||||
)
|
||||
_registered_base_addrs.add(segment["address"])
|
||||
|
||||
if self.is_graph_capture:
|
||||
if self.pre_2_8_0:
|
||||
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
|
||||
else:
|
||||
torch._C._cuda_beginAllocateCurrentThreadToPool(
|
||||
self.device, _graph_pool_id
|
||||
)
|
||||
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
ncclWindow_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
@@ -279,6 +280,23 @@ class NCCLLibrary:
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
exported_functions_symm_mem = [
|
||||
# ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
|
||||
Function(
|
||||
"ncclCommWindowRegister",
|
||||
ncclResult_t,
|
||||
[
|
||||
ncclComm_t,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ctypes.POINTER(ncclWindow_t),
|
||||
ctypes.c_int,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
|
||||
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
@@ -312,7 +330,10 @@ class NCCLLibrary:
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
exported_functions = NCCLLibrary.exported_functions
|
||||
if hasattr(self.lib, "ncclCommWindowRegister"):
|
||||
exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
|
||||
for func in exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
@@ -328,10 +349,14 @@ class NCCLLibrary:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
def ncclGetRawVersion(self) -> int:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903
|
||||
return version.value
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version_str = str(self.ncclGetRawVersion())
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
@@ -460,6 +485,20 @@ class NCCLLibrary:
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
def ncclCommWindowRegister(
|
||||
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
|
||||
) -> ncclWindow_t:
|
||||
window = ncclWindow_t()
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommWindowRegister"](
|
||||
comm, buff, size, ctypes.byref(window), win_flags
|
||||
)
|
||||
)
|
||||
return window
|
||||
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary",
|
||||
|
||||
@@ -497,6 +497,17 @@ class GroupCoordinator:
|
||||
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
||||
return self.npu_communicator.all_reduce(input_)
|
||||
|
||||
if (
|
||||
self.pynccl_comm is not None
|
||||
and hasattr(input_, "symmetric_memory")
|
||||
and input_.symmetric_memory
|
||||
):
|
||||
with self.pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream()
|
||||
):
|
||||
self.pynccl_comm.all_reduce(input_)
|
||||
return input_
|
||||
|
||||
outplace_all_reduce_method = None
|
||||
if (
|
||||
self.qr_comm is not None
|
||||
|
||||
Reference in New Issue
Block a user