[Feat] Add sparse attn to sgl-kernel (#5327)

Author: PGFLMG
Date: 2025-04-13 02:36:36 +08:00
Committed by: GitHub
Parent: bc92107b03
Commit: 4879e50c6d
5 changed files with 625 additions and 14 deletions

sgl-kernel/CMakeLists.txt

@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
set(BUILD_FA3, OFF)
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
enable_language(CUDA)
@@ -80,7 +78,6 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)
set(CMAKE_CXX_STANDARD 17)
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
@@ -127,7 +127,7 @@ else()
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
set(BUILD_FA3 ON)
set(SGL_KERNEL_ENABLE_FA3 ON)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
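
Read together with the new option above and the guard that appears further down in this diff, the renamed switch gives roughly the following gating logic. This is a sketch assembled from the hunks shown here, not a verbatim excerpt of the resulting CMakeLists.txt:

    # User-facing switch that replaces the old internal BUILD_FA3 variable
    option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)

    # CUDA 12.4+ (or an explicit SM90a build) forces the FA3/Hopper path on
    if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
      set(SGL_KERNEL_ENABLE_FA3 ON)
      list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_90a,code=sm_90a"
      )
    endif()

    # Later in the file, the optional flash_ops build is guarded by the same switch
    if (SGL_KERNEL_ENABLE_FA3)
      # ... FA3 sources and the flash_ops target (see the hunks below) ...
    endif()
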
@@ -187,11 +187,33 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE
${TORCH_INCLUDE_DIRS}
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
)
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
# ============================ Optional Install ============================= #
# set flash-attention sources file
# BF16 source files
if (BUILD_FA3)
if (SGL_KERNEL_ENABLE_FA3)
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
@@ -246,7 +268,9 @@ if (BUILD_FA3)
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_include_directories(flash_ops PRIVATE
${TORCH_INCLUDE_DIRS}
${repo-flash-attention_SOURCE_DIR}/hopper)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
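
Taken together with the include_directories hunk near the top of this diff, which drops the same path, the net effect is to narrow where the flash-attention Hopper headers are visible. A rough before/after sketch, assembled from the hunks in this diff:

    # Before: Hopper headers sat on the global include path, visible to every target
    include_directories(
      ${repo-flash-attention_SOURCE_DIR}/hopper
      # ... other include directories unchanged ...
    )

    # After: only the FA3 extension target sees them
    target_include_directories(flash_ops PRIVATE
      ${TORCH_INCLUDE_DIRS}
      ${repo-flash-attention_SOURCE_DIR}/hopper)
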
@@ -260,14 +284,6 @@ if (BUILD_FA3)
)
endif()
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
# JIT Logic
# DeepGEMM