From 24cafe317746a1051ae965925eeaab539049a09f Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Sun, 19 Jan 2025 22:30:38 +0800
Subject: [PATCH] add config to switch from vllm custom allreduce to
 sgl_kernel custom allreduce (#2981)

---
 python/sglang/srt/_custom_ops.py               | 109 +++++++++++------
 .../device_communicators/custom_all_reduce.py  | 111 ++++++++++++------
 2 files changed, 154 insertions(+), 66 deletions(-)

diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index f59f67605..3c00a8552 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -3,6 +3,7 @@ import contextlib
 import functools
 import importlib
 import logging
+import os
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
@@ -11,12 +12,19 @@ import torch.library
 from sglang.srt.utils import is_hpu

 logger = logging.getLogger(__name__)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)

 if not is_hpu():
-    try:
-        import sgl_kernel
-    except ImportError as e:
-        logger.warning("Failed to import from custom_ar with %r", e)
+    if use_vllm_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
+    else:
+        try:
+            import sgl_kernel
+        except ImportError as e:
+            logger.warning("Failed to import from custom_ar with %r", e)


 def hint_on_error(fn):
@@ -48,43 +56,78 @@ def hint_on_error(fn):
     return wrapper


-# custom ar
-def init_custom_ar(
-    rank_id: int,
-    world_size: int,
-    rank_data_base: torch.Tensor,
-    buffers: List[int],
-    tmp_result_buffers: List[int],
-    barrier_in: List[int],
-    barrier_out: List[int],
-) -> int:
-    return sgl_kernel.ops.init_custom_reduce(
-        rank_id,
-        world_size,
-        rank_data_base,
-        buffers,
-        tmp_result_buffers,
-        barrier_in,
-        barrier_out,
-    )
+if use_vllm_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: List[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, full_nvlink
+        )

+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

-def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-    sgl_kernel.ops.custom_reduce(fa, inp, out)
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)

+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()

-def dispose(fa: int) -> None:
-    sgl_kernel.ops.custom_dispose(fa)
+    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)

+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)

-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

+else:
+    # custom ar
+    def init_custom_ar(
+        rank_id: int,
+        world_size: int,
+        rank_data_base: torch.Tensor,
+        buffers: List[int],
+        tmp_result_buffers: List[int],
+        barrier_in: List[int],
+        barrier_out: List[int],
+    ) -> int:
+        return sgl_kernel.ops.init_custom_reduce(
+            rank_id,
+            world_size,
+            rank_data_base,
+            buffers,
+            tmp_result_buffers,
+            barrier_in,
+            barrier_out,
+        )

-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.ops.custom_reduce(fa, inp, out)
+
+    def dispose(fa: int) -> None:
+        sgl_kernel.ops.custom_dispose(fa)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index ba9feb59d..28aa9d481 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -21,8 +21,10 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    import sgl_kernel
-
+    if ops.use_vllm_custom_allreduce:
+        ops.meta_size()
+    else:
+        import sgl_kernel
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -201,33 +203,58 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        # From TensorRT-LLM getMaxRequiredWorkspaceSize
-        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+        if ops.use_vllm_custom_allreduce:
+            # Buffers memory are owned by this Python class and passed to C++.
+            # Meta data composes of two parts: meta data for synchronization and a
+            # temporary buffer for storing intermediate allreduce results.
+            self.meta_ptrs = self.create_shared_buffer(
+                ops.meta_size() + max_size, group=group
+            )
+            # This is a pre-registered IPC buffer. In eager mode, input tensors
+            # are first copied into this buffer before allreduce is performed
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            # This is a buffer for storing the tuples of pointers pointing to
+            # IPC buffers from all ranks. Each registered tuple has size of
+            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+            # is enough for 131072 such tuples. The largest model I've seen only
+            # needs less than 10000 of registered tuples.
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            )
+            ops.register_buffer(self._ptr, self.buffer_ptrs)
+        else:
+            # From TensorRT-LLM getMaxRequiredWorkspaceSize
+            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]

-        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-        self.barrier_max_size = 8 * (36 + 2) * 8
+            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+            self.barrier_max_size = 8 * (36 + 2) * 8

-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.rank_data_base = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
-        self.barrier_in_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
-        self.barrier_out_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                max_size, group=group
+            )
+            self.rank_data_base = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self.barrier_in_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self.barrier_out_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )

-        self._ptr = ops.init_custom_ar(
-            rank,
-            world_size,
-            self.rank_data_base,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
+            self._ptr = ops.init_custom_ar(
+                rank,
+                world_size,
+                self.rank_data_base,
+                self.buffer_ptrs,
+                self.tmp_result_buffer_ptrs,
+                self.barrier_in_ptrs,
+                self.barrier_out_ptrs,
+            )
         self.disabled = False

     @staticmethod
@@ -307,6 +334,11 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
+        if ops.use_vllm_custom_allreduce:
+            if self.world_size == 2 or self.full_nvlink:
+                return inp_size < self.max_size
+            return False
+
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
@@ -326,6 +358,7 @@ class CustomAllreduce:
         inp: torch.Tensor,
         *,
         out: torch.Tensor = None,
+        registered: bool = False,
     ):
         """Performs an out-of-place all reduce.

@@ -335,7 +368,15 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        ops.all_reduce(self._ptr, inp, out)
+        if ops.use_vllm_custom_allreduce:
+            if registered:
+                ops.all_reduce(self._ptr, inp, out, 0, 0)
+            else:
+                ops.all_reduce(
+                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+                )
+        else:
+            ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -345,21 +386,25 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input)
+                return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input)
+            return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-        self.free_shared_buffer(self.buffer_ptrs)
-        self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-        self.free_shared_buffer(self.barrier_in_ptrs)
-        self.free_shared_buffer(self.barrier_out_ptrs)
+        if ops.use_vllm_custom_allreduce:
+            self.free_shared_buffer(self.meta_ptrs)
+            self.free_shared_buffer(self.buffer_ptrs)
+        else:
+            self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
         self._ptr = 0

     def __del__(self):
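
Usage sketch (not part of the patch): the new flag is read once, when sglang.srt._custom_ops is first imported, so it has to be set in the environment before anything from sglang is imported. Any non-empty value selects the vLLM kernels; leaving it unset (the default) keeps the sgl_kernel path. A minimal way to check which backend was picked:

    import os

    # Must be set before sglang is imported; any non-empty string enables
    # the vLLM custom allreduce path introduced by this patch.
    os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "1"

    from sglang.srt import _custom_ops as ops

    backend = "vllm._C" if ops.use_vllm_custom_allreduce else "sgl_kernel"
    print(f"custom allreduce backend: {backend}")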
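
One caveat about the flag: os.environ.get returns the raw string, so any non-empty value, including "0" or "false", still enables the vLLM path; the variable is never interpreted as a boolean. If stricter parsing were wanted, a hypothetical helper (not what the patch does) could look like:

    import os

    def env_flag(name: str, default: bool = False) -> bool:
        # Treat only common "true" spellings as enabled; everything else,
        # including "0" and "false", falls back to the default.
        value = os.environ.get(name)
        if value is None:
            return default
        return value.strip().lower() in ("1", "true", "yes", "on")

    use_vllm_custom_allreduce = env_flag("USE_VLLM_CUSTOM_ALLREDUCE")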