Sync from v0.13
This commit is contained in:
@@ -1,29 +1,56 @@
|
||||
include(FetchContent)
|
||||
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_EXTENSIONS ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(MACOSX_FOUND TRUE)
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
|
||||
set (ENABLE_NUMA TRUE)
|
||||
|
||||
#
|
||||
# Check the compile flags
|
||||
#
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
|
||||
execute_process(COMMAND cat /proc/cpuinfo
|
||||
RESULT_VARIABLE CPUINFO_RET
|
||||
OUTPUT_VARIABLE CPUINFO)
|
||||
|
||||
if (NOT CPUINFO_RET EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mf16c"
|
||||
)
|
||||
endif()
|
||||
|
||||
if(MACOSX_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
else()
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
endif()
|
||||
|
||||
if (NOT MACOSX_FOUND)
|
||||
execute_process(COMMAND cat /proc/cpuinfo
|
||||
RESULT_VARIABLE CPUINFO_RET
|
||||
OUTPUT_VARIABLE CPUINFO)
|
||||
if (NOT CPUINFO_RET EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
function (find_isa CPUINFO TARGET OUT)
|
||||
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
|
||||
if(NOT ISA_FOUND EQUAL -1)
|
||||
@@ -33,9 +60,52 @@ function (find_isa CPUINFO TARGET OUT)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
|
||||
if (AVX512_FOUND)
|
||||
function(check_sysctl TARGET OUT)
|
||||
execute_process(COMMAND sysctl -n "${TARGET}"
|
||||
RESULT_VARIABLE SYSCTL_RET
|
||||
OUTPUT_VARIABLE SYSCTL_INFO
|
||||
ERROR_QUIET
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
if(SYSCTL_RET EQUAL 0 AND
|
||||
(SYSCTL_INFO STREQUAL "1" OR SYSCTL_INFO GREATER 0))
|
||||
set(${OUT} ON PARENT_SCOPE)
|
||||
else()
|
||||
set(${OUT} OFF PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
||||
function (is_avx512_disabled OUT)
|
||||
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
|
||||
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
|
||||
set(${OUT} ON PARENT_SCOPE)
|
||||
else()
|
||||
set(${OUT} OFF PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
is_avx512_disabled(AVX512_DISABLED)
|
||||
|
||||
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(APPLE_SILICON_FOUND TRUE)
|
||||
set(ENABLE_NUMA OFF)
|
||||
check_sysctl(hw.optional.neon ASIMD_FOUND)
|
||||
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
|
||||
else()
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
find_isa(${CPUINFO} "Power11" POWER11_FOUND)
|
||||
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
|
||||
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
|
||||
find_isa(${CPUINFO} "S390" S390_FOUND)
|
||||
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
|
||||
endif()
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mavx512f"
|
||||
"-mavx512vl"
|
||||
@@ -44,47 +114,292 @@ if (AVX512_FOUND)
|
||||
|
||||
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
||||
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
|
||||
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
|
||||
set(ENABLE_AMXBF16 ON)
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
message(WARNING "vLLM CPU backend using AVX2 ISA")
|
||||
|
||||
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
message(STATUS "PowerPC detected")
|
||||
if (POWER9_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mvsx"
|
||||
"-mcpu=power9"
|
||||
"-mtune=power9")
|
||||
elseif (POWER10_FOUND OR POWER11_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mvsx"
|
||||
"-mcpu=power10"
|
||||
"-mtune=power10")
|
||||
endif()
|
||||
|
||||
elseif (ASIMD_FOUND)
|
||||
message(STATUS "ARMv8 or later architecture detected")
|
||||
if(ARM_BF16_FOUND)
|
||||
message(STATUS "BF16 extension detected")
|
||||
set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16")
|
||||
add_compile_definitions(ARM_BF16_SUPPORT)
|
||||
else()
|
||||
message(WARNING "BF16 functionality is not available")
|
||||
set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
|
||||
endif()
|
||||
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
|
||||
elseif (S390_FOUND)
|
||||
message(STATUS "S390 detected")
|
||||
# Check for S390 VXE support
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mvx"
|
||||
"-mzvector"
|
||||
"-march=native"
|
||||
"-mtune=native")
|
||||
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
|
||||
if(RVV_FOUND)
|
||||
message(FAIL_ERROR "Can't support rvv now.")
|
||||
else()
|
||||
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
|
||||
endif()
|
||||
|
||||
|
||||
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
|
||||
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
|
||||
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
|
||||
if(ASIMD_FOUND)
|
||||
# Set number of parallel build processes
|
||||
include(ProcessorCount)
|
||||
ProcessorCount(NPROC)
|
||||
if(NOT NPROC)
|
||||
set(NPROC 4)
|
||||
endif()
|
||||
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
|
||||
# and create a local shim dir with it
|
||||
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
|
||||
|
||||
find_library(OPEN_MP
|
||||
NAMES gomp
|
||||
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
|
||||
NO_DEFAULT_PATH
|
||||
REQUIRED
|
||||
)
|
||||
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
|
||||
if (OPEN_MP)
|
||||
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
|
||||
endif()
|
||||
|
||||
# Fetch and populate ACL
|
||||
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}")
|
||||
message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}")
|
||||
else()
|
||||
message(STATUS "Downloading Arm Compute Library (ACL) from GitHub")
|
||||
FetchContent_Populate(arm_compute
|
||||
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild"
|
||||
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src"
|
||||
GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git
|
||||
GIT_TAG v52.6.0
|
||||
GIT_SHALLOW TRUE
|
||||
GIT_PROGRESS TRUE
|
||||
)
|
||||
set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}")
|
||||
set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build")
|
||||
endif()
|
||||
|
||||
# Build ACL with CMake
|
||||
set(_cmake_config_cmd
|
||||
${CMAKE_COMMAND} -G Ninja -B build
|
||||
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DARM_COMPUTE_ARCH=armv8.2-a
|
||||
-DARM_COMPUTE_ENABLE_ASSERTS=OFF
|
||||
-DARM_COMPUTE_ENABLE_CPPTHREADS=OFF
|
||||
-DARM_COMPUTE_ENABLE_OPENMP=ON
|
||||
-DARM_COMPUTE_ENABLE_WERROR=OFF
|
||||
-DARM_COMPUTE_BUILD_EXAMPLES=OFF
|
||||
-DARM_COMPUTE_BUILD_TESTING=OFF)
|
||||
set(_cmake_build_cmd
|
||||
${CMAKE_COMMAND} --build build -- -j${NPROC}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${_cmake_config_cmd}
|
||||
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
|
||||
)
|
||||
execute_process(
|
||||
COMMAND ${_cmake_build_cmd}
|
||||
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
|
||||
RESULT_VARIABLE _acl_rc
|
||||
)
|
||||
|
||||
if(NOT _acl_rc EQUAL 0)
|
||||
message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).")
|
||||
endif()
|
||||
message(STATUS "Arm Compute Library (ACL) built successfully.")
|
||||
|
||||
# VLLM/oneDNN settings for ACL
|
||||
set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(VLLM_USE_ACL)
|
||||
endif()
|
||||
|
||||
set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.")
|
||||
|
||||
if(FETCHCONTENT_SOURCE_DIR_ONEDNN)
|
||||
message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}")
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN}
|
||||
)
|
||||
else()
|
||||
message(STATUS "Downloading oneDNN from GitHub")
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.10
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
set(ONEDNN_BUILD_TESTS "OFF")
|
||||
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ONEDNN_BUILD_GRAPH "OFF")
|
||||
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(ONEDNN_VERBOSE "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
|
||||
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
|
||||
target_include_directories(
|
||||
dnnl_ext
|
||||
PUBLIC ${oneDNN_SOURCE_DIR}/include
|
||||
PUBLIC ${oneDNN_BINARY_DIR}/include
|
||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||
)
|
||||
target_link_libraries(dnnl_ext dnnl)
|
||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
||||
list(APPEND LIBS dnnl_ext)
|
||||
set(USE_ONEDNN ON)
|
||||
else()
|
||||
set(USE_ONEDNN OFF)
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
if(ENABLE_NUMA)
|
||||
list(APPEND LIBS numa)
|
||||
else()
|
||||
message(STATUS "NUMA is disabled")
|
||||
add_compile_definitions(-DVLLM_NUMA_DISABLED)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/attention.cpp"
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/pybind.cpp")
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/scratchpad_manager.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
define_gpu_extension_target(
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(USE_ONEDNN)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/dnnl_kernels.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
|
||||
define_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
WITH_SOABI
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
|
||||
add_custom_target(default)
|
||||
message(STATUS "Enabling C extension.")
|
||||
add_dependencies(default _C)
|
||||
|
||||
|
||||
133
cmake/external_projects/flashmla.cmake
Normal file
133
cmake/external_projects/flashmla.cmake
Normal file
@@ -0,0 +1,133 @@
|
||||
include(FetchContent)
|
||||
|
||||
# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
|
||||
# instead of downloading.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{FLASH_MLA_SRC_DIR})
|
||||
set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(FLASH_MLA_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
SOURCE_DIR ${FLASH_MLA_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 46d64a8ebef03fa50b4ae74937276a5c940e3f95
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
FetchContent_MakeAvailable(flashmla)
|
||||
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
|
||||
set(SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
|
||||
list(APPEND SUPPORT_ARCHS 9.0a)
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND SUPPORT_ARCHS 10.0a)
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(FLASH_MLA_ARCHS)
|
||||
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
|
||||
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_Extension_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
|
||||
define_extension_target(
|
||||
_flashmla_extension_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_Extension_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_extension_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
else()
|
||||
# Create empty targets for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
add_custom_target(_flashmla_extension_C)
|
||||
endif()
|
||||
|
||||
97
cmake/external_projects/qutlass.cmake
Normal file
97
cmake/external_projects/qutlass.cmake
Normal file
@@ -0,0 +1,97 @@
|
||||
include(FetchContent)
|
||||
|
||||
set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")
|
||||
|
||||
if(DEFINED ENV{QUTLASS_SRC_DIR})
|
||||
set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(QUTLASS_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
qutlass
|
||||
SOURCE_DIR ${QUTLASS_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
qutlass
|
||||
GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
|
||||
GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
FetchContent_Populate(qutlass)
|
||||
|
||||
if(NOT qutlass_SOURCE_DIR)
|
||||
message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
|
||||
endif()
|
||||
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
|
||||
|
||||
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)
|
||||
|
||||
if(QUTLASS_ARCHS MATCHES "10\\.0a")
|
||||
set(QUTLASS_TARGET_CC 100)
|
||||
elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
|
||||
set(QUTLASS_TARGET_CC 120)
|
||||
else()
|
||||
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
|
||||
endif()
|
||||
|
||||
set(QUTLASS_SOURCES
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
|
||||
)
|
||||
|
||||
set(QUTLASS_INCLUDES
|
||||
${qutlass_SOURCE_DIR}
|
||||
${qutlass_SOURCE_DIR}/qutlass
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/include
|
||||
${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
|
||||
)
|
||||
|
||||
if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
|
||||
list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
|
||||
elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
|
||||
list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
|
||||
message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
|
||||
else()
|
||||
message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
|
||||
"Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
|
||||
endif()
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${QUTLASS_SOURCES}"
|
||||
CUDA_ARCHS "${QUTLASS_ARCHS}"
|
||||
)
|
||||
|
||||
target_sources(_C PRIVATE ${QUTLASS_SOURCES})
|
||||
target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
|
||||
target_compile_definitions(_C PRIVATE
|
||||
QUTLASS_DISABLE_PYBIND=1
|
||||
TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
|
||||
)
|
||||
|
||||
set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
|
||||
)
|
||||
|
||||
else()
|
||||
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
|
||||
message(STATUS
|
||||
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
|
||||
else()
|
||||
message(STATUS
|
||||
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
|
||||
"CUDA_ARCHS='${CUDA_ARCHS}'.")
|
||||
endif()
|
||||
endif()
|
||||
53
cmake/external_projects/triton_kernels.cmake
Normal file
53
cmake/external_projects/triton_kernels.cmake
Normal file
@@ -0,0 +1,53 @@
|
||||
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
|
||||
|
||||
set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")
|
||||
|
||||
# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
|
||||
# be directly set to the triton_kernels python directory.
|
||||
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
|
||||
message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
|
||||
FetchContent_Declare(
|
||||
triton_kernels
|
||||
SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR}
|
||||
)
|
||||
|
||||
else()
|
||||
set(TRITON_GIT "https://github.com/triton-lang/triton.git")
|
||||
message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
|
||||
FetchContent_Declare(
|
||||
triton_kernels
|
||||
# TODO (varun) : Fetch just the triton_kernels directory from Triton
|
||||
GIT_REPOSITORY https://github.com/triton-lang/triton.git
|
||||
GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG}
|
||||
GIT_PROGRESS TRUE
|
||||
SOURCE_SUBDIR python/triton_kernels/triton_kernels
|
||||
)
|
||||
endif()
|
||||
|
||||
# Fetch content
|
||||
FetchContent_MakeAvailable(triton_kernels)
|
||||
|
||||
if (NOT triton_kernels_SOURCE_DIR)
|
||||
message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
|
||||
endif()
|
||||
|
||||
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
|
||||
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
|
||||
else()
|
||||
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
|
||||
endif()
|
||||
|
||||
message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")
|
||||
|
||||
add_custom_target(triton_kernels)
|
||||
|
||||
# Ensure the vllm/third_party directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")")
|
||||
|
||||
## Copy .py files to install directory.
|
||||
install(DIRECTORY
|
||||
${TRITON_KERNELS_PYTHON_DIR}
|
||||
DESTINATION
|
||||
vllm/third_party/triton_kernels/
|
||||
COMPONENT triton_kernels
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
83
cmake/external_projects/vllm_flash_attn.cmake
Normal file
83
cmake/external_projects/vllm_flash_attn.cmake
Normal file
@@ -0,0 +1,83 @@
|
||||
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
|
||||
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
|
||||
# arches in the CUDA case (and instead set the gencodes on a per file basis)
|
||||
# we need to manually set VLLM_GPU_ARCHES here.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build vLLM flash attention from source
|
||||
#
|
||||
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
|
||||
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
# This is to enable local development of vllm-flash-attn within vLLM.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 86f8f157cf82aa2342743752b97788922dd7de43
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
|
||||
|
||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||
# This is here to support installing all components under the same prefix with cmake --install.
|
||||
# setup.py installs every component separately but uses the same prefix for all.
|
||||
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
|
||||
# and these statements don't hurt when installing neither component.
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Restore the install prefix
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
@@ -1,4 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#
|
||||
# A command line tool for running pytorch's hipify preprocessor on CUDA
|
||||
@@ -14,7 +16,7 @@ import shutil
|
||||
|
||||
from torch.utils.hipify.hipify_python import hipify
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Project directory where all the source + include files live.
|
||||
@@ -32,15 +34,14 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
# Source files to convert.
|
||||
parser.add_argument("sources",
|
||||
help="Source files to hipify.",
|
||||
nargs="*",
|
||||
default=[])
|
||||
parser.add_argument(
|
||||
"sources", help="Source files to hipify.", nargs="*", default=[]
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Limit include scope to project_dir only
|
||||
includes = [os.path.join(args.project_dir, '*')]
|
||||
includes = [os.path.join(args.project_dir, "*")]
|
||||
|
||||
# Get absolute path for all source files.
|
||||
extra_files = [os.path.abspath(s) for s in args.sources]
|
||||
@@ -49,25 +50,31 @@ if __name__ == '__main__':
|
||||
# The directory might already exist to hold object files so we ignore that.
|
||||
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
||||
|
||||
hipify_result = hipify(project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True)
|
||||
hipify_result = hipify(
|
||||
project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True,
|
||||
)
|
||||
|
||||
hipified_sources = []
|
||||
for source in args.sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
||||
(s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None)
|
||||
else s_abs)
|
||||
hipified_s_abs = (
|
||||
hipify_result[s_abs].hipified_path
|
||||
if (
|
||||
s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None
|
||||
)
|
||||
else s_abs
|
||||
)
|
||||
hipified_sources.append(hipified_s_abs)
|
||||
|
||||
assert (len(hipified_sources) == len(args.sources))
|
||||
assert len(hipified_sources) == len(args.sources)
|
||||
|
||||
# Print hipified source files.
|
||||
print("\n".join(hipified_sources))
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
||||
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
||||
set(Python_EXECUTABLE ${EXECUTABLE})
|
||||
find_package(Python COMPONENTS Interpreter Development.Module)
|
||||
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
|
||||
if (NOT Python_FOUND)
|
||||
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
||||
endif()
|
||||
@@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
#
|
||||
set(SRCS ${ORIG_SRCS})
|
||||
set(CXX_SRCS ${ORIG_SRCS})
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
|
||||
#
|
||||
# Generate ROCm/HIP source file names from CUDA file names.
|
||||
@@ -76,7 +76,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
|
||||
add_custom_target(
|
||||
hipify${NAME}
|
||||
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
|
||||
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
|
||||
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
|
||||
BYPRODUCTS ${HIP_SRCS}
|
||||
COMMENT "Running hipify on ${NAME} extension source files.")
|
||||
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
||||
"Failed to determine torch nvcc compiler flags")
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
||||
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
|
||||
list(APPEND GPU_FLAGS "-DENABLE_FP8")
|
||||
endif()
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||
list(REMOVE_ITEM GPU_FLAGS
|
||||
@@ -119,24 +119,304 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
||||
|
||||
list(APPEND GPU_FLAGS
|
||||
"-DUSE_ROCM"
|
||||
"-DENABLE_FP8_E4M3"
|
||||
"-DENABLE_FP8"
|
||||
"-U__HIP_NO_HALF_CONVERSIONS__"
|
||||
"-U__HIP_NO_HALF_OPERATORS__"
|
||||
"-Werror=unused-variable"
|
||||
"-fno-gpu-rdc")
|
||||
|
||||
endif()
|
||||
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Find libgomp that gets shipped with PyTorch wheel and create a shim dir with:
|
||||
# libgomp.so -> libgomp-<hash>.so...
|
||||
# libgomp.so.1 -> libgomp-<hash>.so...
|
||||
# OUTPUT: TORCH_GOMP_SHIM_DIR ("" if not found)
|
||||
function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
|
||||
set(${TORCH_GOMP_SHIM_DIR} "" PARENT_SCOPE)
|
||||
|
||||
# Use run_python to locate vendored libgomp; never throw on failure.
|
||||
run_python(_VLLM_TORCH_GOMP_PATH
|
||||
"
|
||||
import os, glob
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
|
||||
# Search both torch.libs and torch/lib
|
||||
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
|
||||
candidates = []
|
||||
for root in roots:
|
||||
if not os.path.isdir(root):
|
||||
continue
|
||||
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
|
||||
|
||||
print(candidates[0] if candidates else '')
|
||||
"
|
||||
"failed to probe for libgomp")
|
||||
|
||||
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Create shim under the build tree
|
||||
set(_shim "${CMAKE_BINARY_DIR}/gomp_shim")
|
||||
file(MAKE_DIRECTORY "${_shim}")
|
||||
|
||||
execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so")
|
||||
execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so.1")
|
||||
execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so")
|
||||
execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so.1")
|
||||
|
||||
set(${TORCH_GOMP_SHIM_DIR} "${_shim}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Macro for converting a `gencode` version number to a cmake version number.
|
||||
macro(string_to_ver OUT_VER IN_STR)
|
||||
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
|
||||
# `CUDA_ARCH_FLAGS`.
|
||||
#
|
||||
# Example:
|
||||
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
|
||||
# clear_cuda_arches(CUDA_ARCH_FLAGS)
|
||||
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
|
||||
# CMAKE_CUDA_FLAGS="-Wall"
|
||||
#
|
||||
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
|
||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Extract unique CUDA architectures from a list of compute capabilities codes in
|
||||
# the form `<major><minor>[<letter>]`, convert them to the form sort
|
||||
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
|
||||
# stores them in `OUT_ARCHES`.
|
||||
#
|
||||
# Example:
|
||||
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
|
||||
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
|
||||
# OUT_ARCHES="7.5;...;9.0"
|
||||
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
|
||||
set(_CUDA_ARCHES)
|
||||
foreach(_ARCH ${CUDA_ARCH_FLAGS})
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
||||
if (_COMPUTE)
|
||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string_to_ver(_COMPUTE_VER ${_COMPUTE})
|
||||
list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHES)
|
||||
list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
|
||||
set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
#
|
||||
# For a specific file set the `-gencode` flag in compile options conditionally
|
||||
# for the CUDA language.
|
||||
#
|
||||
# Example:
|
||||
# set_gencode_flag_for_srcs(
|
||||
# SRCS "foo.cu"
|
||||
# ARCH "compute_75"
|
||||
# CODE "sm_75")
|
||||
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
|
||||
# `foo.cu` (only for the CUDA language).
|
||||
#
|
||||
macro(set_gencode_flag_for_srcs)
|
||||
set(options)
|
||||
set(oneValueArgs ARCH CODE)
|
||||
set(multiValueArgs SRCS)
|
||||
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
|
||||
"${multiValueArgs}" ${ARGN} )
|
||||
set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
|
||||
set_property(
|
||||
SOURCE ${arg_SRCS}
|
||||
APPEND PROPERTY
|
||||
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
|
||||
)
|
||||
|
||||
message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
|
||||
endmacro(set_gencode_flag_for_srcs)
|
||||
|
||||
#
|
||||
# For a list of source files set the `-gencode` flags in the files specific
|
||||
# compile options (specifically for the CUDA language).
|
||||
#
|
||||
# arguments are:
|
||||
# SRCS: list of source files
|
||||
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
|
||||
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
|
||||
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
|
||||
# that is larger than BUILD_PTX_FOR_ARCH.
|
||||
#
|
||||
macro(set_gencode_flags_for_srcs)
|
||||
set(options)
|
||||
set(oneValueArgs BUILD_PTX_FOR_ARCH)
|
||||
set(multiValueArgs SRCS CUDA_ARCHS)
|
||||
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
|
||||
"${multiValueArgs}" ${ARGN} )
|
||||
|
||||
foreach(_ARCH ${arg_CUDA_ARCHS})
|
||||
# handle +PTX suffix: generate both sm and ptx codes if requested
|
||||
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
||||
if(NOT _HAS_PTX EQUAL -1)
|
||||
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "compute_${_STRIPPED_ARCH}")
|
||||
else()
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (${arg_BUILD_PTX_FOR_ARCH})
|
||||
list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
|
||||
if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
|
||||
string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_PTX_ARCH}"
|
||||
CODE "compute_${_PTX_ARCH}")
|
||||
endif()
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
# architecture in `SRC_CUDA_ARCHS`.
|
||||
# The loose intersection is defined as:
|
||||
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
||||
# where `<=` is the version comparison operator.
|
||||
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
|
||||
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||
#
|
||||
# Example With PTX:
|
||||
# SRC_CUDA_ARCHS="8.0+PTX"
|
||||
# TGT_CUDA_ARCHS="9.0"
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0+PTX"
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
||||
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
||||
|
||||
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
||||
set(_PTX_ARCHS)
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\+PTX$")
|
||||
string(REPLACE "+PTX" "" _base "${_arch}")
|
||||
list(APPEND _PTX_ARCHS "${_base}")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
||||
endif()
|
||||
endforeach()
|
||||
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
||||
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
||||
|
||||
# If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "[af]$")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
string(REGEX REPLACE "[af]$" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check version-less-or-equal, and allow PTX arches to match across majors
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
# If we hit a version greater than the target, we can break
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
if(_arch IN_LIST _PTX_ARCHS)
|
||||
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
||||
else()
|
||||
list(APPEND _FINAL_ARCHS "${_arch}")
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
||||
# `GPU_ARCHES`.
|
||||
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
|
||||
# the architectures on a per file basis.
|
||||
#
|
||||
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
||||
#
|
||||
@@ -147,16 +427,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
if (${GPU_LANG} STREQUAL "HIP")
|
||||
#
|
||||
# `GPU_ARCHES` controls the `--offload-arch` flags.
|
||||
# `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
|
||||
# via the `PYTORCH_ROCM_ARCH` env variable.
|
||||
#
|
||||
|
||||
# If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
|
||||
# if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
|
||||
# "rocm_agent_enumerator" in "enable_language(HIP)"
|
||||
# (in file Modules/CMakeDetermineHIPCompiler.cmake)
|
||||
#
|
||||
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
|
||||
set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
|
||||
else()
|
||||
set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
|
||||
endif()
|
||||
#
|
||||
# Find the intersection of the supported + detected architectures to
|
||||
# set the module architecture flags.
|
||||
#
|
||||
set(${GPU_ARCHES})
|
||||
foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES})
|
||||
foreach (_ARCH ${HIP_ARCHITECTURES})
|
||||
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
||||
list(APPEND ${GPU_ARCHES} ${_ARCH})
|
||||
endif()
|
||||
@@ -164,191 +451,98 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
|
||||
if(NOT ${GPU_ARCHES})
|
||||
message(FATAL_ERROR
|
||||
"None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
|
||||
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
|
||||
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
||||
endif()
|
||||
|
||||
elseif(${GPU_LANG} STREQUAL "CUDA")
|
||||
#
|
||||
# Setup/process CUDA arch flags.
|
||||
#
|
||||
# The torch cmake setup hardcodes the detected architecture flags in
|
||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
||||
# can't modified on a per-target basis, e.g. for the `punica` extension.
|
||||
# So, all the `-gencode` flags need to be extracted and removed from
|
||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
||||
# Since it's not possible to use `target_compiler_options` for adding target
|
||||
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
|
||||
# must be used instead. This requires repackaging the architecture flags
|
||||
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
|
||||
#
|
||||
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
|
||||
# to one of the other nvcc options to specify architectures.
|
||||
#
|
||||
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
|
||||
# detected architectures.
|
||||
#
|
||||
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# If this error is triggered, it might mean that torch has changed how it sets
|
||||
# up nvcc architecture code generation flags.
|
||||
if (NOT _CUDA_ARCH_FLAGS)
|
||||
message(FATAL_ERROR
|
||||
"Could not find any architecture related code generation flags in "
|
||||
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
|
||||
endif()
|
||||
|
||||
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
|
||||
|
||||
# Initialize the architecture lists to empty.
|
||||
set(${GPU_ARCHES})
|
||||
|
||||
# Process each `gencode` flag.
|
||||
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
|
||||
# For each flag, extract the version number and whether it refers to PTX
|
||||
# or native code.
|
||||
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
|
||||
# for that match.
|
||||
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
||||
if (_COMPUTE)
|
||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
|
||||
if (_SM)
|
||||
set(_SM ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
|
||||
if (_CODE)
|
||||
set(_CODE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
# Make sure the virtual architecture can be matched.
|
||||
if (NOT _COMPUTE)
|
||||
message(FATAL_ERROR
|
||||
"Could not determine virtual architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
# One of sm_ or compute_ must exist.
|
||||
if ((NOT _SM) AND (NOT _CODE))
|
||||
message(FATAL_ERROR
|
||||
"Could not determine a codegen architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
if (_SM)
|
||||
# -real suffix let CMake to only generate elf code for the kernels.
|
||||
# we want this, otherwise the added ptx (default) will increase binary size.
|
||||
set(_VIRT "-real")
|
||||
set(_CODE_ARCH ${_SM})
|
||||
else()
|
||||
# -virtual suffix let CMake to generate ptx code for the kernels.
|
||||
set(_VIRT "-virtual")
|
||||
set(_CODE_ARCH ${_CODE})
|
||||
endif()
|
||||
|
||||
# Check if the current version is in the supported arch list.
|
||||
string_to_ver(_CODE_VER ${_CODE_ARCH})
|
||||
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
||||
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
# Add it to the arch list.
|
||||
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
|
||||
endforeach()
|
||||
endif()
|
||||
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Define a target named `GPU_MOD_NAME` for a single extension. The
|
||||
# Define a target named `MOD_NAME` for a single extension. The
|
||||
# arguments are:
|
||||
#
|
||||
# DESTINATION <dest> - Module destination directory.
|
||||
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
|
||||
# etc.
|
||||
# LANGUAGE <lang> - The language for this module, e.g. CUDA, HIP,
|
||||
# CXX, etc.
|
||||
# SOURCES <sources> - List of source files relative to CMakeLists.txt
|
||||
# directory.
|
||||
#
|
||||
# Optional arguments:
|
||||
#
|
||||
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
|
||||
# format.
|
||||
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
|
||||
# and `CMAKE_HIP_ARCHITECTURES` for more info.
|
||||
# ARCHITECTURES <arches> - A list of target architectures in cmake format.
|
||||
# For GPU, refer to CMAKE_CUDA_ARCHITECTURES and
|
||||
# CMAKE_HIP_ARCHITECTURES for more info.
|
||||
# ARCHITECTURES will use cmake's defaults if
|
||||
# not provided.
|
||||
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
|
||||
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
||||
# LIBRARIES <libraries> - Extra link libraries.
|
||||
# WITH_SOABI - Generate library with python SOABI suffix name.
|
||||
# USE_SABI <version> - Use python stable api <version>
|
||||
#
|
||||
# Note: optimization level/debug info is set via cmake build type.
|
||||
#
|
||||
function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
function (define_extension_target MOD_NAME)
|
||||
cmake_parse_arguments(PARSE_ARGV 1
|
||||
GPU
|
||||
ARG
|
||||
"WITH_SOABI"
|
||||
"DESTINATION;LANGUAGE"
|
||||
"DESTINATION;LANGUAGE;USE_SABI"
|
||||
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
||||
|
||||
# Add hipify preprocessing step when building with HIP/ROCm.
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
|
||||
if (ARG_LANGUAGE STREQUAL "HIP")
|
||||
hipify_sources_target(ARG_SOURCES ${MOD_NAME} "${ARG_SOURCES}")
|
||||
endif()
|
||||
|
||||
if (GPU_WITH_SOABI)
|
||||
set(GPU_WITH_SOABI WITH_SOABI)
|
||||
if (ARG_WITH_SOABI)
|
||||
set(SOABI_KEYWORD WITH_SOABI)
|
||||
else()
|
||||
set(GPU_WITH_SOABI)
|
||||
set(SOABI_KEYWORD "")
|
||||
endif()
|
||||
|
||||
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
|
||||
run_python(IS_FREETHREADED_PYTHON
|
||||
"import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)"
|
||||
"Failed to determine whether interpreter is free-threaded")
|
||||
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
# Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809),
|
||||
# so avoid using the stable ABI under free-threading only.
|
||||
if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON)
|
||||
Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}")
|
||||
else()
|
||||
Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}")
|
||||
endif()
|
||||
|
||||
if (ARG_LANGUAGE STREQUAL "HIP")
|
||||
# Make this target dependent on the hipify preprocessor step.
|
||||
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
|
||||
add_dependencies(${MOD_NAME} hipify${MOD_NAME})
|
||||
# Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder
|
||||
target_include_directories(${MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc
|
||||
${ARG_INCLUDE_DIRECTORIES})
|
||||
else()
|
||||
target_include_directories(${MOD_NAME} PRIVATE csrc
|
||||
${ARG_INCLUDE_DIRECTORIES})
|
||||
endif()
|
||||
|
||||
if (GPU_ARCHITECTURES)
|
||||
set_target_properties(${GPU_MOD_NAME} PROPERTIES
|
||||
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
|
||||
if (ARG_ARCHITECTURES)
|
||||
set_target_properties(${MOD_NAME} PROPERTIES
|
||||
${ARG_LANGUAGE}_ARCHITECTURES "${ARG_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
|
||||
target_compile_options(${MOD_NAME} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${ARG_LANGUAGE}>:${ARG_COMPILE_FLAGS}>)
|
||||
|
||||
target_compile_options(${GPU_MOD_NAME} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
|
||||
target_compile_definitions(${MOD_NAME} PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=${MOD_NAME}")
|
||||
|
||||
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
|
||||
|
||||
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
||||
${GPU_INCLUDE_DIRECTORIES})
|
||||
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
|
||||
${GPU_LIBRARIES})
|
||||
target_link_libraries(${MOD_NAME} PRIVATE torch ${ARG_LIBRARIES})
|
||||
|
||||
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
|
||||
# dependencies that are not necessary and may not be installed.
|
||||
if (GPU_LANGUAGE STREQUAL "CUDA")
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
|
||||
${CUDA_LIBRARIES})
|
||||
if (ARG_LANGUAGE STREQUAL "CUDA")
|
||||
target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES})
|
||||
else()
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
|
||||
target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
|
||||
endif()
|
||||
|
||||
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
|
||||
install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME})
|
||||
endfunction()
|
||||
|
||||
Reference in New Issue
Block a user