[Misc] feat: Deepgemm update for sgl-kernel (#8790)
This commit is contained in:
@@ -50,22 +50,17 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
# DeepGEMM
|
||||
if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
|
||||
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
|
||||
set(DeepGEMM_TAG "blackwell")
|
||||
elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
|
||||
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
|
||||
set(DeepGEMM_TAG "blackwell")
|
||||
else()
|
||||
set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
|
||||
set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-fmt)
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY ${DeepGEMM_REPO}
|
||||
GIT_TAG ${DeepGEMM_TAG}
|
||||
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
||||
GIT_TAG cabi
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
@@ -422,13 +417,44 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
|
||||
endif()
|
||||
|
||||
# JIT Logic
|
||||
# DeepGEMM
|
||||
# ============================ DeepGEMM (JIT) ============================= #
|
||||
# Create a separate library for DeepGEMM's Python API.
|
||||
# This keeps its compilation isolated from the main common_ops.
|
||||
set(DEEPGEMM_SOURCES
|
||||
"${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
|
||||
)
|
||||
|
||||
install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
|
||||
DESTINATION "deep_gemm"
|
||||
PATTERN ".git*" EXCLUDE
|
||||
PATTERN "__pycache__" EXCLUDE)
|
||||
Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})
|
||||
|
||||
# Link against necessary libraries, including nvrtc for JIT compilation.
|
||||
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
|
||||
|
||||
# Add include directories needed by DeepGEMM.
|
||||
target_include_directories(deep_gemm_cpp PRIVATE
|
||||
${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-fmt_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
# Apply the same compile options as common_ops.
|
||||
target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
|
||||
# Create an empty __init__.py to make `deepgemm` a Python package.
|
||||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
|
||||
DESTINATION deep_gemm
|
||||
RENAME __init__.py
|
||||
)
|
||||
|
||||
# Install the compiled DeepGEMM API library.
|
||||
install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
|
||||
|
||||
# Install the source files required by DeepGEMM for runtime JIT compilation.
|
||||
install(
|
||||
DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
|
||||
DESTINATION deep_gemm
|
||||
)
|
||||
|
||||
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
|
||||
DESTINATION "deep_gemm/include/cute")
|
||||
|
||||
Reference in New Issue
Block a user