From b02da24a5b8cc0b8e4971f59a7e0f8afcfeab9b3 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 30 Dec 2024 18:07:01 +0800
Subject: [PATCH] Refactor sgl-kernel build (#2642)

---
 sgl-kernel/CMakeLists.txt                     |  29 ++---
 sgl-kernel/setup.py                           | 101 ++++++------
 sgl-kernel/src/sgl-kernel/__init__.py         |  12 ++-
 .../src/sgl-kernel/csrc/moe_align_kernel.cu   |   8 +-
 .../src/sgl-kernel/csrc/sgl_kernel_ops.cu     |  32 ++++++
 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc  |  13 ---
 sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc |  14 ---
 .../src/sgl-kernel/csrc/warp_reduce_kernel.cu |   3 +-
 sgl-kernel/src/sgl-kernel/ops/__init__.py     |  22 +++-
 9 files changed, 108 insertions(+), 126 deletions(-)
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
 delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
 delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc

diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index adb81fa2b..137e7a9a8 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -25,46 +25,29 @@ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")
 find_package(Torch REQUIRED)
 
 # Warp Reduce library
-add_library(warp_reduce SHARED
-  src/sgl-kernel/csrc/warp_reduce.cc
+add_library(_kernels SHARED
   src/sgl-kernel/csrc/warp_reduce_kernel.cu
-)
-
-target_include_directories(warp_reduce
-  PRIVATE
-    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
-    ${CUDA_INCLUDE_DIRS}
-    ${TORCH_INCLUDE_DIRS}
-)
-
-target_link_libraries(warp_reduce
-  PRIVATE
-    ${TORCH_LIBRARIES}
-    Python3::Python
-)
-
-# TRT Reduce library
-add_library(trt_reduce SHARED
-  src/sgl-kernel/csrc/trt_reduce.cc
   src/sgl-kernel/csrc/trt_reduce_internal.cu
   src/sgl-kernel/csrc/trt_reduce_kernel.cu
+  src/sgl-kernel/csrc/moe_align_kernel.cu
+  src/sgl-kernel/csrc/sgl_kernel_ops.cu
 )
 
-target_include_directories(trt_reduce
+target_include_directories(_kernels
   PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
     ${CUDA_INCLUDE_DIRS}
    ${TORCH_INCLUDE_DIRS}
 )
 
-target_link_libraries(trt_reduce
+target_link_libraries(_kernels
   PRIVATE
     ${TORCH_LIBRARIES}
     Python3::Python
 )
 
 # Set common properties for both libraries
-foreach(target warp_reduce trt_reduce)
+foreach(target _kernels)
   set_target_properties(${target} PROPERTIES
     CUDA_SEPARABLE_COMPILATION ON
     POSITION_INDEPENDENT_CODE ON

diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 5b8da4b15..bfed5f6e5 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -58,78 +58,45 @@ def update_wheel_platform_tag():
         old_wheel.rename(new_wheel)
 
 
+nvcc_flags = [
+    "-O3",
+    "-Xcompiler",
+    "-fPIC",
+    "-gencode=arch=compute_75,code=sm_75",
+    "-gencode=arch=compute_80,code=sm_80",
+    "-gencode=arch=compute_89,code=sm_89",
+    "-gencode=arch=compute_90,code=sm_90",
+    "-U__CUDA_NO_HALF_OPERATORS__",
+    "-U__CUDA_NO_HALF2_OPERATORS__",
+]
+cxx_flags = ["-O3"]
+libraries = ["c10", "torch", "torch_python"]
+extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
+ext_modules = [
+    CUDAExtension(
+        name="sgl_kernel.ops._kernels",
+        sources=[
+            "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
+            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+            "src/sgl-kernel/csrc/moe_align_kernel.cu",
+            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+        ],
+        extra_compile_args={
+            "nvcc": nvcc_flags,
+            "cxx": cxx_flags,
+        },
+        libraries=libraries,
+        extra_link_args=extra_link_args,
+    ),
+]
+
 setup(
     name="sgl-kernel",
     version=get_version(),
     packages=["sgl_kernel"],
     package_dir={"": "src"},
-    ext_modules=[
-        CUDAExtension(
"sgl_kernel.ops.warp_reduce_cuda", - [ - "src/sgl-kernel/csrc/warp_reduce.cc", - "src/sgl-kernel/csrc/warp_reduce_kernel.cu", - ], - extra_compile_args={ - "nvcc": [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - ], - "cxx": ["-O3"], - }, - libraries=["c10", "torch", "torch_python"], - extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], - ), - CUDAExtension( - "sgl_kernel.ops.custom_reduce_cuda", - [ - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/trt_reduce.cc", - ], - extra_compile_args={ - "nvcc": [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - ], - "cxx": ["-O3"], - }, - libraries=["c10", "torch", "torch_python"], - extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], - ), - CUDAExtension( - "sgl_kernel.ops.moe_align_block_size", - [ - "src/sgl-kernel/csrc/moe_align_kernel.cu", - ], - extra_compile_args={ - "nvcc": [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - ], - "cxx": ["-O3"], - }, - libraries=["c10", "torch", "torch_python"], - extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], - ), - ], + ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, install_requires=["torch"], ) diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 1019896fe..c0a5caa10 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,5 +1,15 @@ -from .ops import moe_align_block_size +from sgl_kernel.ops import ( + custom_dispose, + custom_reduce, + init_custom_reduce, + moe_align_block_size, + warp_reduce, +) __all__ = [ "moe_align_block_size", + "warp_reduce", + "init_custom_reduce", + "custom_dispose", + "custom_reduce", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 795f9157d..dfd28032f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -3,11 +3,11 @@ #include #include #include -#include -#include #include +#include "utils.hpp" + #ifdef USE_ROCM #include #endif @@ -133,7 +133,3 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b token_cnts_buffer.data_ptr(), cumsum_buffer.data_ptr()); }); } - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu new file mode 100644 index 000000000..9518eb4ff --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -0,0 +1,32 @@ +#include "utils.hpp" + +// warp_reduce +torch::Tensor warp_reduce_cuda(torch::Tensor input); + +torch::Tensor warp_reduce(torch::Tensor input) { + CHECK_CUDA_INPUT(input); + return warp_reduce_cuda(input); +} + +// trt_reduce +using fptr_t = int64_t; +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, + const std::vector& 
+                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
+void dispose(fptr_t _fa);
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+
+// moe_align_block_size
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
+                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
+                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // warp_reduce
+  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
+  // trt_reduce
+  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
+  m.def("dispose", &dispose, "dispose custom allreduce meta");
+  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+  // moe_align_block_size
+  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
+}

diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
deleted file mode 100644
index 4d8f732af..000000000
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
+++ /dev/null
@@ -1,13 +0,0 @@
-#include <torch/extension.h>
-
-using fptr_t = int64_t;
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
-void dispose(fptr_t _fa);
-void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
-  m.def("dispose", &dispose, "dispose custom allreduce meta");
-  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
-}

diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
deleted file mode 100644
index 379b4cc15..000000000
--- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
+++ /dev/null
@@ -1,14 +0,0 @@
-#include <torch/extension.h>
-
-#include "utils.hpp"
-
-torch::Tensor warp_reduce_cuda(torch::Tensor input);
-
-torch::Tensor warp_reduce(torch::Tensor input) {
-  CHECK_CUDA_INPUT(input);
-  return warp_reduce_cuda(input);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
-}

diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
index d75cc9bee..7a3f2f5fc 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+
+#include "utils.hpp"
 
 #define FINAL_MASK 0xffffffff
 #define BLOCK_SIZE 256

diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 55318879a..3b4cbfb17 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -1,4 +1,24 @@
-from .moe_align_block_size import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
+from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import reduce as _reduce
+
+
+def warp_reduce(input_tensor):
+    return _reduce(input_tensor)
+
+
+def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
+    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
+
+
+def custom_dispose(fa):
+    _dispose(fa)
+
+
+def custom_reduce(fa, inp, out):
+    _all_reduce(fa, inp, out)
 
 
 def moe_align_block_size(
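
After this patch, everything builds into a single `sgl_kernel.ops._kernels` extension, and the wrappers above are re-exported from the `sgl_kernel` package root. A minimal usage sketch (assumes an installed CUDA build of sgl-kernel; the tensor shape and dtype are illustrative, not taken from this patch):

    import torch
    from sgl_kernel import warp_reduce

    # Any CUDA tensor that passes the kernel's CHECK_CUDA_INPUT guard.
    x = torch.randn(1024, device="cuda", dtype=torch.float32)
    out = warp_reduce(x)  # dispatches to _kernels.reduce via the Python wrapper

The custom all-reduce path follows the same pattern (init_custom_reduce, then custom_reduce, then custom_dispose), but it requires per-rank IPC buffer and barrier pointers, so it is only meaningful inside a multi-GPU process group.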