[feat] add fa3 in sgl-kernel (#4902)
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -25,6 +25,7 @@ find_package(Torch REQUIRED)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# cutlass
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
@@ -32,6 +33,7 @@ FetchContent_Declare(
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
# DeepGEMM
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
|
||||
@@ -39,6 +41,7 @@ FetchContent_Declare(
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
# flashinfer
|
||||
FetchContent_Declare(
|
||||
repo-flashinfer
|
||||
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
|
||||
@@ -46,6 +49,15 @@ FetchContent_Declare(
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashinfer)
|
||||
# flash-attention
|
||||
FetchContent_Declare(
|
||||
repo-flash-attention
|
||||
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
|
||||
GIT_TAG sgl-kernel
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flash-attention)
|
||||
|
||||
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
@@ -54,6 +66,7 @@ include_directories(
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-flash-attention_SOURCE_DIR}/hopper
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"--use_fast_math"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
@@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
# set flash-attention sources file
|
||||
# BF16 source files
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
file(GLOB FA3_BF16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
|
||||
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
|
||||
|
||||
# FP16 source files
|
||||
file(GLOB FA3_FP16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
|
||||
file(GLOB FA3_FP16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
|
||||
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
|
||||
|
||||
# FP8 source files
|
||||
file(GLOB FA3_FP8_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
|
||||
file(GLOB FA3_FP8_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
|
||||
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
|
||||
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/trt_reduce_internal.cu"
|
||||
"csrc/allreduce/trt_reduce_kernel.cu"
|
||||
@@ -160,6 +198,10 @@ set(SOURCES
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
|
||||
"${FA3_GEN_SRCS}"
|
||||
)
|
||||
|
||||
# Support abi3 for build
|
||||
@@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
# Add some flash-attention custom flag for inference
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_SM8x
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
# FLASHATTENTION_DISABLE_ALIBI
|
||||
# FLASHATTENTION_DISABLE_SOFTCAP
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
# FLASHATTENTION_DISABLE_LOCAL
|
||||
FLASHATTENTION_VARLEN_ONLY
|
||||
)
|
||||
|
||||
# JIT Logic
|
||||
# DeepGEMM
|
||||
|
||||
|
||||
Reference in New Issue
Block a user