Add support for NCCL symmetric memory for TP allreduces (#8238)

This commit is contained in:
Nicolas Castet
2025-08-01 18:30:55 -05:00
committed by GitHub
parent b89d37cb11
commit 82e6c3a65a
13 changed files with 266 additions and 30 deletions

View File

@@ -75,6 +75,7 @@ class PyNcclCommunicator:
self.available = True
self.disabled = False
self.nccl_version = self.nccl.ncclGetRawVersion()
if self.rank == 0:
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None

View File

@@ -0,0 +1,133 @@
import tempfile
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict
nccl_allocator_source = """
#include <nccl.h>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
_allocator = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
def is_symmetric_memory_enabled():
return global_server_args_dict["enable_symm_mem"]
def set_graph_pool_id(graph_pool_id):
global _graph_pool_id
_graph_pool_id = graph_pool_id
def get_nccl_mem_pool():
global _allocator, _mem_pool
if _mem_pool is None:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
torch.utils.cpp_extension.load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=True,
is_python_module=False,
build_directory=out_dir,
)
_allocator = CUDAPluggableAllocator(
f"{out_dir}/{nccl_allocator_libname}.so",
"nccl_alloc_plug",
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
return _mem_pool
class use_symmetric_memory:
def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
self.device = None
self.pre_2_8_0 = None
else:
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
def __enter__(self):
if not is_symmetric_memory_enabled():
return self
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
assert (
self.group_coordinator.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
if self.pre_2_8_0:
torch._C._cuda_endAllocateCurrentStreamToPool(
self.device, _graph_pool_id
)
else:
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
return self
def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True
def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
return
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
# See https://github.com/pytorch/pytorch/issues/152861
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
# WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
continue
self.group_coordinator.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])
if self.is_graph_capture:
if self.pre_2_8_0:
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
else:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id
)

View File

@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
ncclWindow_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
exported_functions_symm_mem = [
# ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
Function(
"ncclCommWindowRegister",
ncclResult_t,
[
ncclComm_t,
buffer_type,
ctypes.c_size_t,
ctypes.POINTER(ncclWindow_t),
ctypes.c_int,
],
),
# ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@ class NCCLLibrary:
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
exported_functions = NCCLLibrary.exported_functions
if hasattr(self.lib, "ncclCommWindowRegister"):
exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
for func in exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
@@ -328,10 +349,14 @@ class NCCLLibrary:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
def ncclGetRawVersion(self) -> int:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903
return version.value
def ncclGetVersion(self) -> str:
version_str = str(self.ncclGetRawVersion())
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
@@ -460,6 +485,20 @@ class NCCLLibrary:
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
def ncclCommWindowRegister(
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
) -> ncclWindow_t:
window = ncclWindow_t()
self.NCCL_CHECK(
self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags
)
)
return window
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
__all__ = [
"NCCLLibrary",

View File

@@ -497,6 +497,17 @@ class GroupCoordinator:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
with self.pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
self.pynccl_comm.all_reduce(input_)
return input_
outplace_all_reduce_method = None
if (
self.qr_comm is not None