diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 71f77d51b..489e4563a 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -83,6 +83,15 @@ if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
     set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
 endif()
 
+# Enable gencode below SM90. Defaults to ON, except on aarch64 where it defaults to OFF;
+# pass -DENABLE_BELOW_SM90=ON/OFF on the command line to override the default explicitly.
+set(ENABLE_BELOW_SM90_DEFAULT ON)
+if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+    set(ENABLE_BELOW_SM90_DEFAULT OFF)
+    message(STATUS "For aarch64, disable gencode below SM90 by default")
+endif()
+option(ENABLE_BELOW_SM90 "Enable below SM90" ${ENABLE_BELOW_SM90_DEFAULT})
+
 include_directories(
     ${PROJECT_SOURCE_DIR}/include
     ${PROJECT_SOURCE_DIR}/csrc
@@ -98,9 +107,6 @@ set(SGL_KERNEL_CUDA_FLAGS
     "-O3"
     "-Xcompiler"
     "-fPIC"
-    "-gencode=arch=compute_75,code=sm_75"
-    "-gencode=arch=compute_80,code=sm_80"
-    "-gencode=arch=compute_89,code=sm_89"
     "-gencode=arch=compute_90,code=sm_90"
     "-std=c++17"
     "-DFLASHINFER_ENABLE_F16"
@@ -130,6 +136,14 @@ option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
 option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
 option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
 
+if (ENABLE_BELOW_SM90)
+    list(APPEND SGL_KERNEL_CUDA_FLAGS
+        "-gencode=arch=compute_75,code=sm_75"
+        "-gencode=arch=compute_80,code=sm_80"
+        "-gencode=arch=compute_89,code=sm_89"
+    )
+endif()
+
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A)
     list(APPEND SGL_KERNEL_CUDA_FLAGS
         "-gencode=arch=compute_100,code=sm_110"
@@ -253,8 +267,6 @@ if (SGL_KERNEL_ENABLE_FA3)
         "-O3"
         "-Xcompiler"
         "-fPIC"
-        "-gencode=arch=compute_80,code=sm_80"
-        "-gencode=arch=compute_86,code=sm_86"
         "-gencode=arch=compute_90a,code=sm_90a"
         "-std=c++17"
         "-DCUTE_USE_PACKED_TUPLE=1"
@@ -270,9 +282,15 @@ if (SGL_KERNEL_ENABLE_FA3)
         "-Xcompiler=-fno-strict-aliasing"
     )
 
-    # SM8X Logic
-    file(GLOB FA3_SM8X_GEN_SRCS
-        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
+    if (ENABLE_BELOW_SM90)
+        list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
+            "-gencode=arch=compute_80,code=sm_80"
+            "-gencode=arch=compute_86,code=sm_86"
+        )
+        # SM8X Logic
+        file(GLOB FA3_SM8X_GEN_SRCS
+            "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
+    endif()
 
     file(GLOB FA3_BF16_GEN_SRCS
         "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
@@ -313,14 +331,17 @@ if (SGL_KERNEL_ENABLE_FA3)
     target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
 
     install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
-
-    target_compile_definitions(flash_ops PRIVATE
-        # FLASHATTENTION_DISABLE_SM8x
+    set(FLASH_OPS_COMPILE_DEFS
         FLASHATTENTION_DISABLE_BACKWARD
         FLASHATTENTION_DISABLE_DROPOUT
         FLASHATTENTION_DISABLE_UNEVEN_K
         FLASHATTENTION_VARLEN_ONLY
     )
+
+    if(NOT ENABLE_BELOW_SM90)
+        list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
+    endif()
+    target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
 endif()
 
 # JIT Logic