From 9cf407729441a18561ee5a8b2c1e3e535ba817f9 Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Sun, 2 Mar 2025 15:19:06 -0800 Subject: [PATCH] Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406) --- python/sglang/srt/_custom_ops.py | 114 ++-- .../device_communicators/custom_all_reduce.py | 302 +++++++--- sgl-kernel/setup_rocm.py | 1 + sgl-kernel/src/sgl-kernel/__init__.py | 192 ++++-- .../src/sgl-kernel/csrc/custom_all_reduce.hip | 180 ++++++ .../sgl-kernel/csrc/custom_all_reduce_hip.cuh | 554 ++++++++++++++++++ .../src/sgl-kernel/include/sgl_kernels_ops.h | 18 +- sgl-kernel/src/sgl-kernel/ops/__init__.py | 85 ++- .../src/sgl-kernel/torch_extension_rocm.cc | 31 + 9 files changed, 1282 insertions(+), 195 deletions(-) create mode 100644 sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip create mode 100644 sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 3cb313b91..7002b25d5 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.library -from sglang.srt.utils import is_hpu +from sglang.srt.utils import is_hip, is_hpu logger = logging.getLogger(__name__) use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) - if not is_hpu(): - if use_vllm_custom_allreduce: + # Remove vllm dependency for custom allreduce on ROCm + if use_vllm_custom_allreduce and not is_hip(): try: import vllm._C except ImportError as e: @@ -56,7 +56,7 @@ def hint_on_error(fn): return wrapper -if use_vllm_custom_allreduce: +if use_vllm_custom_allreduce and not is_hip(): # custom ar def init_custom_ar( ipc_tensors: List[torch.Tensor], @@ -95,39 +95,87 @@ if use_vllm_custom_allreduce: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) else: - # custom ar - def 
init_custom_ar( - rank_id: int, - world_size: int, - rank_data_base: torch.Tensor, - buffers: List[int], - tmp_result_buffers: List[int], - barrier_in: List[int], - barrier_out: List[int], - ) -> int: - return sgl_kernel.ops.init_custom_reduce( - rank_id, - world_size, - rank_data_base, - buffers, - tmp_result_buffers, - barrier_in, - barrier_out, - ) + if is_hip(): - def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - sgl_kernel.ops.custom_reduce(fa, inp, out) + def init_custom_ar( + meta: torch.Tensor, + rank_data: torch.Tensor, + handles: List[str], + offsets: List[int], + rank: int, + full_nvlink: bool, + ) -> int: + return sgl_kernel.ops.init_custom_ar( + meta, rank_data, handles, offsets, rank, full_nvlink + ) - def dispose(fa: int) -> None: - sgl_kernel.ops.custom_dispose(fa) + def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + sgl_kernel.ops.all_reduce_reg(fa, inp, out) - def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: - return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + def all_reduce_unreg( + fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor + ) -> None: + sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out) - def register_graph_buffers( - fa: int, handles: List[List[int]], offsets: List[List[int]] - ) -> None: - sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) + def dispose(fa: int) -> None: + sgl_kernel.ops.dispose(fa) + + def meta_size() -> int: + return sgl_kernel.ops.meta_size() + + def register_buffer( + fa: int, t: torch.Tensor, handles: List[str], offsets: List[int] + ) -> None: + return sgl_kernel.ops.register_buffer(fa, t, handles, offsets) + + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: + return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + + def register_graph_buffers( + fa: int, handles: List[str], offsets: List[List[int]] + ) -> None: + sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) + + 
def allocate_meta_buffer(size: int) -> torch.Tensor: + return sgl_kernel.ops.allocate_meta_buffer(size) + + def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: + return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp) + + else: + # custom ar + def init_custom_ar( + rank_id: int, + world_size: int, + rank_data_base: torch.Tensor, + buffers: List[int], + tmp_result_buffers: List[int], + barrier_in: List[int], + barrier_out: List[int], + ) -> int: + return sgl_kernel.ops.init_custom_reduce( + rank_id, + world_size, + rank_data_base, + buffers, + tmp_result_buffers, + barrier_in, + barrier_out, + ) + + def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + sgl_kernel.ops.custom_reduce(fa, inp, out) + + def dispose(fa: int) -> None: + sgl_kernel.ops.custom_dispose(fa) + + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) # temporary fix for https://github.com/vllm-project/vllm/issues/5456 diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index faeac0bba..4eb28ced7 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import gpu_p2p_access_check, ) from sglang.srt.distributed.parallel_state import in_the_same_node_as -from sglang.srt.utils import cuda_device_count_stateless, is_cuda +from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip logger = logging.getLogger(__name__) @@ -28,14 +28,27 @@ if is_cuda(): except ImportError as e: logger.warning("Failed to import pynvml with 
%r", e) +if is_hip(): + try: + from amdsmi import ( + AmdSmiException, + amdsmi_get_gpu_board_info, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + amdsmi_topo_get_link_type, + ) + except ImportError as e: + logger.warning("Failed to import amdsmi with %r", e) + try: - if ops.use_vllm_custom_allreduce: + if ops.use_vllm_custom_allreduce and not is_hip(): ops.meta_size() else: import sgl_kernel custom_ar = True except Exception: - # For AMD GPUs and CPUs + # For CPUs custom_ar = False logger = logging.getLogger(__name__) @@ -47,37 +60,62 @@ _R = TypeVar("_R") def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - pynvml.nvmlInit() - try: - return fn(*args, **kwargs) - finally: - pynvml.nvmlShutdown() + if torch.version.hip: + try: + amdsmi_init() + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + else: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() return wrapper @with_nvml_context -def is_full_nvlink(physical_device_ids: List[int]) -> bool: - """ - query if the set of gpus are fully connected by nvlink (1 hop) - """ - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK - ) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: +def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: + if is_hip(): + """ + query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type(handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or 
link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) return False - except pynvml.NVMLError: - logger.exception( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped." - ) - return False - return True + return True + else: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped." + ) + return False + return True def _can_p2p(rank: int, world_size: int) -> bool: @@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + _MAX_CAR_SIZE = 8192 * 1024 + if is_hip(): + # crossover is at 16MB buffer size for ROCm + _MAX_CAR_SIZE = 2 * 8192 * 1024 # max_size: max supported allreduce size def __init__( self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024, + max_size=_MAX_CAR_SIZE, ) -> None: """ Args: @@ -185,12 +226,9 @@ class CustomAllreduce: # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - if is_cuda(): - assert is_cuda() + if is_cuda() or is_hip(): + full_nvlink = is_full_nvlink(physical_device_ids, world_size) - full_nvlink = is_full_nvlink(physical_device_ids) - else: - full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -201,7 +239,8 @@ class 
CustomAllreduce: # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result - if not _can_p2p(rank, world_size): + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not is_hip() and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -214,7 +253,7 @@ class CustomAllreduce: self.world_size = world_size self.full_nvlink = full_nvlink - if ops.use_vllm_custom_allreduce: + if ops.use_vllm_custom_allreduce and not is_hip(): # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. @@ -237,35 +276,56 @@ class CustomAllreduce: ) ops.register_buffer(self._ptr, self.buffer_ptrs) else: - # From TensorRT-LLM getMaxRequiredWorkspaceSize - self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] + if is_hip(): + # meta data buffers need to be "uncached" for signal on MI200 + self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size) + self.buffer = torch.empty( + max_size, dtype=torch.uint8, device=self.device + ) + handle = ops.get_meta_buffer_ipc_handle(self.meta) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + handles, offsets = self._gather_ipc_meta(shard_data) + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self._ptr = ops.init_custom_ar( + self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink + ) + self.register_buffer(self.buffer) + self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1" + else: + # From TensorRT-LLM getMaxRequiredWorkspaceSize + self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] - # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; - 
self.barrier_max_size = 8 * (36 + 2) * 8 + # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; + self.barrier_max_size = 8 * (36 + 2) * 8 - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - self.tmp_result_buffer_ptrs = self.create_shared_buffer( - max_size, group=group - ) - self.rank_data_base = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device - ) - self.barrier_in_ptrs = self.create_shared_buffer( - self.barrier_max_size, group=group - ) - self.barrier_out_ptrs = self.create_shared_buffer( - self.barrier_max_size, group=group - ) + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + max_size, group=group + ) + self.rank_data_base = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.barrier_in_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + self.barrier_out_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) - self._ptr = ops.init_custom_ar( - rank, - world_size, - self.rank_data_base, - self.buffer_ptrs, - self.tmp_result_buffer_ptrs, - self.barrier_in_ptrs, - self.barrier_out_ptrs, - ) + self._ptr = ops.init_custom_ar( + rank, + world_size, + self.rank_data_base, + self.buffer_ptrs, + self.tmp_result_buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) self.disabled = False @staticmethod @@ -316,23 +376,69 @@ class CustomAllreduce: if not self.disabled: self.register_graph_buffers() - def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - logger.info("Registering %d cuda graph addresses", len(offset)) - # We cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. - # see https://github.com/pytorch/pytorch/issues/126032 for details. 
- all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] - all_data[self.rank] = [handle, offset] - ranks = sorted(dist.get_process_group_ranks(group=self.group)) + def _get_ipc_meta(self, inp: torch.Tensor): + # _share_cuda_() doesn't accept meta buffer not allocated from + # PyTorch cache allocator, use direct HIP call to get IPC handle + handle = ops.get_meta_buffer_ipc_handle(inp) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() for i, rank in enumerate(ranks): dist.broadcast_object_list( all_data[i], src=rank, group=self.group, device="cpu" ) - # Unpack list of tuples to tuple of lists. - handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore - ops.register_graph_buffers(self._ptr, handles, offsets) + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. 
+ + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + ops.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + if is_hip(): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + ops.register_graph_buffers(self._ptr, handles, offsets) + else: + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [ + [None, None] for _ in range(dist.get_world_size(group=self.group)) + ] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list( + all_data[i], src=rank, group=self.group, device="cpu" + ) + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): if self.disabled: @@ -345,11 +451,22 @@ class CustomAllreduce: return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. 
- if ops.use_vllm_custom_allreduce: + if ops.use_vllm_custom_allreduce and not is_hip(): if self.world_size == 2 or self.full_nvlink: return inp_size < self.max_size return False + if is_hip(): + if self.full_nvlink: + if self.world_size == 8: + if self.MSCCL: + return False + else: + return inp_size < self.max_size + else: + return inp_size < self.max_size + return False + if self.world_size == 2: return ( inp_size < self.max_size @@ -364,6 +481,21 @@ class CustomAllreduce: return False + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + def all_reduce( self, inp: torch.Tensor, @@ -397,13 +529,23 @@ class CustomAllreduce: return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) + if is_hip(): + return self.all_reduce_reg(input) + else: + return self.all_reduce(input, registered=True) else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. 
return torch.empty_like(input) else: - return self.all_reduce(input, registered=False) + if is_hip(): + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + return self.all_reduce_unreg(input) + else: + return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: @@ -411,7 +553,7 @@ class CustomAllreduce: if ops.use_vllm_custom_allreduce: self.free_shared_buffer(self.meta_ptrs) self.free_shared_buffer(self.buffer_ptrs) - else: + elif is_cuda(): self.free_shared_buffer(self.buffer_ptrs) self.free_shared_buffer(self.tmp_result_buffer_ptrs) self.free_shared_buffer(self.barrier_in_ptrs) diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 6530cd7c7..e3ad6c546 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -44,6 +44,7 @@ include_dirs = [ sources = [ "src/sgl-kernel/torch_extension_rocm.cc", "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/custom_all_reduce.hip", ] cxx_flags = ["-O3"] diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 25ed6bb74..8af022434 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,74 +1,138 @@ import ctypes import os +import torch + if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"): ctypes.CDLL( "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12", mode=ctypes.RTLD_GLOBAL, ) - -from sgl_kernel.ops import ( - apply_rope_with_cos_sin_cache_inplace, - bmm_fp8, - build_tree_kernel, - build_tree_kernel_efficient, - cublas_grouped_gemm, - custom_dispose, - custom_reduce, - fp8_blockwise_scaled_mm, - fp8_scaled_mm, - fused_add_rmsnorm, - gelu_and_mul, - gelu_tanh_and_mul, - gemma_fused_add_rmsnorm, - gemma_rmsnorm, - get_graph_buffer_ipc_meta, - init_custom_reduce, - 
int8_scaled_mm, - lightning_attention_decode, - min_p_sampling_from_probs, - moe_align_block_size, - register_graph_buffers, - rmsnorm, - sampling_scaling_penalties, - sgl_per_token_group_quant_fp8, - silu_and_mul, - top_k_renorm_prob, - top_k_top_p_sampling_from_probs, - top_p_renorm_prob, - tree_speculative_sampling_target_only, -) - from .version import __version__ -__all__ = [ - "apply_rope_with_cos_sin_cache_inplace", - "bmm_fp8", - "cublas_grouped_gemm", - "custom_dispose", - "custom_reduce", - "fp8_blockwise_scaled_mm", - "fp8_scaled_mm", - "fused_add_rmsnorm", - "gelu_and_mul", - "gelu_tanh_and_mul", - "gemma_fused_add_rmsnorm", - "gemma_rmsnorm", - "get_graph_buffer_ipc_meta", - "init_custom_reduce", - "int8_scaled_mm", - "lightning_attention_decode", - "min_p_sampling_from_probs", - "moe_align_block_size", - "register_graph_buffers", - "rmsnorm", - "sampling_scaling_penalties", - "silu_and_mul", - "top_k_renorm_prob", - "top_k_top_p_sampling_from_probs", - "top_p_renorm_prob", - "tree_speculative_sampling_target_only", - "build_tree_kernel_efficient", - "build_tree_kernel", - "sgl_per_token_group_quant_fp8", -] +if torch.version.hip is not None: + from sgl_kernel.ops import ( + all_reduce_reg, + all_reduce_unreg, + allocate_meta_buffer, + apply_rope_with_cos_sin_cache_inplace, + bmm_fp8, + dispose, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + get_graph_buffer_ipc_meta, + get_meta_buffer_ipc_handle, + init_custom_ar, + int8_scaled_mm, + lightning_attention_decode, + meta_size, + min_p_sampling_from_probs, + moe_align_block_size, + register_buffer, + register_graph_buffers, + rmsnorm, + sampling_scaling_penalties, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + ) + + __all__ = [ + "all_reduce_reg", + "all_reduce_unreg", + "allocate_meta_buffer", + "apply_rope_with_cos_sin_cache_inplace", + "bmm_fp8", + "dispose", + "fp8_scaled_mm", + 
"fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "get_graph_buffer_ipc_meta", + "get_meta_buffer_ipc_handle", + "init_custom_ar", + "int8_scaled_mm", + "lightning_attention_decode", + "meta_size", + "min_p_sampling_from_probs", + "moe_align_block_size", + "register_buffer", + "register_graph_buffers", + "rmsnorm", + "sampling_scaling_penalties", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", + ] +else: + from sgl_kernel.ops import ( + apply_rope_with_cos_sin_cache_inplace, + bmm_fp8, + build_tree_kernel, + build_tree_kernel_efficient, + cublas_grouped_gemm, + custom_dispose, + custom_reduce, + fp8_blockwise_scaled_mm, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + get_graph_buffer_ipc_meta, + init_custom_reduce, + int8_scaled_mm, + lightning_attention_decode, + min_p_sampling_from_probs, + moe_align_block_size, + register_graph_buffers, + rmsnorm, + sampling_scaling_penalties, + sgl_per_token_group_quant_fp8, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + ) + + __all__ = [ + "apply_rope_with_cos_sin_cache_inplace", + "bmm_fp8", + "cublas_grouped_gemm", + "custom_dispose", + "custom_reduce", + "fp8_blockwise_scaled_mm", + "fp8_scaled_mm", + "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "get_graph_buffer_ipc_meta", + "init_custom_reduce", + "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", + "moe_align_block_size", + "register_graph_buffers", + "rmsnorm", + "sampling_scaling_penalties", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", + "tree_speculative_sampling_target_only", + "build_tree_kernel_efficient", + "build_tree_kernel", + 
"sgl_per_token_group_quant_fp8", + ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip b/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip new file mode 100644 index 000000000..6c1ef0d06 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip @@ -0,0 +1,180 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include +#include +#include + +#include "custom_all_reduce_hip.cuh" + +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink) { + int world_size = offsets.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + hipIpcMemHandle_t ipc_handles[8]; + for (int i = 0; i < world_size; i++) { + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t)); + } + return (fptr_t) new vllm::CustomAllreduce( + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. 
A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + hipStream_t stream) { + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(_is_weak_contiguous(out)); + switch (out.scalar_type()) { + case at::ScalarType::Float: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } + case at::ScalarType::Half: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * 
reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, hipMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +int64_t meta_size() { return sizeof(vllm::Signal); } + +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); +} + +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; +} + +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); +} + +void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); } + +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) { + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(), + inp.data_ptr())); + return data_handle; +} + +torch::Tensor allocate_meta_buffer(int64_t size) { + auto device_index = c10::hip::current_device(); + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + void* buffer; + hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed; + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + 
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode)); + AT_CUDA_CHECK( + hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); + AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream)); + AT_CUDA_CHECK(hipStreamSynchronize(stream)); + AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode)); + auto options = torch::TensorOptions() + .dtype(torch::kI8) + .device(torch::kCUDA, device_index); + return torch::from_blob(buffer, {size}, free_meta_buffer, options); +} + +std::vector get_device_bdf(int dev) { + char busIdStr[] = "0000:00:00.0"; + std::vector bdf(sizeof(busIdStr), 0); + CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev)); + bdf.resize(bdf.size() - 1); // remove trailing NULL + return bdf; +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh b/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh new file mode 100644 index 000000000..06173bc42 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh @@ -0,0 +1,554 @@ +// !!! This is a file automatically generated by hipify!!! 
+#pragma once + +#include +#ifdef USE_ROCM +#include +typedef __hip_bfloat16 nv_bfloat16; +#else +#include +#endif +#include +#include + +#include +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + hipError_t e = cmd; \ + if (e != hipSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +constexpr int kMaxBlocks = 64; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links +struct Signal { + alignas(128) uint32_t start[kMaxBlocks][8]; + alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank +}; + +#ifdef USE_ROCM +struct __align__(16) RankData { + const void* ptrs[8]; +}; +#else +struct __align__(16) RankData { + const void* __restrict__ ptrs[8]; +}; +#endif + +struct __align__(16) RankSignals { +#ifndef USE_ROCM + volatile +#endif + Signal* signals[8]; +}; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { + return __half2float(val); +} + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) 
{ + return a += b; +} + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { + return __bfloat162float(val); +} +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void start_sync(const RankSignals& sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, + int rank) { +#ifdef USE_ROCM + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. 
+ // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, + __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < + flag) + ; + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +#else + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->start[blockIdx.x][threadIdx.x]) + ; + } + __syncthreads(); +#endif +} + +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. +template +DINLINE void end_sync(const RankSignals& sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, + int rank) { +#ifdef USE_ROCM + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? 
__ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag) + ; + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +#else + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + if constexpr (!final_sync) __threadfence_system(); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->end[blockIdx.x][threadIdx.x]) + ; + } + if constexpr (!final_sync) __syncthreads(); +#endif +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast

(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, + T* __restrict__ result, int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + start_sync(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + end_sync(sg, self_sg, rank); +} + +template +#ifdef USE_ROCM +DINLINE P* get_tmp_buf(Signal* sg) { +#else +DINLINE P* get_tmp_buf(volatile Signal* sg) { +#endif + return (P*)(((Signal*)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, + T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); + } + auto tmp_out = tmps[0]; + start_sync(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + end_sync(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + // below are device pointers + RankSignals sg_; + std::unordered_map buffers_; + Signal* self_sg_; + + // stores the registered device pointers from all ranks + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * meta is a pointer to device metadata and temporary buffer for allreduce. + * + * There's a total of sizeof(Signal) of prefix before the actual data, + * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. 
Any required buffers + * are passed in from the constructor + */ + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) + : rank_(rank), + world_size_(offsets.size()), + full_nvlink_(full_nvlink), + self_sg_(meta), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + Signal* rank_sg; + if (i != rank_) { + char* handle = open_ipc_handle(&handles[i]); + handle += offsets[i]; + rank_sg = (Signal*)handle; + } else { + rank_sg = self_sg_; + } + sg_.signals[i] = rank_sg; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), + hipIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair, std::vector> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(hipIpcMemHandle_t); + std::vector handles(handle_sz * num_buffers, 0); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (hipPointerGetAttribute(&base_ptr, +#ifdef USE_ROCM + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#else + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#endif + (hipDeviceptr_t)ptr) != hipSuccess) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ 
+ num > d_rank_data_end_) + throw std::runtime_error("Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + void register_buffer(const std::vector& handles, const std::vector& offsets, void* self) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + if (i != rank_) { + char* handle = open_ipc_handle(handles[i].data()); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } + } + auto d_data = d_rank_data_base_++; + CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice)); + buffers_[self] = d_data; + } + + // note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers(const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * This is the result after careful grid search. 
Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. + */ + template + void allreduce(hipStream_t stream, T* input, T* output, int size, +#ifndef USE_ROCM + int threads = 512, int block_limit = 36){ +#else + int threads = 512, int block_limit = 16) { +#endif + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + hipStreamCaptureStatus status; + CUDACHECK(hipStreamIsCapturing(stream, &status)); + if (status == hipStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = ::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + hipLaunchKernelGGL((name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \ + size); +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + 
REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL +} + +~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(hipIpcCloseMemHandle(ptr)); + } +} +}; // namespace vllm +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void vllm::CustomAllreduce::allreduce(hipStream_t, half *, + half *, int, int, int); +*/ +} // namespace vllm diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 92a68d622..6e3eab1af 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -34,8 +34,23 @@ limitations under the License. return PyModule_Create(&module); \ } -// trt_reduce using fptr_t = int64_t; +#ifdef USE_ROCM +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, + const std::vector& offsets, int64_t rank, bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); +void dispose(fptr_t _fa); +int64_t meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, + const std::vector& offsets); +std::tuple> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); +torch::Tensor allocate_meta_buffer(int64_t size); +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); +#else +// trt_reduce fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, const std::vector& tmp_result_buffers, const std::vector& 
if torch.version.hip is not None:
    # ROCm build: vllm-style custom all-reduce, maintained inside sgl-kernel.
    # Each wrapper forwards directly to the registered torch op.

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a custom all-reduce context; returns an opaque handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        """All-reduce `inp` into `out` using a pre-registered input buffer."""
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        """All-reduce an unregistered `inp` by staging through `reg_buffer`."""
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        """Destroy the context referenced by handle `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        """Size in bytes of the per-rank signal/metadata prefix."""
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        """Register tensor `t` (with peers' IPC handles/offsets) for reuse."""
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        """IPC handles + offsets of buffers captured during CUDA-graph capture."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        """Register all graph-captured buffers collected from every rank."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        """Allocate the device metadata/scratch buffer of `size` bytes."""
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        """IPC handle for a meta buffer, as a byte tensor to share with peers."""
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # CUDA build: TensorRT-LLM style custom all-reduce ("trt_reduce").

    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        """Create a trt-style all-reduce context; returns an opaque handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        """Destroy the context referenced by handle `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        """All-reduce `inp` into `out`."""
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        """IPC metadata of buffers captured during CUDA-graph capture."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        """Register all graph-captured buffers collected from every rank."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
#include "sgl_kernels_ops.h" TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // Custom all-reduce kernels + m.def( + "init_custom_ar(Tensor meta, Tensor rank_data, " + "str[] handles, int[] offsets, int rank, " + "bool full_nvlink) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); + m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + + m.def( + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " + "()"); + m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + + m.def("dispose", &dispose); + + m.def("meta_size", &meta_size); + + m.def( + "register_buffer(int fa, Tensor t, str[] handles, " + "int[] offsets) -> ()"); + m.impl("register_buffer", torch::kCUDA, ®ister_buffer); + + m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); + m.def("register_graph_buffers", ®ister_graph_buffers); + m.def("allocate_meta_buffer", &allocate_meta_buffer); + m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer); + m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); + m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle); + // moe_align_block_size m.def( "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "