From aba5ca154d4c16323dc3f4098d00cc4384e75910 Mon Sep 17 00:00:00 2001
From: Yi Zhang <1109276519@qq.com>
Date: Sun, 6 Apr 2025 06:35:55 +0800
Subject: [PATCH] python transfer custom allreduce from trt kernel to vllm
 kernel (#5080)

---
 python/pyproject.toml                         |   2 +-
 python/sglang/srt/_custom_ops.py              | 133 +++++++-----------
 .../device_communicators/custom_all_reduce.py | 102 ++++----------
 scripts/ci_install_dependency.sh              |   2 +-
 4 files changed, 77 insertions(+), 162 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 6504d4320..2e3b367f9 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -47,7 +47,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.7",
+    "sgl-kernel==0.0.8",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "cuda-python",
diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index d0bc51261..2e9db19f9 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -27,17 +27,20 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)
 
 
-if use_vllm_custom_allreduce and not is_hip():
-    # vLLM custom allreduce
+if not is_hip():
+    if use_vllm_custom_allreduce:
+        custom_op = torch.ops._C_custom_ar
+    else:
+        custom_op = sgl_kernel.allreduce
+
+    # custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops._C_custom_ar.init_custom_ar(
-            ipc_tensors, rank_data, rank, full_nvlink
-        )
+        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
 
     def all_reduce(
         fa: int,
@@ -46,105 +49,69 @@ if use_vllm_custom_allreduce and not is_hip():
         reg_buffer: int,
         reg_buffer_sz_bytes: int,
     ) -> None:
-        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
 
     def dispose(fa: int) -> None:
-        torch.ops._C_custom_ar.dispose(fa)
+        custom_op.dispose(fa)
 
     def meta_size() -> int:
-        return torch.ops._C_custom_ar.meta_size()
+        return custom_op.meta_size()
 
     def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+        return custom_op.register_buffer(fa, ipc_tensors)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+        return custom_op.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+        custom_op.register_graph_buffers(fa, handles, offsets)
 
 else:
-    if is_hip():
-        # ROCM custom allreduce
+    # ROCM custom allreduce
 
-        def init_custom_ar(
-            meta: torch.Tensor,
-            rank_data: torch.Tensor,
-            handles: List[str],
-            offsets: List[int],
-            rank: int,
-            full_nvlink: bool,
-        ) -> int:
-            return sgl_kernel.allreduce.init_custom_ar(
-                meta, rank_data, handles, offsets, rank, full_nvlink
-            )
+    def init_custom_ar(
+        meta: torch.Tensor,
+        rank_data: torch.Tensor,
+        handles: List[str],
+        offsets: List[int],
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return sgl_kernel.allreduce.init_custom_ar(
+            meta, rank_data, handles, offsets, rank, full_nvlink
+        )
 
-        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-            sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
+    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
 
-        def all_reduce_unreg(
-            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
-        ) -> None:
-            sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+    def all_reduce_unreg(
+        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
 
-        def dispose(fa: int) -> None:
-            sgl_kernel.allreduce.dispose(fa)
+    def dispose(fa: int) -> None:
+        sgl_kernel.allreduce.dispose(fa)
 
-        def meta_size() -> int:
-            return sgl_kernel.allreduce.meta_size()
+    def meta_size() -> int:
+        return sgl_kernel.allreduce.meta_size()
 
-        def register_buffer(
-            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
-        ) -> None:
-            return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
+    def register_buffer(
+        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+    ) -> None:
+        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
 
-        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-            return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
 
-        def register_graph_buffers(
-            fa: int, handles: List[str], offsets: List[List[int]]
-        ) -> None:
-            sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
+    def register_graph_buffers(
+        fa: int, handles: List[str], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
 
-        def allocate_meta_buffer(size: int) -> torch.Tensor:
-            return sgl_kernel.allreduce.allocate_meta_buffer(size)
+    def allocate_meta_buffer(size: int) -> torch.Tensor:
+        return sgl_kernel.allreduce.allocate_meta_buffer(size)
 
-        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-            return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
-
-    else:
-        # TRTLLM custom allreduce
-        def init_custom_ar(
-            rank_id: int,
-            world_size: int,
-            rank_data_base: torch.Tensor,
-            buffers: List[int],
-            tmp_result_buffers: List[int],
-            barrier_in: List[int],
-            barrier_out: List[int],
-        ) -> int:
-            return sgl_kernel.init_custom_reduce(
-                rank_id,
-                world_size,
-                rank_data_base,
-                buffers,
-                tmp_result_buffers,
-                barrier_in,
-                barrier_out,
-            )
-
-        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-            sgl_kernel.custom_reduce(fa, inp, out)
-
-        def dispose(fa: int) -> None:
-            sgl_kernel.custom_dispose(fa)
-
-        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-            return sgl_kernel.get_graph_buffer_ipc_meta(fa)
-
-        def register_graph_buffers(
-            fa: int, handles: List[List[int]], offsets: List[List[int]]
-        ) -> None:
-            sgl_kernel.register_graph_buffers(fa, handles, offsets)
+    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 8d81e47a1..51bcd4722 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -257,7 +257,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
 
-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -280,56 +280,24 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if _is_hip:
-                # meta data buffers need to be "uncached" for signal on MI200
-                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-                self.buffer = torch.empty(
-                    max_size, dtype=torch.uint8, device=self.device
-                )
-                handle = ops.get_meta_buffer_ipc_handle(self.meta)
-                shard_data = (
-                    bytes(handle),  # ipc handle to base ptr
-                    0,  # offset of base ptr
-                )
-                handles, offsets = self._gather_ipc_meta(shard_data)
-                self.rank_data = torch.empty(
-                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-                )
-                self._ptr = ops.init_custom_ar(
-                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
-                )
-                self.register_buffer(self.buffer)
-                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
-            else:
-                # From TensorRT-LLM getMaxRequiredWorkspaceSize
-                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+            # meta data buffers need to be "uncached" for signal on MI200
+            self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+            handle = ops.get_meta_buffer_ipc_handle(self.meta)
+            shard_data = (
+                bytes(handle),  # ipc handle to base ptr
+                0,  # offset of base ptr
+            )
+            handles, offsets = self._gather_ipc_meta(shard_data)
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+            )
+            self.register_buffer(self.buffer)
+            self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
 
-                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-                self.barrier_max_size = 8 * (36 + 2) * 8
-
-                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-                    max_size, group=group
-                )
-                self.rank_data_base = torch.empty(
-                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-                )
-                self.barrier_in_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-                self.barrier_out_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-
-                self._ptr = ops.init_custom_ar(
-                    rank,
-                    world_size,
-                    self.rank_data_base,
-                    self.buffer_ptrs,
-                    self.tmp_result_buffer_ptrs,
-                    self.barrier_in_ptrs,
-                    self.barrier_out_ptrs,
-                )
         self.disabled = False
 
     @staticmethod
@@ -455,7 +423,7 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
@@ -471,18 +439,6 @@ class CustomAllreduce:
                 return inp_size < self.max_size
             return False
 
-        if self.world_size == 2:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[0]
-            )
-
-        if self.full_nvlink:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[1]
-            )
-
         return False
 
     # all reduce, assuming inp tensor is IPC registered with register_buffer,
@@ -515,15 +471,12 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if ops.use_vllm_custom_allreduce:
-            if registered:
-                ops.all_reduce(self._ptr, inp, out, 0, 0)
-            else:
-                ops.all_reduce(
-                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-                )
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
-            ops.all_reduce(self._ptr, inp, out)
+            ops.all_reduce(
+                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+            )
         return out
 
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -554,14 +507,9 @@ class CustomAllreduce:
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            if ops.use_vllm_custom_allreduce:
+            if _is_cuda:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            elif _is_cuda:
-                self.free_shared_buffer(self.buffer_ptrs)
-                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-                self.free_shared_buffer(self.barrier_in_ptrs)
-                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0
 
     def __del__(self):
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index ec646b804..d7b4c89b4 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -20,7 +20,7 @@ pip install --upgrade pip
 
 # Install flashinfer and sgl-kernel
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.7 --no-cache-dir
+pip install sgl-kernel==0.0.8 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]" --find-links ${FLASHINFER_REPO}
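Note on the pattern above (illustrative, not part of the applied diff): on the
CUDA path, _custom_ops.py now binds whichever backend is selected -- vLLM's
torch.ops._C_custom_ar or sgl-kernel's sgl_kernel.allreduce -- to a single
`custom_op` alias at import time, so each wrapper (init_custom_ar, all_reduce,
dispose, ...) is defined once instead of per backend. A minimal runnable
sketch of that dispatch pattern, using hypothetical stand-ins in place of the
real backend modules:

    # dispatch_sketch.py -- stand-in namespaces, not the real kernels
    from types import SimpleNamespace

    # Hypothetical substitutes for torch.ops._C_custom_ar and
    # sgl_kernel.allreduce; only meta_size() is modeled here.
    _vllm_backend = SimpleNamespace(meta_size=lambda: 1024)
    _sgl_backend = SimpleNamespace(meta_size=lambda: 2048)

    # In _custom_ops.py this flag already exists at module level.
    use_vllm_custom_allreduce = True

    # Bind the chosen backend once; the wrappers below never branch again.
    custom_op = _vllm_backend if use_vllm_custom_allreduce else _sgl_backend

    def meta_size() -> int:
        # Identical wrapper body regardless of the selected backend.
        return custom_op.meta_size()

    print(meta_size())  # -> 1024 with the vLLM-style backend selected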