From 28d4d4728088f551f13edfcafadf12484b32ee64 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Fri, 25 Jul 2025 11:48:42 +0800 Subject: [PATCH] [Feature] Integrate quick allreduce and select the best allreduce implementation (#6619) Signed-off-by: Haoyang Li Co-authored-by: ilmarkov --- python/sglang/srt/_custom_ops.py | 30 +- .../device_communicators/custom_all_reduce.py | 94 +-- .../custom_all_reduce_utils.py | 97 ++- .../device_communicators/quick_all_reduce.py | 273 ++++++++ .../sglang/srt/distributed/parallel_state.py | 76 ++- sgl-kernel/csrc/allreduce/quick_all_reduce.cu | 111 +++ .../csrc/allreduce/quick_all_reduce.cuh | 633 ++++++++++++++++++ sgl-kernel/csrc/allreduce/quick_all_reduce.h | 233 +++++++ .../csrc/allreduce/quick_all_reduce_base.h | 318 +++++++++ sgl-kernel/csrc/torch_extension_rocm.cc | 19 + sgl-kernel/include/sgl_kernel_ops.h | 9 + sgl-kernel/python/sgl_kernel/allreduce.py | 34 +- sgl-kernel/setup_rocm.py | 1 + test/srt/test_quick_allreduce.py | 212 ++++++ 14 files changed, 2031 insertions(+), 109 deletions(-) create mode 100644 python/sglang/srt/distributed/device_communicators/quick_all_reduce.py create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce.cu create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce.cuh create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce.h create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce_base.h create mode 100644 test/srt/test_quick_allreduce.py diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 1c232d19f..5ed175312 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import logging -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -114,6 +114,34 @@ else: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp) + # ROCM custom quick allreduce + + def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None + ) -> int: + return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size) + + def qr_get_handle(fa: int) -> torch.Tensor: + return sgl_kernel.allreduce.qr_get_handle(fa) + + def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + sgl_kernel.allreduce.qr_open_handles(fa, handles) + + def qr_all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool, + ) -> None: + sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) + + def qr_destroy(fa: int) -> None: + sgl_kernel.allreduce.qr_destroy(fa) + + def qr_max_size() -> int: + return sgl_kernel.allreduce.qr_max_size() + def mscclpp_generate_unique_id() -> bytes: return sgl_kernel.allreduce.mscclpp_generate_unique_id() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 9faff648c..a1d28f2fc 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -4,18 +4,18 @@ import ctypes import logging import os from contextlib import contextmanager -from functools import wraps -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, List, Optional, Union import torch import torch.distributed as dist from torch.distributed 
import ProcessGroup -from typing_extensions import ParamSpec from sglang.srt import _custom_ops as ops from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check, + is_full_nvlink, + is_weak_contiguous, ) from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import is_cuda, is_hip @@ -25,23 +25,6 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_hip = is_hip() -if _is_cuda: - try: - import pynvml - except ImportError as e: - logger.warning("Failed to import pynvml with %r", e) - -if _is_hip: - try: - from amdsmi import ( - AmdSmiException, - amdsmi_get_processor_handles, - amdsmi_init, - amdsmi_shut_down, - amdsmi_topo_get_link_type, - ) - except ImportError as e: - logger.warning("Failed to import amdsmi with %r", e) try: if ops.use_vllm_custom_allreduce and not _is_hip: @@ -57,70 +40,6 @@ except Exception: logger = logging.getLogger(__name__) -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(fn) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - if _is_hip: - try: - amdsmi_init() - return fn(*args, **kwargs) - finally: - amdsmi_shut_down() - else: - pynvml.nvmlInit() - try: - return fn(*args, **kwargs) - finally: - pynvml.nvmlShutdown() - - return wrapper - - -@with_nvml_context -def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: - if _is_hip: - """ - query if the set of gpus are fully connected by xgmi (1 hop) - """ - handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - link_type = amdsmi_topo_get_link_type(handle, peer_handle) - # type is 2 for XGMI - if link_type["hops"] != 1 or link_type["type"] != 2: - return False - except AmdSmiException as error: - logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) - return False - return True - else: - """ - query if the set of gpus are fully connected by nvlink (1 hop) - """ - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK - ) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: - return False - except pynvml.NVMLError: - logger.exception( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped." 
- ) - return False - return True - def _can_p2p(rank: int, world_size: int) -> bool: # SGLANG_SKIP_P2P_CHECK can be set to False in sglang @@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True -def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or ( - inp.storage().nbytes() - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size() - ) - - class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _MAX_CAR_SIZE = 8192 * 1024 diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index 86121ac97..c7baac845 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -8,17 +8,44 @@ import pickle import subprocess import sys import tempfile +from functools import wraps from itertools import product -from typing import Dict, List, Optional, Sequence +from typing import Callable, Dict, List, Optional, Sequence, TypeVar import torch import torch.distributed as dist import torch.multiprocessing as mp +from typing_extensions import ParamSpec from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from sglang.srt.utils import is_cuda, is_hip logger = logging.getLogger(__name__) +_is_cuda = is_cuda() +_is_hip = is_hip() + +if _is_cuda: + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + +if _is_hip: + try: + from amdsmi import ( + AmdSmiException, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + amdsmi_topo_get_link_type, + ) + except ImportError as e: + logger.warning("Failed to import amdsmi with %r", e) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + def update_environment_variables(envs: Dict[str, str]): for k, v in envs.items(): @@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: return _gpu_p2p_access_cache[f"{src}->{tgt}"] +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if _is_hip: + try: + amdsmi_init() + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + else: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +@with_nvml_context +def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: + if _is_hip: + """ + query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type(handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) + return False + return True + else: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + 
logger.exception( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped." + ) + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) + + __all__ = ["gpu_p2p_access_check"] if __name__ == "__main__": diff --git a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 000000000..0113c432d --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from enum import Enum +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from sglang.srt import _custom_ops as ops +from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( + is_full_nvlink, + is_weak_contiguous, +) +from sglang.srt.distributed.parallel_state import in_the_same_node_as +from sglang.srt.utils import is_cuda, is_hip + +logger = logging.getLogger(__name__) + +_is_cuda = is_cuda() +_is_hip = is_hip() + + +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs and CUDA + quick_ar = False + + +def qr_rocm_arch_available(): + if not _is_hip: + return False + try: + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + supported_archs = ["gfx94", "gfx95"] + return any(gfx in gcn_arch for gfx in supported_archs) + except Exception as e: + logger.warning("Failed to determine ROCm for quick allreduce: %s", e) + return False + + +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + +MB = 1024 * 1024 + + +class QuickAllReduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # The following data is based on kernel tests. + # In this order [FP, INT8, INT6, INT4]. + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } + + def __init__( + self, group: ProcessGroup, device: Union[int, str, torch.device] + ) -> None: + """ + Custom allreduce provides non-destructive acceleration and is + available for CUDA and ROCm MI300 series. + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 + quantization formats and FP(float16, bfloat16). + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. + Only the ROCm MI300 series is supported for quick allreduce at + this time. + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. 
+ """ + self.disabled = True + if not qr_rocm_arch_available(): + logger.debug( + "Custom quick allreduce is only supported on ROCm MI300 series." + ) + return + + if not quick_ar: + # disable because of missing quick reduce library + # e.g. in a cuda environment + logger.info( + "Custom quick allreduce is disabled because " + "of missing custom quick allreduce library" + ) + return + + self.group = group + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "Custom quick allreduce should be attached to a non-NCCL group." + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom quick allreduce for + # multi-node case. + logger.warning( + "Custom quick allreduce is disabled because this " + "process group spans across nodes." + ) + return + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size + if world_size == 1: + # No need to initialize QuickReduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom quick allreduce is disabled due to an " + "unsupported world size: %d. Supported world sizes: %s.", + world_size, + str(QuickAllReduce._SUPPORTED_WORLD_SIZES), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch.cuda.device_count())) + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(self.world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom quick allreduce is not supported + # this checks hardware and driver support for NVLink + if _is_cuda or _is_hip: + self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size) + if self.world_size > 2 and not self.fully_connected: + logger.debug( + "Custom quick allreduce is disabled because it's not supported " + "on more than two PCIe-only GPUs. " + ) + return + + self.init_quick_all_reduce() + + def init_quick_all_reduce(self): + # On RocM, bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is set to 1, we convert input to fp16 + self.use_fp16_kernels = int( + os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1) + ) + regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE") + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}", + ) + return + + if regime_str == "NONE": + logger.debug( + "Custom quick allreduce is disabled based " + "on env variable " + "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'" + ) + return + self.qr_quant_level = QuickReduceRegime[regime_str] + + # TODO: If the dtype is not bfloat16 or then float16, + # quickallreduce should not be created. 
+ + # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB + qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0)) + if qr_max_size > 0: + if qr_max_size < 1: + logger.info( + "You should not set a max_size smaller than 1MB, which can " + "lead to error or degradation to custom allreduce or rccl." + ) + qr_max_size = qr_max_size * MB + # If qr_max_size is None, then 2GB is used by default. + self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) + self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size() + self.create_shared_buffer() + self.disabled = False + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after init_custom_qr + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled: + return False + if inp.dtype not in self._SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return ( + inp_size <= self.qr_max_size + and inp_size + >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value] + ) + + def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + # quick allreduce doesn't require a separate graph mode, + # as QR uses static IPC buffer. + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce( + self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels + ) + return out + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + if ops is not None: + ops.qr_destroy(self._ptr) + self._ptr = 0 + self.disabled = True + + def __del__(self): + self.close() diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 509c71531..130bc53c7 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -44,6 +44,7 @@ from sglang.srt.utils import ( get_bool_env_var, get_int_env_var, is_cuda_alike, + is_hip, is_npu, is_shm_available, supports_custom_op, @@ -126,14 +127,18 @@ if supports_custom_op(): fake_impl=inplace_all_reduce_fake, ) - def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce( + tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str + ) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) + return group._all_reduce_out_place(tensor, outplace_all_reduce_method) - def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce_fake( + tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str + ) -> torch.Tensor: return torch.empty_like(tensor) direct_register_custom_op( @@ -264,6 +269,12 @@ class GroupCoordinator: PyNcclCommunicator, ) + if is_hip(): + from sglang.srt.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce, + qr_rocm_arch_available, + ) + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( @@ -283,6 +294,7 @@ class GroupCoordinator: ) self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. try: @@ -295,6 +307,18 @@ class GroupCoordinator: f"Setup Custom allreduce failed with {e}. To silence this " "warning, specify --disable-custom-all-reduce explicitly." ) + if is_hip(): + try: + # Initialize a custom quick all-reduce implementation for AMD + # when rocm >= gfx942. Quick reduce is designed as a + # complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). + if qr_rocm_arch_available(): + self.qr_comm = QuickAllReduce( + group=self.cpu_group, device=self.device + ) + except Exception as e: + logger.warning(f"Failed to initialize QuickAllReduce: {e}") from sglang.srt.distributed.device_communicators.hpu_communicator import ( HpuCommunicator, @@ -373,7 +397,8 @@ class GroupCoordinator: graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream - + # We don't need the context of custom quick allreduce because the ipc access + # is already collected in init() and we can capture the quick allreduce directly. ca_comm = self.ca_comm maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() @@ -388,23 +413,24 @@ class GroupCoordinator: # operations. The current status is: # allreduce \ Mode | Eager | Graph | # -------------------------------------------- + # quick allreduce | enabled | enabled | # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | # PyMscclpp | disabled| enabled | # torch.distributed | enabled | disabled| # + # Note: When custom quick allreduce is enabled, a runtime check + # will be performed. If the tensor size is too small, it will + # automatically fall back to the next available option. # Note that custom allreduce will have a runtime check, if the # tensor size is too large, it will fallback to the next # available option. # Note that the PyMsccl needs to register the tensor in ahead, # which will introduce large overhead in the eager case, # therefore it is only supported in the graph case. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. + # In summary: We select the appropriate allreduce method for + # each mode based on the algorithm order in the table and + # their usage conditions. 
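Condensed, the selection order described in the table above and implemented by `all_reduce` below is the following (illustrative helper only; `coordinator` stands in for the `GroupCoordinator` instance):

```python
def pick_outplace_method(coordinator, inp):
    """Condensed restatement of the order used by all_reduce:
    quick allreduce -> custom allreduce -> PyMscclpp -> in-place fallback."""
    qr, ca, msccl = coordinator.qr_comm, coordinator.ca_comm, coordinator.pymscclpp_comm
    if qr is not None and not qr.disabled and qr.should_quick_allreduce(inp):
        return "qr"
    if ca is not None and not ca.disabled and ca.should_custom_ar(inp):
        return "ca"
    if msccl is not None and not msccl.disabled and msccl.should_mscclpp_allreduce(inp):
        return "pymscclpp"
    return None  # caller falls back to inplace_all_reduce (pynccl / torch.distributed)
```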
pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: @@ -464,27 +490,47 @@ class GroupCoordinator: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) + outplace_all_reduce_method = None if ( + self.qr_comm is not None + and not self.qr_comm.disabled + and self.qr_comm.should_quick_allreduce(input_) + ): + outplace_all_reduce_method = "qr" + elif ( self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_) - ) or ( + ): + outplace_all_reduce_method = "ca" + elif ( self.pymscclpp_comm is not None and not self.pymscclpp_comm.disabled and self.pymscclpp_comm.should_mscclpp_allreduce(input_) ): + outplace_all_reduce_method = "pymscclpp" + if outplace_all_reduce_method is not None: return torch.ops.sglang.outplace_all_reduce( - input_, group_name=self.unique_name + input_, + group_name=self.unique_name, + outplace_all_reduce_method=outplace_all_reduce_method, ) else: torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name) return input_ - def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + def _all_reduce_out_place( + self, input_: torch.Tensor, outplace_all_reduce_method: str + ) -> torch.Tensor: + qr_comm = self.qr_comm ca_comm = self.ca_comm pymscclpp_comm = self.pymscclpp_comm - assert ca_comm is not None or pymscclpp_comm is not None - if ca_comm is not None and not ca_comm.disabled: + assert any([qr_comm, ca_comm, pymscclpp_comm]) + if outplace_all_reduce_method == "qr": + assert not qr_comm.disabled + out = qr_comm.quick_all_reduce(input_) + elif outplace_all_reduce_method == "ca": + assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_) else: assert not pymscclpp_comm.disabled diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.cu b/sgl-kernel/csrc/allreduce/quick_all_reduce.cu new file mode 100644 index 000000000..757c05d2b --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.cu @@ -0,0 +1,111 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + +#include "quick_all_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size) { + if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current 
device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce( + quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } + } else { + throw std::runtime_error("quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + +#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh b/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh new file mode 100644 index 000000000..bd9e7b10f --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh @@ -0,0 +1,633 @@ +#pragma once + +#include + +#include "quick_all_reduce_base.h" + +namespace quickreduce { + +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __quickreduce_device_inline__ CodecBase(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } +}; + +// Default full precision codec. +template +struct CodecFP : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of f16x8_t (16B). 
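For intuition, the tile geometry these codec constants imply (values taken from `quick_all_reduce_base.h` later in this patch; the 32 MiB message is an arbitrary example):

```python
# Tile geometry implied by the constants in quick_all_reduce_base.h.
BLOCK_SIZE = 256                     # kBlockSize: threads per workgroup
ATOMS = 8                            # kAtoms: int32x4_t (16 B) atoms per thread
ATOM_BYTES = 16
TILE_SIZE = BLOCK_SIZE * ATOMS * ATOM_BYTES   # kTileSize: 32 KiB of payload per workgroup
MAX_NUM_BLOCKS = 304 * 4                      # kMaxNumBlocks: grid cap (304 CUs on MI300)


def divceil(x: int, y: int) -> int:
    return (x + y - 1) // y


msg_bytes = 32 * 1024 * 1024                  # e.g. a 16M-element fp16 tensor
num_blocks = divceil(msg_bytes, TILE_SIZE)    # 1024 tiles
grid = min(MAX_NUM_BLOCKS, num_blocks)        # grid is capped; larger messages make each block
print(TILE_SIZE, num_blocks, grid)            # loop over several tiles (see the while loop in
                                              # allreduce_prototype_twoshot)
```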
+ static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + __quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ4 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = std::is_same::value ? 
0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into f16x8_t + int32x4_t w; + { + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + w[i] = packed_add(q4, kHalf2_1032); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int6 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * +// kThreadGroupSize. 
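Stripped of the packed f16x2 arithmetic, `CodecQ4` above is symmetric block-scaled int4: one fp16 scale per 32 values (absmax/8, with the sign flip folded into the packed constants), values clamped to the 4-bit range, and eight quantized values packed per `int32_t`. A plain-math sketch of the same round-trip, for illustration only:

```python
import torch

GROUP = 32  # one scale covers 32 values (4 atoms x kThreadGroupSize)


def q4_roundtrip(block: torch.Tensor) -> torch.Tensor:
    """Plain-math equivalent of one CodecQ4 block: symmetric, block-scaled int4."""
    assert block.numel() == GROUP
    scale = block.abs().max().float() / 8.0 + 1e-7              # per-block scale (cf. kScaleFactor)
    q = torch.clamp(torch.round(block.float() / scale), -8, 7)  # 4-bit signed range
    return (q * scale).half()                                   # dequantize


x = torch.randn(GROUP).half()
x_hat = q4_roundtrip(x)
print((x - x_hat).abs().max())  # error is on the order of the per-block scale
```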
+template +struct CodecQ6 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/32.0h, -1/32.0h}, fp16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, fp16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static constexpr int kRangeMax = std::is_same::value ? 0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static constexpr int kRangeBias = 0x00200020; + + __quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + 
} + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static constexpr int kRangeMax = std::is_same::value ? 
0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static constexpr int kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) 
== 2); + + static constexpr int kWorldSize = Codec::kWorldSize; + + __device__ static void + run(T const* __restrict__ input, + T* __restrict__ output, + uint32_t const N, // number of elements + int const block, // block index + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + Codec codec(thread, rank); + int block_id = blockIdx.x; + int grid_size = gridDim.x; + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + if constexpr (cast_bf2half) { + const nv_bfloat162* bf_buf = reinterpret_cast(&tA[i]); + half2 half_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __bfloat1622float2(bf_buf[j]); + half_buf[j] = __float22half2_rn(f); + } + tA[i] = *reinterpret_cast(half_buf); + } + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. + uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[Codec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data0_offset); + uint32_t* flag_ptr = reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. 
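+        // Flag protocol: in Phase-1A every peer wrote flag_color into this
+        // rank's flag slot (indexed by the peer's rank) once its sends completed.
+        // Thread 0 waits until peer r's slot reaches flag_color, and the
+        // __syncthreads() below releases the rest of the workgroup. flag_color
+        // advances per tile (allreduce_prototype_twoshot) and across calls
+        // (DeviceComms::allreduce), which is why the flag area is memset only
+        // once in DeviceComms::init and never cleared between iterations.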
+ if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < Codec::kRankAtoms; i++) { + packed_assign_add(&tR[i], &tA[i]); + } + } + } + + // Phase-2: Write the reduced segment to every other rank + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data1_offset); + uint32_t* flag_ptr = reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + if constexpr (cast_bf2half) { + const half2* half_buf = reinterpret_cast(&tA[i]); + nv_bfloat162 bf16_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __half22float2(half_buf[j]); + bf16_buf[j] = __float22bfloat162_rn(f); + } + buffer_store_dwordx4(*reinterpret_cast(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0); + } else { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + } + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.h b/sgl-kernel/csrc/allreduce/quick_all_reduce.h new file mode 100644 index 000000000..1d629e018 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.h @@ -0,0 +1,233 @@ +#pragma once + +#include + +#include + +#include "quick_all_reduce.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. 
%s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot( + T const* A, + T* B, + uint32_t N, + uint32_t num_blocks, + int rank, + uint8_t** dbuffer_list, + uint32_t data_offset, + uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + int64_t kMaxProblemSize = static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { + destroy(); + } + + void init(int world_size, int rank, std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { + return world_size; + } + int get_rank() { + return rank; + } + bool status() { + return initialized; + } + hipIpcMemHandle_t const get_handle() { + return buffer_ipc_handle; + } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK( + hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); + } + + // Configuration. + uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(cudaGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h b/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h new file mode 100644 index 000000000..759b28f38 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h @@ -0,0 +1,318 @@ +#pragma once + +#include +#include +#include + +#include + +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) + +namespace quickreduce { + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. 
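Putting the pieces together, the host-side lifecycle that `QuickAllReduce` drives through the `_custom_ops` wrappers looks roughly like this (sketch only; assumes one process per GPU on a single ROCm MI300 node and an already-initialized non-NCCL process group; `qr_lifecycle` and the tensor size are illustrative):

```python
import torch
import torch.distributed as dist

from sglang.srt import _custom_ops as ops


def qr_lifecycle(rank: int, world_size: int, group: dist.ProcessGroup) -> torch.Tensor:
    """Illustrative driver; `group` is an already-initialized non-NCCL (e.g. gloo) group."""
    torch.cuda.set_device(rank)
    ptr = ops.init_custom_qr(rank, world_size)      # default max problem size: 2 GB
    handle = ops.qr_get_handle(ptr)                 # hipIpcMemHandle_t as a uint8 CPU tensor
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)
    ops.qr_open_handles(ptr, handles)               # map every peer's buffer into this rank

    inp = torch.randn(4 * 1024 * 1024, dtype=torch.float16, device="cuda")  # 8 MB payload
    out = torch.empty_like(inp)
    # quant_level 1 == INT8 (QuickReduceRegime / QuickReduceQuantLevel)
    ops.qr_all_reduce(ptr, inp, out, quant_level=1, cast_bf2half=False)

    ops.qr_destroy(ptr)
    return out
```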
+#if defined(__gfx942__)
+// CDNA3: Scope bits sc0, sc1
+#define MUBUF_ACQUIRE 16
+#define MUBUF_RELEASE 16
+#elif (defined(__gfx908__) || defined(__gfx90a__))
+// CDNA1 and CDNA2 - glc bit
+#define MUBUF_ACQUIRE 1
+#define MUBUF_RELEASE 0
+#endif
+
+static constexpr int kNegOne = 0xBC00BC00;  // {-1, -1}, fp16x2_t
+
+// Number of atoms (4xf16x2_t) processed by a single thread
+static constexpr int kAtoms = 8;
+
+// We use a workgroup of 256 threads
+static constexpr int kBlockSize = 256;
+static constexpr int kAtomStride = kBlockSize;
+
+// Size and atom stride of source/destination data that the block will
+// process.
+// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
+static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
+
+// Max number of blocks. 304 CUs on MI300
+static constexpr int kMaxNumBlocks = 304 * 4;
+
+// Standard CDNA wavefront size.
+static constexpr int kWavefront = 64;
+
+// 256 thread, 4 wavefronts.
+static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
+
+// Number of threads in a group for quantization
+// It corresponds to 32 F16 elements in quantization block
+static constexpr int kThreadGroupSize = 8;
+
+// Methods
+__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
+  return ((x + y - 1) / y);
+}
+
+union BufferResource {
+  __quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
+
+  __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
+      : address(buffer_address), range(buffer_size), config(0x00020000U) {}
+
+  int32x4_t descriptor;
+  struct {
+    void* address;    // 8B, out of which first 48b is address, and 16b is stride
+                      // (unused)
+    uint32_t range;   // Byte range for the buffer resource
+    uint32_t config;  // Constant, DFMT=32b
+  };
+};
+
+__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
+    int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
+
+__quickreduce_device_inline__ static void
+buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
+    "llvm.amdgcn.raw.buffer.store.v4i32");
+
+__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
+#if defined(__gfx942__)
+  if (value) {
+    asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
+  } else {
+    asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
+  }
+#endif
+}
+
+union bf162_int_union {
+  int i;
+  nv_bfloat162 bf2;
+};
+
+template <typename T>
+__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
+
+template <>
+__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
+  int32x4_t& tR_fragment = A[0];
+  int32x4_t& tA_fragment = B[0];
+
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
+}
+
+template <>
+__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
+  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
+  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    tA[i] = __hadd2(tA[i], tB[i]);
+  }
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_max(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
+  int result;
+  asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+  return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hmax2(A.bf2, B.bf2);
+  return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_min(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
+  int result;
+  asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+  return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hmin2(A.bf2, B.bf2);
+  return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_abs_max(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
+  half2 wmaxh2 = __builtin_bit_cast(half2, a);
+  half2 wminh2 = __builtin_bit_cast(half2, b);
+  half2 wblockmaxh2;
+
+  wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
+  wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
+  return __builtin_bit_cast(int, wblockmaxh2);
+}
+
+template <>
+__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
+  R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
+  return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_add(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
+  int result;
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+  return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hadd2(A.bf2, B.bf2);
+  return R.i;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
+  int result;
+  asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+  return result;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_sub(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
+  int result;
+
+  // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
+  return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hsub2(A.bf2, B.bf2);
+  return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_mul(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
+  int result;
+  asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+  return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
+  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
+  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
+  nv_bfloat162 tR = __hmul2(*tA, *tB);
+  return *(reinterpret_cast<int*>(&tR));
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_rcp(int a);
+
+template <>
+__quickreduce_device_inline__ int packed_rcp<half>(int a) {
+  return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
+}
+
+template <>
+__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
+  bf162_int_union A, R;
+  A.i = a;
+  R.bf2 = h2rcp(A.bf2);
+  return R.i;
+}
+
+// changes dtype
+__quickreduce_device_inline__ float T2float_cast(half a) {
+  return __half2float(a);
+}
+
+__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
+  return __bfloat162float(a);
+}
+
+template <typename T>
+__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
+  const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
+
+  int wmax, wmin, wblockmax;
+  int a, b;
+  a = packed_max<T>(atom[0], atom[1]);
+  b = packed_max<T>(atom[2], atom[3]);
+
+  wmax = packed_max<T>(a, b);
+
+  a = packed_min<T>(atom[0], atom[1]);
+  b = packed_min<T>(atom[2], atom[3]);
+
+  wmin = packed_min<T>(a, b);
+
+  // Reduce the max among a group of threads
+  // Note: This is basically 2 blocks of values setup as the
+  // upper/lower halves of the f16x2_t
+  for (int i = 1; i < kThreadGroupSize; i <<= 1) {
+    int x = __shfl_down(wmax, i);
+    wmax = packed_max<T>(wmax, x);
+
+    int y = __shfl_down(wmin, i);
+    wmin = packed_min<T>(wmin, y);
+  }
+  wblockmax = packed_abs_max<T>(wmax, wmin);
+  // Share with the cohort
+  wblockmax = __shfl(wblockmax, group_leader);
+  return wblockmax;
+}
+
+__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
+  __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
+}
+
+__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
+  while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
+  }
+}
+
+} // namespace quickreduce
diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc
index 84f9d1e7a..46a50ca6b 100644
--- a/sgl-kernel/csrc/torch_extension_rocm.cc
+++ b/sgl-kernel/csrc/torch_extension_rocm.cc
@@ -54,6 +54,25 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
   m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
 
+  // quick allreduce
+#ifdef USE_ROCM
+  m.def(
+      "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
+      "cast_bf2half) -> ()");
+  m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
+
+  m.def("init_custom_qr", &init_custom_qr);
+  m.def("qr_destroy", &qr_destroy);
+
+  m.def("qr_get_handle", &qr_get_handle);
+
+  m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
handles) -> ()"); + m.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + // Max input size in bytes + m.def("qr_max_size", &qr_max_size); +#endif + /* * From csrc/moe */ diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 6b589101f..ffd240a04 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -66,6 +66,13 @@ void register_graph_buffers( fptr_t _fa, const std::vector& handles, const std::vector>& offsets); torch::Tensor allocate_meta_buffer(int64_t size); torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); +// quick allreduce +fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); +int64_t qr_max_size(); #else // custom allreduce fptr_t @@ -77,6 +84,8 @@ std::tuple, std::vector> get_graph_buffer_ipc_meta void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); void register_graph_buffers( fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); + +// mscclpp torch::Tensor mscclpp_generate_unique_id(); fptr_t mscclpp_init_context( const torch::Tensor& unique_id, diff --git a/sgl-kernel/python/sgl_kernel/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py index 317b2f1a7..544fc1d77 100644 --- a/sgl-kernel/python/sgl_kernel/allreduce.py +++ b/sgl-kernel/python/sgl_kernel/allreduce.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -49,6 +49,38 @@ if torch.version.hip is not None: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp) + # ROCM quick allreduce + def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None + ) -> int: + return torch.ops.sgl_kernel.init_custom_qr.default( + world_size, rank, qr_max_size + ) + + def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops.sgl_kernel.qr_get_handle.default(fa) + + def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + torch.ops.sgl_kernel.qr_open_handles.default(fa, handles) + + def qr_all_reduce( + fa: int, + profile: int, + inp: torch.Tensor, + out: torch.Tensor, + cast_bf162half: bool, + ) -> None: + torch.ops.sgl_kernel.qr_all_reduce.default( + fa, profile, inp, out, cast_bf162half + ) + + def qr_destroy(fa: int) -> None: + torch.ops.sgl_kernel.qr_destroy.default(fa) + + def qr_max_size() -> int: + return torch.ops.sgl_kernel.qr_max_size.default() + + # mscclpp def mscclpp_generate_unique_id() -> bytes: raise NotImplementedError() diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 4ab8635a8..a814b8196 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -41,6 +41,7 @@ include_dirs = [ sources = [ "csrc/allreduce/custom_all_reduce.hip", + "csrc/allreduce/quick_all_reduce.cu", "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/torch_extension_rocm.cc", diff --git a/test/srt/test_quick_allreduce.py b/test/srt/test_quick_allreduce.py new file mode 100644 index 000000000..ed081255f --- /dev/null +++ b/test/srt/test_quick_allreduce.py @@ -0,0 +1,212 @@ +import os +import random +import socket +import unittest +from typing import Any + +import ray +import torch +import 
torch.distributed as dist + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.device_communicators.quick_all_reduce import ( + qr_rocm_arch_available, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, +) +from sglang.test.test_utils import CustomTestCase + +torch.manual_seed(42) +random.seed(44) # keep the deterministic seed + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, cls: Any, test_target: Any, quant_mode: str +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. + # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append( + test_target.remote(cls, world_size, rank, distributed_init_port, quant_mode) + ) + ray.get(refs) + + ray.shutdown() + + +class TestQuickAllReduce(CustomTestCase): + TEST_SIZES = [ + 2 * 1024 * 1024, + 4 * 1024 * 1024, + 8 * 1024 * 1024, + 16 * 1024 * 1024, + 32 * 1024 * 1024, + ] + TEST_LOOP = 5 + # Too many configurations can lead to a test grid that is too large + # The tp takes too long to boot,let's just choose 4 out of 12 configurations + # WORLD_SIZES = [2, 4, 8] + # QUANT_MODE = ["FP", "INT8", "INT6", "INT4"] + QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]] + + @unittest.skipIf( + not qr_rocm_arch_available(), + "Only test Quick AllReduce on ROCm architectures >= gfx94*", + ) + def test_graph_allreduce(self): + for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART: + quant_mode = quant_mode_world_size_part[0] + world_size = quant_mode_world_size_part[1] + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.graph_allreduce, quant_mode) + + @unittest.skipIf( + not qr_rocm_arch_available(), + "Only test Quick AllReduce on ROCm architectures >= gfx94*", + ) + def test_eager_allreduce(self): + for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART: + quant_mode = quant_mode_world_size_part[0] + world_size = quant_mode_world_size_part[1] + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.eager_allreduce, quant_mode) + + @ray.remote(num_gpus=1, max_calls=1) + def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode + os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. 
+ # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + for sz in self.TEST_SIZES: + for dtype in [torch.float16, torch.bfloat16]: + for _ in range(self.TEST_LOOP): + with graph_capture() as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint( + 1, + 23, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + inp2 = torch.randint( + -23, + 1, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph( + graph, stream=graph_capture_context.stream + ): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + atol = 1.25 * world_size + rtol = 0.5 * world_size + for inp, out in [[inp1, out1], [inp2, out2]]: + torch.testing.assert_close(out, inp, atol=atol, rtol=rtol) + # try: + # torch.testing.assert_close(out, inp, atol=atol, rtol=rtol) + # except AssertionError as e: + # print("Max abs diff:", (out - inp).abs().max()) + # print("Max rel diff:", ((out - inp).abs() / inp.abs().clamp(min=1e-5)).max()) + + @ray.remote(num_gpus=1, max_calls=1) + def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode + os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + for sz in self.TEST_SIZES: + for dtype in [torch.float16, torch.bfloat16]: + for _ in range(self.TEST_LOOP): + inp1 = torch.randint( + 1, + 23, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + atol = 1.25 * world_size + rtol = 0.5 * world_size + torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol) + # try: + # torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol) + # except AssertionError as e: + # print("Max abs diff:", (out1 - inp1).abs().max()) + # print("Max rel diff:", ((out1 - inp1).abs() / inp1.abs().clamp(min=1e-5)).max()) + + +if __name__ == "__main__": + unittest.main()
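Usage sketch (illustrative only, not part of the patch). The ops registered in torch_extension_rocm.cc above imply a fixed handshake: every rank creates a context with init_custom_qr, exchanges IPC handles (qr_get_handle / qr_open_handles), and only then calls qr_all_reduce; qr_destroy releases the buffers. The snippet below is a minimal sketch of that flow using the sgl_kernel.allreduce wrappers from this patch; it assumes an already-initialized torch.distributed process group, an fp16/bf16 input no larger than qr_max_size() bytes, and quant level 0 (F16, i.e. no quantization). The real integration in this patch lives in quick_all_reduce.py and parallel_state.py, not in a helper like the one named here.

    # Hypothetical driver for the raw qr_* ops (assumes torch.distributed is initialized).
    import torch
    import torch.distributed as dist
    from sgl_kernel import allreduce as qr_ops


    def quick_allreduce_once(rank: int, world_size: int, inp: torch.Tensor) -> torch.Tensor:
        # Create the per-rank context; qr_max_size=None keeps the default buffer size.
        fa = qr_ops.init_custom_qr(rank, world_size)
        try:
            # Every rank publishes its IPC handle, then maps the peers' buffers.
            handles = [None] * world_size
            dist.all_gather_object(handles, qr_ops.qr_get_handle(fa))
            qr_ops.qr_open_handles(fa, handles)

            out = torch.empty_like(inp)
            # quant_level=0 maps to QuickReduceQuantLevel::F16; cast_bf2half left False here.
            qr_ops.qr_all_reduce(fa, inp, out, 0, False)
            torch.cuda.synchronize()
            return out
        finally:
            qr_ops.qr_destroy(fa)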
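The test matrix drives the quantization mode through the ROCM_QUICK_REDUCE_QUANTIZATION environment variable ("FP", "INT8", "INT6", "INT4") and the bf16-to-fp16 cast through ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16. The actual parsing lives in quick_all_reduce.py, which is not shown in this excerpt; a plausible reading of it, mirroring the QuickReduceQuantLevel enum declared in quick_all_reduce.h, would look like the sketch below. The helper names and the defaults are assumptions for illustration only.

    import os

    # Assumed mapping; mirrors QuickReduceQuantLevel {F16=0, INT8=1, INT6=2, INT4=3}.
    _QR_QUANT_LEVELS = {"FP": 0, "INT8": 1, "INT6": 2, "INT4": 3}


    def qr_quant_level_from_env(default: str = "FP") -> int:
        mode = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", default).upper()
        return _QR_QUANT_LEVELS.get(mode, _QR_QUANT_LEVELS[default])


    def qr_cast_bf16_to_fp16_from_env() -> bool:
        # Default value here is an assumption; the tests above set it explicitly to "0".
        return os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "1") in ("1", "true", "True")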