From 82e6c3a65ab3701c3ef498bc51fbe447e8c6cbe5 Mon Sep 17 00:00:00 2001 From: Nicolas Castet <26874160+nvcastet@users.noreply.github.com> Date: Fri, 1 Aug 2025 18:30:55 -0500 Subject: [PATCH] Add support for NCCL symmetric memory for TP allreduces (#8238) --- docs/backend/server_arguments.md | 1 + .../device_communicators/pynccl.py | 7 + .../device_communicators/pynccl_allocator.py | 133 ++++++++++++++++++ .../device_communicators/pynccl_wrapper.py | 45 +++++- .../sglang/srt/distributed/parallel_state.py | 11 ++ python/sglang/srt/entrypoints/engine.py | 5 +- python/sglang/srt/layers/linear.py | 8 +- .../srt/layers/moe/fused_moe_triton/layer.py | 43 +++--- .../srt/layers/vocab_parallel_embedding.py | 8 +- python/sglang/srt/managers/schedule_batch.py | 1 + .../srt/model_executor/cuda_graph_runner.py | 13 +- python/sglang/srt/models/deepseek_v2.py | 15 +- python/sglang/srt/server_args.py | 6 + 13 files changed, 266 insertions(+), 30 deletions(-) create mode 100644 python/sglang/srt/distributed/device_communicators/pynccl_allocator.py diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index ac56aebf6..047458123 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -251,6 +251,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False | | `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False | | `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False | +| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | False | | `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. 
| False | | `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False | | `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False | diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 6459f70fd..81dd81780 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -75,6 +75,7 @@ class PyNcclCommunicator: self.available = True self.disabled = False + self.nccl_version = self.nccl.ncclGetRawVersion() if self.rank == 0: logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion()) @@ -259,6 +260,12 @@ class PyNcclCommunicator: cudaStream_t(stream.cuda_stream), ) + def register_comm_window_raw(self, ptr: int, size: int): + return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1) + + def deregister_comm_window(self, window): + return self.nccl.ncclCommWindowDeregister(self.comm, window) + @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py new file mode 100644 index 000000000..d7274cf2c --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -0,0 +1,133 @@ +import tempfile + +import torch +from packaging import version +from torch.cuda.memory import CUDAPluggableAllocator + +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.managers.schedule_batch import global_server_args_dict + +nccl_allocator_source = """ +#include <nccl.h> +extern "C" { + +void* nccl_alloc_plug(size_t size, int device, void* stream) { + void* ptr; + ncclResult_t err = ncclMemAlloc(&ptr, size); 
return ptr; + +} + +void nccl_free_plug(void* ptr, size_t size, int device, void* stream) { + ncclResult_t err = ncclMemFree(ptr); +} + +} +""" + +_allocator = None +_mem_pool = None +_registered_base_addrs = set() +_graph_pool_id = None + + +def is_symmetric_memory_enabled(): + return global_server_args_dict["enable_symm_mem"] + + +def set_graph_pool_id(graph_pool_id): + global _graph_pool_id + _graph_pool_id = graph_pool_id + + +def get_nccl_mem_pool(): + global _allocator, _mem_pool + if _mem_pool is None: + out_dir = tempfile.gettempdir() + nccl_allocator_libname = "nccl_allocator" + torch.utils.cpp_extension.load_inline( + name=nccl_allocator_libname, + cpp_sources=nccl_allocator_source, + with_cuda=True, + extra_ldflags=["-lnccl"], + verbose=True, + is_python_module=False, + build_directory=out_dir, + ) + _allocator = CUDAPluggableAllocator( + f"{out_dir}/{nccl_allocator_libname}.so", + "nccl_alloc_plug", + "nccl_free_plug", + ).allocator() + _mem_pool = torch.cuda.MemPool(_allocator) + return _mem_pool + + +class use_symmetric_memory: + def __init__(self, group_coordinator: GroupCoordinator): + if not is_symmetric_memory_enabled(): + self.group_coordinator = None + self._mem_pool_ctx = None + self.is_graph_capture = None + self.device = None + self.pre_2_8_0 = None + else: + self.group_coordinator = group_coordinator + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) + self.is_graph_capture = torch.cuda.is_current_stream_capturing() + self.device = torch.cuda.current_device() + self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") + + def __enter__(self): + if not is_symmetric_memory_enabled(): + return self + assert ( + self.group_coordinator.pynccl_comm is not None + ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'" + assert ( + self.group_coordinator.pynccl_comm.nccl_version >= 22703 + ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + if 
self.is_graph_capture: + assert ( + _graph_pool_id is not None + ), "graph_pool_id is not set under graph capture" + # Pause graph memory pool to use symmetric memory with cuda graph + if self.pre_2_8_0: + torch._C._cuda_endAllocateCurrentStreamToPool( + self.device, _graph_pool_id + ) + else: + torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) + self._mem_pool_ctx.__enter__() + return self + + def tag(self, tensor: torch.Tensor): + if not is_symmetric_memory_enabled(): + return + tensor.symmetric_memory = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if not is_symmetric_memory_enabled(): + return + global _registered_base_addrs + self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) + for segment in get_nccl_mem_pool().snapshot(): + if segment["address"] not in _registered_base_addrs: + if segment["stream"] == 0 and self.pre_2_8_0: + # PyTorch version < 2.8.0 has a multi-thread MemPool bug + # See https://github.com/pytorch/pytorch/issues/152861 + # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b + # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream + continue + self.group_coordinator.pynccl_comm.register_comm_window_raw( + segment["address"], segment["total_size"] + ) + _registered_base_addrs.add(segment["address"]) + + if self.is_graph_capture: + if self.pre_2_8_0: + torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id) + else: + torch._C._cuda_beginAllocateCurrentThreadToPool( + self.device, _graph_pool_id + ) diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index afb477334..cad39624e 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -67,6 +67,7 @@ def find_nccl_library() -> str: ncclResult_t = ctypes.c_int 
ncclComm_t = ctypes.c_void_p +ncclWindow_t = ctypes.c_void_p class ncclUniqueId(ctypes.Structure): @@ -279,6 +280,23 @@ class NCCLLibrary: Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), ] + exported_functions_symm_mem = [ + # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); + Function( + "ncclCommWindowRegister", + ncclResult_t, + [ + ncclComm_t, + buffer_type, + ctypes.c_size_t, + ctypes.POINTER(ncclWindow_t), + ctypes.c_int, + ], + ), + # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win); + Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]), + ] + # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times path_to_library_cache: Dict[str, Any] = {} @@ -312,7 +330,10 @@ class NCCLLibrary: if so_file not in NCCLLibrary.path_to_dict_mapping: _funcs: Dict[str, Any] = {} - for func in NCCLLibrary.exported_functions: + exported_functions = NCCLLibrary.exported_functions + if hasattr(self.lib, "ncclCommWindowRegister"): + exported_functions.extend(NCCLLibrary.exported_functions_symm_mem) + for func in exported_functions: f = getattr(self.lib, func.name) f.restype = func.restype f.argtypes = func.argtypes @@ -328,10 +349,14 @@ class NCCLLibrary: error_str = self.ncclGetErrorString(result) raise RuntimeError(f"NCCL error: {error_str}") - def ncclGetVersion(self) -> str: + def ncclGetRawVersion(self) -> int: version = ctypes.c_int() self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) - version_str = str(version.value) + # something like 21903 + return version.value + + def ncclGetVersion(self) -> str: + version_str = str(self.ncclGetRawVersion()) # something like 21903 --> "2.19.3" major = version_str[0].lstrip("0") minor = version_str[1:3].lstrip("0") @@ -460,6 +485,20 @@ class NCCLLibrary: def ncclCommDestroy(self, comm: ncclComm_t) -> None: 
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + def ncclCommWindowRegister( + self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int + ) -> ncclWindow_t: + window = ncclWindow_t() + self.NCCL_CHECK( + self._funcs["ncclCommWindowRegister"]( + comm, buff, size, ctypes.byref(window), win_flags + ) + ) + return window + + def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + __all__ = [ "NCCLLibrary", diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 279393f95..4e81f80dc 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -497,6 +497,17 @@ class GroupCoordinator: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) + if ( + self.pynccl_comm is not None + and hasattr(input_, "symmetric_memory") + and input_.symmetric_memory + ): + with self.pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + self.pynccl_comm.all_reduce(input_) + return input_ + outplace_all_reduce_method = None if ( self.qr_comm is not None diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index cfe3e0a5b..0e764081a 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -623,8 +623,9 @@ class Engine(EngineBase): def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem)) + if not server_args.enable_symm_mem: + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) 
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" os.environ["CUDA_MODULE_LOADING"] = "AUTO" diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 9d8ab8632..9e765ebf9 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -13,10 +13,14 @@ from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + parallel_state, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + sm.tag(output_parallel) if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce: output = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index ba590dbef..3960e22a6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -14,8 +14,12 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + 
use_symmetric_memory, +) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.layers.quantization.base_config import ( @@ -626,24 +630,27 @@ class FusedMoE(torch.nn.Module): ) # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - topk_output=topk_output, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - routed_scaling_factor=self.routed_scaling_factor, - **( - dict( - tp_rank=self.moe_tp_rank, - tp_size=self.moe_tp_size, - ep_rank=self.moe_ep_rank, - ep_size=self.moe_ep_size, - ) - if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod" - else {} - ), - ) + with use_symmetric_memory(get_tp_group()) as sm: + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + topk_output=topk_output, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + routed_scaling_factor=self.routed_scaling_factor, + **( + dict( + tp_rank=self.moe_tp_rank, + tp_size=self.moe_tp_size, + ep_rank=self.moe_ep_rank, + ep_size=self.moe_ep_size, + ) + if self.quant_method.__class__.__name__ + == "ModelOptNvFp4FusedMoEMethod" + else {} + ), + ) + sm.tag(final_hidden_states) if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index d925506f5..ab1ced99a 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -11,8 +11,12 @@ from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + parallel_state, tensor_model_parallel_all_reduce, ) +from 
sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter @@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module): else: masked_input = input_ # Get the embeddings. - output_parallel = self.quant_method.embedding(self, masked_input.long()) + with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + output_parallel = self.quant_method.embedding(self, masked_input.long()) + sm.tag(output_parallel) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 4b8d07b96..3bfb31b6b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "weight_loader_disable_mmap", "enable_triton_kernel_moe", "enable_multimodal", + "enable_symm_mem", ] # Put some global args for easy access diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index fb703255b..e5a8cc872 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -643,11 
+646,15 @@ class CudaGraphRunner: run_once() - global global_graph_memory_pool - with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream): + if get_global_graph_memory_pool() is None: + set_global_graph_memory_pool(torch.cuda.graph_pool_handle()) + # Set graph pool id globally to be able to use symmetric memory + set_graph_pool_id(get_global_graph_memory_pool()) + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): out = run_once() - global_graph_memory_pool = graph.pool() return graph, out def recapture_if_needed(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index aaafdb085..b5b13d9ac 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -34,6 +34,9 @@ from sglang.srt.distributed import ( parallel_state, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo @@ -481,7 +484,11 @@ class DeepseekV2MoE(nn.Module): if not _is_cuda: final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - final_hidden_states += shared_output + with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + final_hidden_states_out = torch.empty_like(final_hidden_states) + torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) + final_hidden_states = final_hidden_states_out + sm.tag(final_hidden_states) if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -507,7 +514,11 @@ class DeepseekV2MoE(nn.Module): # fused in 
biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + final_hidden_states_out = torch.empty_like(final_hidden_states) + torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) + final_hidden_states = final_hidden_states_out + sm.tag(final_hidden_states) if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b8b025a79..037505dd7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -218,6 +218,7 @@ class ServerArgs: enable_profile_cuda_graph: bool = False enable_cudagraph_gc: bool = False enable_nccl_nvls: bool = False + enable_symm_mem: bool = False enable_tokenizer_batch_encode: bool = False disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False @@ -1599,6 +1600,11 @@ class ServerArgs: action="store_true", help="Enable NCCL NVLS for prefill heavy requests when available.", ) + parser.add_argument( + "--enable-symm-mem", + action="store_true", + help="Enable NCCL symmetric memory for fast collectives.", + ) parser.add_argument( "--enable-tokenizer-batch-encode", action="store_true",