diff --git a/benchmark/kernels/all_reduce/benchmark_symm_mem.py b/benchmark/kernels/all_reduce/benchmark_symm_mem.py
new file mode 100644
index 000000000..c16397eaa
--- /dev/null
+++ b/benchmark/kernels/all_reduce/benchmark_symm_mem.py
@@ -0,0 +1,234 @@
+"""For now, SYMM_MEM all-reduce is only supported in the TP8 case.
+
+export WORLD_SIZE=1
+export RANK=0
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=12345
+
+torchrun --nproc_per_node gpu \
+--nnodes $WORLD_SIZE \
+--node_rank $RANK \
+--master_addr $MASTER_ADDR \
+--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
+"""
+
+import os
+from contextlib import nullcontext
+from typing import List
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed import init_distributed_environment
+from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
+from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
+from sglang.srt.distributed.parallel_state import (
+    get_tensor_model_parallel_group,
+    graph_capture,
+    initialize_model_parallel,
+    set_symm_mem_all_reduce,
+)
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+
+def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
+    dist.all_reduce(torch_input, group=group)
+    return torch_input
+
+
+def symm_mem_allreduce(
+    symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
+) -> torch.Tensor:
+    return symm_mem_comm.all_reduce(symm_mem_input)
+
+
+def pynccl_allreduce(
+    pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
+) -> torch.Tensor:
+    pynccl_comm.all_reduce(pynccl_input)
+    return pynccl_input
+
+
+def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
+    graph_input = inp_randn.clone()
+    with graph_capture() as graph_capture_context:
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
+            for _ in range(graph_loop):
+                graph_out = func(graph_input)
+
+    graph.replay()
+    func_output = graph_out.clone()
+
+    for _ in range(warmup_loop):
+        graph.replay()
+    torch.cuda.synchronize()
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    latencies: List[float] = []
+    for _ in range(test_loop):
+        torch.cuda.synchronize()
+        dist.barrier()
+        start_event.record()
+        graph.replay()
+        end_event.record()
+        end_event.synchronize()
+        latencies.append(start_event.elapsed_time(end_event))
+    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
+    graph.reset()
+    return func_output, func_cost_us
+
+
+def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
+    eager_input = inp_randn.clone()
+    eager_output = func(eager_input)
+    func_output = eager_output.clone()
+
+    for _ in range(warmup_loop):
+        func(eager_input)
+    torch.cuda.synchronize()
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    torch.cuda.synchronize()
+    start_event.record()
+    for _ in range(test_loop):
+        func(eager_input)
+    end_event.record()
+    torch.cuda.synchronize()
+    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
+
+    return func_output, func_cost_us
+
+
+def get_torch_prof_ctx(do_prof: bool):
+    ctx = (
+        torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            record_shapes=True,
+            with_stack=True,
+        )
+        if do_prof
+        else nullcontext()
+    )
+    return ctx
+
+
+def human_readable_size(size, decimal_places=1):
+    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
+        if size < 1024.0 or unit == "PiB":
+            break
+        size /= 1024.0
+    return f"{size:.{decimal_places}f} {unit}"
+
+
+try:
+    from tabulate import tabulate
+except ImportError:
+    print("tabulate not installed, skipping table printing")
+    tabulate = None
+
+
+def print_markdown_table(data):
+    if tabulate is not None:
+        print(tabulate(data, headers="keys", tablefmt="github"))
+        return
+    headers = data[0].keys()
+    header_row = "| " + " | ".join(headers) + " |"
+    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
+    rows = []
+    for item in data:
+        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
+        rows.append(row)
+    markdown_table = "\n".join([header_row, separator] + rows)
+    print(markdown_table)
+
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        force=True,
+    )
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    world, world_size = dist.group.WORLD, dist.get_world_size()
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank % 8)
+    device = torch.cuda.current_device()
+    set_symm_mem_all_reduce(True)
+    init_distributed_environment(
+        world_size=world_size,
+        rank=rank,
+        local_rank=rank % 8,
+    )
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    group = get_tensor_model_parallel_group().device_group
+    cpu_group = get_tensor_model_parallel_group().cpu_group
+    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
+    symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm
+    dist.barrier()
+    profile = False
+    dtype = torch.bfloat16
+    ctx = get_torch_prof_ctx(profile)
+    result = []
+
+    with ctx:
+        if IS_CI:
+            i_range = range(10, 11)
+        else:
+            i_range = range(10, 20)
+        for i in i_range:
+            sz = 2**i
+            if sz * dtype.itemsize > 2**24:
+                break
+            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
+
+            memory = torch.empty_like(inp_randn)
+            memory_out = torch.empty_like(memory)
+            torch_eager_output, torch_eager_time = _bench_eager_time(
+                lambda inp: torch_allreduce(inp, group), inp_randn
+            )
+            symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
+                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
+            )
+            symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
+                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
+            )
+            # pynccl all-reduce is an in-place op, so the returned result is not correct when graph_loop > 1
+            _, pynccl_graph_time = _bench_graph_time(
+                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
+            )
+            torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)
+            torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)
+            result.append(
+                {
+                    "msg_size": human_readable_size(inp_randn.nbytes),
+                    "torch eager time": torch_eager_time,
+                    "symm mem eager time": symm_mem_eager_time,
+                    "symm mem graph time": symm_mem_graph_time,
+                    "pynccl graph time": pynccl_graph_time,
+                }
+            )
+            if rank == 0:
+                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
+    if rank == 0:
+        print_markdown_table(result)
+    if profile:
+        prof_dir = f"prof/symm_mem"
+        os.makedirs(prof_dir, exist_ok=True)
+        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
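Note: the benchmark above reports per-call latency in microseconds. If a bandwidth figure is more convenient for comparing message sizes, the usual all-reduce conversion is bus bandwidth = 2 * (N - 1) / N * message_size / time (the nccl-tests convention). The helper below is a minimal sketch and not part of this change; its name and the example numbers are illustrative.

```python
def allreduce_bus_bandwidth_gbps(msg_bytes: int, time_us: float, world_size: int) -> float:
    """Convert an all-reduce latency into bus bandwidth (GB/s).

    Uses the nccl-tests convention: busbw = algbw * 2 * (N - 1) / N,
    where algbw = message size / elapsed time.
    """
    alg_bw = msg_bytes / (time_us * 1e-6)  # bytes per second
    bus_bw = alg_bw * 2 * (world_size - 1) / world_size
    return bus_bw / 1e9


# Example: a 1 MiB message reduced across 8 GPUs in 25 us.
print(f"{allreduce_bus_bandwidth_gbps(1 << 20, 25.0, 8):.1f} GB/s")
```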
diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py
index b0781ca30..de93a6e13 100644
--- a/benchmark/lora/launch_server.py
+++ b/benchmark/lora/launch_server.py
@@ -28,6 +28,8 @@ def launch_server(args):
         cmd += "--disable-custom-all-reduce"
     if args.enable_mscclpp:
         cmd += "--enable-mscclpp"
+    if args.enable_torch_symm_mem:
+        cmd += "--enable-torch-symm-mem"
     print(cmd)
     os.system(cmd)
 
@@ -70,6 +72,11 @@
         action="store_true",
         help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
     )
+    parser.add_argument(
+        "--enable-torch-symm-mem",
+        action="store_true",
+        help="Enable using torch symmetric memory for the all-reduce kernel and fall back to NCCL.",
+    )
     args = parser.parse_args()
 
     launch_server(args)
diff --git a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py
new file mode 100644
index 000000000..99d6ebf2e
--- /dev/null
+++ b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py
@@ -0,0 +1,16 @@
+MiB = 1024 * 1024
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    9: {
+        2: 64 * MiB,  # 64 MiB
+        4: 32 * MiB,  # 32 MiB
+        6: 64 * MiB,  # 64 MiB
+        8: 64 * MiB,  # 64 MiB
+    },
+    10: {
+        2: 64 * MiB,  # 64 MiB
+        4: 32 * MiB,  # 32 MiB
+        6: 128 * MiB,  # 128 MiB
+        8: 128 * MiB,  # 128 MiB
+    },
+}
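`SYMM_MEM_ALL_REDUCE_MAX_SIZES` is keyed by compute-capability major version, then by world size, and gives the largest payload (in bytes) that the symmetric-memory path is allowed to handle. A minimal sketch of the lookup, assuming the table is imported from the new module above; the helper name is illustrative, not part of this change.

```python
import torch

from sglang.srt.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)


def symm_mem_max_bytes(world_size: int, device: torch.device) -> int:
    """Return the symm-mem size cap for this (capability, world size), or 0 if unsupported."""
    cc_major = torch.cuda.get_device_capability(device)[0]
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(cc_major, {}).get(world_size, 0)


# e.g. on an SM90 device with TP=8 this returns 64 MiB; unknown combinations return 0.
```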
+ """ + + self.disabled = True + + if not symm_mem_available: + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = torch.cuda.get_device_capability(device)[0] + if self.device_capability < 9: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size + ] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning( + "SymmMemCommunicator: symmetric memory " + "multicast operations are not supported." + ) + self.buffer = None + self.disabled = True + return + self.disabled = False + + def should_symm_mem_allreduce(self, inp: torch.Tensor): + """ + Fast-path eligibility check for a given tensor. + + Conditions: + - Communicator must be enabled. + - dtype must be bfloat16 (matches kernel + buffer dtype). + - Total byte size must be 4-byte aligned (hardware requirement). + - Payload must be smaller than the symmetric-memory max size. + + Returns: + True if the symmetric-memory path can handle this tensor. + """ + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + # enforce 4-byte alignment + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None + ) -> Optional[torch.Tensor]: + """ + Perform an in-place sum all-reduce via symmetric memory. + + Args: + inp: Input tensor on the target CUDA device (bfloat16). + out: Optional output tensor; if omitted, a new tensor is allocated. + + Returns: + The reduced tensor (same shape as inp), or None if disabled. + + Implementation details: + - Stages 'inp' into the symmetric buffer. + - Selects 'multimem' or 'two_shot' kernel based on topology. + - Writes the result into 'out' and returns it. 
+ """ + if out is None: + out = torch.empty_like(inp) + self.buffer[: inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) + else: + torch.ops.symm_mem.two_shot_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) + out.copy_(self.buffer[: inp.numel()].view(out.shape)) + return out diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index ec980c030..009aba52e 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -215,12 +215,14 @@ class GroupCoordinator: use_pynccl: bool # a hint of whether to use PyNccl use_pymscclpp: bool # a hint of whether to use PyMsccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce use_message_queue_broadcaster: ( bool # a hint of whether to use message queue broadcaster ) # communicators are only created for world size > 1 pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator + symm_mem_comm: Optional[Any] # Symm mem communicator mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( @@ -231,6 +233,7 @@ class GroupCoordinator: use_pynccl: bool, use_pymscclpp: bool, use_custom_allreduce: bool, + use_torch_symm_mem: bool, use_hpu_communicator: bool, use_xpu_communicator: bool, use_npu_communicator: bool, @@ -279,6 +282,7 @@ class GroupCoordinator: self.use_pynccl = use_pynccl self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce + self.use_torch_symm_mem = use_torch_symm_mem self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator self.use_npu_communicator = use_npu_communicator @@ -294,6 +298,9 @@ class GroupCoordinator: from sglang.srt.distributed.device_communicators.pynccl import ( PyNcclCommunicator, ) + from sglang.srt.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator, + ) if is_hip(): from sglang.srt.distributed.device_communicators.quick_all_reduce import ( @@ -342,6 +349,13 @@ class GroupCoordinator: except Exception as e: logger.warning(f"Failed to initialize QuickAllReduce: {e}") + self.symm_mem_comm: Optional[SymmMemCommunicator] = None + if self.use_torch_symm_mem and self.world_size > 1: + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + # Create communicator for other hardware backends from sglang.srt.distributed.device_communicators.hpu_communicator import ( HpuCommunicator, @@ -446,6 +460,7 @@ class GroupCoordinator: # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | # PyMscclpp | disabled| enabled | + # TorchSymmMem | disabled| enabled | # torch.distributed | enabled | disabled| # # Note: When custom quick allreduce is enabled, a runtime check @@ -547,7 +562,12 @@ class GroupCoordinator: and self.pymscclpp_comm.should_mscclpp_allreduce(input_) ): outplace_all_reduce_method = "pymscclpp" - + elif ( + self.symm_mem_comm is not None + and not self.symm_mem_comm.disabled + and self.symm_mem_comm.should_symm_mem_allreduce(input_) + ): + outplace_all_reduce_method = "symm_mem" if outplace_all_reduce_method is not None: return torch.ops.sglang.outplace_all_reduce( input_, @@ -564,6 +584,7 @@ class GroupCoordinator: ca_comm = self.ca_comm 
         qr_comm = self.qr_comm
+        symm_mem_comm = self.symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
         if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
@@ -571,6 +592,9 @@
         elif outplace_all_reduce_method == "qr":
             assert not qr_comm.disabled
             out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "symm_mem":
+            assert not symm_mem_comm.disabled
+            out = symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -1219,6 +1243,7 @@
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    use_symm_mem_allreduce: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
+    if use_symm_mem_allreduce is None:
+        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
@@ -1246,6 +1274,7 @@
         use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
+        use_torch_symm_mem=use_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)
 
 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
+_ENABLE_SYMM_MEM_ALL_REDUCE = False
 
 
 def set_custom_all_reduce(enable: bool):
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable
 
 
+def set_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_SYMM_MEM_ALL_REDUCE
+    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 9fa359eaf..d4db39a33 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -263,6 +263,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index eb02c41e3..2d87ec6f6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -646,6 +647,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
 
         if not self.is_draft_worker:
             if self.device == "cpu":
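For readers following the dispatch in GroupCoordinator.all_reduce: the new symm_mem path is tried after quick-allreduce, custom allreduce, and mscclpp, and everything else still falls back to PyNccl / torch.distributed. The function below is a simplified paraphrase of that selection for illustration only; the real method also applies per-communicator eligibility checks (message size, dtype, CUDA-graph capture state) that are elided here.

```python
def pick_outplace_all_reduce(coord, input_):
    """Simplified paraphrase of GroupCoordinator's out-of-place all-reduce selection."""
    if coord.qr_comm is not None and not coord.qr_comm.disabled:
        return "qr"  # HIP quick all-reduce (only created on ROCm)
    if coord.ca_comm is not None and not coord.ca_comm.disabled:
        return "ca"  # custom all-reduce
    if (
        coord.pymscclpp_comm is not None
        and not coord.pymscclpp_comm.disabled
        and coord.pymscclpp_comm.should_mscclpp_allreduce(input_)
    ):
        return "pymscclpp"
    if (
        coord.symm_mem_comm is not None
        and not coord.symm_mem_comm.disabled
        and coord.symm_mem_comm.should_symm_mem_allreduce(input_)
    ):
        return "symm_mem"
    return None  # fall back to pynccl / torch.distributed
```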
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5c7e2f57b..dfa4d8e8f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -382,6 +382,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     enable_mscclpp: bool = False
+    enable_torch_symm_mem: bool = False
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -2443,6 +2444,11 @@
             action="store_true",
             help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
         )
+        parser.add_argument(
+            "--enable-torch-symm-mem",
+            action="store_true",
+            help="Enable using torch symmetric memory for the all-reduce kernel and fall back to NCCL. Only supported on CUDA devices with SM90 and above. SM90 supports world sizes 4, 6, and 8; SM100 supports world sizes 6 and 8.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
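Before turning on --enable-torch-symm-mem on a new machine, it can be useful to confirm the prerequisites the communicator relies on. The script below is a hedged sanity check that only mirrors the gating in SymmMemCommunicator (PyTorch symmetric-memory support plus an SM90-or-newer CUDA device); it is not part of this change.

```python
import torch


def symm_mem_prereqs_ok() -> bool:
    """Check the local prerequisites for torch symmetric-memory all-reduce."""
    try:
        import torch.distributed._symmetric_memory  # noqa: F401
    except ImportError:
        print("torch.distributed._symmetric_memory is not available in this PyTorch build")
        return False
    if not torch.cuda.is_available():
        print("a CUDA device is required")
        return False
    major, minor = torch.cuda.get_device_capability()
    if major < 9:
        print(f"compute capability {major}.{minor} < 9.0, symm-mem all-reduce stays disabled")
        return False
    return True


if __name__ == "__main__":
    print("symm-mem prerequisites satisfied" if symm_mem_prereqs_ok() else "prerequisites not satisfied")
```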