[Feat] Add sparse attn to sgl-kernel (#5327)
This commit is contained in:
@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
|
||||
set(BUILD_FA3, OFF)
|
||||
|
||||
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||
|
||||
enable_language(CUDA)
|
||||
@@ -80,7 +78,6 @@ include_directories(
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-flash-attention_SOURCE_DIR}/hopper
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100,code=sm_100"
|
||||
@@ -127,7 +127,7 @@ else()
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
||||
set(BUILD_FA3 ON)
|
||||
set(SGL_KERNEL_ENABLE_FA3 ON)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
)
|
||||
@@ -187,11 +187,33 @@ set(SOURCES
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
|
||||
)
|
||||
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(common_ops PRIVATE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
)
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
# ============================ Optional Install ============================= #
|
||||
# set flash-attention sources file
|
||||
# BF16 source files
|
||||
if (BUILD_FA3)
|
||||
if (SGL_KERNEL_ENABLE_FA3)
|
||||
set(SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
"-DOPERATOR_NAMESPACE=sgl-kernel"
|
||||
@@ -246,7 +268,9 @@ if (BUILD_FA3)
|
||||
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
|
||||
|
||||
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
target_include_directories(flash_ops PRIVATE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${repo-flash-attention_SOURCE_DIR}/hopper)
|
||||
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
|
||||
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
@@ -260,14 +284,6 @@ if (BUILD_FA3)
|
||||
)
|
||||
endif()
|
||||
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
# JIT Logic
|
||||
# DeepGEMM
|
||||
|
||||
|
||||
Reference in New Issue
Block a user