cmake_minimum_required(VERSION 3.10)
# Pin the GNU toolchain. These must be set before project() so that compiler
# detection uses them.
set(CMAKE_C_COMPILER "gcc")
set(CMAKE_CXX_COMPILER "g++")

project(kernel_test)
message(STATUS "project name: ${PROJECT_NAME}")

# All Neuware (Cambricon SDK) paths below are derived from NEUWARE_HOME, which
# is read from the environment at configure time only.
if("$ENV{NEUWARE_HOME}" STREQUAL "")
  message(WARNING "NEUWARE_HOME is not set; Neuware headers and libraries (cnnl, cnrt, cndrv) will be searched under '/'.")
endif()

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/archive")
list(APPEND CMAKE_MODULE_PATH "$ENV{NEUWARE_HOME}/cmake/modules")

# C++17 without GNU extensions. CMake emits the matching -std=c++17 for C++
# sources itself, so it is intentionally not repeated in the option list below,
# which applies to every language in this directory (C included).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
add_compile_options(
  -O3 -g -fPIC
  -Wall -Werror -Wextra
  -Wno-unused-parameter
  -Wno-unused-variable
  -Wno-unused-function
  -Wno-unknown-pragmas
)

link_directories("$ENV{NEUWARE_HOME}/lib64")
link_libraries(
  m stdc++ dl pthread
  cnnl cnnl_extra cnrt cndrv
  # Embed the Neuware library dir as an RPATH (not RUNPATH, hence
  # --disable-new-dtags). Each flag is its own link item so no generator
  # has to split a space-separated string.
  "-Wl,-rpath,$ENV{NEUWARE_HOME}/lib64"
  "-Wl,--disable-new-dtags"
)
include_directories("$ENV{NEUWARE_HOME}/include")
include_directories(
  "${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/"
  "${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/kernels"
)

# find_torch()
#   Query the python3 interpreter on PATH for the installed torch package and
#   export to the caller's scope:
#     Torch_DIR        - <torch package>/share/cmake/Torch, for find_package(Torch)
#     TORCH_CXX11_ABI  - the value of torch.compiled_with_cxx11_abi() as printed
#                        by Python ("True"/"False")
#   If python3 cannot import torch, nothing is exported, so Torch_DIR keeps any
#   value the caller or user already provided.
function(find_torch)
  # Python prints "<path>;<abi>", which CMake reads directly as a two-element list.
  execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0], torch.compiled_with_cxx11_abi(), sep=';')"
    RESULT_VARIABLE TORCH_NOT_FOUND
    OUTPUT_VARIABLE TORCH_INFO
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )

  if(TORCH_NOT_FOUND)
    # Non-zero exit code (or an error string if python3 could not be run).
    message(STATUS "python3 could not import torch: ${TORCH_NOT_FOUND}")
    return()
  endif()

  list(GET TORCH_INFO 0 TORCH_PATH)
  message(STATUS "torch path: ${TORCH_PATH}")

  list(GET TORCH_INFO 1 TORCH_CXX11_ABI)
  message(STATUS "torch cxx11 abi: ${TORCH_CXX11_ABI}")

  set(Torch_DIR "${TORCH_PATH}/share/cmake/Torch" PARENT_SCOPE)
  # Export the ABI flag too, so the caller can compare it with USE_CXX11_ABI.
  # A plain set() here would be lost when the function returns.
  set(TORCH_CXX11_ABI "${TORCH_CXX11_ABI}" PARENT_SCOPE)
endfunction()

# import pytorch
# find_torch() points Torch_DIR at the installed torch package, then
# TorchConfig.cmake is loaded. QUIET because a missing torch is handled further
# down by skipping the test library instead of failing configuration.
find_torch()
message(STATUS "Torch_DIR: ${Torch_DIR}")
find_package(Torch QUIET)
# libtorch_python is linked separately from TORCH_LIBRARIES. Search the torch
# install's lib dir first (HINTS). TORCH_INSTALL_PREFIX is only defined once
# TorchConfig.cmake has been loaded, so skip the search otherwise.
if(Torch_FOUND)
  find_library(TORCH_PYTHON_LIBRARY torch_python HINTS "${TORCH_INSTALL_PREFIX}/lib")
endif()

# import torch_mlu
# Ask the installed torch_mlu package where its CMake package config lives.
# The directory is used later when locating the TorchMLU package.
execute_process(
  COMMAND python3 -c "import torch_mlu.utils as mlu_utils;print(mlu_utils.cmake_prefix_path)"
  RESULT_VARIABLE TORCH_MLU_NOT_FOUND
  OUTPUT_VARIABLE Torch_MLU_MODULE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_MLU_NOT_FOUND)
  # Not fatal: the TorchMLU lookup is QUIET. Still, say why the directory is
  # empty instead of failing obscurely at compile/link time.
  message(WARNING "python3 could not import torch_mlu (${TORCH_MLU_NOT_FOUND}); the TorchMLU package will not be found.")
endif()

# Locate the custom-op shared library shipped by the torch_mlu_ops Python
# package. The test library links against it, so a failed import is fatal.
execute_process(
  COMMAND python3 -c "import torch_mlu_ops as ops;print(ops._utils.get_custom_op_library_path())"
  OUTPUT_VARIABLE LIBOPS_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE LIBOPS_NOT_FOUND
)
if(LIBOPS_NOT_FOUND)
  message(FATAL_ERROR
    "torch_mlu_ops not installed, can not find ops library.")
endif()

# Make torch_mlu's package config discoverable, then load it (optional).
list(APPEND CMAKE_PREFIX_PATH ${Torch_MLU_MODULE_DIR})
find_package(TorchMLU QUIET)

# torch_mlu code trips -Werror=sign-compare, so that warning is disabled.
add_compile_options(-Wno-sign-compare)

# TorchMLUConfig.cmake leaves a TORCH_ATEN_LIBRARY-NOTFOUND entry in
# TORCH_MLU_LIBRARIES, which breaks linking. Strip it into a cleaned copy.
string(REPLACE
  "TORCH_ATEN_LIBRARY-NOTFOUND" ""
  TORCH_MLU_LIBRARIES_MODIFIED
  "${TORCH_MLU_LIBRARIES}")

# Python C headers directory, added to the include path of the test library.
# distutils was removed in Python 3.12, so query the stdlib sysconfig module,
# which reports the same interpreter include dir.
execute_process(
  COMMAND python3 -c "import sysconfig; print(sysconfig.get_path('include'))"
  OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(PYTHON_INCLUDE_DIR STREQUAL "")
  message(WARNING "could not determine the Python include directory via python3 sysconfig")
endif()

# Build the unit-test Python extension only when torch was found and
# TORCH_CXX11_ABI agrees with USE_CXX11_ABI (both truthy or both falsy).
if(Torch_FOUND AND ((TORCH_CXX11_ABI AND USE_CXX11_ABI) OR (NOT TORCH_CXX11_ABI AND NOT USE_CXX11_ABI)))
  include_directories(${PYTHON_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} ${TORCH_MLU_INCLUDE_DIRS})
  link_libraries(${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} ${TORCH_MLU_LIBRARIES_MODIFIED})

  # Sources are collected at configure time; re-run CMake after adding files.
  file(GLOB_RECURSE TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")

  # Name the module with the interpreter's extension suffix so the produced
  # file can be imported directly by that Python.
  execute_process(
    COMMAND python3 -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))"
    OUTPUT_VARIABLE TEST_LIB_NAME_SUFFIX
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  message(STATUS "python extension suffix: ${TEST_LIB_NAME_SUFFIX}")

  set(TEST_LIB_NAME_PREFIX "btunittests")
  set(TEST_LIB_NAME "${TEST_LIB_NAME_PREFIX}${TEST_LIB_NAME_SUFFIX}")
  add_library(${TEST_LIB_NAME} SHARED ${TEST_SRCS})
  # The target name already is the full file name. Drop CMake's default
  # "lib" prefix and shared-library suffix for this target only, instead of
  # overriding them for the whole directory.
  set_target_properties(${TEST_LIB_NAME} PROPERTIES PREFIX "" SUFFIX "")
  target_link_libraries(${TEST_LIB_NAME} PRIVATE ${LIBOPS_PATH} -lstdc++fs)
else()
  message(STATUS "Torch not found, or torch abi is different with which you specified, will not build")
  message(STATUS "if torch not found, please set env Torch_DIR to the directory containing TorchConfig.cmake")
endif()
