Disable compiling arch below sm_90 in aarch64 by default (#6380)
This commit is contained in:
@@ -83,6 +83,15 @@ if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
|
||||
endif()
|
||||
|
||||
# Enable gencode below SM90
|
||||
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
|
||||
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||
set(ENABLE_BELOW_SM90 OFF)
|
||||
message(STATUS "For aarch64, disable gencode below SM90 by default")
|
||||
endif()
|
||||
|
||||
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
@@ -98,9 +107,6 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_75,code=sm_75"
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_89,code=sm_89"
|
||||
"-gencode=arch=compute_90,code=sm_90"
|
||||
"-std=c++17"
|
||||
"-DFLASHINFER_ENABLE_F16"
|
||||
@@ -130,6 +136,14 @@ option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
|
||||
if (ENABLE_BELOW_SM90)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_75,code=sm_75"
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_89,code=sm_89"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100,code=sm_110"
|
||||
@@ -253,8 +267,6 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_86,code=sm_86"
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
"-std=c++17"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
@@ -270,9 +282,15 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
|
||||
# SM8X Logic
|
||||
file(GLOB FA3_SM8X_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
|
||||
if (ENABLE_BELOW_SM90)
|
||||
list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_86,code=sm_86"
|
||||
)
|
||||
# SM8X Logic
|
||||
file(GLOB FA3_SM8X_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
|
||||
endif()
|
||||
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
@@ -313,14 +331,17 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
|
||||
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
target_compile_definitions(flash_ops PRIVATE
|
||||
# FLASHATTENTION_DISABLE_SM8x
|
||||
set(FLASH_OPS_COMPILE_DEFS
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
FLASHATTENTION_VARLEN_ONLY
|
||||
)
|
||||
|
||||
if(NOT ENABLE_BELOW_SM90)
|
||||
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
|
||||
endif()
|
||||
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
|
||||
endif()
|
||||
|
||||
# JIT Logic
|
||||
|
||||
Reference in New Issue
Block a user