[Feat] Scale up fa3 kernel to sm8x arch (#5912)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -121,12 +121,12 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
|
||||
)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -233,7 +233,7 @@ install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
|
||||
|
||||
# ============================ Optional Install ============================= #
|
||||
# Set the flash-attention source files
|
||||
# BF16 source files
|
||||
# FA3 now supports sm80/sm86/sm90
|
||||
if (SGL_KERNEL_ENABLE_FA3)
|
||||
set(SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
@@ -241,6 +241,8 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_86,code=sm_86"
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
"-std=c++17"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
@@ -256,6 +258,10 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
|
||||
# SM8X Logic
|
||||
file(GLOB FA3_SM8X_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
|
||||
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
file(GLOB FA3_BF16_GEN_SRCS_
|
||||
@@ -276,7 +282,7 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
|
||||
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
|
||||
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})
|
||||
|
||||
set(FLASH_SOURCES
|
||||
"csrc/flash_extension.cc"
|
||||
@@ -297,7 +303,7 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
target_compile_definitions(flash_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_SM8x
|
||||
# FLASHATTENTION_DISABLE_SM8x
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
|
||||
Reference in New Issue
Block a user