[feat] add fa3 in sgl-kernel (#4902)

Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
Author: yinfan98
Date: 2025-03-31 03:57:10 +08:00
Committer: GitHub
Parent: 9adf178cc2
Commit: 37c66ec856
7 changed files with 1300 additions and 0 deletions

CMakeLists.txt

@@ -25,6 +25,7 @@ find_package(Torch REQUIRED)
include(FetchContent)
# cutlass
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
@@ -32,6 +33,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-cutlass)
# DeepGEMM
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
@@ -39,6 +41,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-deepgemm)
# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
@@ -46,6 +49,15 @@ FetchContent_Declare(
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
# flash-attention
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)
include_directories(
${PROJECT_SOURCE_DIR}/include
@@ -54,6 +66,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)
set(CMAKE_CXX_STANDARD 17)
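(Aside, not part of the commit:) FetchContent_Populate only clones the declared repository at configure time and does not run the fetched project's own CMakeLists.txt, which fits here because sgl-kernel compiles the hopper/ sources itself instead of linking a prebuilt flash-attention target. A minimal standalone sketch of the same pattern, with the URL and tag copied from the diff and the message() line purely illustrative:

include(FetchContent)
FetchContent_Declare(
    repo-flash-attention
    GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
    GIT_TAG sgl-kernel
    GIT_SHALLOW OFF
)
# Clone only; no add_subdirectory() of the fetched project.
FetchContent_Populate(repo-flash-attention)
# Populate exposes <name>_SOURCE_DIR for include/source paths.
message(STATUS "FA3 checkout: ${repo-flash-attention_SOURCE_DIR}")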
@@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
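(Aside, not part of the commit:) the newly added --use_fast_math flag makes nvcc emit faster, less accurate device code (it implies options such as --ftz=true and --prec-div=false), trading some numerical accuracy for speed. If that trade-off ever needed to be configurable, a hypothetical opt-out could look like this (option name invented here):

option(SGL_KERNEL_FAST_MATH "Compile device code with --use_fast_math" ON)  # hypothetical knob
if(SGL_KERNEL_FAST_MATH)
    list(APPEND SGL_KERNEL_CUDA_FLAGS "--use_fast_math")
endif()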
@@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
# set flash-attention source files
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
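(Aside, not part of the commit:) the six GLOBs above collect the pre-generated hopper kernel instantiations per dtype before merging them into FA3_GEN_SRCS. Since file(GLOB) accepts several globbing expressions in one call, an equivalent single-call formulation would be:

file(GLOB FA3_GEN_SRCS
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")

Keeping separate per-dtype variables, as the commit does, leaves room to compile out a single dtype later.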
set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
@@ -160,6 +198,10 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)
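(Aside, not part of the commit:) the quoted "${FA3_GEN_SRCS}" entry is safe even though it is a list variable: set() joins all of its arguments with semicolons, so the quoted list flattens into SOURCES rather than nesting. A three-line demonstration with invented names:

set(demo_list "a.cu;b.cu")                # a CMake list is a ;-joined string
set(demo_srcs "main.cu" "${demo_list}")   # set() re-joins: main.cu;a.cu;b.cu
list(LENGTH demo_srcs n)                  # n == 3, not 2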
# Support abi3 for build
@@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
# Add flash-attention custom flags for inference
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)
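(Aside, not part of the commit:) these FLASHATTENTION_DISABLE_* definitions compile out template instantiations that inference on Hopper never exercises (SM8x kernels, the backward pass, dropout, uneven head-dim handling, and non-varlen entry points), shrinking both compile time and binary size, while the commented-out lines keep ALiBi, softcap, and local attention available. If any of these trims ever needed to be user-configurable, a hypothetical sketch (option name invented here):

option(SGL_FA3_KEEP_LOCAL "Keep local (sliding-window) attention compiled in" ON)  # hypothetical
if(NOT SGL_FA3_KEEP_LOCAL)
    target_compile_definitions(common_ops PRIVATE FLASHATTENTION_DISABLE_LOCAL)
endif()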
# JIT Logic
# DeepGEMM