From e0a2c963088552f39d4858118b0c8dae294f16c4 Mon Sep 17 00:00:00 2001 From: kk <43161300+kkHuang-amd@users.noreply.github.com> Date: Tue, 4 Mar 2025 18:59:03 +0800 Subject: [PATCH] Fix breakage problem when using custom_ar (#4052) --- python/sglang/srt/_custom_ops.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 24c0dd79f..c5056ffc2 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -75,42 +75,42 @@ else: rank: int, full_nvlink: bool, ) -> int: - return sgl_kernel.ops.init_custom_ar( + return sgl_kernel.ops.allreduce.init_custom_ar( meta, rank_data, handles, offsets, rank, full_nvlink ) def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - sgl_kernel.ops.all_reduce_reg(fa, inp, out) + sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out) def all_reduce_unreg( fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor ) -> None: - sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out) + sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out) def dispose(fa: int) -> None: - sgl_kernel.ops.dispose(fa) + sgl_kernel.ops.allreduce.dispose(fa) def meta_size() -> int: - return sgl_kernel.ops.meta_size() + return sgl_kernel.ops.allreduce.meta_size() def register_buffer( fa: int, t: torch.Tensor, handles: List[str], offsets: List[int] ) -> None: - return sgl_kernel.ops.register_buffer(fa, t, handles, offsets) + return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets) def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: - return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: List[str], offsets: List[List[int]] ) -> None: - sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) + sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets) def allocate_meta_buffer(size: int) -> torch.Tensor: - return sgl_kernel.ops.allocate_meta_buffer(size) + return sgl_kernel.ops.allreduce.allocate_meta_buffer(size) def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: - return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp) + return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp) else: # TRTLLM custom allreduce