From b4403985d00939958db69194f94a795d3ea95bce Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 31 Dec 2024 14:28:29 +0800 Subject: [PATCH] Add cutlass submodule for sgl-kernel (#2676) --- .gitmodules | 3 +++ sgl-kernel/3rdparty/cutlass | 1 + sgl-kernel/CMakeLists.txt | 4 ++++ sgl-kernel/setup.py | 6 ++++++ 4 files changed, 14 insertions(+) create mode 160000 sgl-kernel/3rdparty/cutlass diff --git a/.gitmodules b/.gitmodules index e69de29bb..3a14f6297 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sgl-kernel/3rdparty/cutlass"] + path = sgl-kernel/3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass new file mode 160000 index 000000000..bf9da7b76 --- /dev/null +++ b/sgl-kernel/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 137e7a9a8..26081a8e7 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -8,6 +8,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) +set(CUTLASS_DIR "3rdparty/cutlass") + # Set CUDA architectures set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") @@ -38,6 +40,8 @@ target_include_directories(_kernels ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc ${CUDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} + ${CUTLASS_DIR}/include + ${CUTLASS_DIR}/tools/util/include ) target_link_libraries(_kernels diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index bfed5f6e5..4a40aeb2d 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -58,6 +58,11 @@ def update_wheel_platform_tag(): old_wheel.rename(new_wheel) +cutlass = root / "3rdparty" / "cutlass" +include_dirs = [ + cutlass.resolve() / "include", + cutlass.resolve() / "tools" / "util" / "include", +] nvcc_flags = [ "-O3", "-Xcompiler", @@ -82,6 +87,7 @@ ext_modules = [ "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", ], + include_dirs=include_dirs, extra_compile_args={ "nvcc": nvcc_flags, "cxx": cxx_flags,