[Fix] fix fa3 build at cu118 (#5036)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
||||
project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# We only want to download the third-party dependencies, not build them.
|
||||
# FetchContent_MakeAvailable would also build them, so it is not used here.
|
||||
cmake_policy(SET CMP0169 OLD)
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
|
||||
# FA3 (flash-attention 3) targets are disabled by default; BUILD_FA3 is turned
# ON later only when the CUDA version / SM90A condition is satisfied.
# Note: CMake arguments are whitespace-separated -- a comma after the name
# would become part of the variable name.
set(BUILD_FA3 OFF)
|
||||
|
||||
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||
|
||||
enable_language(CUDA)
|
||||
@@ -22,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
|
||||
endif()
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
# Clean up the flags set by Torch
|
||||
clear_cuda_arches(CMAKE_FLAG)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
@@ -53,8 +57,8 @@ FetchContent_Populate(repo-flashinfer)
|
||||
FetchContent_Declare(
|
||||
repo-flash-attention
|
||||
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
|
||||
GIT_TAG sgl-kernel
|
||||
GIT_SHALLOW OFF
|
||||
GIT_TAG sgl-kernel
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flash-attention)
|
||||
|
||||
@@ -92,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90,code=sm_90"
|
||||
"-std=c++17"
|
||||
"-DFLASHINFER_ENABLE_F16"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
|
||||
"-DCUTLASS_VERSIONS_GENERATED"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_TEST_LEVEL=0"
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"--use_fast_math"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
@@ -122,6 +125,7 @@ else()
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
||||
set(BUILD_FA3 ON)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
)
|
||||
@@ -152,30 +156,6 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
# set flash-attention sources file
|
||||
# BF16 source files
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
file(GLOB FA3_BF16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
|
||||
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
|
||||
|
||||
# FP16 source files
|
||||
file(GLOB FA3_FP16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
|
||||
file(GLOB FA3_FP16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
|
||||
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
|
||||
|
||||
# FP8 source files
|
||||
file(GLOB FA3_FP8_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
|
||||
file(GLOB FA3_FP8_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
|
||||
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
|
||||
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/trt_reduce_internal.cu"
|
||||
"csrc/allreduce/trt_reduce_kernel.cu"
|
||||
@@ -202,39 +182,94 @@ set(SOURCES
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/torch_extension.cc"
|
||||
"csrc/common_extension.cc"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
|
||||
"${FA3_GEN_SRCS}"
|
||||
)
|
||||
|
||||
# Support abi3 for build
|
||||
# set flash-attention sources file
|
||||
# BF16 source files
|
||||
if (BUILD_FA3)
|
||||
set(SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
"-DOPERATOR_NAMESPACE=sgl-kernel"
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
"-std=c++17"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
|
||||
"-DCUTLASS_VERSIONS_GENERATED"
|
||||
"-DCUTLASS_TEST_LEVEL=0"
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"--expt-extended-lambda"
|
||||
"--use_fast_math"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
file(GLOB FA3_BF16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
|
||||
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
|
||||
|
||||
# FP16 source files
|
||||
file(GLOB FA3_FP16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
|
||||
file(GLOB FA3_FP16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
|
||||
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
|
||||
|
||||
# FP8 source files
|
||||
file(GLOB FA3_FP8_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
|
||||
file(GLOB FA3_FP8_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
|
||||
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
|
||||
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
|
||||
|
||||
set(FLASH_SOURCES
|
||||
"csrc/flash_extension.cc"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
|
||||
"${FA3_GEN_SRCS}"
|
||||
)
|
||||
|
||||
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
|
||||
|
||||
# Apply the flash-attention CUDA flag list to CUDA sources only. The generator
# expression is quoted so the whole flag list stays inside a single argument.
target_compile_options(flash_ops PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>")
|
||||
target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
|
||||
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
target_compile_definitions(flash_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_SM8x
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
# FLASHATTENTION_DISABLE_ALIBI
|
||||
# FLASHATTENTION_DISABLE_SOFTCAP
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
# FLASHATTENTION_DISABLE_LOCAL
|
||||
FLASHATTENTION_VARLEN_ONLY
|
||||
)
|
||||
endif()
|
||||
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
# Apply the kernel CUDA flag list to CUDA sources only. The generator
# expression is quoted so the whole flag list stays inside a single argument.
target_compile_options(common_ops PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>")
|
||||
|
||||
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
# Add flash-attention custom flags for inference
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_SM8x
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
# FLASHATTENTION_DISABLE_ALIBI
|
||||
# FLASHATTENTION_DISABLE_SOFTCAP
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
# FLASHATTENTION_DISABLE_LOCAL
|
||||
FLASHATTENTION_VARLEN_ONLY
|
||||
)
|
||||
|
||||
# JIT Logic
|
||||
# DeepGEMM
|
||||
|
||||
|
||||
Reference in New Issue
Block a user