diff --git a/benchmark/kernels/all_reduce/benchmark_mscclpp.py b/benchmark/kernels/all_reduce/benchmark_mscclpp.py new file mode 100644 index 000000000..eebbd00ce --- /dev/null +++ b/benchmark/kernels/all_reduce/benchmark_mscclpp.py @@ -0,0 +1,224 @@ +"""For Now, MSCCL is only supported on TP16 and TP8 case + +export WORLD_SIZE=1 +export RANK=0 +export MASTER_ADDR=127.0.0.1 +export MASTER_PORT=12345 + +torchrun --nproc_per_node gpu \ +--nnodes $WORLD_SIZE \ +--node_rank $RANK \ +--master_addr $MASTER_ADDR \ +--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py +""" + +import os +from contextlib import nullcontext +from typing import List + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator +from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, + set_mscclpp_all_reduce, +) + + +def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor: + dist.all_reduce(torch_input, group=group) + return torch_input + + +def msccl_allreduce( + msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator +) -> torch.Tensor: + return msccl_comm.all_reduce(msccl_input) + + +def pynccl_allreduce( + msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator +) -> torch.Tensor: + pynccl_comm.all_reduce(msccl_input) + return msccl_input + + +def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10): + graph_input = inp_randn.clone() + with graph_capture() as graph_capture_context: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=graph_capture_context.stream): + for _ in range(graph_loop): + graph_out = func(graph_input) + + graph.replay() + func_output = graph_out.clone() + + for _ in range(warmup_loop): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: List[float] = [] + for _ in range(test_loop): + torch.cuda.synchronize() + dist.barrier() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000 + graph.reset() + return func_output, func_cost_us + + +def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10): + eager_input = inp_randn.clone() + eager_output = func(eager_input) + func_output = eager_output.clone() + + for _ in range(warmup_loop): + func(eager_input) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + for _ in range(test_loop): + func(eager_input) + end_event.record() + torch.cuda.synchronize() + func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000 + + return func_output, func_cost_us + + +def get_torch_prof_ctx(do_prof: bool): + ctx = ( + torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + with_stack=True, + ) + if do_prof + else nullcontext() + ) + return ctx + + +def human_readable_size(size, decimal_places=1): + for unit in ["B", 
"KiB", "MiB", "GiB", "TiB", "PiB"]: + if size < 1024.0 or unit == "PiB": + break + size /= 1024.0 + return f"{size:.{decimal_places}f} {unit}" + + +try: + from tabulate import tabulate +except ImportError: + print("tabulate not installed, skipping table printing") + tabulate = None + + +def print_markdown_table(data): + if tabulate is not None: + print(tabulate(data, headers="keys", tablefmt="github")) + return + headers = data[0].keys() + header_row = "| " + " | ".join(headers) + " |" + separator = "| " + " | ".join(["---"] * len(headers)) + " |" + rows = [] + for item in data: + row = "| " + " | ".join(str(item[key]) for key in headers) + " |" + rows.append(row) + markdown_table = "\n".join([header_row, separator] + rows) + print(markdown_table) + + +if __name__ == "__main__": + import logging + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + world, world_size = dist.group.WORLD, dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank % 8) + device = torch.cuda.current_device() + set_mscclpp_all_reduce(True) + init_distributed_environment( + world_size=world_size, + rank=rank, + local_rank=rank % 8, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + cpu_group = get_tensor_model_parallel_group().cpu_group + pynccl_comm = get_tensor_model_parallel_group().pynccl_comm + pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm + dist.barrier() + profile = False + dtype = torch.bfloat16 + ctx = get_torch_prof_ctx(profile) + result = [] + + with ctx: + for i in range(10, 20): + sz = 2**i + if sz * dtype.itemsize > 2**20: + break + inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device) + + memory = torch.empty_like(inp_randn) + memory_out = torch.empty_like(memory) + torch_eager_output, torch_eager_time = _bench_eager_time( + lambda inp: torch_allreduce(inp, group), inp_randn + ) + msccl_eager_output, msccl_eager_time = _bench_eager_time( + lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn + ) + msccl_graph_output, msccl_graph_time = _bench_graph_time( + lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn + ) + # since pynccl is inplace op, this return result is not correct if graph loop > 1 + _, pynccl_graph_time = _bench_graph_time( + lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn + ) + torch.testing.assert_close(torch_eager_output, msccl_graph_output) + torch.testing.assert_close(torch_eager_output, msccl_eager_output) + result.append( + { + "msg_size": human_readable_size(inp_randn.nbytes), + "torch eager time": torch_eager_time, + "msccl eager time": msccl_eager_time, + "msccl graph time": msccl_graph_time, + "pynccl graph time": pynccl_graph_time, + } + ) + if rank == 0: + print(f"sz={sz}, dtype={dtype}: correctness check PASS!") + if rank == 0: + print_markdown_table(result) + if profile: + prof_dir = f"prof/msccl" + os.makedirs(prof_dir, exist_ok=True) + ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz") diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py index 6ed85d9d2..b0781ca30 100644 --- a/benchmark/lora/launch_server.py +++ b/benchmark/lora/launch_server.py @@ -26,6 +26,8 @@ def launch_server(args): cmd += f"--tp-size {args.tp_size} " if args.disable_custom_all_reduce: cmd += "--disable-custom-all-reduce" + if 
args.enable_mscclpp: + cmd += "--enable-mscclpp" print(cmd) os.system(cmd) @@ -63,6 +65,11 @@ if __name__ == "__main__": action="store_true", help="Disable custom all reduce when device does not support p2p communication", ) + parser.add_argument( + "--enable-mscclpp", + action="store_true", + help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.", + ) args = parser.parse_args() launch_server(args) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index b36be6487..7ebeac5ce 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -201,6 +201,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `disable_cuda_graph_padding` | Disable CUDA Graph when padding is needed; otherwise, still use CUDA Graph. | `False` | | `disable_outlines_disk_cache` | Disable disk cache for outlines grammar backend. | `False` | | `disable_custom_all_reduce` | Disable usage of custom all-reduce kernel. | `False` | +| `enable_mscclpp` | Enable usage of mscclpp kernel for small message all-reduce. | `False` | | `disable_overlap_schedule` | Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). | `False` | | `enable_nan_detection` | Enable warning if the logits contain `NaN`. | `False` | | `enable_p2p_check` | Turns off the default of always allowing P2P checks when accessing GPU. | `False` | diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 07c087bf6..5d5b999a2 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -113,3 +113,37 @@ else: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp) + + +def mscclpp_generate_unique_id() -> bytes: + return sgl_kernel.allreduce.mscclpp_generate_unique_id() + + +def mscclpp_init_context( + unique_id: bytes, + rank: int, + world_size: int, + scratch: torch.Tensor, + put_buffer: torch.Tensor, + nranks_per_node: int, + rank_to_node: List[int], + rank_to_ib: List[int], + context_selection: int, +) -> int: + return sgl_kernel.allreduce.mscclpp_init_context( + unique_id, + rank, + world_size, + scratch, + put_buffer, + nranks_per_node, + rank_to_node, + rank_to_ib, + context_selection, + ) + + +def mscclpp_allreduce( + context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int +) -> None: + return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks) diff --git a/python/sglang/srt/distributed/device_communicators/pymscclpp.py b/python/sglang/srt/distributed/device_communicators/pymscclpp.py new file mode 100644 index 000000000..78269ed05 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/pymscclpp.py @@ -0,0 +1,315 @@ +import bisect +import logging +import math +import os +from contextlib import contextmanager +from enum import IntEnum +from typing import Any, Callable, List, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from sglang.srt import _custom_ops as ops +from sglang.srt.utils import is_cuda, is_hip + +logger = logging.getLogger(__name__) + +_is_cuda = is_cuda() +_is_hip = is_hip() + +mscclpp_is_available = False +if _is_hip: + # TODO(zyksir): mscclpp is untested on AMD and therefore disabled. 
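+    # Keep mscclpp force-disabled on ROCm for now; the CUDA branch below only marks
+    # it available when the sgl_kernel bindings import successfully.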
+    mscclpp_is_available = False
+if _is_cuda:
+    try:
+        import sgl_kernel
+
+        mscclpp_is_available = True
+    except:
+        mscclpp_is_available = False
+
+
+class MscclContextSelection(IntEnum):
+    MSCCL1SHOT1NODELL = 1
+    MSCCL1SHOT2NODELL = 2
+
+
+def mscclpp_is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (
+        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
+        == inp.numel() * inp.element_size()
+    )
+
+
+def mscclpp_convert_to_bytes(size_str):
+    """
+    Convert a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
+    into the equivalent number of bytes using binary units.
+
+    Args:
+        size_str (str): A string representing a size with a unit (B, KB, MB, GB).
+
+    Returns:
+        int: Number of bytes.
+    """
+    size_str = size_str.strip().lower()
+
+    if not size_str:
+        raise ValueError("Empty input string")
+
+    # Extract numeric part and unit
+    for i in range(len(size_str)):
+        if not size_str[i].isdigit() and size_str[i] != ".":
+            break
+    num_str = size_str[:i]
+    unit = size_str[i:].strip()
+
+    try:
+        num = float(num_str)
+    except ValueError:
+        raise ValueError(f"Invalid numeric value in '{size_str}'")
+
+    # Conversion factors
+    if unit == "b":
+        return int(num)
+    elif unit == "kb":
+        return int(num * 1024)
+    elif unit == "mb":
+        return int(num * 1024 * 1024)
+    elif unit == "gb":
+        return int(num * 1024 * 1024 * 1024)
+    else:
+        raise ValueError(f"Unsupported unit: {unit}; only B, KB, MB, GB are supported")
+
+
+def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
+    # warmup
+    for _ in range(warmup_niter):
+        func()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    torch.cuda.synchronize()
+    dist.barrier()
+    start_event.record()
+    for _ in range(test_niter):
+        func()
+    end_event.record()
+    end_event.synchronize()
+    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
+    return func_cost_us
+
+
+class PyMscclppCommunicator:
+    _SUPPORTED_WORLD_SIZES = [8, 16]
+    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
+    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]
+
+    # max_bytes: the maximum message size supported by the mscclpp all-reduce.
+    # On A100, mscclpp is faster than NCCL only for messages smaller than 1MB.
+    def __init__(
+        self,
+        group: ProcessGroup,
+        device: Union[int, str, torch.device],
+        max_bytes=_MAX_BYTES,
+    ) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the PyMscclppCommunicator to. If None,
+                it will be bound to f"cuda:{local_rank}".
+            It is the caller's responsibility to make sure each communicator
+            is bound to a unique device, and that all communicators in this group
+            are on the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if not mscclpp_is_available:
+            # Disabled because the mscclpp library is missing,
+            # e.g. in a non-CUDA environment.
+            return
+
+        self.group = group
+
+        assert (
+            dist.get_backend(group) != dist.Backend.NCCL
+        ), "PyMscclppCommunicator should be attached to a non-NCCL group."
+
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        if world_size == 1:
+            # No need to initialize mscclpp for the single-GPU case.
+            return
+
+        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "PyMscclpp is disabled due to an unsupported world"
+                " size: %d. Supported world sizes: %s. 
To silence this " + "warning, specify disable_mscclpp=True explicitly.", + world_size, + str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES), + ) + return + + self.ranks = torch.distributed.get_process_group_ranks(group) + self.nranks_per_node = torch.cuda.device_count() + # for now mscclpp with stride in the communicator is not tested + if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1): + logger.warning( + "PyMscclpp is disabled due to an unsupported group %s." + "Please ensure all ranks in the group are consecutive." + "To silence this warning, specify disable_mscclpp=True explicitly.", + str(self.ranks), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + self.max_bytes = max_bytes + self.rank = rank + self.world_size = world_size + + if dist.get_rank(group) == 0: + unique_id = [ops.mscclpp_generate_unique_id()] + else: + unique_id = [None] + dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group) + self.unique_id = unique_id[0] + self.rank_to_node, self.rank_to_ib = list(range(world_size)), list( + range(world_size) + ) + for r in range(world_size): + self.rank_to_node[r] = r // 8 + self.rank_to_ib[r] = self.rank % 8 + + self._context = None + self.context_selection = None + self.msg_size_for_finetune = [ + 2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1) + ] + self.msg_size2best_config = {} + if world_size == 8: + self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL + elif world_size == 16: + self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL + if not _is_hip: + self.scratch = torch.empty( + self.max_bytes * 8, + dtype=torch.uint8, + device=self.device, + ) + self.put_buffer = torch.empty( + self.max_bytes * 8 // self.nranks_per_node, + dtype=torch.uint8, + device=self.device, + ) + self._context = ops.mscclpp_init_context( + self.unique_id, + self.rank, + self.world_size, + self.scratch, + self.put_buffer, + self.nranks_per_node, + self.rank_to_node, + self.rank_to_ib, + int(self.context_selection), + ) + else: + raise NotImplementedError("HIP Mscclpp is not supported yet.") + + self.msg_size2best_config = {} + self.pre_tune_config() + if dist.get_rank(group) == 0: + msg_size2best_config = [self.msg_size2best_config] + else: + msg_size2best_config = [None] + dist.broadcast_object_list( + msg_size2best_config, src=self.ranks[0], group=self.group + ) + self.msg_size2best_config = msg_size2best_config[0] + + # PyMscclpp is enabled only in cuda graph + self.disabled = True + + def pre_tune_config(self, dtype=torch.bfloat16) -> bool: + logger.debug(f"start to pre-tune configs for rank {self.rank}") + nthreads_to_try = [256, 512, 1024] + nblocks_to_try = [21, 42, 84] + inp_randn = torch.ones( + self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda" + ) + oup_randn = torch.empty_like(inp_randn) + for msg_size in self.msg_size_for_finetune: + mock_inp, mock_outp = ( + inp_randn[: msg_size // dtype.itemsize], + oup_randn[: msg_size // dtype.itemsize], + ) + best_config, best_time = None, None + for nthreads in nthreads_to_try: + for nblocks in nblocks_to_try: + cur_cost = mscclpp_bench_time( + lambda: ops.mscclpp_allreduce( + self._context, mock_inp, mock_outp, nthreads, nblocks + ) + ) + if best_time is None or cur_cost < best_time: + best_config = (nthreads, nblocks) + best_time = 
cur_cost + self.msg_size2best_config[msg_size] = best_config + if self.rank == 0: + logger.debug( + f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us" + ) + + def should_mscclpp_allreduce( + self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM + ) -> bool: + if self.disabled or self._context is None: + return False + if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE: + return False + if not mscclpp_is_weak_contiguous(inp): + return False + # only support sum op + if op != ReduceOp.SUM: + return False + if inp.numel() * inp.element_size() > self.max_bytes: + return False + return True + + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM): + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + self.graph_input_set.add((tensor.dtype, tensor.numel())) + msg_size = tensor.numel() * tensor.itemsize + index = bisect.bisect_left(self.msg_size_for_finetune, msg_size) + msg_size_finetune = self.msg_size_for_finetune[index] + nthreads, nblocks = self.msg_size2best_config[msg_size_finetune] + result = torch.empty_like(tensor) + ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks) + return result + + @contextmanager + def change_state( + self, + enable: Optional[bool] = None, + ): + if enable is None: + # guess a default value when not specified + enable = self.available + + old_disable = self.disabled + self.disabled = not enable + + yield + + self.disabled = old_disable diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 943bccc53..cc2ba95a6 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -190,6 +190,7 @@ class GroupCoordinator: cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication use_pynccl: bool # a hint of whether to use PyNccl + use_pymscclpp: bool # a hint of whether to use PyMsccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce use_message_queue_broadcaster: ( bool # a hint of whether to use message queue broadcaster @@ -205,6 +206,7 @@ class GroupCoordinator: local_rank: int, torch_distributed_backend: Union[str, Backend], use_pynccl: bool, + use_pymscclpp: bool, use_custom_allreduce: bool, use_hpu_communicator: bool, use_xpu_communicator: bool, @@ -244,6 +246,7 @@ class GroupCoordinator: self.device = torch.device("cpu") self.use_pynccl = use_pynccl + self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator @@ -265,6 +268,17 @@ class GroupCoordinator: device=self.device, ) + from sglang.srt.distributed.device_communicators.pymscclpp import ( + PyMscclppCommunicator, + ) + + self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None + if use_pymscclpp and self.world_size > 1: + self.pymscclpp_comm = PyMscclppCommunicator( + group=self.cpu_group, + device=self.device, + ) + self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. 
@@ -373,11 +387,15 @@ class GroupCoordinator: # -------------------------------------------- # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | + # PyMscclpp | disabled| enabled | # torch.distributed | enabled | disabled| # # Note that custom allreduce will have a runtime check, if the # tensor size is too large, it will fallback to the next # available option. + # Note that the PyMsccl needs to register the tensor in ahead, + # which will introduce large overhead in the eager case, + # therefore it is only supported in the graph case. # In summary: When using CUDA graph, we use # either custom all-reduce kernel or pynccl. When not using # CUDA graph, we use either custom all-reduce kernel or @@ -392,7 +410,14 @@ class GroupCoordinator: maybe_pynccl_context = pynccl_comm.change_state( enable=True, stream=torch.cuda.current_stream() ) - with maybe_pynccl_context: + + pymscclpp_comm = self.pymscclpp_comm + maybe_pymscclpp_context: Any + if not pymscclpp_comm: + maybe_pymscclpp_context = nullcontext() + else: + maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True) + with maybe_pynccl_context, maybe_pymscclpp_context: yield graph_capture_context def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -437,6 +462,10 @@ class GroupCoordinator: self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_) + ) or ( + self.pymscclpp_comm is not None + and not self.pymscclpp_comm.disabled + and self.pymscclpp_comm.should_mscclpp_allreduce(input_) ): return torch.ops.sglang.outplace_all_reduce( input_, group_name=self.unique_name @@ -447,9 +476,13 @@ class GroupCoordinator: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: ca_comm = self.ca_comm - assert ca_comm is not None - assert not ca_comm.disabled - out = ca_comm.custom_all_reduce(input_) + pymscclpp_comm = self.pymscclpp_comm + assert ca_comm is not None or pymscclpp_comm is not None + if ca_comm is not None and not ca_comm.disabled: + out = ca_comm.custom_all_reduce(input_) + else: + assert not pymscclpp_comm.disabled + out = pymscclpp_comm.all_reduce(input_) assert out is not None return out @@ -958,6 +991,7 @@ def init_world_group( local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, + use_pymscclpp=False, use_custom_allreduce=False, use_hpu_communicator=False, use_xpu_communicator=False, @@ -973,14 +1007,18 @@ def init_model_parallel_group( use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, + use_mscclpp_allreduce: Optional[bool] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + if use_mscclpp_allreduce is None: + use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=not is_npu(), + use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_hpu_communicator=True, use_xpu_communicator=True, @@ -1037,6 +1075,7 @@ def graph_capture(): logger = logging.getLogger(__name__) _ENABLE_CUSTOM_ALL_REDUCE = True +_ENABLE_MSCCLPP_ALL_REDUCE = False def set_custom_all_reduce(enable: bool): @@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def set_mscclpp_all_reduce(enable: bool): + global _ENABLE_MSCCLPP_ALL_REDUCE + _ENABLE_MSCCLPP_ALL_REDUCE = enable + + def init_distributed_environment( world_size: int = -1, 
rank: int = -1, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 5fa7ce092..b1862ff2c 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -98,11 +98,12 @@ def initialize_dp_attention( ], local_rank, torch.distributed.get_backend(tp_group.device_group), - SYNC_TOKEN_IDS_ACROSS_TP, - False, - False, - False, - False, + use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP, + use_pymscclpp=False, + use_custom_allreduce=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + use_npu_communicator=False, group_name="attention_tp", ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a7885cad4..adabc897f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -35,6 +35,7 @@ from sglang.srt.distributed import ( init_distributed_environment, initialize_model_parallel, set_custom_all_reduce, + set_mscclpp_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.attention.tbo_backend import TboAttnBackend @@ -460,6 +461,7 @@ class ModelRunner: else: dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_mscclpp_all_reduce(self.server_args.enable_mscclpp) if not self.is_draft_worker: # Only initialize the distributed environment on the target model worker. diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 961350f7c..b3557a472 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -165,6 +165,7 @@ class ServerArgs: enable_tokenizer_batch_encode: bool = False disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False + enable_mscclpp: bool = False disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False @@ -1168,6 +1169,11 @@ class ServerArgs: action="store_true", help="Disable the custom all-reduce kernel and fall back to NCCL.", ) + parser.add_argument( + "--enable-mscclpp", + action="store_true", + help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.", + ) parser.add_argument( "--disable-overlap-schedule", action="store_true", diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 09f8f529f..ab11ded67 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -73,6 +73,14 @@ FetchContent_Declare( GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention) +# mscclpp +FetchContent_Declare( + repo-mscclpp + GIT_REPOSITORY https://github.com/microsoft/mscclpp.git + GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6 + GIT_SHALLOW OFF +) +FetchContent_Populate(repo-mscclpp) # ccache option option(ENABLE_CCACHE "Whether to use ccache" ON) @@ -99,6 +107,7 @@ include_directories( ${repo-cutlass_SOURCE_DIR}/tools/util/include ${repo-flashinfer_SOURCE_DIR}/include ${repo-flashinfer_SOURCE_DIR}/csrc + ${repo-mscclpp_SOURCE_DIR}/include ) set(SGL_KERNEL_CUDA_FLAGS @@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") set(SOURCES + "csrc/allreduce/mscclpp_allreduce.cu" "csrc/allreduce/custom_all_reduce.cu" "csrc/attention/cascade.cu" "csrc/attention/merge_attn_states.cu" @@ -250,7 +260,27 @@ 
target_include_directories(common_ops PRIVATE ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src ) -target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt) + +find_package(Python3 COMPONENTS Interpreter REQUIRED) +execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))" + OUTPUT_VARIABLE TORCH_CXX11_ABI + OUTPUT_STRIP_TRAILING_WHITESPACE +) +if(TORCH_CXX11_ABI STREQUAL "0") + message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") +else() + message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") +endif() +set(MSCCLPP_USE_CUDA ON) +set(MSCCLPP_BYPASS_GPU_CHECK ON) +set(MSCCLPP_BUILD_TESTS OFF) +add_subdirectory(${repo-mscclpp_SOURCE_DIR}) +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) target_compile_definitions(common_ops PRIVATE FLASHATTENTION_DISABLE_BACKWARD diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 830f2821e..fcd6153f3 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules @git submodule update --init --recursive ln: submodule ## Create compilation database - @rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES + @rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5 install: submodule ## Install package in development mode @pip install -e . --no-build-isolation build: install-deps submodule ## Build and install wheel package - @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . 
--verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps clean: ## Remove build artifacts @rm -rf build dist *.egg-info diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index a3fb8a303..70a982ab2 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -50,6 +50,9 @@ docker run --rm \ which cmake cmake --version + yum install numactl-devel -y && \ + yum install libibverbs -y && \ + ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \ ${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \ ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ diff --git a/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu new file mode 100644 index 000000000..d9cda22f0 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu @@ -0,0 +1,140 @@ +#include +#include +#include +#include + +#include "mscclpp_allreduce.cuh" + +enum MscclContextSelection { + MSCCL1NODELL = 1, + MSCCL2NODELL = 2, +}; + +class MscclContext { + public: + MscclContextSelection selection_; + std::shared_ptr msccl_1nodeLL_context; + std::shared_ptr msccl_2nodeLL_context; + MscclContext(MscclContextSelection selection) : selection_(selection) {} + template + void allreduce( + cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) { + if (selection_ == MSCCL1NODELL) { + msccl_1nodeLL_context->allreduce(stream, input, output, input_numel, threads, block_limit); + } else if (selection_ == MSCCL2NODELL) { + msccl_2nodeLL_context->allreduce(stream, input, output, input_numel, threads, block_limit); + } + } +}; + +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) { + auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU); + auto tensor = torch::empty({static_cast(unique_id.size())}, options); + std::memcpy(tensor.data_ptr(), unique_id.data(), unique_id.size()); + return tensor; +} + +// Function to convert vector of int32_t back to array of uint8_t +mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) { + mscclpp::UniqueId unique_id; + std::memcpy(unique_id.data(), tensor.data_ptr(), unique_id.size()); + return unique_id; +} + +torch::Tensor mscclpp_generate_unique_id() { + mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId(); + return _unique_id2tensor(unique_id); +} + +fptr_t mscclpp_init_context( + const torch::Tensor& unique_id, + const int64_t rank, + const int64_t world_size, + torch::Tensor& scratch, + torch::Tensor& put_buffer, + const int64_t nranks_per_node, + const std::vector& rank_to_node, + const std::vector& rank_to_ib, + const int64_t context_selection) { + MscclContext* context_ptr = new MscclContext(static_cast(context_selection)); + mscclpp::UniqueId uid = _tensor2unique_id(unique_id); + if (context_selection == MSCCL1NODELL) { + void* scratch_ptr = reinterpret_cast(scratch.data_ptr()); + const size_t scratch_bytes = scratch.numel() * scratch.element_size(); + context_ptr->msccl_1nodeLL_context = std::make_shared( + uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib); + } else if (context_selection == MSCCL2NODELL) { + void* scratch_ptr = reinterpret_cast(scratch.data_ptr()); + const size_t scratch_bytes = scratch.numel() * scratch.element_size(); + void* 
put_buffer_ptr = reinterpret_cast(put_buffer.data_ptr()); + const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size(); + context_ptr->msccl_2nodeLL_context = std::make_shared( + uid, + rank, + world_size, + scratch_ptr, + scratch_bytes, + put_buffer_ptr, + put_buffer_bytes, + nranks_per_node, + rank_to_node, + rank_to_ib); + } else { + throw std::runtime_error("invalid context selection"); + } + return (fptr_t)context_ptr; +} + +bool _mscclpp_is_weak_contiguous(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); +} +void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) { + MscclContext* context = reinterpret_cast(_context); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(_mscclpp_is_weak_contiguous(out)); + TORCH_CHECK(_mscclpp_is_weak_contiguous(inp)); + switch (out.scalar_type()) { + case at::ScalarType::Float: { + context->allreduce( + stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + inp.numel(), + nthreads, + nblocks); + break; + } + case at::ScalarType::Half: { + context->allreduce( + stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + inp.numel(), + nthreads, + nblocks); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + context->allreduce<__nv_bfloat16>( + stream, + reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), + inp.numel(), + nthreads, + nblocks); + break; + } +#endif + default: + throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16"); + } +} diff --git a/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh new file mode 100644 index 000000000..2e064d704 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh @@ -0,0 +1,779 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
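+//
+// This header contains the mscclpp-based all-reduce implementation: two LL-packet
+// device kernels (allreduce_LL_1node for a single node, allreduce_LL_2node for two
+// nodes) and the host-side contexts (Msccl1NodeLLcontext / Msccl2NodeLLcontext) that
+// set up the memory/port channels and launch those kernels.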
+#pragma once +#if defined(__HIP_PLATFORM_AMD__) +#include +#else +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +// comment this for test_mscclpp_allreduce.cu +#include "utils.h" + +namespace sglang { + +__device__ mscclpp::DeviceSyncer deviceSyncer; +__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer; +__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer; +__device__ mscclpp::DeviceSyncer ibDeviceSyncer; + +template +__forceinline__ __device__ To bit_cast(const From& src) { + static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast"); + + union { + From f; + To t; + } u; + u.f = src; + return u.t; +} + +template +__forceinline__ __device__ T add_elements(T a, T b) { + return a + b; +} + +template <> +__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) { + return __hadd2(a, b); +} + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +template <> +__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) { + return __hadd2(a, b); +} +#endif + +template +__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) { + int4 ret; + ret.w = bit_cast(add_elements(bit_cast(a.w), bit_cast(b.w))); + ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y))); + ret.z = bit_cast(add_elements(bit_cast(a.z), bit_cast(b.z))); + return ret; +} + +template +__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) { + return add_vectors_helper(a, b); +} + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +template <> +__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) { + return add_vectors_helper<__nv_bfloat162>(a, b); +} +#endif + +template <> +__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) { + uint2 ret; + ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y))); + return ret; +} + +template +__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) { + return add_vectors_helper(a, b); +} + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +template <> +__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) { + return add_vectors_helper<__nv_bfloat162>(a, b); +} +#endif + +template <> +__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ int add_vectors_helper(int a, int b) { + return bit_cast(add_elements(bit_cast(a), bit_cast(b))); +} + +template +__forceinline__ __device__ int add_vectors(int a, int b) { + return add_vectors_helper(a, b); +} + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +template <> +__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) { + return add_vectors_helper<__nv_bfloat162>(a, b); +} +#endif + +template <> +__forceinline__ __device__ int add_vectors<__half>(int a, int b) { + return add_vectors_helper<__half2>(a, b); +} + +// ------------------------------------------------------- +// allreduce_LL_1node using LLPacket, origin allreduce2 +// ------------------------------------------------------- + +__device__ uint64_t globalFlag = 1; + +template +__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node( + mscclpp::MemoryChannelDeviceHandle* memChans, + 
TYPE* buff, + TYPE* scratch, + void* resultBuff, + int rank, + int worldSize, + size_t nelems) { + nelems = nelems / (sizeof(int) / sizeof(TYPE)); + // This version of allreduce only works for single nodes + const int nPeers = worldSize - 1; + const size_t nPkts = nelems / 2; + const int nelemsPerRank = nelems / worldSize; + const int nPktsPerRank = nelemsPerRank / 2; + // flag for packets. Initially 1 + const uint32_t flag = (uint32_t)globalFlag; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; + mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx]; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket); + size_t scratchResultOffset = + (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket); + size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int); + uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + // step 1: write to scratch buffer + memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag); + // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { + uint2 data = make_uint2(0, 0); + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? 
index : index + 1; + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; + uint2 val = dstPkt[idx].read(flag); + data = add_vectors(val, data); + } + data = add_vectors(data, src[idx]); + dst[idx] = data; + + mscclpp::LLPacket packet; + packet.data1 = data.x; + packet.flag1 = flag; + packet.data2 = data.y; + packet.flag2 = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank); + for (int index = 0; index < nPeers; index++) { + memChans[index].write(offset, packet); + } + } + // step 3: get data result from scratch buffer + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; + uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { + uint2 data = dstPkt[idx + dstOffset].read(flag); + result[idx].x = data.x; + result[idx].y = data.y; + } + if (threadIdx.x == 0 && blockIdx.x == 0) { + globalFlag += 1; + } +} + +// ------------------------------------------------------- +// allreduce_LL_2node using LLPacket, origin allreduce5 +// ------------------------------------------------------- + +template +__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node( + mscclpp::MemoryChannelDeviceHandle* memChans, + mscclpp::PortChannelDeviceHandle* portChans, + TYPE* buff, + TYPE* scratch, + TYPE* putBuff, + TYPE* resultBuff, + int rank, + int nRanksPerNode, + int worldSize, + size_t nelems) { + nelems = nelems / (sizeof(int) / sizeof(TYPE)); + // This version of allreduce only works for single nodes + const int nPeersInNode = nRanksPerNode - 1; + const int nPkts = nelems / 2; + const int nelemsPerLocalRank = nelems / nRanksPerNode; + const int nPktsPerLocalRank = nelemsPerLocalRank / 2; + const int localRankId = rank % nRanksPerNode; + // flag for packets. Initially 1 + const uint32_t flag = (uint32_t)globalFlag; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeersInNode; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1; + mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx]; + mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId]; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket); + size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket); + size_t scratchResultOffset = + (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket); + size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int); + uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int)); + uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int)); + + // step 1: write to scratch buffer + if (nRanksPerNode > 1) { + memChan.putPackets( + scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag); + } + // step 2: get data from scratch buffer, do local reduce-scatter in each node. 
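+  // Each thread reduces its packet slice across the local peers, adds the local
+  // source data, stores the partial result to the local result buffer, and stages
+  // it in putBuff for the cross-node put in step 3.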
+ mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset); + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) { + uint2 data = make_uint2(0, 0); + for (int index = 0; index < nPeersInNode; index++) { + const int remoteRank = index < localRankId ? index : index + 1; + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank; + uint2 val = dstPkt[idx].read(flag); + data = add_vectors(val, data); + } + data = add_vectors(data, src[idx]); + putPkt[idx].write(data.x, data.y, flag); + dst[idx] = data; + } + deviceSyncer.sync(gridDim.x); + // step 3. send local reduced data to remote node. + if (threadIdx.x == 0 && blockIdx.x == 0) { + portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket)); + if ((flag & 63) == 0) { + portChan.flush(); + } + } + // step 4. try to read the data from scratch buffer and write to local peers + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank; + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) { + uint2 res = dst[idx]; + uint2 val = dstPkt[idx].read(flag); + res = add_vectors(res, val); + + mscclpp::LLPacket packet; + packet.data1 = res.x; + packet.flag1 = flag; + packet.data2 = res.y; + packet.flag2 = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank); + for (int index = 0; index < nPeersInNode; index++) { + memChans[index].write(offset, packet); + } + dst[idx] = res; + } + + // step 5: get data result from scratch buffer + dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRankIdx * nPktsPerLocalRank; + uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int)); + if (nRanksPerNode > 1) { + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank; + idx += blockDim.x * nBlocksPerPeer) { + uint2 data = dstPkt[idx + dstOffset].read(flag); + result[idx] = data; + } + } + if (threadIdx.x == 0 && blockIdx.x == 0) { + globalFlag += 1; + } +} + +static const mscclpp::Transport IBs[] = { + mscclpp::Transport::IB0, + mscclpp::Transport::IB1, + mscclpp::Transport::IB2, + mscclpp::Transport::IB3, + mscclpp::Transport::IB4, + mscclpp::Transport::IB5, + mscclpp::Transport::IB6, + mscclpp::Transport::IB7}; + +class MscclCommGroup { + public: + std::shared_ptr comm_; + const size_t rank_; + const size_t world_size_; + const std::vector rank_to_node_; + const std::vector rank_to_ib_; + MscclCommGroup( + mscclpp::UniqueId unique_id, + const size_t rank, + const size_t world_size, + const std::vector& rank_to_node, + const std::vector& rank_to_ib) + : rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) { + auto bootstrap = std::make_shared(rank, world_size); + bootstrap->initialize(unique_id); + comm_ = std::make_shared(bootstrap); + } + template + void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) { + throw std::runtime_error("you should not call allreduce of a base context"); + } + bool is_same_node(int r1, int r2) { + return rank_to_node_[r1] == rank_to_node_[r2]; + } + + void make_connection( + std::unordered_map>& same_node_connections, + std::unordered_map>& cross_node_connections) { + same_node_connections.clear(); + cross_node_connections.clear(); + 
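+    // Same-node peers connect over CUDA IPC while cross-node peers use the IB
+    // transport chosen by rank_to_ib_; the futures returned by connectOnSetup are
+    // resolved after comm_->setup() runs.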
std::unordered_map>> conn_futures; + for (int r = 0; r < world_size_; ++r) { + if (r == rank_) continue; + mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]]; + conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport)); + } + comm_->setup(); + for (int r = 0; r < world_size_; ++r) { + if (r == rank_) continue; + if (is_same_node(r, rank_)) { + same_node_connections.emplace(r, conn_futures[r].get()); + } else { + cross_node_connections.emplace(r, conn_futures[r].get()); + } + } + } + + void make_memory_channels_with_scratch( + void* tensor_ptr, + const size_t tensor_bytes, + void* scratch_ptr, + const size_t scratch_bytes, + const std::unordered_map>& connections, + std::unordered_map>& semaphores, + std::unordered_map& registered_memories, + std::unordered_map& channels) { + channels.clear(); + make_semaphores(connections, semaphores); + register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories); + for (const auto& [peer, _] : connections) { + channels.emplace( + peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr)); + } + } + void make_port_channels_with_scratch( + std::shared_ptr proxyService, + void* tensor_ptr, + const size_t tensor_bytes, + void* scratch_ptr, + const size_t scratch_bytes, + const std::unordered_map>& connections, + std::unordered_map>& semaphores, + std::unordered_map& registered_memories, + std::unordered_map& channels) { + channels.clear(); + make_semaphores(connections, semaphores); + + mscclpp::TransportFlags flags; + for (const auto& [_, conn] : connections) { + flags |= conn->transport(); + } + auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags); + + register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories); + std::unordered_map semaphore_ids; + std::unordered_map memory_ids; + memory_ids[rank_] = proxyService->addMemory(local_reg_memory); + for (const auto& [peer, memory] : registered_memories) { + if (peer == rank_) continue; + memory_ids[peer] = proxyService->addMemory(memory); + } + for (const auto& [peer, semaphore] : semaphores) { + semaphore_ids[peer] = proxyService->addSemaphore(semaphore); + } + + for (const auto& [peer, _] : connections) { + channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_])); + } + } + + template + void make_semaphores( + const std::unordered_map>& connections, + std::unordered_map>& semaphores) { + semaphores.clear(); + for (const auto& [peer, conn] : connections) { + semaphores[peer] = std::make_shared(*comm_, conn); + } + comm_->setup(); + } + + void register_tensor_with_connections( + void* tensor_ptr, + size_t tensor_bytes, + const std::unordered_map>& connections, + std::unordered_map& registered_memories) { + registered_memories.clear(); + mscclpp::TransportFlags all_transports; + for (const auto& [_, connection] : connections) { + all_transports |= connection->transport(); + } + mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports); + registered_memories[rank_] = buf_reg_mem; + + std::unordered_map> remote_mem_futures; + for (const auto& [r, connection] : connections) { + comm_->sendMemoryOnSetup(buf_reg_mem, r, 0); + auto remoteMemory = comm_->recvMemoryOnSetup(r, 0); + remote_mem_futures.emplace(r, remoteMemory); + } + comm_->setup(); + for (auto& [r, mem_feature] : remote_mem_futures) { + 
registered_memories.emplace(r, mem_feature.get()); + } + } + + void make_device_memory_handle_base_on_new_ptr( + const std::unordered_map& old_memory_channels, + std::unordered_map& registered_sm_memories, + std::unordered_map>& memory_semaphores, + std::unordered_map& memory_channels, + mscclpp::GpuBuffer& device_memory_handle, + void* input, + void* scratch, + const cudaStream_t stream) { + memory_channels.clear(); + for (const auto& [peer, channel] : old_memory_channels) { + memory_channels.emplace( + peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch)); + } + std::vector memory_channels_list; + for (int r = 0; r < world_size_; r++) { + if (r == rank_) continue; + if (is_same_node(r, rank_)) { + memory_channels_list.push_back(memory_channels[r]); + } + } + std::vector memory_channel_handlers(memory_channels_list.size()); + std::transform( + memory_channels_list.begin(), + memory_channels_list.end(), + memory_channel_handlers.begin(), + [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); }); + mscclpp::gpuMemcpyAsync( + device_memory_handle.data(), + memory_channel_handlers.data(), + memory_channel_handlers.size(), + stream, + cudaMemcpyHostToDevice); + } +}; + +class Msccl1NodeLLcontext { + private: + std::shared_ptr comm_group_ = nullptr; + void* scratch_; + const size_t scratch_bytes_; + std::unordered_map> same_node_connections_; + std::unordered_map> cross_node_connections_; + + std::unordered_map registered_sm_memories_; + std::unordered_map> memory_semaphores_; + std::unordered_map memory_channels_; + mscclpp::GpuBuffer d_memHandles_; + std::unordered_map> input_ptr2memory_channels_; + std::unordered_map> input_ptr2d_memHandles_; + cudaStream_t h2d_stream; + const size_t nranks_per_node_; + + public: + Msccl1NodeLLcontext( + mscclpp::UniqueId unique_id, + const size_t rank, + const size_t world_size, + void* scratch, + const size_t scratch_bytes, + const size_t nranks_per_node, + const std::vector& rank_to_node, + const std::vector& rank_to_ib) + : scratch_(scratch), + scratch_bytes_(scratch_bytes), + nranks_per_node_(nranks_per_node), + d_memHandles_(nranks_per_node - 1) { + CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking)); + comm_group_ = std::make_shared(unique_id, rank, world_size, rank_to_node, rank_to_ib); + comm_group_->make_connection(same_node_connections_, cross_node_connections_); + comm_group_->make_memory_channels_with_scratch( + scratch_, + scratch_bytes_, + scratch_, + scratch_bytes_, + same_node_connections_, + memory_semaphores_, + registered_sm_memories_, + memory_channels_); + std::vector memory_channels_list; + for (int r = 0; r < comm_group_->world_size_; r++) { + if (r == comm_group_->rank_) continue; + memory_channels_list.push_back(memory_channels_[r]); + } + std::vector memory_channel_handlers(memory_channels_list.size()); + std::transform( + memory_channels_list.begin(), + memory_channels_list.end(), + memory_channel_handlers.begin(), + [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); }); + mscclpp::gpuMemcpy( + d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice); + } + + ~Msccl1NodeLLcontext() { + CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream)); + } + + template + void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) { + dim3 nthrs(nthreads); + dim3 nblks(nblocks); + cudaStreamCaptureStatus capturing_status; 
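+    // Outside CUDA graph capture, the channel handles for this input pointer are
+    // rebuilt on h2d_stream and synchronized before launch; during capture they are
+    // built once per input pointer and cached in input_ptr2d_memHandles_ so graph
+    // replays keep using the same device handles.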
+ CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status)); + mscclpp::MemoryChannelDeviceHandle* memChans; + if (capturing_status != cudaStreamCaptureStatusActive) { + std::unordered_map memory_channels; + comm_group_->make_device_memory_handle_base_on_new_ptr( + memory_channels_, + registered_sm_memories_, + memory_semaphores_, + memory_channels, + d_memHandles_, + input, + scratch_, + h2d_stream); + CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream)); + memChans = d_memHandles_.data(); + } else { + void* input_void_ptr = reinterpret_cast(input); + if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) { + std::unordered_map memory_channels; + mscclpp::GpuBuffer device_memory_handle(comm_group_->world_size_ - 1); + comm_group_->make_device_memory_handle_base_on_new_ptr( + memory_channels_, + registered_sm_memories_, + memory_semaphores_, + memory_channels, + device_memory_handle, + input, + scratch_, + h2d_stream); + input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels); + input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle); + } + auto it = input_ptr2d_memHandles_.find(input_void_ptr); + memChans = it->second.data(); + } + allreduce_LL_1node<<>>( + memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel); + + cudaError_t status = cudaGetLastError(); + if (status != cudaSuccess) { + printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status)); + } + } +}; + +class Msccl2NodeLLcontext { + private: + std::shared_ptr comm_group_ = nullptr; + void* scratch_; + const size_t scratch_bytes_; + void* put_buffer_; + const size_t put_buffer_bytes_; + std::unordered_map> same_node_connections_; + std::unordered_map> cross_node_connections_; + + std::unordered_map registered_sm_memories_; + std::unordered_map registered_port_memories_; + + std::unordered_map> memory_semaphores_; + std::unordered_map> port_semaphores_; + + std::unordered_map memory_channels_; + std::unordered_map port_channels_; + + mscclpp::GpuBuffer d_memHandles_; + mscclpp::GpuBuffer d_portHandles_; + + std::shared_ptr proxyService; + cudaStream_t h2d_stream; + const size_t nranks_per_node_; + + std::unordered_map> input_ptr2memory_channels_; + std::unordered_map> input_ptr2d_memHandles_; + + public: + Msccl2NodeLLcontext( + mscclpp::UniqueId unique_id, + const size_t rank, + const size_t world_size, + void* scratch, + const size_t scratch_bytes, + void* put_buffer, + const size_t put_buffer_bytes, + const size_t nranks_per_node, + const std::vector& rank_to_node, + const std::vector& rank_to_ib) + : scratch_(scratch), + scratch_bytes_(scratch_bytes), + put_buffer_(put_buffer), + put_buffer_bytes_(put_buffer_bytes), + nranks_per_node_(nranks_per_node), + d_memHandles_(nranks_per_node - 1), + d_portHandles_(world_size - nranks_per_node) { + CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking)); + comm_group_ = std::make_shared(unique_id, rank, world_size, rank_to_node, rank_to_ib); + proxyService = std::make_shared(); + proxyService->startProxy(); + comm_group_->make_connection(same_node_connections_, cross_node_connections_); + comm_group_->make_memory_channels_with_scratch( + scratch_, + scratch_bytes_, + scratch_, + scratch_bytes_, + same_node_connections_, + memory_semaphores_, + registered_sm_memories_, + memory_channels_); + comm_group_->make_port_channels_with_scratch( + proxyService, + put_buffer_, + put_buffer_bytes_, + scratch_, + 
scratch_bytes_, + cross_node_connections_, + port_semaphores_, + registered_port_memories_, + port_channels_); + std::vector memory_channels_list; + std::vector port_channels_list; + for (int r = 0; r < comm_group_->world_size_; r++) { + if (r == comm_group_->rank_) continue; + if (comm_group_->is_same_node(r, comm_group_->rank_)) { + memory_channels_list.push_back(memory_channels_[r]); + } else { + port_channels_list.push_back(port_channels_[r]); + } + } + std::vector memory_channel_handlers(memory_channels_list.size()); + std::transform( + memory_channels_list.begin(), + memory_channels_list.end(), + memory_channel_handlers.begin(), + [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); }); + mscclpp::gpuMemcpy( + d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice); + + std::vector port_channel_handlers(port_channels_list.size()); + std::transform( + port_channels_list.begin(), + port_channels_list.end(), + port_channel_handlers.begin(), + [](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); }); + mscclpp::gpuMemcpy( + d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice); + } + + ~Msccl2NodeLLcontext() { + CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream)); + if (proxyService) { + proxyService->stopProxy(); + } + } + + template + void + allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) { + dim3 nthrs(nthreads); + dim3 nblks(nblocks); + cudaStreamCaptureStatus capturing_status; + CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status)); + mscclpp::MemoryChannelDeviceHandle* memChans; + if (capturing_status != cudaStreamCaptureStatusActive) { + std::unordered_map memory_channels; + comm_group_->make_device_memory_handle_base_on_new_ptr( + memory_channels_, + registered_sm_memories_, + memory_semaphores_, + memory_channels, + d_memHandles_, + input, + scratch_, + h2d_stream); + CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream)); + memChans = d_memHandles_.data(); + } else { + void* input_void_ptr = reinterpret_cast(input); + if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) { + std::unordered_map memory_channels; + mscclpp::GpuBuffer device_memory_handle(7); + comm_group_->make_device_memory_handle_base_on_new_ptr( + memory_channels_, + registered_sm_memories_, + memory_semaphores_, + memory_channels, + device_memory_handle, + input, + scratch_, + h2d_stream); + input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels); + input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle); + } + auto it = input_ptr2d_memHandles_.find(input_void_ptr); + memChans = it->second.data(); + } + allreduce_LL_2node<<>>( + memChans, + d_portHandles_.data(), + (T*)input, + (T*)scratch_, + (T*)put_buffer_, + output, + comm_group_->rank_, + nranks_per_node_, + comm_group_->world_size_, + input_numel); + + cudaError_t status = cudaGetLastError(); + if (status != cudaSuccess) { + printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status)); + } + } +}; + +} // namespace sglang diff --git a/sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu b/sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu new file mode 100644 index 000000000..4ca0c5739 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu @@ -0,0 +1,153 @@ +/* + * this file is used to test mscclpp_allreduce.cu 
using mpirun + * this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu +usage: +cd PATH-TO-THIS-FILE +export MPI_HOME=/usr/local/mpi +# export MPI_HOME=/opt/hpcx/ompi/ +export MSCCLPP_HOME=/workspace/test/mscclpp +nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \ + -o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \ + -I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \ + -lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi + +/opt/hpcx/ompi/bin/ +mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \ + --map-by ppr:8:node \ + --mca btl_openib_warn_no_device_params_found 0 \ + --mca btl_tcp_if_include bond0 \ + --allow-run-as-root -np 8 \ + -x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \ + -x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce + */ +#include +#include +#include +#include + +#ifndef CHECK_CUDA_SUCCESS +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +#endif + +#include + +#include "mscclpp_allreduce.cuh" + +template +bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { + return fabs(a - b) <= (atol + rtol * fabs(b)); +} + +int main(int argc, char* argv[]) { + // init mpi + MPI_Init(&argc, &argv); + printf("MPI Initialized.\n"); + int nranks, rank; + + // get work size and rank id + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + cudaSetDevice(rank); + printf("nranks: %d, rank: %d\n", nranks, rank); + + // init host and device buffers + using T = float; + using ReduceT = float; + const size_t num_elems = 2 * 1024 * 1024; + std::vector host_buf(num_elems); + for (uint32_t i = 0; i < num_elems; ++i) { + host_buf[i] = T(i + rank); + } + thrust::device_vector device_buf(host_buf); + const size_t buf_size_in_bytes = num_elems * sizeof(T); + std::vector host_result_buf(num_elems); + thrust::device_vector device_result_buf(host_result_buf); + + std::vector host_scratch_buf(num_elems * 8); + for (uint32_t i = 0; i < num_elems; ++i) { + host_scratch_buf[i] = 1; + } + thrust::device_vector device_scratch_buf(host_scratch_buf); + std::vector host_put_buf(num_elems); + thrust::device_vector device_put_buf(host_put_buf); + + mscclpp::UniqueId unique_id; + if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId(); + MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD); + + std::vector rank_to_node(nranks); + std::vector rank_to_ib(nranks); + for (int i = 0; i < nranks; i++) { + rank_to_node[i] = i / 8; + rank_to_ib[i] = i % 8; + } + + cudaStream_t s; + CHECK_CUDA_SUCCESS(cudaStreamCreate(&s)); + CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s)); + if (nranks == 8) { + auto context = std::make_shared( + unique_id, + rank, + nranks, + thrust::raw_pointer_cast(device_scratch_buf.data()), + buf_size_in_bytes * 8, + rank_to_node, + rank_to_ib); + printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank); + MPI_Barrier(MPI_COMM_WORLD); + context->allreduce( + s, + thrust::raw_pointer_cast(device_buf.data()), + thrust::raw_pointer_cast(device_result_buf.data()), + device_buf.size()); + } else if (nranks == 16) { + // TODO: this branch is untested since there is something wrong with mpirun in my test machince + auto context = std::make_shared( + unique_id, + rank, + nranks, + thrust::raw_pointer_cast(device_scratch_buf.data()), + 
buf_size_in_bytes * 8, + thrust::raw_pointer_cast(device_put_buf.data()), + buf_size_in_bytes, + rank_to_node, + rank_to_ib); + printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank); + MPI_Barrier(MPI_COMM_WORLD); + context->allreduce( + s, + thrust::raw_pointer_cast(device_buf.data()), + thrust::raw_pointer_cast(device_result_buf.data()), + device_buf.size()); + } + + // check result correctness + thrust::host_vector host_buf_result = device_result_buf; + size_t num_results_error_atol_1e_3_rtol_1e_3 = 0; + bool nan_detected = false; + + for (uint32_t i = 0; i < num_elems; ++i) { + T expected = T(i * nranks + (nranks - 1) * nranks / 2); + if (std::isnan(float(host_buf_result[i]))) { + nan_detected = true; + } + if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) { + num_results_error_atol_1e_3_rtol_1e_3++; + } + } + float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems); + + printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy); + + CHECK_CUDA_SUCCESS(cudaStreamDestroy(s)); + MPI_Finalize(); + return 0; +} diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 1baa5aa07..73e937ec1 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " "int reg_buffer_sz_bytes) -> ()"); m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id); + m.def( + "mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, " + "int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int"); + m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context); + + m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! 
out, int nthreads, int nblocks) -> ()"); + m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce); /* * From csrc/attention */ diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 4c6a37bf8..93db3d952 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -74,6 +74,18 @@ std::tuple, std::vector> get_graph_buffer_ipc_meta void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); void register_graph_buffers( fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); +torch::Tensor mscclpp_generate_unique_id(); +fptr_t mscclpp_init_context( + const torch::Tensor& unique_id, + const int64_t rank, + const int64_t world_size, + torch::Tensor& scratch, + torch::Tensor& put_buffer, + const int64_t nranks_per_node, + const std::vector& rank_to_node, + const std::vector& rank_to_ib, + const int64_t context_selection); +void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks); #endif /* diff --git a/sgl-kernel/python/sgl_kernel/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py index 135275961..317b2f1a7 100644 --- a/sgl-kernel/python/sgl_kernel/allreduce.py +++ b/sgl-kernel/python/sgl_kernel/allreduce.py @@ -49,6 +49,27 @@ if torch.version.hip is not None: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp) + def mscclpp_generate_unique_id() -> bytes: + raise NotImplementedError() + + def mscclpp_init_context( + unique_id: bytes, + rank: int, + world_size: int, + scratch: torch.Tensor, + put_buffer: torch.Tensor, + nranks_per_node: int, + rank_to_node: List[int], + rank_to_ib: List[int], + context_selection: int, + ) -> int: + raise NotImplementedError() + + def mscclpp_allreduce( + context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int + ) -> None: + raise NotImplementedError() + else: def init_custom_ar( @@ -85,3 +106,36 @@ else: def meta_size() -> int: return torch.ops.sgl_kernel.meta_size.default() + + def mscclpp_generate_unique_id() -> torch.Tensor: + return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default() + + def mscclpp_init_context( + unique_id: torch.Tensor, + rank: int, + world_size: int, + scratch: torch.Tensor, + put_buffer: torch.Tensor, + nranks_per_node: int, + rank_to_node: List[int], + rank_to_ib: List[int], + context_selection: int, + ) -> int: + return torch.ops.sgl_kernel.mscclpp_init_context.default( + unique_id, + rank, + world_size, + scratch, + put_buffer, + nranks_per_node, + rank_to_node, + rank_to_ib, + context_selection, + ) + + def mscclpp_allreduce( + context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int + ) -> None: + torch.ops.sgl_kernel.mscclpp_allreduce.default( + context, inp, out, nthreads, nblocks + ) diff --git a/sgl-kernel/tests/test_mscclpp.py b/sgl-kernel/tests/test_mscclpp.py new file mode 100644 index 000000000..0a4332bd3 --- /dev/null +++ b/sgl-kernel/tests/test_mscclpp.py @@ -0,0 +1,146 @@ +import multiprocessing as mp +import os +import socket +import unittest +from enum import IntEnum +from typing import Any + +import sgl_kernel.allreduce as custom_ops +import torch +import torch.distributed as dist + + +class MscclContextSelection(IntEnum): + MSCCL1SHOT1NODELL = 1 + MSCCL1SHOT2NODELL = 2 + + +def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes): + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + 
torch.cuda.set_device(device)
+    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    dist.init_process_group(
+        backend="nccl",
+        init_method=distributed_init_method,
+        rank=rank,
+        world_size=world_size,
+    )
+    group = dist.group.WORLD
+    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
+    if rank == 0:
+        unique_id = [custom_ops.mscclpp_generate_unique_id()]
+    else:
+        unique_id = [None]
+    dist.broadcast_object_list(
+        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
+    )
+    unique_id = unique_id[0]
+    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
+    for r in range(world_size):
+        rank_to_node[r] = r // 8
+        rank_to_ib[r] = r % 8
+    MAX_BYTES = 2**20
+    scratch = torch.empty(
+        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
+    )
+    put_buffer = torch.empty(
+        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
+    )
+    print(f"[{rank}] start mscclpp_context init")
+    nranks_per_node = torch.cuda.device_count()
+    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
+    mscclpp_context = custom_ops.mscclpp_init_context(
+        unique_id,
+        rank,
+        world_size,
+        scratch,
+        put_buffer,
+        nranks_per_node,
+        rank_to_node,
+        rank_to_ib,
+        selection,
+    )
+    try:
+        test_loop = 10
+        for sz in test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                if sz * dtype.itemsize > MAX_BYTES:
+                    continue
+                if rank == 0:
+                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
+                for _ in range(test_loop):
+                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
+                    inp1_ref = inp1.clone()
+                    out1 = torch.empty_like(inp1)
+                    custom_ops.mscclpp_allreduce(
+                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
+                    )
+                    dist.all_reduce(inp1_ref, group=group)
+                    torch.testing.assert_close(out1, inp1_ref)
+    finally:
+        dist.barrier(group=group)
+        dist.destroy_process_group(group=group)
+
+
+def get_open_port() -> int:
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("127.0.0.1", 0))
+            return s.getsockname()[1]
+    except OSError:
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("::1", 0))
+            return s.getsockname()[1]
+
+
+def multi_process_parallel(
+    world_size: int, test_target: Any, target_args: tuple = ()
+) -> None:
+    mp.set_start_method("spawn", force=True)
+
+    procs = []
+    distributed_init_port = get_open_port()
+    for i in range(world_size):
+        proc_args = (world_size, i, distributed_init_port) + target_args
+        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
+        proc.start()
+        procs.append(proc)
+
+    for i in range(world_size):
+        procs[i].join()
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"
+
+
+class TestMSCCLAllReduce(unittest.TestCase):
+    test_sizes = [
+        512,
+        2560,
+        4096,
+        5120,
+        7680,
+        32768,
+        262144,
+        524288,
+    ]
+    world_sizes = [8]
+
+    def test_correctness(self):
+        for world_size in self.world_sizes:
+            available_gpus = torch.cuda.device_count()
+            if world_size > available_gpus:
+                print(
+                    f"Skipping world_size={world_size}: only {available_gpus} GPUs available "
+                    "(multi-node via ray is not supported in this test)"
+                )
+                continue
+
+            print(f"Running test for world_size={world_size}")
+            multi_process_parallel(
+                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
+            )
+            print(f"mscclpp allreduce tp = {world_size}: OK")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_mscclpp.py b/test/srt/test_mscclpp.py
new file mode 100644
index 
000000000..894598b3d --- /dev/null +++ b/test/srt/test_mscclpp.py @@ -0,0 +1,205 @@ +"""For Now, MSCCL is only supported on TP16 and TP8 case + +if [[ $RANK -eq 0 ]]; then + ray start --block --head --port=6379 & + python3 test_mscclpp.py; +else + ray start --block --address=${MASTER_ADDR}:6379; +fi +""" + +import itertools +import os +import random +import socket +import unittest +from contextlib import contextmanager, nullcontext +from typing import Any, List, Optional, Union + +import ray +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, +) +from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator +from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, + set_custom_all_reduce, + set_mscclpp_all_reduce, +) +from sglang.srt.distributed.utils import StatelessProcessGroup +from sglang.test.test_utils import CustomTestCase + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + master_addr: str, + cls: Any, + test_target: Any, +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append( + test_target.remote( + cls, world_size, master_addr, rank, distributed_init_port + ) + ) + ray.get(refs) + + ray.shutdown() + + +class TestMSCCLAllReduce(CustomTestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + # 1KB to 1MB + cls.test_sizes = [512, 4096, 32768, 262144, 524288] + cls.world_sizes = [8] + TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0")) + if TEST_TP16: + cls.world_sizes = [16] + cls.test_loop = 10 + + def test_graph_allreduce(self): + TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost") + for world_size in self.world_sizes: + if world_size not in [8, 16]: + continue + multi_process_parallel( + world_size, TEST_MASTER_ADDR, self, self.graph_allreduce + ) + + def test_eager_allreduce(self): + TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost") + for world_size in self.world_sizes: + if world_size not in [8, 16]: + continue + multi_process_parallel( + world_size, TEST_MASTER_ADDR, self, self.eager_allreduce + ) + + @ray.remote(num_gpus=1, max_calls=1) + def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}" + set_mscclpp_all_reduce(True) + set_custom_all_reduce(False) + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank % torch.cuda.device_count(), + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
+        data = torch.zeros(1)
+        data = data.to(device=device)
+        torch.distributed.all_reduce(data, group=group)
+        torch.cuda.synchronize()
+        del data
+
+        for sz in self.test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(self.test_loop):
+                    with graph_capture() as graph_capture_context:
+                        # use integers so result matches NCCL exactly
+                        inp1 = torch.randint(
+                            1,
+                            16,
+                            (sz,),
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                        )
+                        inp2 = torch.randint(
+                            1,
+                            16,
+                            (sz,),
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                        )
+                        torch.cuda.synchronize()
+                        graph = torch.cuda.CUDAGraph()
+                        with torch.cuda.graph(
+                            graph, stream=graph_capture_context.stream
+                        ):
+                            out1 = tensor_model_parallel_all_reduce(inp1)
+                            # the input buffer is immediately modified to test
+                            # synchronization
+                            dist.all_reduce(inp1, group=group)
+                            out2 = tensor_model_parallel_all_reduce(inp2)
+                            dist.all_reduce(inp2, group=group)
+                    graph.replay()
+                    torch.testing.assert_close(out1, inp1)
+                    torch.testing.assert_close(out2, inp2)
+
+    @ray.remote(num_gpus=1, max_calls=1)
+    def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+        torch.cuda.set_device(device)
+        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
+        set_mscclpp_all_reduce(True)
+        set_custom_all_reduce(False)
+        init_distributed_environment(
+            world_size=world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            local_rank=rank % torch.cuda.device_count(),
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        group = get_tensor_model_parallel_group().device_group
+
+        for sz in self.test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(self.test_loop):
+                    inp1 = torch.randint(
+                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
+                    )
+                    out1 = tensor_model_parallel_all_reduce(inp1)
+                    dist.all_reduce(inp1, group=group)
+                    torch.testing.assert_close(out1, inp1)
+
+
+if __name__ == "__main__":
+    unittest.main()
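Reviewer note (not part of the patch): the new kernel-level API is wired up the same way in both tests above — rank 0 creates a bootstrap id with mscclpp_generate_unique_id(), broadcasts it over a CPU (gloo) group, each rank allocates a scratch buffer sized 8x the largest message plus a put_buffer, and then builds a context whose flavor depends on topology (the C++ test uses the 1-node LL context for 8 ranks and the 2-node LL context for 16 ranks; the enum names MSCCL1SHOT1NODELL / MSCCL1SHOT2NODELL suggest the same mapping for context_selection 1 and 2). Below is a minimal Python sketch of that setup, mirroring sgl-kernel/tests/test_mscclpp.py; choose_selection and init_mscclpp are hypothetical helper names introduced here for illustration, and the 2**20-byte cap follows the MAX_BYTES used in that test.

import torch
import sgl_kernel.allreduce as custom_ops


def choose_selection(world_size: int, nranks_per_node: int) -> int:
    # Map topology to the context_selection argument of mscclpp_init_context,
    # following the MscclContextSelection enum in the tests
    # (1 = 1-shot 1-node LL, 2 = 1-shot 2-node LL).
    if world_size == nranks_per_node == 8:
        return 1  # MSCCL1SHOT1NODELL
    if world_size == 16 and nranks_per_node == 8:
        return 2  # MSCCL1SHOT2NODELL
    raise ValueError("MSCCLPP all-reduce currently targets TP8 and TP16 only")


def init_mscclpp(unique_id, rank: int, world_size: int, nranks_per_node: int,
                 max_bytes: int = 2**20):
    # Scratch is 8x the largest message and put_buffer is 1x, matching both tests.
    scratch = torch.empty(max_bytes * 8, dtype=torch.bfloat16, device="cuda")
    put_buffer = torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda")
    rank_to_node = [r // nranks_per_node for r in range(world_size)]
    rank_to_ib = [r % nranks_per_node for r in range(world_size)]
    context = custom_ops.mscclpp_init_context(
        unique_id, rank, world_size, scratch, put_buffer,
        nranks_per_node, rank_to_node, rank_to_ib,
        choose_selection(world_size, nranks_per_node),
    )
    # Keep scratch and put_buffer alive for as long as the context is used:
    # the context only stores their raw device pointers.
    return context, scratch, put_buffer

After initialization, custom_ops.mscclpp_allreduce(context, inp, out, nthreads=512, nblocks=21) writes the reduced tensor into out, which is exactly what _run_correctness_worker checks against dist.all_reduce.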