diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index 741ef3d3..b095e696 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -96,6 +96,11 @@ jobs: --exclude libge_common_base.so \ --exclude libc10.so \ --exclude libc_sec.so \ + --exclude libnnopbase.so \ + --exclude libprofapi.so \ + --exclude libgraph_base.so \ + --exclude libgraph.so \ + --exclude libexe_graph.so \ --exclude "libascend*.so" \ --exclude "libtorch*.so" \ --exclude "libopapi.so" \ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 34819347..f2f42d5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: codespell args: [ --toml, pyproject.toml, - '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', + '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND' ] additional_dependencies: @@ -37,7 +37,7 @@ repos: - id: typos args: [ "--force-exclude", - "--exclude", "csrc/mla_preprocess/**" + "--exclude", "csrc/**" ] - repo: https://github.com/PyCQA/isort rev: 6.0.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e810fa8..f0136bc4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,7 @@ set( ${TORCH_NPU_INCLUDE_DIRS} ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform + ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform ) pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC}) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt new file mode 100644 index 00000000..dab92509 --- /dev/null +++ b/csrc/CMakeLists.txt @@ -0,0 +1,642 @@ +# Copyright (c) 2024 Huawei Technologies 
Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +cmake_minimum_required(VERSION 3.16) + +project(cann_ops_custom) + +option(BUILD_OPEN_PROJECT "Build open ascend ops project." ON) +option(ENABLE_CCACHE "Enable ccache capability" ON) +set(ASCEND_COMPUTE_UNIT "ascend910b" CACHE STRING "soc that need to be compiled") +set(ASCEND_OP_NAME "ALL" CACHE STRING "operators that need to be compiled") +set(VENDOR_NAME "customize" CACHE STRING "vendor name") + +include(cmake/config.cmake) +include(cmake/func.cmake) +include(cmake/intf.cmake) + +if (BUILD_OPEN_PROJECT) + set(_op_host_aclnn_link + $ + exe_graph + register + c_sec + ) + set(CMAKE_MODULE_PATH + ${CMAKE_MODULE_PATH} + ${CMAKE_CURRENT_LIST_DIR}/cmake/modules + ) + + set(CMAKE_PREFIX_PATH + ${CMAKE_PREFIX_PATH} + ${ASCEND_CANN_PACKAGE_PATH} + ) + + find_package(alog MODULE REQUIRED) + add_library(op_host_aclnn SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnn PRIVATE + ${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnn PRIVATE + $<$:-std=gnu++1z> + ) + + add_library(op_host_aclnnInner SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnnInner PRIVATE + ${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnnInner PRIVATE + $<$:-std=gnu++1z> + ) + + add_library(op_host_aclnnExc SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnnExc PRIVATE + 
${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnnExc PRIVATE + $<$:-std=gnu++1z> + ) + + # op api + add_library(opapi SHARED) + # When compiling a specified operator without aclnn src + if (NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + ) + target_sources(opapi PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + ) + endif() + + target_compile_options(opapi PRIVATE + $<$:-std=gnu++1z> + ) + target_include_directories(opapi PRIVATE + $ + $ + $ + ) + target_compile_options(opapi PRIVATE + -Werror=format + ) + target_compile_definitions(opapi PRIVATE + -DACLNN_LOG_FMT_CHECK + ) + target_link_libraries(opapi PRIVATE + $ + -Wl,--whole-archive + ops_aclnn + -Wl,--no-whole-archive + -lopapi + nnopbase + profapi + ge_common_base + ascend_dump + ascendalog + dl + ) + set_target_properties(opapi PROPERTIES OUTPUT_NAME + cust_opapi + ) + install(TARGETS opapi + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_api/lib + ) + + # op proto + add_library(opsproto SHARED) + target_compile_options(opsproto PRIVATE + $<$:-std=c++11> + -fvisibility=hidden + ) + target_compile_definitions(opsproto PRIVATE + LOG_CPP + PROCESS_LOG + ) + target_link_libraries(opsproto PRIVATE + $ + $ + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + -Wl,--no-as-needed + exe_graph + graph + graph_base + register + ascendalog + error_manager + platform + -Wl,--as-needed + c_sec + ) + set_target_properties(opsproto PROPERTIES OUTPUT_NAME + cust_opsproto_rt2.0 + ) + install(TARGETS opsproto + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR} + ) + + # op tiling + add_library(optiling SHARED) + target_sources(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/utils/src/fallback_comm.cpp + ) + target_compile_options(optiling PRIVATE + $<$:-std=c++11> + -fvisibility=hidden + ) + 
target_compile_definitions(optiling PRIVATE + LOG_CPP + PROCESS_LOG + ) + target_link_libraries(optiling PRIVATE + $ + $ + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + -Wl,--no-as-needed + graph + graph_base + exe_graph + platform + register + ascendalog + error_manager + -Wl,--as-needed + -Wl,--whole-archive + tiling_api + -Wl,--no-whole-archive + mmpa + c_sec + ) + set_target_properties(optiling PROPERTIES OUTPUT_NAME + cust_opmaster_rt2.0 + ) + add_custom_command(TARGET optiling + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${TILING_CUSTOM_DIR} + COMMAND ln -sf $ ${TILING_CUSTOM_FILE} + ) + install(TARGETS optiling + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR} + ) + + # optiling compat + set(compat_optiling_dir ${CMAKE_CURRENT_BINARY_DIR}/compat) + set(compat_optiling_file ${compat_optiling_dir}/liboptiling.so) + add_custom_target(optiling_compat ALL + DEPENDS ${compat_optiling_file} + ) + + add_custom_command( + OUTPUT ${compat_optiling_file} + COMMAND ${CMAKE_COMMAND} -E make_directory ${compat_optiling_dir} + COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$ ${compat_optiling_file} + ) + + install(FILES ${compat_optiling_file} + DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling + ) + + add_ops_tiling_keys( + OP_NAME "ALL" + TILING_KEYS ${TILING_KEY} + ) + + add_opc_config( + OP_NAME "ALL" + CONFIG ${OP_DEBUG_CONFIG} + ) + + if(ADD_OPS_COMPILE_OPTION_V2) + add_ops_compile_options( + OP_NAME "ALL" + OPTIONS ${OPS_COMPILE_OPTIONS} + ) + endif() +endif () + +add_subdirectory(utils) + +set(OP_LIST) +set(OP_DIR_LIST) +op_add_subdirectory(OP_LIST OP_DIR_LIST) + +foreach (OP_DIR ${OP_DIR_LIST}) + add_subdirectory(${OP_DIR}/op_host) +endforeach () + +set(OP_DEPEND_DIR_LIST) +op_add_depend_directory( + OP_LIST ${OP_LIST} + OP_DIR_LIST OP_DEPEND_DIR_LIST +) +foreach (OP_DEPEND_DIR ${OP_DEPEND_DIR_LIST}) + 
add_subdirectory(${OP_DEPEND_DIR}/op_host) +endforeach () + +# ------------------------------------------------ aclnn ------------------------------------------------ +get_target_property(base_aclnn_srcs op_host_aclnn SOURCES) +get_target_property(base_aclnn_inner_srcs op_host_aclnnInner SOURCES) +get_target_property(base_aclnn_exclude_srcs op_host_aclnnExc SOURCES) + +if (BUILD_OPEN_PROJECT) + set(base_aclnn_binary_dir ${ASCEND_AUTOGEN_DIR}) +else() + get_target_property(base_aclnn_binary_dir op_host_aclnn BINARY_DIR) +endif () + +set(generate_aclnn_srcs) +set(generate_aclnn_inner_srcs) +set(generate_aclnn_headers) +set(generate_proto_dir ${base_aclnn_binary_dir}) +set(generate_exclude_proto_srcs) +set(generate_proto_srcs) +set(generate_proto_headers) + +if (base_aclnn_srcs) + foreach (_src ${base_aclnn_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_aclnn_srcs ${base_aclnn_binary_dir}/aclnn_${_op_name}.cpp) + list(APPEND generate_aclnn_headers ${base_aclnn_binary_dir}/aclnn_${_op_name}.h) + list(APPEND generate_proto_srcs ${generate_proto_dir}/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/${_op_name}_proto.h) + endif () + endforeach () +else () + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + ) + + target_sources(op_host_aclnn PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + ) +endif () + +if (base_aclnn_inner_srcs) + foreach (_src ${base_aclnn_inner_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_aclnn_inner_srcs 
${base_aclnn_binary_dir}/inner/aclnnInner_${_op_name}.cpp) + list(APPEND generate_proto_srcs ${generate_proto_dir}/inner/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/inner/${_op_name}_proto.h) + endif () + endforeach () +else () + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + ) + + target_sources(op_host_aclnnInner PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + ) +endif () + +if (base_aclnn_exclude_srcs) + foreach (_src ${base_aclnn_exclude_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_exclude_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp) + list(APPEND generate_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/exc/${_op_name}_proto.h) + endif () + endforeach () +else() + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + ) + + target_sources(op_host_aclnnExc PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + ) +endif () + +if (BUILD_OPEN_PROJECT) + if (generate_aclnn_srcs OR generate_aclnn_inner_srcs) + set(ops_aclnn_src ${generate_aclnn_srcs} ${generate_aclnn_inner_srcs}) + else () + set(ops_aclnn_src ${CMAKE_CURRENT_BINARY_DIR}/ops_aclnn_src_stub.cpp) + + add_custom_command(OUTPUT ${ops_aclnn_src} + COMMAND touch ${ops_aclnn_src} + ) + endif () + + set_source_files_properties(${ops_aclnn_src} + PROPERTIES GENERATED TRUE + ) + add_library(ops_aclnn STATIC + ${ops_aclnn_src} + ) + target_compile_options(ops_aclnn PRIVATE + $<$:-std=gnu++1z> + ) + target_link_libraries(ops_aclnn PRIVATE + $ + ) + 
add_dependencies(ops_aclnn opbuild_gen_default opbuild_gen_inner) + + set_source_files_properties(${generate_proto_srcs} + PROPERTIES GENERATED TRUE + ) + target_sources(opsproto PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto ops_proto_headers) + + install(FILES ${generate_proto_headers} + DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/inc OPTIONAL + ) + + redefine_file_macro( + TARGET_NAME + op_host_aclnn + op_host_aclnnInner + op_host_aclnnExc + opapi + opsproto + optiling + ops_aclnn + ) +else() + if (generate_aclnn_srcs OR generate_aclnn_inner_srcs) + set_source_files_properties(${generate_aclnn_srcs} ${generate_aclnn_inner_srcs} + TARGET_DIRECTORY acl_op_builtin + PROPERTIES GENERATED TRUE + ) + + target_sources(acl_op_builtin PRIVATE + ${generate_aclnn_srcs} + ${generate_aclnn_inner_srcs} + ) + endif () + + if (generate_proto_srcs) + set_source_files_properties(${generate_proto_srcs} + TARGET_DIRECTORY opsproto opsproto_rt2.0 + PROPERTIES GENERATED TRUE + ) + target_sources(opsproto PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto ops_proto_headers) + + target_sources(opsproto_rt2.0 PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto_rt2.0 ops_proto_headers) + endif () + + add_target_source( + TARGET_NAME opmaster_rt2.0 opmaster_static_rt2.0 + BASE_TARGET optiling + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) + + add_target_source( + TARGET_NAME opsproto_rt2.0 opsproto_static_rt2.0 + BASE_TARGET opsproto + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) + + add_static_ops( + ACLNN_SRC ${generate_aclnn_srcs} + ACLNN_INNER_SRC ${generate_aclnn_inner_srcs} + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) +endif () + +if (generate_aclnn_headers) + install(FILES ${generate_aclnn_headers} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL + ) +endif () + +add_library(ops_proto_headers INTERFACE) + +target_include_directories(ops_proto_headers INTERFACE + $ + $ + $ + $ +) + +if ((NOT BUILD_OPEN_PROJECT) AND ("${PRODUCT_SIDE}" STREQUAL 
"device")) + ExternalProject_Add(extern_opbuild_gen_default + SOURCE_DIR ${TOP_DIR}/cmake/superbuild + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -G ${CMAKE_GENERATOR} + -DHOST_PACKAGE=opp + -DBUILD_MOD=ops + -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/opbuild_output + -DFEATURE_LIST=custom_opbuild_out_dir=${generate_proto_dir} + + BUILD_COMMAND TARGETS=opbuild_gen_all $(MAKE) + INSTALL_COMMAND "" + LIST_SEPARATOR :: + EXCLUDE_FROM_ALL TRUE + ) + add_dependencies(ops_proto_headers extern_opbuild_gen_default) +else() + add_dependencies(ops_proto_headers opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) +endif () + +if (NOT BUILD_OPEN_PROJECT) + if (generate_proto_srcs) + install_package( + PACKAGE ops_adv + TARGETS ops_proto_headers + FILES ${generate_proto_headers} + DESTINATION include/ops_adv/proto + ) + endif () +endif () + +# ------------------------------------------------ opbuild ------------------------------------------------ +if (BUILD_OPEN_PROJECT) + if (generate_aclnn_srcs) + add_custom_command(OUTPUT ${generate_aclnn_srcs} ${generate_aclnn_headers} + COMMAND mkdir -p ${base_aclnn_binary_dir} + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=1 + OPS_PROJECT_NAME=aclnn + ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir} + ) + endif () + + add_custom_target(opbuild_gen_default + DEPENDS ${generate_aclnn_srcs} ${generate_aclnn_headers} op_host_aclnn + ) + + if (generate_aclnn_inner_srcs) + add_custom_command(OUTPUT ${generate_aclnn_inner_srcs} + COMMAND mkdir -p ${base_aclnn_binary_dir}/inner + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=1 + OPS_PROJECT_NAME=aclnnInner + ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir}/inner + ) + endif () + + add_custom_target(opbuild_gen_inner + DEPENDS ${generate_aclnn_inner_srcs} op_host_aclnnInner + ) + + if (generate_exclude_proto_srcs) + add_custom_command(OUTPUT ${generate_exclude_proto_srcs} + COMMAND mkdir -p ${base_aclnn_binary_dir}/exc + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=0 + OPS_PROJECT_NAME=aclnnExc 
+ ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir}/exc + ) + endif () + + add_custom_target(opbuild_gen_exc + DEPENDS ${generate_exclude_proto_srcs} op_host_aclnnExc + ) +endif () + +# ------------------------------------------------ generate adapt py ------------------------------------------------ +add_custom_target(generate_adapt_py + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_impl_build.py + \"\" + \"\" + \"\" + \"\" + ${ASCEND_IMPL_OUT_DIR} + ${ASCEND_AUTOGEN_DIR} + --opsinfo-dir ${base_aclnn_binary_dir} ${base_aclnn_binary_dir}/inner ${base_aclnn_binary_dir}/exc +) + +add_dependencies(generate_adapt_py opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + +foreach (_op_name ${OP_LIST}) + install(FILES ${ASCEND_IMPL_OUT_DIR}/dynamic/${_op_name}.py + DESTINATION ${IMPL_DYNAMIC_INSTALL_DIR} + OPTIONAL + ) +endforeach () + +install(DIRECTORY ${OPS_ADV_UTILS_KERNEL_INC}/ + DESTINATION ${IMPL_INSTALL_DIR}/ascendc/common +) + +foreach (op_dir ${OP_DIR_LIST}) + get_filename_component(_op_name "${op_dir}" NAME) + + file(GLOB KERNEL_FILES + ${op_dir}/op_kernel/*.cpp + ${op_dir}/op_kernel/*.h + ) + + install(FILES ${KERNEL_FILES} + DESTINATION ${IMPL_INSTALL_DIR}/ascendc/${_op_name} + OPTIONAL + ) +endforeach () + +# ------------------------------------------------ generate compile cmd ------------------------------------------------ +if (BUILD_OPEN_PROJECT) + add_custom_target(prepare_build ALL) + add_custom_target(generate_compile_cmd ALL) + add_custom_target(generate_ops_info ALL) + add_dependencies(prepare_build generate_adapt_py generate_compile_cmd) + + foreach (compute_unit ${ASCEND_COMPUTE_UNIT}) + add_compile_cmd_target( + COMPUTE_UNIT ${compute_unit} + ) + + add_ops_info_target( + COMPUTE_UNIT ${compute_unit} + ) + endforeach () +else() + add_dependencies(tbe_ops_json_info generate_adapt_py) +endif () + +# ------------------------------------------------ opp kernel ------------------------------------------------ +if (ENABLE_OPS_KERNEL) + 
add_custom_target(ops_kernel ALL) + add_custom_target(ops_config ALL) + add_dependencies(ops_kernel ops_config) + + foreach (compute_unit ${ASCEND_COMPUTE_UNIT}) + add_bin_compile_target( + COMPUTE_UNIT + ${compute_unit} + OP_INFO + ${OP_DIR_LIST} + ) + endforeach () +endif () + +if (BUILD_OPEN_PROJECT) + add_custom_target(modify_vendor ALL + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh + ) + + # modify VENDOR_NAME in install.sh and upgrade.sh + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh + COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/scripts + COMMAND cp -r ${ASCEND_PROJECT_DIR}/scripts/* ${CMAKE_CURRENT_BINARY_DIR}/scripts/ + COMMAND chmod +w ${CMAKE_CURRENT_BINARY_DIR}/scripts/* + COMMAND sed -i "s/vendor_name=customize/vendor_name=${VENDOR_NAME}/g" ${CMAKE_CURRENT_BINARY_DIR}/scripts/* + ) + + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/scripts/ + DESTINATION . 
FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ + ) + + # gen version.info + set(version_info_dir ${CMAKE_CURRENT_BINARY_DIR}) + set(version_info_file ${version_info_dir}/version.info) + add_custom_target(gen_version_info ALL + DEPENDS ${version_info_file} + ) + + add_custom_command(OUTPUT ${version_info_file} + COMMAND bash ${ASCENDC_CMAKE_UTIL_DIR}/gen_version_info.sh ${ASCEND_CANN_PACKAGE_PATH} ${version_info_dir} + ) + + install(FILES ${version_info_file} + DESTINATION packages/vendors/${VENDOR_NAME}/ + ) + + # CPack config + set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) + set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION}) + set(CPACK_PACKAGE_DESCRIPTION "CPack ops project") + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack ops project") + set(CPACK_PACKAGE_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CPACK_PACKAGE_FILE_NAME "CANN-custom_ops-${CANN_VERSION}-linux.${CMAKE_SYSTEM_PROCESSOR}.run") + set(CPACK_GENERATOR External) + set(CPACK_CMAKE_GENERATOR "Unix Makefiles") + set(CPACK_EXTERNAL_ENABLE_STAGING TRUE) + set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${ASCEND_CMAKE_DIR}/makeself.cmake) + set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME}) + include(CPack) +endif () diff --git a/csrc/build.sh b/csrc/build.sh new file mode 100644 index 00000000..76efeaaa --- /dev/null +++ b/csrc/build.sh @@ -0,0 +1,189 @@ +#!/bin/bash +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +set -e + +CURRENT_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]})) +BUILD_DIR=${CURRENT_DIR}/build +OUTPUT_DIR=${CURRENT_DIR}/output +USER_ID=$(id -u) +PARENT_JOB="false" +CHECK_COMPATIBLE="true" +ASAN="false" +COV="false" +VERBOSE="false" + +if [ "${USER_ID}" != "0" ]; then + DEFAULT_TOOLKIT_INSTALL_DIR="${HOME}/Ascend/ascend-toolkit/latest" + DEFAULT_INSTALL_DIR="${HOME}/Ascend/latest" +else + DEFAULT_TOOLKIT_INSTALL_DIR="/usr/local/Ascend/ascend-toolkit/latest" + DEFAULT_INSTALL_DIR="/usr/local/Ascend/latest" +fi + +CUSTOM_OPTION="-DBUILD_OPEN_PROJECT=ON" + +function help_info() { + echo "Usage: $0 [options]" + echo "Options:" + echo + echo "-h|--help Displays help message." + echo + echo "-n|--op-name Specifies the compiled operator. If there are multiple values, separate them with semicolons and use quotation marks. The default is all." + echo " For example: -n \"flash_attention_score\" or -n \"flash_attention_score;flash_attention_score_grad\"" + echo + echo "-c|--compute-unit Specifies the chip type. If there are multiple values, separate them with semicolons and use quotation marks. The default is ascend910b." + echo " For example: -c \"ascend910b\" or -c \"ascend910b;ascend310p\"" + echo + echo "--cov Compiles with cov." + echo + echo "--verbose Displays more compilation information." + echo +} + +function log() { + local current_time=`date +"%Y-%m-%d %H:%M:%S"` + echo "[$current_time] "$1 +} + +function set_env() +{ + source $ASCEND_CANN_PACKAGE_PATH/bin/setenv.bash || echo "0" + + export BISHENG_REAL_PATH=$(which bisheng || true) + + if [ -z "${BISHENG_REAL_PATH}" ];then + log "Error: bisheng compilation tool not found, Please check whether the cann package or environment variables are set." 
+ exit 1 + fi +} + +function clean() +{ + if [ -n "${BUILD_DIR}" ];then + rm -rf ${BUILD_DIR} + fi + mkdir -p ${BUILD_DIR} ${OUTPUT_DIR} +} + +function cmake_config() +{ + local extra_option="$1" + log "Info: cmake config ${CUSTOM_OPTION} ${extra_option} ." + cmake .. ${CUSTOM_OPTION} ${extra_option} +} + +function build() +{ + local target="$1" + if [ "${VERBOSE}" == "true" ];then + local option="--verbose" + fi + cmake --build . --target ${target} ${JOB_NUM} ${option} +} + +function gen_bisheng(){ + local ccache_program=$1 + local gen_bisheng_dir=${BUILD_DIR}/gen_bisheng_dir + + if [ ! -d "${gen_bisheng_dir}" ];then + mkdir -p ${gen_bisheng_dir} + fi + + pushd ${gen_bisheng_dir} + $(> bisheng) + echo "#!/bin/bash" >> bisheng + echo "ccache_args=""\"""${ccache_program} ${BISHENG_REAL_PATH}""\"" >> bisheng + echo "args=""$""@" >> bisheng + + if [ "${VERBOSE}" == "true" ];then + echo "echo ""\"""$""{ccache_args} ""$""args""\"" >> bisheng + fi + + echo "eval ""\"""$""{ccache_args} ""$""args""\"" >> bisheng + chmod +x bisheng + + export PATH=${gen_bisheng_dir}:$PATH + popd +} + +function build_package(){ + build package +} + +function build_host(){ + build_package +} + +function build_kernel(){ + build ops_kernel +} + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + help_info + exit + ;; + -n|--op-name) + ascend_op_name="$2" + shift 2 + ;; + -c|--compute-unit) + ascend_compute_unit="$2" + shift 2 + ;; + *) + help_info + exit 1 + ;; + esac +done + +if [ -n "${ascend_compute_unit}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_COMPUTE_UNIT=${ascend_compute_unit}" +fi + +if [ -n "${ascend_op_name}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_OP_NAME=${ascend_op_name}" +fi + +if [ -n "${ASCEND_HOME_PATH}" ];then + ASCEND_CANN_PACKAGE_PATH=${ASCEND_HOME_PATH} +elif [ -n "${ASCEND_OPP_PATH}" ];then + ASCEND_CANN_PACKAGE_PATH=$(dirname ${ASCEND_OPP_PATH}) +elif [ -d "${DEFAULT_TOOLKIT_INSTALL_DIR}" ];then + 
ASCEND_CANN_PACKAGE_PATH=${DEFAULT_TOOLKIT_INSTALL_DIR} +elif [ -d "${DEFAULT_INSTALL_DIR}" ];then + ASCEND_CANN_PACKAGE_PATH=${DEFAULT_INSTALL_DIR} +else + log "Error: Please set the toolkit package installation directory through parameter -p|--package-path." + exit 1 +fi + +if [ "${PARENT_JOB}" == "false" ];then + CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2)) + JOB_NUM="-j${CPU_NUM}" +fi + +CUSTOM_OPTION="${CUSTOM_OPTION} -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE}" + +set_env +clean + +ccache_system=$(which ccache || true) +if [ -n "${ccache_system}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DENABLE_CCACHE=ON -DCUSTOM_CCACHE=${ccache_system}" + gen_bisheng ${ccache_system} +fi + +cd ${BUILD_DIR} +cmake_config +build_package diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh new file mode 100644 index 00000000..9dba287e --- /dev/null +++ b/csrc/build_aclnn.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +ROOT_DIR=$1 +SOC_VERSION=$2 + +if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then + # ASCEND310P series + # currently, no custom aclnn ops for ASCEND310 series + # CUSTOM_OPS="" + # SOC_ARG="ascend310p" + exit 0 +elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then + # ASCEND910B (A2) series + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + SOC_ARG="ascend910b" +elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then + # ASCEND910C (A3) series + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + SOC_ARG="ascend910_93" +else + # others + # currently, no custom aclnn ops for other series + exit 0 +fi + +# build custom ops +cd csrc +rm -rf build output +echo "building custom ops $CUSTOM_OPS for $SOC_VERSION" +bash build.sh -n $CUSTOM_OPS -c $SOC_ARG + +# install custom ops to vllm_ascend/_cann_ops_custom +./output/CANN-custom_ops*.run --install-path=$ROOT_DIR/vllm_ascend/_cann_ops_custom +source $ROOT_DIR/vllm_ascend/_cann_ops_custom/vendors/customize/bin/set_env.bash diff --git 
a/csrc/cmake/config.cmake b/csrc/cmake/config.cmake new file mode 100644 index 00000000..38553f82 --- /dev/null +++ b/csrc/cmake/config.cmake @@ -0,0 +1,235 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +######################################################################################################################## +# Environment Check +######################################################################################################################## + +# Python3 +find_package(Python3) +if ((NOT Python3_FOUND) OR (${Python3_EXECUTABLE} STREQUAL "")) + message(FATAL_ERROR "Can't find python3.") +endif () +set(HI_PYTHON "${Python3_EXECUTABLE}" CACHE STRING "python executor") + +# Get the base CANN path +if (CUSTOM_ASCEND_CANN_PACKAGE_PATH) + set(ASCEND_CANN_PACKAGE_PATH ${CUSTOM_ASCEND_CANN_PACKAGE_PATH}) +elseif (DEFINED ENV{ASCEND_HOME_PATH}) + set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH}) +elseif (DEFINED ENV{ASCEND_OPP_PATH}) + get_filename_component(ASCEND_CANN_PACKAGE_PATH "$ENV{ASCEND_OPP_PATH}/.." 
ABSOLUTE) +else() + set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/latest") +endif () +message(STATUS "ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH}") + +######################################################################################################################## +# Common Configuration +######################################################################################################################## + +# Switches +option(PREPARE_BUILD "Prepare build." OFF) +option(ENABLE_OPS_HOST "Build ops host." ON) +option(ENABLE_OPS_KERNEL "Build ops kernel." ON) +if (TESTS_EXAMPLE_OPS_TEST OR TESTS_UT_OPS_TEST) + set(ENABLE_OPS_KERNEL OFF) +endif () +set(OP_DEBUG_CONFIG "false" CACHE STRING "op debug config") + +# Path configuration +# Source tree related paths +get_filename_component(OPS_ADV_DIR "${CMAKE_CURRENT_SOURCE_DIR}" REALPATH) +get_filename_component(OPS_ADV_CMAKE_DIR "${OPS_ADV_DIR}/cmake" REALPATH) +get_filename_component(OPS_ADV_UTILS_KERNEL_INC "${OPS_ADV_DIR}/utils/inc/kernel" REALPATH) + + +# Build tree related paths +set(ASCEND_IMPL_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/impl CACHE STRING "ascend impl output directories") +set(ASCEND_BINARY_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary CACHE STRING "ascend binary output directories") +set(ASCEND_AUTOGEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/autogen CACHE STRING "Auto generate file directories") +set(ASCEND_CUSTOM_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_compile_options.ini) +set(ASCEND_CUSTOM_TILING_KEYS ${ASCEND_AUTOGEN_DIR}/custom_tiling_keys.ini) +set(ASCEND_CUSTOM_OPC_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_opc_options.ini) +set(OP_BUILD_TOOL ${ASCEND_CANN_PACKAGE_PATH}/tools/opbuild/op_build CACHE STRING "op_build tool") +file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_DIR}) +file(REMOVE ${ASCEND_CUSTOM_OPTIONS}) +file(TOUCH ${ASCEND_CUSTOM_OPTIONS}) +file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS}) +file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS}) +file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS}) +file(TOUCH 
${ASCEND_CUSTOM_OPC_OPTIONS}) +if (BUILD_OPEN_PROJECT) + if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project/cmake) + set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project) + else() + set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/op_project_templates/ascendc/customize) + endif() + set(ASCEND_CMAKE_DIR ${ASCEND_PROJECT_DIR}/cmake CACHE STRING "ascend project cmake") + set(IMPL_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl) + set(IMPL_DYNAMIC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl/dynamic) + set(ACLNN_INC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_api/include) +else() + set(ASCEND_CMAKE_DIR ${TOP_DIR}/asl/ops/cann/ops/built-in/ascendc/samples/customize/cmake CACHE STRING "ascend project cmake") + set(IMPL_INSTALL_DIR lib/ascendc/impl) + set(IMPL_DYNAMIC_INSTALL_DIR lib/ascendc/impl/dynamic) + set(ACLNN_INC_INSTALL_DIR lib/include) + set(OPS_STATIC_TYPES infer train) + set(OPS_STATIC_SCRIPT ${TOP_DIR}/asl/ops/cann/ops/built-in/kernel/binary_script/build_opp_kernel_static.py) +endif () +set(ASCENDC_CMAKE_UTIL_DIR ${ASCEND_CMAKE_DIR}/util) +set(CUSTOM_DIR ${CMAKE_BINARY_DIR}/custom) +set(TILING_CUSTOM_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/op_tiling) +set(TILING_CUSTOM_FILE ${TILING_CUSTOM_DIR}/liboptiling.so) + +# Temporary adaptation for ascendc changes, to be removed after switching to the new version of ascendc +if(EXISTS ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py) + set(ADD_OPS_COMPILE_OPTION_V2 ON) +else() + set(ADD_OPS_COMPILE_OPTION_V2 OFF) +endif() + +######################################################################################################################## +# CMake Options, Default Parameters Setting +# Configure CMake options and default parameters according to the CMake build process +# CMake build process: 1) Configuration phase; 2) Build phase; 3) Installation phase; 
+######################################################################################################################## +if (BUILD_OPEN_PROJECT) + # Build phase + # Build type + # The Generator in CMake is a tool used to generate native build systems. Generally divided into two types: + # 1. Single-configuration generator: + # In the configuration phase, only one build type is allowed to be specified through the variable CMAKE_BUILD_TYPE; + # In the build phase, the build type cannot be changed, and only the build type specified through the variable CMAKE_BUILD_TYPE in the configuration phase can be used; + # Common generators of this type include: Ninja, Unix Makefiles + # 2. Multi-configuration generator: + # In the configuration phase, only the list of build types available in the build phase is specified through the variable CMAKE_CONFIGURATION_TYPES; + # In the build phase, the specific build type of the build phase is specified through the "--config" parameter; + # Common generators of this type include: Xcode, Visual Studio + # Therefore: + # 1. In the single-configuration generator scenario, if the build type (CMAKE_BUILD_TYPE) is not specified, the default is Debug; + # 2. 
In the multi-configuration generator scenario, if the build types available in the build phase (CMAKE_CONFIGURATION_TYPES) are not specified, + # it is defaulted to the full set of build types allowed by CMake [Debug;Release;MinSizeRel;RelWithDebInfo] + get_property(GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + if (GENERATOR_IS_MULTI_CONFIG) + if (NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;MinSizeRel;RelWithDebInfo" CACHE STRING "Configuration Build type" FORCE) + endif () + else () + if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type(default Debug)" FORCE) + endif () + endif () + + # Build phase + # Executable runtime library file search path RPATH + # Do not skip RPATH in UTest and Example scenarios + if (TESTS_UT_OPS_TEST OR TESTS_EXAMPLE_OPS_TEST) + set(CMAKE_SKIP_RPATH FALSE) + else () + set(CMAKE_SKIP_RPATH TRUE) + endif () + + # Build phase + # CCACHE configuration + if (ENABLE_CCACHE) + if (CUSTOM_CCACHE) + set(CCACHE_PROGRAM ${CUSTOM_CCACHE}) + else() + find_program(CCACHE_PROGRAM ccache) + endif () + if (CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "C cache Compiler") + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "CXX cache Compiler") + endif () + endif () + + # Installation phase + # Installation path + # When CMAKE_INSTALL_PREFIX is not explicitly set (i.e., CMAKE_INSTALL_PREFIX takes the default value), + # correct its value to be level with the build tree root directory CMAKE_CURRENT_BINARY_DIR + if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + get_filename_component(_Install_Path_Prefix "${CMAKE_CURRENT_BINARY_DIR}/../output" REALPATH) + set(CMAKE_INSTALL_PREFIX "${_Install_Path_Prefix}" CACHE STRING "Install path" FORCE) + endif () +endif () + +######################################################################################################################## +# Public Compilation Parameters 
+######################################################################################################################## +list(TRANSFORM ASCEND_COMPUTE_UNIT TOLOWER) +if (BUILD_OPEN_PROJECT) + message(STATUS "ENABLE_CCACHE=${ENABLE_CCACHE}, CUSTOM_CCACHE=${CUSTOM_CCACHE}") + message(STATUS "CCACHE_PROGRAM=${CCACHE_PROGRAM}") + message(STATUS "ASCEND_COMPUTE_UNIT=${ASCEND_COMPUTE_UNIT}") + message(STATUS "ASCEND_OP_NAME=${ASCEND_OP_NAME}") + message(STATUS "TILING_KEY=${TILING_KEY}") + message(STATUS "TESTS_UT_OPS_TEST=${TESTS_UT_OPS_TEST}") + message(STATUS "TESTS_EXAMPLE_OPS_TEST=${TESTS_EXAMPLE_OPS_TEST}") +endif () + +######################################################################################################################## +# Preprocessing +######################################################################################################################## +if (BUILD_OPEN_PROJECT) + if (NOT PREPARE_BUILD AND ENABLE_OPS_KERNEL) + if (TILING_KEY) + string(REPLACE ";" "::" EP_TILING_KEY "${TILING_KEY}") + else() + set(EP_TILING_KEY FALSE) + endif () + + if (OPS_COMPILE_OPTIONS) + string(REPLACE ";" "::" EP_OPS_COMPILE_OPTIONS "${OPS_COMPILE_OPTIONS}") + else() + set(EP_OPS_COMPILE_OPTIONS FALSE) + endif () + + string(REPLACE ";" "::" EP_ASCEND_COMPUTE_UNIT "${ASCEND_COMPUTE_UNIT}") + + execute_process(COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/prepare.sh + -s ${CMAKE_CURRENT_SOURCE_DIR} + -b ${CMAKE_CURRENT_BINARY_DIR}/prepare_build + -p ${ASCEND_CANN_PACKAGE_PATH} + --autogen-dir ${ASCEND_AUTOGEN_DIR} + --build-open-project ${BUILD_OPEN_PROJECT} + --binary-out-dir ${ASCEND_BINARY_OUT_DIR} + --impl-out-dir ${ASCEND_IMPL_OUT_DIR} + --op-build-tool ${OP_BUILD_TOOL} + --ascend-cmake-dir ${ASCEND_CMAKE_DIR} + --tiling-key ${EP_TILING_KEY} + --ops-compile-options ${EP_OPS_COMPILE_OPTIONS} + --check-compatible ${CHECK_COMPATIBLE} + --ascend-compute_unit ${EP_ASCEND_COMPUTE_UNIT} + --op_debug_config ${OP_DEBUG_CONFIG} + --ascend-op-name 
"${ASCEND_OP_NAME}" + RESULT_VARIABLE result + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PREPARE_BUILD_OUTPUT_VARIABLE) + if (result) + message(FATAL_ERROR "Error: ops prepare build failed.") + endif () + + file(REMOVE ${ASCEND_CUSTOM_OPTIONS}) + file(TOUCH ${ASCEND_CUSTOM_OPTIONS}) + file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS}) + file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS}) + file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS}) + file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS}) + endif () +endif () + +######################################################################################################################## +# Other Configuration +######################################################################################################################## +if (BUILD_OPEN_PROJECT) + if (TESTS_UT_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/config_utest.cmake) + endif () +endif () diff --git a/csrc/cmake/func.cmake b/csrc/cmake/func.cmake new file mode 100644 index 00000000..f2bebf75 --- /dev/null +++ b/csrc/cmake/func.cmake @@ -0,0 +1,609 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +function(add_target_source) + cmake_parse_arguments(ADD "" "BASE_TARGET;SRC_DIR" "TARGET_NAME" ${ARGN}) + + get_target_property(all_srcs ${ADD_BASE_TARGET} SOURCES) + set(add_srcs) + foreach(_src ${all_srcs}) + string(REGEX MATCH "^${ADD_SRC_DIR}" is_match "${_src}") + if (is_match) + list(APPEND add_srcs ${_src}) + endif () + endforeach() + + get_target_property(all_includes ${ADD_BASE_TARGET} INCLUDE_DIRECTORIES) + set(add_includes) + foreach(_include ${all_includes}) + string(REGEX MATCH "^${ADD_SRC_DIR}" is_match "${_include}") + if (is_match) + list(APPEND add_includes ${_include}) + endif () + endforeach() + + foreach(_target_name ${ADD_TARGET_NAME}) + target_sources(${_target_name} PRIVATE + ${add_srcs} + ) + + target_include_directories(${_target_name} PRIVATE + ${add_includes} + ) + endforeach() +endfunction() + +function(op_add_subdirectory OP_LIST OP_DIR_LIST) + set(_OP_LIST) + set(_OP_DIR_LIST) + + file(GLOB OP_HOST_CMAKE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/**/op_host/CMakeLists.txt") + + foreach(OP_CMAKE_FILE ${OP_HOST_CMAKE_FILES}) + get_filename_component(OP_HOST_DIR "${OP_CMAKE_FILE}" DIRECTORY) + get_filename_component(OP_DIR "${OP_HOST_DIR}" DIRECTORY) + get_filename_component(OP_NAME "${OP_DIR}" NAME) + + if (NOT BUILD_OPEN_PROJECT) + if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${OP_NAME}) + continue() + endif () + endif () + + if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "") + if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + if (NOT ${OP_NAME} IN_LIST ASCEND_OP_NAME) + continue() + endif () + endif () + endif () + + list(APPEND _OP_LIST ${OP_NAME}) + list(APPEND _OP_DIR_LIST ${OP_DIR}) + endforeach() + + list(REMOVE_DUPLICATES _OP_LIST) + list(REMOVE_DUPLICATES _OP_DIR_LIST) + list(SORT _OP_LIST) + list(SORT _OP_DIR_LIST) + set(${OP_LIST} 
${_OP_LIST} PARENT_SCOPE) + set(${OP_DIR_LIST} ${_OP_DIR_LIST} PARENT_SCOPE) +endfunction() + +function(op_add_depend_directory) + cmake_parse_arguments(DEP "" "OP_DIR_LIST" "OP_LIST" ${ARGN}) + set(_OP_DEPEND_DIR_LIST) + foreach(op_name ${DEP_OP_LIST}) + if (DEFINED ${op_name}_depends) + foreach(depend_info ${${op_name}_depends}) + if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}/op_host/CMakeLists.txt) + continue() + endif () + + get_filename_component(_depend_op_name "${depend_info}" NAME) + if (NOT BUILD_OPEN_PROJECT) + if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${_depend_op_name}) + continue() + endif () + endif () + + if (NOT ${_depend_op_name} IN_LIST DEP_OP_LIST) + list(APPEND _OP_DEPEND_DIR_LIST ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}) + endif () + endforeach() + endif() + endforeach() + + list(SORT _OP_DEPEND_DIR_LIST) + set(${DEP_OP_DIR_LIST} ${_OP_DEPEND_DIR_LIST} PARENT_SCOPE) +endfunction() + +function(add_compile_cmd_target) + cmake_parse_arguments(CMD "" "COMPUTE_UNIT" "" ${ARGN}) + + if(ADD_OPS_COMPILE_OPTION_V2) + set(OP_DEBUG_CONFIG_OPTION --opc-config-file ${ASCEND_CUSTOM_OPC_OPTIONS}) + else() + if(OP_DEBUG_CONFIG) + set(OP_DEBUG_CONFIG_OPTION --op-debug-config ${OP_DEBUG_CONFIG}) + endif() + set(OP_TILING_KEY_OPTION --tiling-keys ${ASCEND_CUSTOM_TILING_KEYS}) + endif() + + set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${CMD_COMPUTE_UNIT}) + set(GEN_OUT_DIR ${_OUT_DIR}/gen) + set(COMPILE_CMD_TARGET generate_compile_cmd_${CMD_COMPUTE_UNIT}) + add_custom_target(${COMPILE_CMD_TARGET} ALL + COMMAND mkdir -p ${GEN_OUT_DIR} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + ${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/inner/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + 
${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/exc/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + ${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + ) + + add_dependencies(${COMPILE_CMD_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + add_dependencies(generate_compile_cmd ${COMPILE_CMD_TARGET}) +endfunction() + +function(add_ops_info_target) + cmake_parse_arguments(OPINFO "" "COMPUTE_UNIT" "" ${ARGN}) + + set(OPS_INFO_TARGET generate_ops_info_${OPINFO_COMPUTE_UNIT}) + set(OPS_INFO_JSON ${ASCEND_AUTOGEN_DIR}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.json) + set(CUSTOM_OPS_INFO_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT}) + + set(OPS_INFO_INI ${base_aclnn_binary_dir}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + set(OPS_INFO_INNER_INI ${base_aclnn_binary_dir}/inner/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + set(OPS_INFO_EXCLUDE_INI ${base_aclnn_binary_dir}/exc/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + + add_custom_command(OUTPUT ${OPS_INFO_JSON} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/parse_ini_to_json.py + ${OPS_INFO_INI} + ${OPS_INFO_INNER_INI} + ${OPS_INFO_EXCLUDE_INI} + ${OPS_INFO_JSON} + COMMAND mkdir -p ${CUSTOM_OPS_INFO_DIR} + COMMAND cp -f ${OPS_INFO_JSON} ${CUSTOM_OPS_INFO_DIR} + ) + + add_custom_target(${OPS_INFO_TARGET} ALL + DEPENDS ${OPS_INFO_JSON} + ) + + add_dependencies(${OPS_INFO_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + add_dependencies(generate_ops_info ${OPS_INFO_TARGET}) + + install(FILES ${OPS_INFO_JSON} + DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT} OPTIONAL + ) +endfunction() + +function(add_ops_compile_options) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;OPTIONS" ${ARGN}) + + if(NOT OP_COMPILE_OPTIONS) + return() + endif() + + 
if(ADD_OPS_COMPILE_OPTION_V2) + execute_process(COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py + ${ASCEND_CUSTOM_OPTIONS} ${OP_COMPILE_OP_NAME} + ${OP_COMPILE_COMPUTE_UNIT} ${OP_COMPILE_OPTIONS} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR) + if (EXEC_RESULT) + message("add ops compile options info: ${EXEC_INFO}") + message("add ops compile options error: ${EXEC_ERROR}") + message(FATAL_ERROR "Error: add ops compile options failed!") + endif () + else() + file(APPEND ${ASCEND_CUSTOM_OPTIONS} + "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_OPTIONS}\n" + ) + endif() +endfunction() + +function(add_ops_tiling_keys) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;TILING_KEYS" ${ARGN}) + + if(NOT OP_COMPILE_TILING_KEYS) + return() + endif() + + if(ADD_OPS_COMPILE_OPTION_V2) + list(JOIN OP_COMPILE_TILING_KEYS "," STRING_TILING_KEYS) + add_ops_compile_options( + OP_NAME ${OP_COMPILE_OP_NAME} + OPTIONS --tiling_key=${STRING_TILING_KEYS} + ) + else() + file(APPEND ${ASCEND_CUSTOM_TILING_KEYS} + "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_TILING_KEYS}\n" + ) + endif() +endfunction() + +function(add_opc_config) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;CONFIG" ${ARGN}) + + if(NOT ADD_OPS_COMPILE_OPTION_V2) + return() + endif() + + if(NOT OP_COMPILE_CONFIG) + return() + endif() + + string(REPLACE "," ";" OP_COMPILE_CONFIG_LIST "${OP_COMPILE_CONFIG}") + + set(_OPC_CONFIG) + + foreach(_option ${OP_COMPILE_CONFIG_LIST}) + if("${_option}" STREQUAL "ccec_g") + list(APPEND _OPC_CONFIG "-g") + elseif("${_option}" STREQUAL "ccec_O0") + list(APPEND _OPC_CONFIG "-O0") + elseif("${_option}" STREQUAL "oom") + list(APPEND _OPC_CONFIG "--oom") + elseif("${_option}" STREQUAL "dump_cce") + list(APPEND _OPC_CONFIG "--save-temp-files") + endif() + endforeach() + + if(_OPC_CONFIG) + add_ops_compile_options( + OP_NAME ${OP_COMPILE_OP_NAME} + OPTIONS ${_OPC_CONFIG} + 
) + endif() +endfunction() + +function(add_ops_src_copy) + cmake_parse_arguments(SRC_COPY "" "TARGET_NAME;SRC;DST;BE_RELIED;COMPUTE_UNIT" "" ${ARGN}) + + set(OPS_UTILS_INC_KERNEL_TARGET ops_utils_inc_kernel_${SRC_COPY_COMPUTE_UNIT}) + if (EXISTS ${OPS_ADV_UTILS_KERNEL_INC}) + if (NOT TARGET ${OPS_UTILS_INC_KERNEL_TARGET}) + get_filename_component(_ROOT_OPS_SRC_DIR "${SRC_COPY_DST}" DIRECTORY) + set(OPS_UTILS_INC_KERNEL_DIR ${_ROOT_OPS_SRC_DIR}/ascendc/common) + add_custom_command(OUTPUT ${OPS_UTILS_INC_KERNEL_DIR} + COMMAND mkdir -p ${OPS_UTILS_INC_KERNEL_DIR} + COMMAND cp -rf ${OPS_ADV_UTILS_KERNEL_INC}/*.* ${OPS_UTILS_INC_KERNEL_DIR} + ) + + add_custom_target(${OPS_UTILS_INC_KERNEL_TARGET} + DEPENDS ${OPS_UTILS_INC_KERNEL_DIR} + ) + endif () + endif () + + if (NOT TARGET ${SRC_COPY_TARGET_NAME}) + set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done) + add_custom_command(OUTPUT ${_BUILD_FLAG} + COMMAND mkdir -p ${SRC_COPY_DST} + COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST} + COMMAND touch ${_BUILD_FLAG} + ) + + add_custom_target(${SRC_COPY_TARGET_NAME} + DEPENDS ${_BUILD_FLAG} + ) + endif () + + if (TARGET ${OPS_UTILS_INC_KERNEL_TARGET}) + add_dependencies(${SRC_COPY_TARGET_NAME} ${OPS_UTILS_INC_KERNEL_TARGET}) + endif () + + if (DEFINED SRC_COPY_BE_RELIED) + add_dependencies(${SRC_COPY_BE_RELIED} ${SRC_COPY_TARGET_NAME}) + endif () + +endfunction() + +function(add_bin_compile_target) + cmake_parse_arguments(BINARY "" "COMPUTE_UNIT" "OP_INFO" ${ARGN}) + + set(_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/kernel) + set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${BINARY_COMPUTE_UNIT}) + + set(BIN_OUT_DIR ${_OUT_DIR}/bin) + set(GEN_OUT_DIR ${_OUT_DIR}/gen) + set(SRC_OUT_DIR ${_OUT_DIR}/src) + file(MAKE_DIRECTORY ${BIN_OUT_DIR}) + + foreach(_op_info ${BINARY_OP_INFO}) + get_filename_component(_op_name "${_op_info}" NAME) + set(${_op_name}_dir ${_op_info}) + endforeach() + + set(_ops_target_list) + set(compile_scripts) + file(GLOB 
scripts_list ${GEN_OUT_DIR}/*.sh) + list(APPEND compile_scripts ${scripts_list}) + + foreach(bin_script ${compile_scripts}) + get_filename_component(bin_file ${bin_script} NAME_WE) + string(REPLACE "-" ";" bin_sep ${bin_file}) + list(GET bin_sep 0 op_type) + list(GET bin_sep 1 op_file) + list(GET bin_sep 2 op_index) + + if (NOT DEFINED ${op_file}_dir) + continue() + endif () + + if (NOT TARGET ${op_file}) + add_custom_target(${op_file}) + add_dependencies(ops_kernel ${op_file}) + endif () + + set(OP_TARGET_NAME ${op_file}_${BINARY_COMPUTE_UNIT}) + + if (NOT TARGET ${OP_TARGET_NAME}) + add_custom_target(${OP_TARGET_NAME}) + add_dependencies(${op_file} ${OP_TARGET_NAME}) + list(APPEND _ops_target_list ${OP_TARGET_NAME}) + + set(OP_SRC_OUT_DIR ${SRC_OUT_DIR}/${op_file}) + set(OP_BIN_OUT_DIR ${BIN_OUT_DIR}/${op_file}) + file(MAKE_DIRECTORY ${OP_SRC_OUT_DIR}) + + add_ops_src_copy( + TARGET_NAME + ${OP_TARGET_NAME}_src_copy + SRC + ${${op_file}_dir} + DST + ${OP_SRC_OUT_DIR} + COMPUTE_UNIT + ${BINARY_COMPUTE_UNIT} + ) + + if (DEFINED ${op_file}_depends) + foreach(depend_info ${${op_file}_depends}) + get_filename_component(_depend_op_name "${depend_info}" NAME) + set(_depend_op_target ${_depend_op_name}_${BINARY_COMPUTE_UNIT}_src_copy) + add_ops_src_copy( + TARGET_NAME + ${_depend_op_target} + SRC + ${CMAKE_SOURCE_DIR}/${depend_info} + DST + ${SRC_OUT_DIR}/${_depend_op_name} + COMPUTE_UNIT + ${BINARY_COMPUTE_UNIT} + BE_RELIED + ${OP_TARGET_NAME}_src_copy + ) + endforeach() + endif () + + set(DYNAMIC_PY_FILE ${OP_SRC_OUT_DIR}/${op_type}.py) + add_custom_command(OUTPUT ${DYNAMIC_PY_FILE} + COMMAND cp -rf ${ASCEND_IMPL_OUT_DIR}/dynamic/${op_file}.py ${DYNAMIC_PY_FILE} + ) + + add_custom_target(${OP_TARGET_NAME}_py_copy + DEPENDS ${DYNAMIC_PY_FILE} + ) + + add_custom_command(OUTPUT ${OP_BIN_OUT_DIR} + COMMAND mkdir -p ${OP_BIN_OUT_DIR} + ) + + add_custom_target(${OP_TARGET_NAME}_mkdir + DEPENDS ${OP_BIN_OUT_DIR} + ) + + install(DIRECTORY ${OP_BIN_OUT_DIR} + DESTINATION 
${_INSTALL_DIR}/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + + install(FILES ${BIN_OUT_DIR}/${op_file}.json + DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + endif () + + set(_group "1-0") + if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "") + if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + if (${op_file} IN_LIST ASCEND_OP_NAME) + list(LENGTH ASCEND_OP_NAME _len) + list(FIND ASCEND_OP_NAME ${op_file} _index) + math(EXPR _next_index "${_index} + 1") + if (${_next_index} LESS ${_len}) + list(GET ASCEND_OP_NAME ${_next_index} _group_str) + set(_regex "^[0-9]+-[0-9]+$") + string(REGEX MATCH "${_regex}" match "${_group_str}") + if (match) + set(_group ${_group_str}) + endif () + endif () + endif () + endif () + endif () + + string(REPLACE "-" ";" _group_sep ${_group}) + + list(GET _group_sep 1 start_index) + set(end_index ${op_index}) + list(GET _group_sep 0 step) + + set(_compile_flag false) + if (${start_index} LESS ${end_index}) + foreach(i RANGE ${start_index} ${end_index} ${step}) + if (${i} EQUAL ${end_index}) + set(_compile_flag true) + break() + endif () + endforeach() + elseif (${start_index} EQUAL ${end_index}) + set(_compile_flag true) + else() + set(_compile_flag false) + endif () + + if (_compile_flag) + set(_BUILD_COMMAND) + set(_BUILD_FLAG ${GEN_OUT_DIR}/${OP_TARGET_NAME}_${op_index}.done) + if (ENABLE_OPS_HOST) + list(APPEND _BUILD_COMMAND export ASCEND_CUSTOM_OPP_PATH=${CUSTOM_DIR} &&) + endif () + list(APPEND _BUILD_COMMAND export HI_PYTHON="python3" &&) + list(APPEND _BUILD_COMMAND export TILINGKEY_PAR_COMPILE=1 &&) + list(APPEND _BUILD_COMMAND export BIN_FILENAME_HASHED=1 &&) + list(APPEND _BUILD_COMMAND bash ${bin_script} ${OP_SRC_OUT_DIR}/${op_type}.py ${OP_BIN_OUT_DIR}) + if(CMAKE_GENERATOR MATCHES "Unix Makefiles") + list(APPEND _BUILD_COMMAND && echo $(MAKE)) + endif() + + add_custom_command(OUTPUT ${_BUILD_FLAG} + COMMAND ${_BUILD_COMMAND} + COMMAND touch ${_BUILD_FLAG} 
+ WORKING_DIRECTORY ${GEN_OUT_DIR} + ) + + add_custom_target(${OP_TARGET_NAME}_${op_index} + DEPENDS ${_BUILD_FLAG} + ) + + if (ENABLE_OPS_HOST) + add_dependencies(${OP_TARGET_NAME}_${op_index} optiling generate_ops_info) + endif () + add_dependencies(${OP_TARGET_NAME}_${op_index} ${OP_TARGET_NAME}_src_copy ${OP_TARGET_NAME}_py_copy ${OP_TARGET_NAME}_mkdir) + add_dependencies(${OP_TARGET_NAME} ${OP_TARGET_NAME}_${op_index}) + endif () + endforeach() + + if (_ops_target_list) + set(OPS_CONFIG_TARGET ops_config_${BINARY_COMPUTE_UNIT}) + set(BINARY_INFO_CONFIG_FILE ${BIN_OUT_DIR}/binary_info_config.json) + + add_custom_command(OUTPUT ${BINARY_INFO_CONFIG_FILE} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_ops_config.py -p ${BIN_OUT_DIR} -s ${BINARY_COMPUTE_UNIT} + ) + + add_custom_target(${OPS_CONFIG_TARGET} + DEPENDS ${BINARY_INFO_CONFIG_FILE} + ) + + add_dependencies(ops_config ${OPS_CONFIG_TARGET}) + + foreach(_op_target ${_ops_target_list}) + add_dependencies(${OPS_CONFIG_TARGET} ${_op_target}) + endforeach() + + install(FILES ${BINARY_INFO_CONFIG_FILE} + DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + endif () +endfunction() + +function(redefine_file_macro) + cmake_parse_arguments(_FILE "" "" "TARGET_NAME" ${ARGN}) + + foreach(_target_name ${_FILE_TARGET_NAME}) + target_compile_options(${_target_name} PRIVATE + -Wno-builtin-macro-redefined + ) + + get_target_property(_srcs ${_target_name} SOURCES) + + foreach(_src ${_srcs}) + get_filename_component(_src_name "${_src}" NAME) + set_source_files_properties(${_src} + PROPERTIES COMPILE_DEFINITIONS __FILE__="${_src_name}" + ) + endforeach() + endforeach() +endfunction() + +function(add_static_ops) + cmake_parse_arguments(STATIC "" "SRC_DIR" "ACLNN_SRC;ACLNN_INNER_SRC" ${ARGN}) + set(prepare_ops_adv_static_target prepare_ops_adv_static) + set(static_src_temp_dir ${CMAKE_CURRENT_BINARY_DIR}/static_src_temp_dir) + set(modified_files) + foreach(ops_type ${OPS_STATIC_TYPES}) + 
get_target_property(all_srcs aclnn_ops_${ops_type} SOURCES) + set(add_srcs) + set(generate_aclnn_srcs) + foreach(_src ${all_srcs}) + string(REGEX MATCH "^${STATIC_SRC_DIR}" is_match "${_src}") + if (is_match) + list(APPEND add_srcs ${_src}) + endif () + endforeach() + + foreach(_src ${add_srcs}) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "^aclnn_" "" _op_name ${name_without_ext}) + + foreach(_aclnn_src ${STATIC_ACLNN_SRC}) + get_filename_component(aclnn_name ${_aclnn_src} NAME_WE) + if("aclnn_${_op_name}" STREQUAL "${aclnn_name}") + list(APPEND generate_aclnn_srcs ${_aclnn_src}) + break() + endif() + endforeach() + + foreach(_aclnn_inner_src ${STATIC_ACLNN_INNER_SRC}) + get_filename_component(aclnn_inner_name ${_aclnn_inner_src} NAME_WE) + if("aclnnInner_${_op_name}" STREQUAL "${aclnn_inner_name}") + list(APPEND generate_aclnn_srcs ${_aclnn_inner_src}) + break() + endif() + endforeach() + endforeach() + + if(add_srcs) + list(TRANSFORM add_srcs REPLACE "${STATIC_SRC_DIR}" "${static_src_temp_dir}" OUTPUT_VARIABLE add_static_srcs) + list(APPEND modified_files ${add_static_srcs}) + set(aclnn_ops_static_target aclnn_ops_${ops_type}_static) + set_source_files_properties(${add_static_srcs} + TARGET_DIRECTORY ${aclnn_ops_static_target} + PROPERTIES GENERATED TRUE + ) + + target_sources(${aclnn_ops_static_target} PRIVATE + ${add_static_srcs} + ) + add_dependencies(${aclnn_ops_static_target} ${prepare_ops_adv_static_target}) + endif() + + if(generate_aclnn_srcs) + list(REMOVE_DUPLICATES generate_aclnn_srcs) + set(aclnn_op_target acl_op_${ops_type}_builtin) + set_source_files_properties(${generate_aclnn_srcs} + TARGET_DIRECTORY ${aclnn_op_target} + PROPERTIES GENERATED TRUE + ) + + target_sources(${aclnn_op_target} PRIVATE + ${generate_aclnn_srcs} + ) + endif() + endforeach() + + if(NOT TARGET ${prepare_ops_adv_static_target}) + list(REMOVE_DUPLICATES modified_files) + add_custom_command(OUTPUT ${static_src_temp_dir} + COMMAND mkdir -p 
${static_src_temp_dir} + COMMAND cp -rf ${STATIC_SRC_DIR}/src ${static_src_temp_dir} + COMMAND ${HI_PYTHON} -B ${OPS_STATIC_SCRIPT} InsertIni -p ${static_src_temp_dir} -f ${modified_files} + ) + + add_custom_target(${prepare_ops_adv_static_target} + DEPENDS ${static_src_temp_dir} + ) + endif() +endfunction() + +if (BUILD_OPEN_PROJECT) + if (TESTS_UT_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/func_utest.cmake) + endif () + if (TESTS_EXAMPLE_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/func_examples.cmake) + endif () +endif () diff --git a/csrc/cmake/intf.cmake b/csrc/cmake/intf.cmake new file mode 100644 index 00000000..20c63563 --- /dev/null +++ b/csrc/cmake/intf.cmake @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +if (BUILD_OPEN_PROJECT) + include(${OPS_ADV_CMAKE_DIR}/intf_pub.cmake) +endif () diff --git a/csrc/cmake/intf_pub.cmake b/csrc/cmake/intf_pub.cmake new file mode 100644 index 00000000..4856aeef --- /dev/null +++ b/csrc/cmake/intf_pub.cmake @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +# Custom package scenario, public compilation configuration for Host side targets +# Note: To ensure compatibility with the built-in package compilation process, the intf_pub name cannot be changed +add_library(intf_pub INTERFACE) +target_include_directories(intf_pub + INTERFACE + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include/external + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof +) +target_link_directories(intf_pub + INTERFACE + ${ASCEND_CANN_PACKAGE_PATH}/lib64 +) +target_compile_options(intf_pub + INTERFACE + -fPIC + -O2 + -Wall -Wundef -Wcast-qual -Wpointer-arith -Wdate-time + -Wfloat-equal -Wformat=2 -Wshadow + -Wsign-compare -Wunused-macros -Wvla -Wdisabled-optimization -Wempty-body -Wignored-qualifiers + -Wimplicit-fallthrough=3 -Wtype-limits -Wshift-negative-value -Wswitch-default + -Wframe-larger-than=32768 -Woverloaded-virtual + -Wnon-virtual-dtor -Wshift-overflow=2 -Wshift-count-overflow + -Wwrite-strings -Wmissing-format-attribute -Wformat-nonliteral + -Wdelete-non-virtual-dtor -Wduplicated-cond + -Wtrampolines -Wsized-deallocation -Wlogical-op -Wsuggest-attribute=format + -Wduplicated-branches + -Wmissing-include-dirs -Wformat-signedness + -Wreturn-local-addr -Wextra + -Wredundant-decls -Wfloat-conversion + -Wno-write-strings -Wall -Wno-dangling-else -Wno-comment -Wno-conversion-null -Wno-return-type + -Wno-unknown-pragmas -Wno-sign-compare + -Wno-error=undef + -Wno-error=comment + 
-Wno-error=conversion-null + -Wno-error=dangling-else + -Wno-error=return-type + -Wno-error=shadow + -Wno-error=sign-compare + -Wno-error=unknown-pragmas + -Wno-error=unused-parameter + -Wno-error=cast-qual + -Wno-error=format= + -Wno-error=maybe-uninitialized + -Wno-error=missing-field-initializers + -Wno-error=redundant-decls + -Wno-error=unused-variable + $<$:-Wnested-externs> + $<$:-g> + $,-fstack-protector-strong,-fstack-protector-all> +) +target_compile_definitions(intf_pub + INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:_FORTIFY_SOURCE=2> +) +target_link_options(intf_pub + INTERFACE + $<$,EXECUTABLE>:-pie> + $<$:-s> + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack +) diff --git a/csrc/cmake/modules/Findalog.cmake b/csrc/cmake/modules/Findalog.cmake new file mode 100644 index 00000000..d016a50e --- /dev/null +++ b/csrc/cmake/modules/Findalog.cmake @@ -0,0 +1,113 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +if (alog_FOUND) + message(STATUS "Package alog has been found.") + return() +endif() + +set(_cmake_targets_defined "") +set(_cmake_targets_not_defined "") +set(_cmake_expected_targets "") +foreach(_cmake_expected_target IN ITEMS slog alog alog_headers) + list(APPEND _cmake_expected_targets "${_cmake_expected_target}") + if(TARGET "${_cmake_expected_target}") + list(APPEND _cmake_targets_defined "${_cmake_expected_target}") + else() + list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") + endif() +endforeach() +unset(_cmake_expected_target) + +if(_cmake_targets_defined STREQUAL _cmake_expected_targets) + unset(_cmake_targets_defined) + unset(_cmake_targets_not_defined) + unset(_cmake_expected_targets) + unset(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() + +if(NOT _cmake_targets_defined STREQUAL "") + string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") + string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") +endif() +unset(_cmake_targets_defined) +unset(_cmake_targets_not_defined) +unset(_cmake_expected_targets) + +find_path(_INCLUDE_DIR + NAMES base/alog_pub.h + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +find_library(slog_SHARED_LIBRARY + NAMES libascendalog.so + PATH_SUFFIXES lib64 + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +find_library(alog_SHARED_LIBRARY + NAMES libascendalog.so + PATH_SUFFIXES lib64 + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(alog + FOUND_VAR + alog_FOUND + REQUIRED_VARS + _INCLUDE_DIR + slog_SHARED_LIBRARY + 
alog_SHARED_LIBRARY +) + +if(alog_FOUND) + set(alog_INCLUDE_DIR "${_INCLUDE_DIR}") + include(CMakePrintHelpers) + message(STATUS "Variables in alog module:") + cmake_print_variables(alog_INCLUDE_DIR) + cmake_print_variables(slog_SHARED_LIBRARY) + cmake_print_variables(alog_SHARED_LIBRARY) + + add_library(slog SHARED IMPORTED) + set_target_properties(slog PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG" + INTERFACE_LINK_LIBRARIES "alog_headers" + IMPORTED_LOCATION "${slog_SHARED_LIBRARY}" + ) + + add_library(alog SHARED IMPORTED) + set_target_properties(alog PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG" + INTERFACE_LINK_LIBRARIES "alog_headers" + IMPORTED_LOCATION "${alog_SHARED_LIBRARY}" + ) + + add_library(alog_headers INTERFACE IMPORTED) + set_target_properties(alog_headers PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${alog_INCLUDE_DIR}" + ) + + include(CMakePrintHelpers) + cmake_print_properties(TARGETS slog + PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION + ) + cmake_print_properties(TARGETS alog + PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION + ) + cmake_print_properties(TARGETS alog_headers + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ) +endif() + +# Cleanup temporary variables. +set(_INCLUDE_DIR) diff --git a/csrc/cmake/scripts/prepare.sh b/csrc/cmake/scripts/prepare.sh new file mode 100644 index 00000000..dac9da76 --- /dev/null +++ b/csrc/cmake/scripts/prepare.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2)) +JOB_NUM="-j${CPU_NUM}" + +while [[ $# -gt 0 ]]; do + case $1 in + -s) + PATH_TO_SOURCE="$2" + shift 2 + ;; + -b) + PATH_TO_BUILD="$2" + shift 2 + ;; + -p) + ASCEND_CANN_PACKAGE_PATH="$2" + shift 2 + ;; + --autogen-dir) + ASCEND_AUTOGEN_DIR="$2" + shift 2 + ;; + --build-open-project) + BUILD_OPEN_PROJECT="$2" + shift 2 + ;; + --binary-out-dir) + ASCEND_BINARY_OUT_DIR="$2" + shift 2 + ;; + --impl-out-dir) + ASCEND_IMPL_OUT_DIR="$2" + shift 2 + ;; + --op-build-tool) + OP_BUILD_TOOL="$2" + shift 2 + ;; + --ascend-cmake-dir) + ASCEND_CMAKE_DIR="$2" + shift 2 + ;; + --tiling-key) + TILING_KEY="$2" + shift 2 + ;; + --ops-compile-options) + OPS_COMPILE_OPTIONS="$2" + shift 2 + ;; + --check-compatible) + CHECK_COMPATIBLE="$2" + shift 2 + ;; + --ascend-compute_unit) + ASCEND_COMPUTE_UNIT="$2" + shift 2 + ;; + --ascend-op-name) + ASCEND_OP_NAME="$2" + shift 2 + ;; + --op_debug_config) + OP_DEBUG_CONFIG="$2" + shift 2 + ;; + *) + break + ;; + esac +done + +function clean() { + if [ -n "${PATH_TO_BUILD}" ];then + rm -rf ${PATH_TO_BUILD} + mkdir -p ${PATH_TO_BUILD} + fi +} + +function convert_string() { + local _input=$1 + _output=$(echo $_input | sed 's/::/;/g') + echo "${_output}" +} + +function set_env() { + CONVERT_TILING_KEY="$(convert_string ${TILING_KEY})" + + CONVERT_OPS_COMPILE_OPTIONS="$(convert_string ${OPS_COMPILE_OPTIONS})" + + CONVERT_ASCEND_COMPUTE_UNIT="$(convert_string ${ASCEND_COMPUTE_UNIT})" +} + +function build() { + cd ${PATH_TO_BUILD} + cmake ${PATH_TO_SOURCE} \ + 
-DBUILD_OPEN_PROJECT=${BUILD_OPEN_PROJECT} \ + -DPREPARE_BUILD=ON \ + -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} \ + -DASCEND_AUTOGEN_DIR=${ASCEND_AUTOGEN_DIR} \ + -DASCEND_BINARY_OUT_DIR=${ASCEND_BINARY_OUT_DIR} \ + -DASCEND_IMPL_OUT_DIR=${ASCEND_IMPL_OUT_DIR} \ + -DOP_BUILD_TOOL=${OP_BUILD_TOOL} \ + -DASCEND_CMAKE_DIR=${ASCEND_CMAKE_DIR} \ + -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE} \ + -DTILING_KEY="${CONVERT_TILING_KEY}" \ + -DOPS_COMPILE_OPTIONS="${CONVERT_OPS_COMPILE_OPTIONS}" \ + -DASCEND_COMPUTE_UNIT=${CONVERT_ASCEND_COMPUTE_UNIT} \ + -DOP_DEBUG_CONFIG=${OP_DEBUG_CONFIG} \ + -DASCEND_OP_NAME=${ASCEND_OP_NAME} + + make ${JOB_NUM} prepare_build +} + +function main() { + clean + set_env + build +} + +main diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt new file mode 100644 index 00000000..44fecd39 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt @@ -0,0 +1,54 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME GroupedMatmulSwigluQuantWeightNzTensorList + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnExc PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp +) + +target_sources(opapi PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + ) +endif () + +target_sources(optiling PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000..d5992610 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include "aclnn_kernels/contiguous.h" +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_kernels/common/op_error_check.h" +#include "opdev/common_types.h" +#include "opdev/data_type_utils.h" +#include "opdev/format_utils.h" +#include "opdev/op_dfx.h" +#include "opdev/op_executor.h" +#include "opdev/op_log.h" +#include "opdev/platform.h" +#include "opdev/shape_utils.h" +#include "opdev/tensor_view_utils.h" +#include "opdev/make_op_executor.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" +#include "aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" + +using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +static constexpr int64_t SPLIT = 2; +static constexpr int64_t K_LIMIT = 65536; +static constexpr int64_t N_LIMIT = 4096; +static constexpr int64_t NZ_DIM_3 = 32; +static constexpr int64_t NZ_DIM_2 = 16; +static constexpr int64_t OUTPUT_IDX_0 = 0; +static constexpr int64_t OUTPUT_IDX_1 = 1; +static constexpr size_t X_DIM_LIMIT = 2; +static constexpr size_t WEIGHT_ND_DIM_LIMIT = 2; +static constexpr size_t WEIGHT_NZ_DIM_LIMIT = 4; +static constexpr size_t WEIGHT_SCALE_DIM_LIMIT = 1; +static constexpr size_t TOKEN_SCALE_DIM_LIMIT = 1; +static constexpr size_t GROUP_LIST_DIM_LIMIT = 1; +static constexpr size_t QUANTOUT_DIM_LIMIT = 2; +static constexpr size_t QUANTSCALEOUT_DIM_LIMIT = 1; + +static const std::initializer_list X_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const 
std::initializer_list WEIGHT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const std::initializer_list WEIGHT_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16}; +static const std::initializer_list X_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16}; +static const std::initializer_list GROUP_LIST_DTYPE_SUPPORT_LIST = {DataType::DT_INT64}; +static const std::initializer_list QUANTOUT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const std::initializer_list QUANTSCALEOUT_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT}; + +static bool CheckNotNull(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset, + const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) +{ + OP_CHECK_NULL(x, return false); + OP_CHECK_NULL(weight, return false); + OP_CHECK_NULL(weightScale, return false); + OP_CHECK_NULL(xScale, return false); + OP_CHECK_NULL(groupList, return false); + OP_CHECK_NULL(output, return false); + OP_CHECK_NULL(outputScale, return false); + if (bias != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where bias is not 0. " + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + if (offset != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where offset is not 0. " + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + if (outputOffset != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where outputOffset is not 0. 
" + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + return true; +} + +static bool CheckInputOutDims(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + OP_CHECK_WRONG_DIMENSION(x, X_DIM_LIMIT, return false); + op::Format weightViewFormat = (*weight)[0]->GetViewFormat(); + if (IsPrivateFormat(weightViewFormat)){ + OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_NZ_DIM_LIMIT, return false); + } else { + OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_ND_DIM_LIMIT, return false); + } + OP_CHECK_WRONG_DIMENSION((*weightScale)[0], WEIGHT_SCALE_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(xScale, TOKEN_SCALE_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(groupList, GROUP_LIST_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(output, QUANTOUT_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(outputScale, QUANTSCALEOUT_DIM_LIMIT, return false); + return true; +} + +static bool CheckInputOutShape(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + int64_t m = x->GetViewShape().GetDim(0); + int64_t k = x->GetViewShape().GetDim(1); + int64_t n = (*weightScale)[0]->GetViewShape().GetDim(0); + int64_t e = weight->Size(); + if (n % SPLIT != 0){ + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, N is %ld , not an even number.", n); + return false; + } + int64_t nAfterHalve = static_cast(n / SPLIT); + // x shape is expected to be [M, K] + op::Shape xExpectShape = {m, k}; + // The ND shape of each weight in TensorList is expected to be [K, N] + op::Shape weightNDExpectShape = {k, n}; + // The NZ shape of each weight in TensorList is expected to be [N // 32, K // 16, 16, 32] + op::Shape 
weightNZExpectShape = {static_cast(n / NZ_DIM_3), + static_cast(k / NZ_DIM_2), + NZ_DIM_2, NZ_DIM_3}; + // weightScale shape is expected to be [N] + op::Shape weightScaleExpectShape = {n}; + // xScale shape is expected to be [E, N] + op::Shape xScaleExpectShape = {m}; + // output shape is expected to be [M, N] + op::Shape outputExpectShape = {m, nAfterHalve}; + // outputScale shape is expected to be [M] + op::Shape outputScaleExpectShape = {m}; + for (size_t i = 0; i < weight->Size(); ++i) { + op::Format weightViewFormat = (*weight)[i]->GetViewFormat(); + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(x, xExpectShape, return false); + if (IsPrivateFormat(weightViewFormat)){ + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNZExpectShape, return false); + } else { + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNDExpectShape, return false); + } + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weightScale)[i], weightScaleExpectShape, return false); + } + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(xScale, xScaleExpectShape, return false); + + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(output, outputExpectShape, return false); + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(outputScale, outputScaleExpectShape, return false); + // The length of groupList should be less than or equal to the number of experts in weight + int64_t groupListLen = groupList->GetViewShape().GetDim(0); + if(groupListLen > e) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, Length of 'groupList' out of range (expected to be in range of [1, %ld], but got %ld)", + e, groupListLen); + return false; + } + if(nAfterHalve > N_LIMIT) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + where N after halve is %ld greater than %ld.", + nAfterHalve, N_LIMIT); + return false; + } + if(k >= K_LIMIT) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + 
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + The tail axis dimension of input0(x) is %ld, which need lower than %ld.", + k, K_LIMIT); + return false; + } + return true; +} + +static bool CheckDtypeValid(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + OP_CHECK_DTYPE_NOT_SUPPORT(x, X_DTYPE_SUPPORT_LIST, return false); + for (size_t i = 0; i < weight->Size(); ++i) { + OP_CHECK_DTYPE_NOT_SUPPORT((*weight)[i], WEIGHT_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT((*weightScale)[i], WEIGHT_SCALE_DTYPE_SUPPORT_LIST, return false); + } + OP_CHECK_DTYPE_NOT_SUPPORT(xScale, X_SCALE_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(groupList, GROUP_LIST_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(output, QUANTOUT_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(outputScale, QUANTSCALEOUT_DTYPE_SUPPORT_LIST, return false); + return true; +} + +static bool CheckFormat(const aclTensor* x, const aclTensorList* weight, const aclTensor* output) +{ + bool isNZ = (*weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ; + if (!isNZ) { + // fp16 in fp32 out that is split k template, not precision-advanced now + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + weight Format expect is FRACTAL_NZ, but got [%s].", op::ToString((*weight)[0]->GetStorageFormat()).GetString()); + return false; + } + if (IsPrivateFormat(x->GetStorageFormat())) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + x Format Not support Private Format."); + return false; + } + if (IsPrivateFormat(output->GetStorageFormat())) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, 
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + output Format Not support Private Format."); + return false; + } + return true; +} + +static aclnnStatus CheckParams(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset, + const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) { + // 1. Check if parameters are null pointers + CHECK_RET(CheckNotNull(x, weight, bias, offset, weightScale, xScale, + groupList, output, outputScale, outputOffset), ACLNN_ERR_PARAM_NULLPTR); + + // 2. Verify input and output parameter dimensions + CHECK_RET(CheckInputOutDims(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 3. Verify input and output shape parameters + CHECK_RET(CheckInputOutShape(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 4. Check if the input data types are within the supported data type range + CHECK_RET(CheckDtypeValid(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 5. 
Check if data format is supported + CHECK_RET(CheckFormat(x, weight, output), ACLNN_ERR_PARAM_INVALID); + + return ACLNN_SUCCESS; +} + +static aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(const aclTensor *x, const aclTensorList *weight, + const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, + const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, + aclTensor *outputOffset, uint64_t *workspaceSize, + aclOpExecutor **executor){ + // Fixed pattern, create OpExecutor + auto uniqueExecutor = CREATE_EXECUTOR(); + CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + // Fixed pattern, parameter check + auto ret = CheckParams(x, weight, bias, offset, weightScale, xScale, + groupList, output, outputScale, outputOffset); + CHECK_RET(ret == ACLNN_SUCCESS, ret); + // Empty tensor scenario + if (output->IsEmpty() || groupList->IsEmpty() || outputScale->IsEmpty()) { + *workspaceSize = 0; + uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; + } + // Convert to contiguous + x = l0op::Contiguous(x, uniqueExecutor.get()); + CHECK_RET(x != nullptr, ACLNN_ERR_INNER_NULLPTR); + for (size_t i = 0; i < weight->Size(); ++i) { + (*weight)[i]->SetOriginalShape((*weight)[i]->GetViewShape()); + } + xScale = l0op::Contiguous(xScale, uniqueExecutor.get()); + CHECK_RET(xScale != nullptr, ACLNN_ERR_INNER_NULLPTR); + groupList = l0op::Contiguous(groupList, uniqueExecutor.get()); + CHECK_RET(groupList != nullptr, ACLNN_ERR_INNER_NULLPTR); + // Call L0 operator capability + auto ret_0 = l0op::GroupedMatmulSwigluQuantWeightNzTensorList(x, weight, weightScale, xScale, groupList, uniqueExecutor.get()); + CHECK_RET(ret_0 != std::tuple(nullptr, nullptr), ACLNN_ERR_INNER_NULLPTR); + auto out0 = std::get(ret_0); + auto ret_1 = l0op::ViewCopy(out0, output, uniqueExecutor.get()); + CHECK_RET(ret_1 != nullptr, ACLNN_ERR_INNER_NULLPTR); + auto out1 = std::get(ret_0); 
+ auto ret_2 = l0op::ViewCopy(out1, outputScale, uniqueExecutor.get()); + CHECK_RET(ret_2 != nullptr, ACLNN_ERR_INNER_NULLPTR); + *workspaceSize = uniqueExecutor->GetWorkspaceSize(); + uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; +} + +aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(const aclTensor *x, const aclTensorList *weight, + const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, + const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, + aclTensor *outputOffset, uint64_t *workspaceSize, + aclOpExecutor **executor) { + OP_CHECK_COMM_INPUT(workspaceSize, executor); + L2_DFX_PHASE_1(aclnnGroupedMatmulSwigluQuantWeightNzTensorList, + DFX_IN(x, weight, bias, offset, weightScale, xScale, groupList), + DFX_OUT(output, outputScale, outputOffset)); + // weight is forcibly bound to StorageFormat and ViewFormat as NZ in this scenario + CHECK_RET(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR); + for (size_t i = 0; i < weight->Size(); ++i) { + auto storgeShape = (*weight)[i]->GetStorageShape(); + auto viewShape = (*weight)[i]->GetViewShape(); + aclTensor* weightNZ = const_cast((*weight)[i]); + CHECK_COND((storgeShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT), + ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNZTensorList, The dimnum of storageShape for second input (weight) \ + must be 4. 
\n But StorageShape got %s , and dimNum is %lu.", + op::ToString(storgeShape).GetString(), storgeShape.GetDimNum()); + // The StorageFormat of weight is unconditionally regarded as NZ + weightNZ->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ); + if (viewShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT){ + // If the viewShape of weight is 4-dimensional, it is regarded as NZ + weightNZ->SetViewFormat(op::Format::FORMAT_FRACTAL_NZ); + } else if (viewShape.GetDimNum() == WEIGHT_ND_DIM_LIMIT){ + // If the viewShape of weight is 2-dimensional, it is regarded as ND + weightNZ->SetViewFormat(op::Format::FORMAT_ND); + } + } + // Call the common interface + return aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(x, weight, bias, offset, weightScale, xScale, groupList, + output, outputScale, outputOffset, workspaceSize, executor); +} + +aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) { + L2_DFX_PHASE_2(aclnnGroupedMatmulSwigluQuantWeightNzTensorList); + CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER, + "This is an error in GroupedMatmulSwigluQuantWeightNzTensorList launch aicore"); + return ACLNN_SUCCESS; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000..407f27f4 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#include "aclnn/aclnn_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief The first interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, which calculates the workspace size according to the specific calculation process. + * @domain aclnn_ops_infer + * + * @param [in] x: Represents x in the formula. The data type supports INT8, and the data format supports ND. + * @param [in] weight: + * Represents weight in the formula. The data type supports INT8, and the data format supports NZ. + * @param [in] weightScale: Represents quantization parameters. The data type supports FLOAT16, BFLOAT16, and FLOAT32. The data format supports ND, with a maximum length of 128. + * Represents per Channel parameters. The data type supports FLOAT16 and BFLOAT16. The data format supports ND. + * @param [in] xScale: + * Represents per Token quantization parameters. The data type supports FLOAT32, and the data format supports ND. + * @param [in] groupList: Required parameter, representing the index situation on the input and output grouping axes. The data type supports INT64. + * @param [out] quantOutput: Represents out in the formula. The data type supports INT8, and the data format supports ND. + * @param [out] quantScaleOutput: Represents outQuantScale in the formula. The data type supports Float32. + * @param [out] workspaceSize: Returns the workspace size that users need to apply for on the npu device side. 
+ * @param [out] executor: Returns the op executor, containing the operator calculation process. + * @return aclnnStatus: Returns the status code. + */ +__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize( + const aclTensor *x, const aclTensorList *weight, const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, aclTensor *outputOffset, uint64_t *workspaceSize, aclOpExecutor **executor); + +/** + * @brief The second interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, used to execute calculations. + * @param [in] workspace: The starting address of the workspace memory applied for on the npu device side. + * @param [in] workspaceSize: The workspace size applied for on the npu device side, obtained from the first interface aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize. + * @param [in] stream: acl stream. + * @param [in] executor: op executor, containing the operator calculation process. + * @return aclnnStatus: Returns the status code. + */ +__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void* workspace, + uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000..9181552b --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "opdev/op_log.h" +#include "opdev/op_dfx.h" +#include "opdev/make_op_executor.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" + +using namespace op; + +namespace l0op { +OP_TYPE_REGISTER(GroupedMatmulSwigluQuantWeightNzTensorList); + +const std::tuple GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x, + const aclTensorList *weight, + const aclTensorList *perChannelScale, + const aclTensor *perTokenScale, + const aclTensor *groupList, + aclOpExecutor *executor) { + L0_DFX(GroupedMatmulSwigluQuantWeightNzTensorList, x, weight, perChannelScale, perTokenScale, groupList); + if (x == nullptr) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "x is nullptr."); + return std::tuple(nullptr, nullptr); + } + int64_t m = perTokenScale->GetViewShape().GetDim(0); + int64_t n = (*perChannelScale)[0]->GetViewShape().GetDim(0); + int64_t nAfterHalve = static_cast(n / 2); + gert::Shape outShape({m, nAfterHalve}); + gert::Shape scaleOutShape({m}); + auto out = executor->AllocTensor(outShape, DataType::DT_INT8, ge::FORMAT_ND); + auto scaleOut = executor->AllocTensor(scaleOutShape, DataType::DT_FLOAT, ge::FORMAT_ND); + auto ret = INFER_SHAPE(GroupedMatmulSwigluQuantWeightNzTensorList, + OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList), + OP_OUTPUT(out, scaleOut)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "InferShape failed."); + return std::tuple(nullptr, nullptr); + } + ret = 
ADD_TO_LAUNCHER_LIST_AICORE(GroupedMatmulSwigluQuantWeightNzTensorList, + OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList), + OP_OUTPUT(out, scaleOut)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed."); + return std::tuple(nullptr, nullptr); + } + return std::tie(out, scaleOut); +} + +} // namespace l0op diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000..f47ad8a8 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H +#define OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H + +#include "opdev/op_executor.h" + +namespace l0op { +const std::tuple GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x, + const aclTensorList *weight, + const aclTensorList *perChannelScale, + const aclTensor *perTokenScale, + const aclTensor *groupList, + aclOpExecutor *executor); +} + +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp new file mode 100644 index 00000000..bd7a80b1 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp + * \brief + */ + +#include +#include "register/op_def_registry.h" +namespace ops { +class GroupedMatmulSwigluQuantWeightNzTensorList : public OpDef { +public: + explicit GroupedMatmulSwigluQuantWeightNzTensorList(const char* name) : OpDef(name) + { + this->Input("x") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("weight") + .ParamType(DYNAMIC) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + this->Input("weight_scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("x_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Input("group_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64,ge::DT_INT64,ge::DT_INT64}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Output("y_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true); + + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; + +OP_ADD(GroupedMatmulSwigluQuantWeightNzTensorList); +} diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp 
b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp new file mode 100644 index 00000000..5e3d4432 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp + * \brief + */ +#include "register/op_impl_registry.h" +#include "log/ops_log.h" +#include "platform/platform_info.h" + +using namespace ge; +namespace ops { +const int64_t X_INDEX = 0; +const int64_t WEIGHTSCALE_INDEX = 2; +const int64_t M_DIM_INDEX = 0; +const int64_t N_DIM_INDEX = 0; +static ge::graphStatus InferShape4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferShapeContext* context) { + const gert::Shape* xShape = context->GetInputShape(X_INDEX); + const gert::Shape* weightScaleShape = context->GetDynamicInputShape(WEIGHTSCALE_INDEX, 0); + int64_t m = xShape->GetDim(M_DIM_INDEX); + int64_t n = static_cast(weightScaleShape->GetDim(N_DIM_INDEX) / 2); + auto outShape = context->GetOutputShape(0); + outShape->SetDimNum(2); + outShape->SetDim(0, m); + outShape->SetDim(1, n); + auto outScaleShape = context->GetOutputShape(1); + outScaleShape->SetDimNum(1); + outScaleShape->SetDim(0, m); + return GRAPH_SUCCESS; +} + +static graphStatus 
InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferDataTypeContext* context) { + context->SetOutputDataType(0, DataType::DT_INT8); + context->SetOutputDataType(1, DataType::DT_FLOAT); + return GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(GroupedMatmulSwigluQuantWeightNzTensorList) + .InferShape(InferShape4GroupedMatmulSwigluQuantWeightNzTensorList) + .InferDataType(InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList); +} // namespace ops diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp new file mode 100644 index 00000000..f3495939 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp + * \brief + */ +#include +#include +#include "register/op_impl_registry.h" +#include "log/ops_log.h" +#include "error/ops_error.h" +#include "tiling/tiling_base.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h" +using namespace ge; +using namespace AscendC; +using namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling; + +template +static T1 CeilDiv(T1 a, T2 b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +namespace optiling { + +struct GMMSwigluCompileInfo { + uint64_t ubSize_ = 0; + uint32_t aicNum_ = 0; + uint32_t baseM_ = 128; + uint32_t baseN_ = 256; +}; + +static uint64_t CalcMaxTmpSize(const uint32_t row, const uint64_t n) { + std::vector shape_vec = {static_cast(row * n)}; + Shape shape(shape_vec); + uint32_t max; + uint32_t min; + GetSwiGLUMaxMinTmpSize(shape, 4, max, min, false); + uint32_t averageTmp = (max + min) >> 1; + GetAscendQuantMaxMinTmpSize(shape, 4, max, min); + uint32_t average = (max + min) >> 1; + average = average > averageTmp ? average : averageTmp; + GetAscendDequantMaxMinTmpSize(shape, 4, max, min); + averageTmp = (max + min) >> 1; + return average > averageTmp ? 
average : averageTmp; +} + +static uint64_t CalRows(const uint64_t ubSize, const uint64_t n) { + uint64_t tokenSize = n << 2; + uint64_t expectSize = ubSize - tokenSize; + uint64_t rows = expectSize / (8 + tokenSize); + uint64_t realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n); + while (expectSize < realSize) { + rows -= CeilDiv(realSize - expectSize, (8 + tokenSize) << 2); + realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n); + } + return rows; +} + +static void SetTilingKey(gert::TilingContext* context, bool isSplitWorkSpace) { + if(isSplitWorkSpace){ + context->SetTilingKey(1); + context->SetScheduleMode(BATCH_MODE_SCHEDULE); + } else { + context->SetTilingKey(0); + context->SetScheduleMode(BATCH_MODE_SCHEDULE); + } +} + +static bool IsPreFill(GMMSwigluQuantTilingData &tilingData) { + int64_t k = tilingData.gmmSwigluBaseParams.get_K(); + int64_t n = tilingData.gmmSwigluBaseParams.get_N(); + int64_t m = tilingData.gmmSwigluBaseParams.get_M(); + int64_t groupNum = tilingData.gmmSwigluBaseParams.get_groupNum(); + if (groupNum == 128 && m >= PREFILL_M_MIN_SIZE) { // 128:prefiling groupNum + std::array kNList = {k, n}; // 2: kNList size + if (PREFILL_WHITE_LIST.count(kNList)) { + return true; + } + } + return false; +} + +ASCENDC_EXTERN_C graphStatus TilingGMMSwigluQuant(gert::TilingContext* context) { + // set info + OPS_LOG_I(context->GetNodeName(), "Begin Run GMM Swiglu Tiling ."); + + auto compileInfoPtr = context->GetCompileInfo(); + auto xTensor = context->GetInputTensor(X_INDEX); + OPS_LOG_E_IF_NULL(context, xTensor, return GRAPH_FAILED); + const int64_t m = xTensor->GetStorageShape().GetDim(0); + const int64_t k = xTensor->GetStorageShape().GetDim(1); + auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0); + OPS_LOG_E_IF_NULL(context, wTensor, return GRAPH_FAILED); + const int64_t n = wTensor->GetStorageShape().GetDim(0) * wTensor->GetStorageShape().GetDim(3); + auto groupListTensor = 
context->GetDynamicInputTensor(GROUPLIST_INDEX, 0); + OPS_LOG_E_IF_NULL(context, groupListTensor, return GRAPH_FAILED); + const int64_t groupNum = groupListTensor->GetStorageShape().GetDim(0); + GMMSwigluQuantTilingData tilingData; + const int64_t row = CalRows(compileInfoPtr->ubSize_, n); + tilingData.gmmSwigluBaseParams.set_groupNum(groupNum); + tilingData.gmmSwigluBaseParams.set_coreNum(compileInfoPtr->aicNum_); + tilingData.gmmSwigluBaseParams.set_K(k); + tilingData.gmmSwigluBaseParams.set_N(n); + tilingData.gmmSwigluBaseParams.set_M(m); + tilingData.gmmSwiglu.set_maxProcessRowNum(row); + tilingData.gmmSwiglu.set_groupListLen(groupNum); + tilingData.gmmSwiglu.set_tokenLen(n); + + OPS_LOG_D(context->GetNodeName(),"grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling."); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.groupNum: %ld", groupNum); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.coreNum: %u ", compileInfoPtr->aicNum_); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.M: %ld", m); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.K: %ld", k); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.N: %ld", n); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.maxProcessRowNum: %ld", row); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.groupListLen: %ld", groupNum); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.tokenLen: %ld", n); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + using namespace matmul_tiling; + MatmulApiTiling tiling(ascendcPlatform); + tiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT8); + tiling.SetBType(TPosition::GM, CubeFormat::NZ, matmul_tiling::DataType::DT_INT8); + tiling.SetCType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + tiling.SetBias(false); + tiling.SetShape(compileInfoPtr->baseM_, compileInfoPtr->baseN_, k); + tiling.SetOrgShape(m, n, k); + tiling.SetBufferSpace(-1, -1, -1); + 
OPS_ERR_IF(tiling.GetTiling(tilingData.mmTilingData) == -1, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling, get tiling failed"), + return GRAPH_FAILED); + auto workspaceSizes = context->GetWorkspaceSizes(1); + bool isPreFill = IsPreFill(tilingData); + tilingData.gmmSwigluBaseParams.set_isPreFill(isPreFill); + int64_t usrWorkspaceLimut = isPreFill ? PREFILL_USER_WORKSPACE_LIMIT : USER_WORKSPACE_LIMIT; + int64_t mLimit = ((usrWorkspaceLimut / DOUBLE_WORKSPACE_SPLIT) / INT32_DTYPE_SIZE) / n; + OPS_ERR_IF(mLimit <= 0, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),"mLimit is %ld must over then 0.", mLimit), + return GRAPH_FAILED); + tilingData.gmmSwigluBaseParams.set_mLimit(mLimit); + workspaceSizes[0] = SYS_WORKSPACE_SIZE + ((mLimit * DOUBLE_WORKSPACE_SPLIT > m \ + ? m \ + : mLimit * DOUBLE_WORKSPACE_SPLIT) * n * sizeof(int32_t)); + bool isSplitWorkSpace = m > mLimit * DOUBLE_WORKSPACE_SPLIT; + OPS_LOG_D(context->GetNodeName(), "USER_WORKSPACE_LIMIT: %ld", usrWorkspaceLimut); + OPS_LOG_D(context->GetNodeName(), "mLimit: %ld", mLimit); + OPS_LOG_D(context->GetNodeName(), "workspaceSizes: %lu", workspaceSizes[0]); + OPS_LOG_D(context->GetNodeName(), "isSplitWorkSpace: %s", isSplitWorkSpace ? "true" : "false"); + OPS_LOG_D(context->GetNodeName(), "isPreFill: %s", isPreFill ? 
"true" : "false"); + SetTilingKey(context, isSplitWorkSpace); + tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->SetBlockDim(compileInfoPtr->aicNum_); // block dim is the number of aicube + context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize()); + + OPS_LOG_D(context->GetNodeName(), "End Run GMM Swiglu Tiling."); + return GRAPH_SUCCESS; +} + +ASCENDC_EXTERN_C graphStatus TilingPrepareForGMMSwigluQuant(gert::TilingParseContext* context) { + // get info + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + OPS_LOG_E_IF_NULL(context, platformInfoPtr, return GRAPH_FAILED); + auto compileInfoPtr = context->GetCompiledInfo(); + OPS_LOG_E_IF_NULL(context, compileInfoPtr, return GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + compileInfoPtr->aicNum_ = ascendcPlatform.GetCoreNumAic(); + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize_); + OPS_LOG_D(context->GetNodeName(), "ubSize is %lu, aicNum is %u.", compileInfoPtr->ubSize_, compileInfoPtr->aicNum_); + return GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(GroupedMatmulSwigluQuantWeightNzTensorList) +.Tiling(TilingGMMSwigluQuant) +.TilingParse(TilingPrepareForGMMSwigluQuant); +} // namespace optiling diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h new file mode 100644 index 00000000..ccc0d459 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h + * \brief + */ +#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H + +#include +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GMMSwigluBaseParams) + TILING_DATA_FIELD_DEF(uint32_t, groupNum); + TILING_DATA_FIELD_DEF(uint32_t, coreNum); + TILING_DATA_FIELD_DEF(uint32_t, K); + TILING_DATA_FIELD_DEF(uint32_t, N); + TILING_DATA_FIELD_DEF(uint32_t, M); + TILING_DATA_FIELD_DEF(uint32_t, mLimit); + TILING_DATA_FIELD_DEF(uint64_t, isPreFill); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(GMMSwigluBaseParamsOp, GMMSwigluBaseParams) + +BEGIN_TILING_DATA_DEF(GMMSwiglu) + TILING_DATA_FIELD_DEF(uint32_t, maxProcessRowNum); + TILING_DATA_FIELD_DEF(uint32_t, groupListLen); + TILING_DATA_FIELD_DEF(uint32_t, tokenLen); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(GMMSwigluOp, GMMSwiglu) + +BEGIN_TILING_DATA_DEF(GMMSwigluQuantTilingData) + TILING_DATA_FIELD_DEF_STRUCT(GMMSwigluBaseParams, gmmSwigluBaseParams); + TILING_DATA_FIELD_DEF_STRUCT(GMMSwiglu, gmmSwiglu); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(GroupedMatmulSwigluQuantWeightNzTensorList, GMMSwigluQuantTilingData) +} + +namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling { +constexpr uint32_t X_INDEX = 0; +constexpr uint32_t WEIGHT_INDEX = 1; +constexpr uint32_t 
GROUPLIST_INDEX = 4; +constexpr uint32_t BATCH_MODE_SCHEDULE = 1; +constexpr uint32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; +constexpr int64_t USER_WORKSPACE_LIMIT = 256 * 1024 * 1024; +constexpr int64_t PREFILL_USER_WORKSPACE_LIMIT = 64 * 1024 * 1024; +constexpr int64_t DOUBLE_WORKSPACE_SPLIT = 2; +constexpr int64_t INT32_DTYPE_SIZE = 4; +constexpr uint32_t PREFILL_M_MIN_SIZE = 16 * 1024; + +const std::set> PREFILL_WHITE_LIST = { // used for preFill case + {{2048, 1536}}, + {{4096, 3072}} +}; +} // namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling + +#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000..ba60d80b --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + * \brief + */ +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" +#include +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h" +using namespace AscendC; +using namespace matmul; +using namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST; +using MM_DTYPE_Y = int32_t; + +template +using xType = MatmulType; + +template +using weightType = MatmulType; + +using yType = MatmulType; + +#define GMM_CV_SPLIT_IMP(computeClass, dtypeC, transA, transB, sync, cfg, aType, bType, cType) \ + do { \ + using matmulType = MMImplType, bType, cType, cType, cfg>; \ + matmulType::MT mm; \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwigluBaseParams, gmmSwigluBaseParams_, tiling); \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, mmTilingData, mmTilingData_, tiling); \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwiglu, gmmSwiglu_, tiling); \ + if ASCEND_IS_AIC { \ + mm.SetSubBlockIdx(0); \ + mm.Init(&mmTilingData_, &tPipe); \ + } \ + computeClass computeOp(mm); \ + computeOp.Init(x, weight, perChannelScale, perTokenScale, groupList, quantOutput, quantScaleOutput, \ + user1, &gmmSwigluBaseParams_, &mmTilingData_, &gmmSwiglu_, &tPipe); \ + computeOp.Process(); \ + } while (0) + +extern "C" __global__ __aicore__ void grouped_matmul_swiglu_quant_weight_nz_tensor_list(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, GM_ADDR tiling) { + TPipe tPipe; + AscendCUtils::SetOverflow(1); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + GM_ADDR user1 = GetUserWorkspace(workspace); + if (TILING_KEY_IS(0)) { // antiquant msd + KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_2); + GMM_CV_SPLIT_IMP( + GMMSwigluCompute, // computeClass + DTYPE_WEIGHT_SCALE, + false, // transA + false, // transB + false, // sync + NZ_CFG_MDL, // cfg + xType, // aType + weightType, 
// bType + yType); // cType + } else if(TILING_KEY_IS(1)){ + KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_2); + GMM_CV_SPLIT_IMP( + GMMSwigluSplitWorkSpaceCompute, // computeClass + DTYPE_WEIGHT_SCALE, + false, // transA + false, // transB + false, // sync + NZ_CFG_MDL, // cfg + xType, // aType + weightType, // bType + yType); // cType + } +} diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000..45d95488 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,498 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H + +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h" +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +/** @brief intenal computation class +*/ +template +class GMMSwigluCompute{ + public: + using AT = typename mmType::AT::T; + using BT = typename mmType::BT::T; + using B = typename mmType::BT; + using CT = typename mmType::CT::T; + using BiasT = typename mmType::BiasT::T; + using WT = int8_t; + constexpr static bool transposeX = mmType::AT::isTrans; + constexpr static bool transposeW = mmType::BT::isTrans; + + /** @brief constructor */ + __aicore__ inline GMMSwigluCompute(typename mmType::MT& mm_): mm(mm_) {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmBaseParamsIN, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN); + __aicore__ inline void Process(); + private: + __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx); + + __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig); + + __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k); + + __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN); + template + __aicore__ inline void UpdateChannelScale(uint32_t loopidx); + __aicore__ inline void VectorCompute(uint32_t loopidx); + 
template + __aicore__ inline void PreLoadTokenAndChannel(LocalTensor& channelScaleLocal); + __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig); + __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx); + __aicore__ inline void customDataCopyOut(); + __aicore__ inline void Dequant(uint32_t loopidx); + __aicore__ inline void Quant(uint32_t loopidx); + __aicore__ inline void Swiglu(uint32_t loopidx); + private: + typename mmType::MT& mm; + const GMMSwigluBaseParams* __restrict gmmBaseParams; + const GMMSwiglu* __restrict gmmSwiglu; + const TCubeTiling* __restrict mmTilingData; + uint32_t blockIdx; + VecConfig vecConfig; + TPipe* pipe; + GlobalTensor xGM, weightGM; + GlobalTensor perChannelScaleGM; + GlobalTensor perTokenScaleGM; + GlobalTensor groupListGM; + GlobalTensor quantOutputGM; + GlobalTensor quantScaleOutputGM; + GlobalTensor mmOutGM; + // define the que + TQue mmOutQueue; + TQue perChannelScaleInQueue; + TQue quantOutQueue; + TQue quantScaleOutQueue; + TBuf reduceWorkspace; + TBuf castWorkspace; + bool sequentialWrite = true; + uint32_t cubeNum; // Matmul completions on the kernel + uint32_t groupNum; // Matmul completions on the kernel + int32_t preOffset; + int64_t aicCoreNum; + int64_t aivCoreNum; + GM_ADDR xTensorPtr; + GM_ADDR weightTensorPtr; + GM_ADDR perChannelScalePtr; +}; + +template +__aicore__ inline void GMMSwigluCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN) +{ + aicCoreNum = GetBlockNum(); + aivCoreNum = aicCoreNum * 2; + blockIdx = GetBlockIdx(); + mmTilingData = mmTilingDataIN; + gmmBaseParams = gmmSwigluBaseParamsIn; + gmmSwiglu = gmmSwigluIN; + pipe = tPipeIN; + xTensorPtr = x; + weightTensorPtr = weight; + 
perChannelScalePtr = perChannelScale; + groupNum = gmmSwiglu->groupListLen; + if ASCEND_IS_AIC { + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen); + } + if ASCEND_IS_AIV { + mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen); + perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, gmmSwiglu->groupListLen * gmmSwiglu->tokenLen); + perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmSwiglu->maxProcessRowNum); + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2); + quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmSwiglu->maxProcessRowNum); + } +} + +template +__aicore__ inline void GMMSwigluCompute::Process() { + MNConfig mnConfig; + if ASCEND_IS_AIC { + preOffset = 0; + int32_t prevSplitValue = 0; + for (uint32_t groupIdx = 0, count = 0; groupIdx < gmmSwiglu->groupListLen; ++groupIdx) { + UpdateMnConfig(mnConfig); + int32_t currSplitValue = static_cast(groupListGM.GetValue(groupIdx)); + int32_t splitValue = currSplitValue - prevSplitValue; + prevSplitValue = currSplitValue; + SetMNConfig(splitValue, groupIdx, mnConfig); + if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + + uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + uint32_t curBlock = blockIdx >= count ? 
blockIdx : blockIdx + gmmBaseParams->coreNum; + uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN; + + while (curBlock < curCount) { + MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN); + MMCompute(groupIdx, mnConfig, blockIdx); + curBlock += aicCoreNum; + } + count = curCount % gmmBaseParams->coreNum; + } + SyncAll(); + } + + if ASCEND_IS_AIV { + UpdateVecConfig(blockIdx, vecConfig); + if (blockIdx < vecConfig.usedCoreNum) { + LocalTensor channelScaleLocal = perChannelScaleInQueue.AllocTensor(); + LocalTensor mmLocal = mmOutQueue.AllocTensor(); + LocalTensor quantLocal = quantOutQueue.AllocTensor(); + LocalTensor quantScaleLocal = quantScaleOutQueue.AllocTensor(); + mmOutQueue.EnQue(mmLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); + quantOutQueue.EnQue(quantLocal); + PreLoadTokenAndChannel(channelScaleLocal); + } + SyncAll(); + if (blockIdx < vecConfig.usedCoreNum) { + for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) { + vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1) + ? 
vecConfig.tailLoopNum + : gmmSwiglu->maxProcessRowNum; + customDataCopyIn(outLoopIdx); + for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) { + UpdateChannelScale(innerLoopIdx); + VectorCompute(innerLoopIdx); + } + customDataCopyOut(); + } + + LocalTensor channelScaleLocal = perChannelScaleInQueue.DeQue(); + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + perChannelScaleInQueue.FreeTensor(channelScaleLocal); + mmOutQueue.FreeTensor(mmLocal); + quantScaleOutQueue.FreeTensor(quantScaleLocal); + quantOutQueue.FreeTensor(quantLocal); + } else { + return; + } + } +} + +template +template +__aicore__ inline void GMMSwigluCompute::PreLoadTokenAndChannel(LocalTensor& channelScaleLocal) +{ + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + DataCopyExtParams copyChannelParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = channelScaleLocal.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams); + PipeBarrier(); + Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams); + } + perChannelScaleInQueue.EnQue(channelScaleLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx) +{ + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? 
mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset); + weightGM.SetGlobalBuffer(GetTensorAddr(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k)); + if (mnConfig.blockDimM == 1){ + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } + mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset; + mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(xGM[xOffset], transposeX); + mm.SetTensorB(weightGM, transposeW); + mm.template IterateAll(mmOutGM[mnConfig.workSpaceOffset], 0); +} + +template +__aicore__ inline void GMMSwigluCompute::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMSwigluCompute::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) { + SetMKN(splitValue, groupIdx, mnConfig); + mnConfig.baseM = BASIC_M; + mnConfig.baseN = BASIC_N; + mnConfig.singleM = SINGLE_CORE_M; + mnConfig.singleN = SINGLE_CORE_N; +} + +template +__aicore__ inline void GMMSwigluCompute::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) +{ + mnConfig.m = static_cast(splitValue); + mnConfig.k = gmmBaseParams->K; // tilingData + mnConfig.n = gmmBaseParams->N; // tilingData +} + +template 
+__aicore__ inline uint64_t GMMSwigluCompute::GetWOffset(uint32_t tailN, uint32_t k) { + uint64_t wOffset = 0; + if constexpr (mmType::BT::format == CubeFormat::NZ) { + wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size + } else { + wOffset = tailN; + } + return wOffset; +} + +template +__aicore__ inline void GMMSwigluCompute::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN) { + mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN; +} + +template +__aicore__ inline void GMMSwigluCompute::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig) +{ + // Step 1: Read grouplist reduceSum to calculate total data count + int64_t prevM = 0; + for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + vecConfig.M += tempM; + } + // Step 2: Calculate core allocation + uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum; + vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M; + uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum; + vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1; + vecConfig.startIdx = blockIdx < tailCoreIdx + ? 
eachCoreTaskNum * blockIdx + :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx); + vecConfig.curIdx = vecConfig.startIdx; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + vecConfig.curOffset = vecConfig.startOffset; + int64_t curStartIdx = vecConfig.startIdx; + prevM = 0; + for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + if (curStartIdx >= 0 && curStartIdx - tempM < 0) { + vecConfig.curGroupIdx = groupIdx; + vecConfig.nextUpadteInterVal = tempM - curStartIdx; + } + curStartIdx -= tempM; + } + // Step 3: Calculate total data volume + vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum; + vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + : gmmSwiglu->maxProcessRowNum; + pipe->Reset(); + // Step 4: Allocate space + pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t)); + pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float)); + pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)); + pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float)); + pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float)); + pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t)); +} + +template +__aicore__ inline void GMMSwigluCompute::customDataCopyIn(uint32_t outLoopIdx) +{ + LocalTensor _inMMLocal_0 = mmOutQueue.DeQue(); + DataCopyExtParams copyParams_0{1, static_cast(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams_0{false, 0 ,0, 0}; + DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0); + + mmOutQueue.EnQue(_inMMLocal_0); + + LocalTensor _inMMLocal_1 = 
mmOutQueue.DeQue(); + + Cast(_inMMLocal_1.ReinterpretCast(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen); + + mmOutQueue.EnQue(_inMMLocal_1); + LocalTensor _inMMLocal_2 = mmOutQueue.DeQue(); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){ + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + float scale = perTokenScaleGM.GetValue(vecConfig.curIdx); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curIdx++; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen; + mmOutQueue.EnQue(_inMMLocal_2); +} + +template +template +__aicore__ inline void GMMSwigluCompute::UpdateChannelScale(uint32_t loopIdx){ + // Update perChannel + if (unlikely(vecConfig.nextUpadteInterVal == 0)) { + int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx; + while (loop--) { + int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + vecConfig.curGroupIdx++; + int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + if(nextTemp != curTemp){ + vecConfig.nextUpadteInterVal = nextTemp - curTemp; + break; + } + } + LocalTensor _inChannel = perChannelScaleInQueue.DeQue(); + DataCopyExtParams copyParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = _inChannel.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams); + PipeBarrier(); + Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } 
else { + DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams); + } + PipeBarrier(); + perChannelScaleInQueue.EnQue(_inChannel); + } +} + +template +__aicore__ inline void GMMSwigluCompute::VectorCompute(uint32_t loopIdx) { + Dequant(loopIdx); + Swiglu(loopIdx); + Quant(loopIdx); +} + +template +__aicore__ inline void GMMSwigluCompute::Dequant(uint32_t loopIdx) { + // perChanelScale * perTokenScale + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor perChannelLocal = perChannelScaleInQueue.DeQue(); + Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen); + vecConfig.nextUpadteInterVal--; + mmOutQueue.EnQue(mmLocal); + perChannelScaleInQueue.EnQue(perChannelLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::Swiglu(uint32_t loopIdx) { + // High-level API swiglu + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + float beta = 1.0f; + LocalTensor workspaceLocal= reduceWorkspace.Get(); + LocalTensor src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2]; + LocalTensor src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen]; + SwiGLU(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2); + PipeBarrier(); + DataCopyParams repeatParams{1, static_cast((gmmSwiglu->tokenLen / 2) / 8), 0, 0}; + DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams); + mmOutQueue.EnQue(_inMMLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::Quant(uint32_t loopIdx) { + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT], + _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + gmmSwiglu->tokenLen / BISECT); + LocalTensor workspaceLocal= reduceWorkspace.Get(); + PipeBarrier(); + ReduceMaxTemplate(workspaceLocal, + _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + float quantScale = 
workspaceLocal.GetValue(0) / QUANT_SCALE_INT8; + PipeBarrier(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + PipeBarrier(); + quantScaleLocal.SetValue(loopIdx, quantScale); + PipeBarrier(); + quantScale = 1 / quantScale; + PipeBarrier(); + Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + quantScale, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + int32_t dstTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen / BISECT); + int32_t srcTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen); + int32_t tempCount = static_cast(gmmSwiglu->tokenLen / BISECT); + LocalTensor castSpace = castWorkspace.Get(); + CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount); + mmOutQueue.EnQue(_inMMLocal); + quantOutQueue.EnQue(quantLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::customDataCopyOut() { + // Copy quantized output and per-row quant scales from UB back to GM + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantScaleOutputGM[vecConfig.startIdx], quantScaleLocal, copyParams_0); + LocalTensor quantLocal = quantOutQueue.DeQue(); + DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantOutputGM[vecConfig.startIdx * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1); + PipeBarrier(); + vecConfig.startIdx += vecConfig.innerLoopNum; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + quantOutQueue.EnQue(quantLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); +} + +} // namespace GROUPED_MATMUL +#endif // ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h
b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h new file mode 100644 index 00000000..a5c8571d --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H + +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h" +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +/** @brief internal computation class +*/ + +template +class GMMSwigluSplitWorkSpaceCompute{ + public: + using AT = typename mmType::AT::T; + using BT = typename mmType::BT::T; + using B = typename mmType::BT; + using CT = typename mmType::CT::T; + using BiasT = typename mmType::BiasT::T; + using WT = int8_t; + constexpr static bool transposeX = mmType::AT::isTrans; + constexpr static bool transposeW = mmType::BT::isTrans; + + /** @brief constructor */ + __aicore__ inline GMMSwigluSplitWorkSpaceCompute(typename mmType::MT& mm_): mm(mm_) {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR 
groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmBaseParamsIN, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN); + __aicore__ inline void Process(); + + private: + __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor &mmOutGM); + + __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig); + + __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k); + + __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN); + + template + __aicore__ inline void UpdateChannelScale(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void VectorCompute(uint32_t loopidx, VecConfig& vecConfig); + + template + __aicore__ inline void PreLoadTokenAndChannel(LocalTensor& channelScaleLocal, VecConfig& vecConfig); + + __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig); + + __aicore__ inline void UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx); + + __aicore__ inline void InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig); + + __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx, GlobalTensor &mmOutGM, VecConfig& vecConfig); + + __aicore__ inline void customDataCopyOut(VecConfig& vecConfig); + + __aicore__ inline void Dequant(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void Quant(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void Swiglu(uint32_t loopidx, VecConfig& vecConfig); + + private: + typename mmType::MT& mm; + const GMMSwigluBaseParams* __restrict 
gmmBaseParams; + const GMMSwiglu* __restrict gmmSwiglu; + const TCubeTiling* __restrict mmTilingData; + uint32_t blockIdx; + WorkSpaceSplitConfig workspaceSplitConfig; + TPipe* pipe; + GlobalTensor xGM; + GlobalTensor weightGM; + GlobalTensor perChannelScaleGM; + GlobalTensor perTokenScaleGM; + GlobalTensor groupListGM; + GlobalTensor quantOutputGM; + GlobalTensor quantScaleOutputGM; + GlobalTensor mmOutGM1; + GlobalTensor mmOutGM2; + // define the que + TQue mmOutQueue; + TQue perChannelScaleInQueue; + TQue quantOutQueue; + TQue quantScaleOutQueue; + TBuf reduceWorkspace; + TBuf castWorkspace; + bool sequentialWrite = true; + uint32_t cubeNum; // Matmul completions on the kernel + uint32_t groupNum; // Matmul completions on the kernel + int64_t aicCoreNum; + int64_t aivCoreNum; + GM_ADDR xTensorPtr; + GM_ADDR weightTensorPtr; + GM_ADDR perChannelScalePtr; +}; + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Init(GM_ADDR x, GM_ADDR weight, + GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, + GM_ADDR quantScaleOutput, GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN) +{ + aicCoreNum = GetBlockNum(); + aivCoreNum = aicCoreNum * 2; + blockIdx = GetBlockIdx(); + pipe = tPipeIN; + xTensorPtr = x; + weightTensorPtr = weight; + perChannelScalePtr = perChannelScale; + mmTilingData = mmTilingDataIN; + gmmBaseParams = gmmSwigluBaseParamsIn; + gmmSwiglu = gmmSwigluIN; + groupNum = gmmSwiglu->groupListLen; + if ASCEND_IS_AIC { + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen, + gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + } + if ASCEND_IS_AIV { 
+ mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen, + gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, + gmmSwiglu->groupListLen * gmmSwiglu->tokenLen); + perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmBaseParams->M); + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2); + quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmBaseParams->M); + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig) +{ + workspaceSplitConfig.M = groupListGM.GetValue(gmmSwiglu->groupListLen - 1); + workspaceSplitConfig.loopCount = Ceil(workspaceSplitConfig.M, gmmBaseParams->mLimit); + workspaceSplitConfig.notLastTaskSize = gmmBaseParams->mLimit; + workspaceSplitConfig.lastLoopTaskSize = workspaceSplitConfig.M - (workspaceSplitConfig.loopCount - 1) * gmmBaseParams->mLimit; + workspaceSplitConfig.leftMatrixStartIndex = 0; + workspaceSplitConfig.rightMatrixExpertStartIndex = 0; + workspaceSplitConfig.rightMatrixExpertNextStartIndex = 0; + workspaceSplitConfig.isLastLoop = false; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx) +{ + workspaceSplitConfig.leftMatrixStartIndex = workspaceSplitLoopIdx * gmmBaseParams->mLimit; + workspaceSplitConfig.rightMatrixExpertStartIndex = workspaceSplitConfig.rightMatrixExpertNextStartIndex; + workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertStartIndex; + // Calculate the right expert matrix end index 
(rightMatrixExpertEndIndex) and the next start index (rightMatrixExpertNextStartIndex) + int32_t curTaskNum = 0; + int32_t nextTaskNum = 0; + while(workspaceSplitConfig.rightMatrixExpertEndIndex < gmmSwiglu->groupListLen) + { + curTaskNum = groupListGM.GetValue(workspaceSplitConfig.rightMatrixExpertEndIndex) - workspaceSplitConfig.leftMatrixStartIndex; + int32_t nextTaskIdx = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen - 1 \ + ? gmmSwiglu->groupListLen - 1 \ + : workspaceSplitConfig.rightMatrixExpertEndIndex + 1; + nextTaskNum = groupListGM.GetValue(nextTaskIdx) - workspaceSplitConfig.leftMatrixStartIndex; + if (curTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex; + break; + } else if (curTaskNum == gmmBaseParams->mLimit && nextTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex + 1; + break; + } else if (nextTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertEndIndex++; + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex; + break; + } + workspaceSplitConfig.rightMatrixExpertEndIndex++; + } + workspaceSplitConfig.isLastLoop = workspaceSplitLoopIdx == workspaceSplitConfig.loopCount - 1 ? true : false; + + if (workspaceSplitConfig.isLastLoop) { + workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen \ + ? gmmSwiglu->groupListLen - 1 \ + : workspaceSplitConfig.rightMatrixExpertEndIndex; + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Process() { + InitWorkSpaceSplitConfig(workspaceSplitConfig); + int32_t parallelNum = gmmBaseParams->isPreFill ? 
2 : 1; // 2: double workspace buffer + for (int32_t workspaceSplitLoopIdx = 0; workspaceSplitLoopIdx < workspaceSplitConfig.loopCount; workspaceSplitLoopIdx++) { + UpdateWorkSpaceSplitConfig(workspaceSplitConfig, workspaceSplitLoopIdx); + GlobalTensor mmOutGM = (workspaceSplitLoopIdx % 2 == 0 ) ? mmOutGM1 : mmOutGM2; + + if ASCEND_IS_AIC { + if (workspaceSplitLoopIdx >= parallelNum){ // first parallelNum core no need to wait + SyncAll(); + } + MNConfig mnConfig; + int32_t prevSplitValue = workspaceSplitConfig.leftMatrixStartIndex; + for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex, count = 0; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; ++groupIdx) { + UpdateMnConfig(mnConfig); + int32_t currSplitValue = static_cast(groupListGM.GetValue(groupIdx)); + currSplitValue = currSplitValue > (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \ + ? (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \ + : currSplitValue; + int32_t splitValue = currSplitValue - prevSplitValue; + prevSplitValue = currSplitValue; + SetMNConfig(splitValue, groupIdx, mnConfig); + if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + + uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + uint32_t curBlock = blockIdx >= count ? 
blockIdx : blockIdx + gmmBaseParams->coreNum; + uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN; + + while (curBlock < curCount) { + MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN); + MMCompute(groupIdx, mnConfig, blockIdx, mmOutGM); + curBlock += aicCoreNum; + } + count = curCount % gmmBaseParams->coreNum; + } + SyncAll(); + } + + if ASCEND_IS_AIV { + VecConfig vecConfig; + UpdateVecConfig(blockIdx, vecConfig); + if (blockIdx < vecConfig.usedCoreNum) { + LocalTensor channelScaleLocal = perChannelScaleInQueue.AllocTensor(); + LocalTensor mmLocal = mmOutQueue.AllocTensor(); + LocalTensor quantLocal = quantOutQueue.AllocTensor(); + LocalTensor quantScaleLocal = quantScaleOutQueue.AllocTensor(); + mmOutQueue.EnQue(mmLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); + quantOutQueue.EnQue(quantLocal); + PreLoadTokenAndChannel(channelScaleLocal, vecConfig); + } + SyncAll(); + if (blockIdx < vecConfig.usedCoreNum) { + for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) { + vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1) + ? 
vecConfig.tailLoopNum + : gmmSwiglu->maxProcessRowNum; + PipeBarrier(); + customDataCopyIn(outLoopIdx, mmOutGM, vecConfig); + PipeBarrier(); + for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) { + UpdateChannelScale(innerLoopIdx, vecConfig); + VectorCompute(innerLoopIdx, vecConfig); + } + PipeBarrier(); + customDataCopyOut(vecConfig); + PipeBarrier(); + } + + LocalTensor channelScaleLocal = perChannelScaleInQueue.DeQue(); + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + perChannelScaleInQueue.FreeTensor(channelScaleLocal); + mmOutQueue.FreeTensor(mmLocal); + quantScaleOutQueue.FreeTensor(quantScaleLocal); + quantOutQueue.FreeTensor(quantLocal); + } + if (workspaceSplitLoopIdx < workspaceSplitConfig.loopCount - parallelNum){ + SyncAll(); + } + } + } +} + +template +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::PreLoadTokenAndChannel(LocalTensor& channelScaleLocal, VecConfig& vecConfig) +{ + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + DataCopyExtParams copyChannelParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = channelScaleLocal.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams); + PipeBarrier(); + Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams); + } + perChannelScaleInQueue.EnQue(channelScaleLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor &mmOutGM) 
+{ + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset + workspaceSplitConfig.leftMatrixStartIndex * mnConfig.k); + weightGM.SetGlobalBuffer(GetTensorAddr(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k)); + if (mnConfig.blockDimM == 1){ + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } else { + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); + } + mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset; + mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(xGM[xOffset], transposeX); + mm.SetTensorB(weightGM, transposeW); + mm.template IterateAll(mmOutGM[mnConfig.workSpaceOffset], 0); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) { + SetMKN(splitValue, groupIdx, mnConfig); + mnConfig.baseM = BASIC_M; + mnConfig.baseN = 
BASIC_N; + mnConfig.singleM = SINGLE_CORE_M; + mnConfig.singleN = SINGLE_CORE_N; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) +{ + mnConfig.m = static_cast(splitValue); + mnConfig.k = gmmBaseParams->K; // tilingData + mnConfig.n = gmmBaseParams->N; // tilingData +} + +template +__aicore__ inline uint64_t GMMSwigluSplitWorkSpaceCompute::GetWOffset(uint32_t tailN, uint32_t k) { + uint64_t wOffset = 0; + if constexpr (mmType::BT::format == CubeFormat::NZ) { + wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size + } else { + wOffset = tailN; + } + return wOffset; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN) { + mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig) +{ + // Step 1: Read grouplist reduceSum to calculate total data count + vecConfig.M = workspaceSplitConfig.isLastLoop \ + ? workspaceSplitConfig.lastLoopTaskSize\ + : workspaceSplitConfig.notLastTaskSize; + // Step 2: Calculate core allocation + uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum; + vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M; + uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum; + vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1; + vecConfig.startIdx = blockIdx < tailCoreIdx + ? 
eachCoreTaskNum * blockIdx + :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx); + vecConfig.curIdx = vecConfig.startIdx; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + vecConfig.curOffset = vecConfig.startOffset; + int64_t curStartIdx = vecConfig.startIdx; + int64_t prevM = workspaceSplitConfig.leftMatrixStartIndex; + for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + if (curStartIdx >= 0 && curStartIdx - tempM < 0) { + vecConfig.curGroupIdx = groupIdx; + vecConfig.nextUpadteInterVal = tempM - curStartIdx; + } + curStartIdx -= tempM; + } + // Step 3: Calculate total data volume + vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum; + vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + : gmmSwiglu->maxProcessRowNum; + pipe->Reset(); + // Step 4: Allocate space + pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t)); + pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float)); + pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)); + pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float)); + pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float)); + pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t)); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::customDataCopyIn(uint32_t outLoopIdx, GlobalTensor &mmOutGM, VecConfig& vecConfig) +{ + LocalTensor _inMMLocal_0 = mmOutQueue.DeQue(); + DataCopyExtParams copyParams_0{1, static_cast(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams 
padParams_0{false, 0 ,0, 0}; + PipeBarrier(); + DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0); + mmOutQueue.EnQue(_inMMLocal_0); + + LocalTensor _inMMLocal_1 = mmOutQueue.DeQue(); + + Cast(_inMMLocal_1.ReinterpretCast(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen); + + mmOutQueue.EnQue(_inMMLocal_1); + LocalTensor _inMMLocal_2 = mmOutQueue.DeQue(); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){ + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + float scale = perTokenScaleGM.GetValue(vecConfig.curIdx + workspaceSplitConfig.leftMatrixStartIndex); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curIdx++; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen; + mmOutQueue.EnQue(_inMMLocal_2); +} + +template +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateChannelScale(uint32_t loopIdx, VecConfig& vecConfig){ + // Update perChannel + if (unlikely(vecConfig.nextUpadteInterVal == 0)) { + int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx; + while (loop--) { + int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + vecConfig.curGroupIdx++; + int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + if(nextTemp != curTemp){ + vecConfig.nextUpadteInterVal = nextTemp - curTemp; + break; + } + } + LocalTensor _inChannel = perChannelScaleInQueue.DeQue(); + DataCopyExtParams copyParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + if 
constexpr(!IsSameType::value) { + LocalTensor dstLocalT = _inChannel.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams); + PipeBarrier(); + Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams); + } + PipeBarrier(); + + perChannelScaleInQueue.EnQue(_inChannel); + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::VectorCompute(uint32_t loopIdx, VecConfig& vecConfig) { + Dequant(loopIdx, vecConfig); + Swiglu(loopIdx, vecConfig); + Quant(loopIdx, vecConfig); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Dequant(uint32_t loopIdx, VecConfig& vecConfig) { + // perChanelScale * perTokenScale + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor perChannelLocal = perChannelScaleInQueue.DeQue(); + Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen); + vecConfig.nextUpadteInterVal--; + mmOutQueue.EnQue(mmLocal); + perChannelScaleInQueue.EnQue(perChannelLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Swiglu(uint32_t loopIdx, VecConfig& vecConfig) { + // High-level API swiglu + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + float beta = 1.0f; + LocalTensor workspaceLocal= reduceWorkspace.Get(); + LocalTensor src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2]; + LocalTensor src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen]; + SwiGLU(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2); + PipeBarrier(); + DataCopyParams repeatParams{1, static_cast((gmmSwiglu->tokenLen / 2) / 8), 0, 0}; + DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams); + mmOutQueue.EnQue(_inMMLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Quant(uint32_t loopIdx, 
VecConfig& vecConfig) { + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT], + _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + gmmSwiglu->tokenLen / BISECT); + LocalTensor workspaceLocal= reduceWorkspace.Get(); + PipeBarrier(); + ReduceMaxTemplate(workspaceLocal, + _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8; + PipeBarrier(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + PipeBarrier(); + quantScaleLocal.SetValue(loopIdx, quantScale); + PipeBarrier(); + quantScale = 1 / quantScale; + PipeBarrier(); + Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + quantScale, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + int32_t dstTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen / BISECT); + int32_t srcTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen); + int32_t tempCount = static_cast(gmmSwiglu->tokenLen / BISECT); + LocalTensor castSpace = castWorkspace.Get(); + CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount); + mmOutQueue.EnQue(_inMMLocal); + quantOutQueue.EnQue(quantLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::customDataCopyOut(VecConfig& vecConfig) { + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantScaleOutputGM[workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx], quantScaleLocal, copyParams_0); + LocalTensor quantLocal = quantOutQueue.DeQue(); + DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0}; + PipeBarrier(); + 
DataCopyPad(quantOutputGM[(workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx) * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1); + PipeBarrier(); + vecConfig.startIdx += vecConfig.innerLoopNum; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + quantOutQueue.EnQue(quantLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); +} + +} // namespace GROUPED_MATMUL +#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h new file mode 100644 index 00000000..37ddd845 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H + +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" +#include "lib/matmul_intf.h" + +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +using namespace AscendC; +constexpr uint32_t INT8_BITS = 8; // a int8 number has 8 bits +constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32; // 32: a block has 32 bytes data +constexpr uint32_t THRESHOLD_BLOCK_NUM = 8; +constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64; // 64: a block has 64 bytes data +constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2; // 2: a float16 data has two bytes +constexpr uint32_t SINGLE_CORE_M = 128; +constexpr uint32_t SINGLE_CORE_N = 512; +constexpr uint32_t SINGLE_CORE_K = 7168; +constexpr uint32_t BASIC_M = 128; +constexpr uint32_t BASIC_N = 256; +constexpr uint32_t BASIC_K = 128; +constexpr uint32_t STEP_M = 1; +constexpr uint32_t STEP_N = 1; +constexpr uint32_t STEP_Ka = 4; +constexpr uint32_t STEP_Kb = 4; +constexpr uint32_t DEPTH_A1 = 8; +constexpr uint32_t DEPTH_B1 = 8; +constexpr uint32_t VEC_LEN_ONCE_REPEAT_ELE = 64; +constexpr uint32_t VEC_LEN_ONCE_REPEAT_BLOCK = 8; +constexpr uint32_t BISECT = 2; +constexpr uint32_t MOD_32_MASK = 0x1F; +constexpr uint32_t MOD_16_MASK = 0x0F; +constexpr uint32_t ALIGN_8_ELE = 8; +constexpr uint32_t ALIGN_16_ELE = 16; +constexpr float QUANT_SCALE_INT8 = 127.0f; +constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, true); +constexpr MatmulConfig GetMMCFG() { + MatmulConfig MM_CFG = NZ_CFG_MDL; + MM_CFG.singleCoreM = SINGLE_CORE_M; + MM_CFG.singleCoreN= SINGLE_CORE_N; + MM_CFG.singleCoreK= SINGLE_CORE_K; + MM_CFG.basicM= BASIC_M; + MM_CFG.basicN= BASIC_N; + MM_CFG.basicK= BASIC_K; + return MM_CFG; +} + +constexpr static MatmulApiStaticTiling 
GetMMTiling(const MatmulApiStaticTiling& mmTiling) +{ + MatmulApiStaticTiling tiling = mmTiling; + tiling.stepM = STEP_M; + tiling.stepN = STEP_N; + tiling.stepKa = STEP_Ka; + tiling.stepKb = STEP_Kb; + tiling.depthA1 = DEPTH_A1; + tiling.depthB1 = DEPTH_B1; + return tiling; +} +template +struct MMImplType { + using AT = AT_; + using BT = BT_; + using CT = CT_; + using BiasT = BiasT_; + static constexpr MatmulConfig cfg = GetMMCFG(); + static constexpr MatmulApiStaticTiling mdl = GetMMTiling(GetMatmulApiTiling(cfg)); + using MT = matmul::MatmulImpl; +}; + +struct MNConfig { + int64_t m = 0; + int64_t k = 0; + int64_t n = 0; + int64_t baseM = 0; + int64_t baseN = 0; + int64_t mIdx = 0; + int64_t nIdx = 0; + int64_t blockDimM = 0; + int64_t blockDimN = 0; + int64_t singleM = 0; + int64_t singleN = 0; + int64_t wBaseOffset = 0; + int64_t nAxisBaseOffset = 0; + int64_t mAxisBaseOffset = 0; + int64_t xBaseOffset = 0; + int64_t yBaseOffset = 0; + int64_t wOutOffset = 0; + int64_t workSpaceOffset = 0; +}; + +struct VecConfig { + int64_t M = 0; + int64_t usedCoreNum = 0; + int64_t startOffset = 0; + int64_t curOffset = 0; + int64_t startIdx = 0; + int64_t curIdx = 0; + int64_t taskNum = 0; + int64_t curGroupIdx = 0; + int64_t outLoopNum = 0; + int64_t innerLoopNum = 0; + int64_t tailLoopNum = 0; + int64_t nextUpadteInterVal = 0; +}; + +struct WorkSpaceSplitConfig { + int64_t M = 0; + int64_t loopCount = 0; + int64_t leftMatrixStartIndex = 0; + int64_t rightMatrixExpertStartIndex = 0; + int64_t rightMatrixExpertNextStartIndex = 0; + int64_t rightMatrixExpertEndIndex = 0; + int64_t notLastTaskSize = 0; + int64_t lastLoopTaskSize = 0; + bool isLastLoop = false; +}; + +template +__aicore__ inline T AlignUp(T a) { + return (a + base - 1) / base * base; +} + +template +__aicore__ inline T AlignUp(T a, T base) { + return (a + base - 1) / base * base; +} + +template +__aicore__ inline T AlignDown(T a, T base) { + if (unlikely(base == 0)) { + return a; + } + return a / base * base; 
+} + +template <> +__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) { + // to be Multiple of 4, result should be in a format of b(xxxx,x100). + // This means last two bits should be zero, requiring that + // result = num & b(1111,1100) = num & (~3). + // &(~3) operator may reduces num into the range [num, num - 3]. + // As the result should be no less than a (result >= a), it means num - 3 >= a in the worst case. + // In this case, num >= a+3. On the other hand, num should also be less then a+4, otherwise, + // the result will not be least multiple of 4 for 3. In other cases like [num, num - 2], + // num = a + 3 also satisfies the goal condition. + return (a + 3) & ~3; // & ~3: set last two bits of (a+3) to be zero +} + +template <> +__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) { + // In general, if we want to get the least multiple of b (b is the power of 2) for a, + // it comes to a conclusion from the above comment: result = (a + (b - 1)) & (~b) + return (a + 7) & ~7; // & ~7: set last four bits of (a+7) to be zero +} + +template <> +__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) { + // In general, if we want to get the least multiple of b (b is the power of 2) for a, + // it comes to a conclusion from the above comment: result = (a + (b - 1)) & (~b) + return (a + 15) & ~15; // & ~15: set last four bits of (a+15) to be zero +} + +template <> +__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) { + // refer to the above comments. 
+ return (a + 31) & ~31; // & ~31: set last five bits of (a+31) to be zero} +} + +__aicore__ inline void ReduceMaxTemplate(LocalTensor& dstLocal, LocalTensor& srcLocal, + uint32_t srcOffset, uint32_t count) +{ + if (likely(count > VEC_LEN_ONCE_REPEAT_ELE && count % VEC_LEN_ONCE_REPEAT_ELE == 0)){ + WholeReduceMax(dstLocal, + srcLocal[srcOffset], VEC_LEN_ONCE_REPEAT_ELE, + count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1, + VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE); + PipeBarrier(); + WholeReduceMax(dstLocal, dstLocal, + count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1, 1, + VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE); + } else if (count <= VEC_LEN_ONCE_REPEAT_ELE) { + WholeReduceMax(dstLocal, + srcLocal[srcOffset], + count, 1, 1, 1, VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE); + } else { + ReduceMax(dstLocal, srcLocal[srcOffset], dstLocal, count, false); + } +} + +__aicore__ inline void CastFp32ToInt8Template(LocalTensor& dstLocal, LocalTensor& srcLocal, + LocalTensor& oneBlockWorkspace, + int32_t dstOffset, int32_t srcOffset, int32_t count) +{ + Cast(srcLocal[srcOffset].ReinterpretCast(), srcLocal[srcOffset], RoundMode::CAST_RINT, count); + PipeBarrier(); + if ((dstOffset & MOD_32_MASK) == 0) { + Cast(dstLocal[dstOffset], + srcLocal[srcOffset].ReinterpretCast(), + RoundMode::CAST_RINT, count); + } else if ((dstOffset & MOD_16_MASK) == 0) { + Cast(dstLocal[dstOffset + ALIGN_16_ELE], + srcLocal[srcOffset + ALIGN_8_ELE].ReinterpretCast(), + RoundMode::CAST_RINT, count - ALIGN_16_ELE); + PipeBarrier(); + Cast(oneBlockWorkspace, srcLocal[srcOffset].ReinterpretCast(), + RoundMode::CAST_RINT, ALIGN_16_ELE); + PipeBarrier(); + for (int32_t i = 0; i < ALIGN_16_ELE; i++) { + int8_t temp = oneBlockWorkspace.GetValue(i); + dstLocal.SetValue(dstOffset + i, temp); + } + PipeBarrier(); + } +} + +template +__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) { + __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ 
uint64_t*>(tensorPtr); + uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address. + // Moving 3 bits to the right means dividing by sizeof(uint64 t). + __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} + +} // namespace GROUPED_MATMUL + +#endif // ASCENDC_GROUPED_MATMUL_UTILS_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index d2a1f90e..06338e4f 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -552,6 +552,41 @@ std::tuple grouped_matmul_swiglu_quant( output_offset); return std::tuple(output, output_scale, output_offset); } + +std::tuple grouped_matmul_swiglu_quant_weight_nz_tensor_list( + const at::Tensor & x, + const at::TensorList & weight, + const at::TensorList & weight_scale, + const at::Tensor & x_scale, + const at::Tensor & group_list, + const c10::optional & bias, + const c10::optional & offset) +{ + auto x_size = x.sizes(); + int n = weight[0].sizes()[1]; + int m = x_size[0]; + int k = x_size[1]; + + at::Tensor output = at::zeros({m, n/2}, x.options().dtype(at::kChar)); + at::Tensor output_scale = at::zeros({m}, x.options().dtype(at::kFloat)); + at::Tensor output_offset = at::zeros({m}, x.options().dtype(at::kFloat)); + + EXEC_NPU_CMD( + aclnnGroupedMatmulSwigluQuantWeightNzTensorList, + x, + weight, + bias, + offset, + weight_scale, + x_scale, + group_list, + output, + output_scale, + output_offset); + + return std::tuple(output, output_scale, output_offset); +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -614,4 +649,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " Tensor group_list, *, Tensor? bias=None," " Tensor? 
offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)"); ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant); + + ops.def( + "grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale," + " Tensor group_list, *," + " Tensor? bias=None, Tensor? offset=None) ->" + " (Tensor output, Tensor output_scale, Tensor output_offset)" + ); + ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index e3b35b10..26b3d66d 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -130,14 +130,34 @@ std::tuple grouped_matmul_swiglu_quant( return {output, output_scale, output_offset}; } +std::tuple grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta( + const at::Tensor & x, + const at::TensorList & weight, + const at::TensorList & weight_scale, + const at::Tensor & x_scale, + const at::Tensor & group_list, + const c10::optional & bias, + const c10::optional & offset) +{ + auto x_size = x.sizes(); + int n = weight[0].sizes()[1]; + int m = x_size[0]; + int k = x_size[1]; + + at::Tensor output = at::zeros({m, n/2}, c10::dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::zeros({m}, c10::dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::zeros({m}, c10::dtype(c10::ScalarType::Float)); + + return std::tuple(output, output_scale, output_offset); +} } // namespace meta } // namespace vllm_ascend namespace { - // Register the meta implementations of the custom kernels for symbolic tracing, this will also - // the custom kernel been captured into aclgraph - TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { +// Register the meta implementations of the custom kernels for symbolic tracing, this will also +// the custom kernel been captured into 
aclgraph +TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation @@ -150,5 +170,7 @@ namespace { ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); // grouped_matmul_swiglu_quant meta implementation ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant); + // Grouped matmul swiglu quant weight nz tensor list + ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta); } } diff --git a/csrc/utils/CMakeLists.txt b/csrc/utils/CMakeLists.txt new file mode 100644 index 00000000..db468cb2 --- /dev/null +++ b/csrc/utils/CMakeLists.txt @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_library(ops_utils_tiling_headers INTERFACE) + +target_include_directories(ops_utils_tiling_headers INTERFACE + $ + $<$:$> + $<$:$> + $<$:$> + $<$:$> + $ +) + +target_compile_definitions(ops_utils_tiling_headers INTERFACE + OPS_UTILS_LOG_SUB_MOD_NAME="OP_TILING" + OPS_UTILS_LOG_PACKAGE_TYPE=$,"[Custom]",""> +) + +add_library(ops_utils_proto_headers INTERFACE) + +target_include_directories(ops_utils_proto_headers INTERFACE + $ + $<$:$> + $<$:$> + $<$:$> + $ +) + +target_compile_definitions(ops_utils_proto_headers INTERFACE + OPS_UTILS_LOG_SUB_MOD_NAME="OP_PROTO" + OPS_UTILS_LOG_PACKAGE_TYPE=$,"[Custom]",""> +) + +if(NOT BUILD_OPEN_PROJECT) + install_package( + PACKAGE ops_adv + TARGETS ops_utils_tiling_headers ops_utils_proto_headers + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/inc/ + DESTINATION include/ops_adv/utils + ) +endif() diff --git a/csrc/utils/inc/aclnn_util.h b/csrc/utils/inc/aclnn_util.h new file mode 100644 index 00000000..472ea4db --- /dev/null +++ b/csrc/utils/inc/aclnn_util.h @@ -0,0 +1,14 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef OP_API_INC_ACLNN_UTIL_H +#define OP_API_INC_ACLNN_UTIL_H + +#define ACLNN_API __attribute__((visibility("default"))) +#endif // OP_API_INC_ACLNN_UTIL_H \ No newline at end of file diff --git a/csrc/utils/inc/error/ops_error.h b/csrc/utils/inc/error/ops_error.h new file mode 100644 index 00000000..fbb5c295 --- /dev/null +++ b/csrc/utils/inc/error/ops_error.h @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ops_error.h + * \brief + */ + +#pragma once + +#include "log/ops_log.h" + +/* 基础报错 */ +#define OPS_REPORT_VECTOR_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E89999", OPS_DESC, __VA_ARGS__) +#define OPS_REPORT_CUBE_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E69999", OPS_DESC, __VA_ARGS__) + +/* 条件报错 */ +#define OPS_ERR_IF(COND, LOG_FUNC, EXPR) OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) diff --git a/csrc/utils/inc/fallback.h b/csrc/utils/inc/fallback.h new file mode 100644 index 00000000..eb19050d --- /dev/null +++ b/csrc/utils/inc/fallback.h @@ -0,0 +1,497 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file fallback.h + * \brief + */ + +#ifndef ACLNNFALLBACK_OPAPI_H_ +#define ACLNNFALLBACK_OPAPI_H_ + +#include + +#include +#include +#include +#include + +#include "aclnn/aclnn_base.h" +#include "fallback_comm.h" +#include "error/ops_error.h" +#include "runtime/base.h" + +namespace fallback { +using namespace std; +using namespace gert; +using namespace ge; +using namespace std; + +namespace std_utils { + template + struct index_sequence {}; + + template + struct make_index_sequence_helper : make_index_sequence_helper {}; + + template + struct make_index_sequence_helper<0, Is...> { + using type = index_sequence; + }; + + template + using make_index_sequence = typename make_index_sequence_helper::type; +} + +using aclOpExecutor = struct aclOpExecutor; +using aclTensor = struct aclTensor; +using aclScalar = struct aclScalar; +using aclIntArray = struct aclIntArray; +using aclFloatArray = struct aclFloatArray; +using aclBoolArray = struct aclBoolArray; +using aclTensorList = struct aclTensorList; + +using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t* stride, int64_t offset, aclFormat format, + const int64_t* storage_dims, uint64_t storage_dims_num, void* tensor_data); + +using _aclCreateScalar = aclScalar* (*)(void* value, aclDataType data_type); +using _aclCreateIntArray = aclIntArray* (*)(const int64_t* value, uint64_t size); +using _aclCreateFloatArray = aclFloatArray* (*)(const float* value, uint64_t size); +using _aclCreateBoolArray = aclBoolArray* (*)(const bool* value, uint64_t size); +using _aclCreateTensorList = aclTensorList* (*)(const aclTensor* const* value, 
uint64_t size); + +using _aclDestroyTensor = int (*)(const aclTensor* tensor); +using _aclDestroyScalar = int (*)(const aclScalar* scalar); +using _aclDestroyIntArray = int (*)(const aclIntArray* array); +using _aclDestroyFloatArray = int (*)(const aclFloatArray* array); +using _aclDestroyBoolArray = int (*)(const aclBoolArray* array); +using _aclDestroyTensorList = int (*)(const aclTensorList* array); + +#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +inline const char* GetOpApiLibName(void) { + return "libopapi.so"; +} + +inline const char* GetCustOpApiLibName(void) { + return "libcust_opapi.so"; +} + +inline void* GetOpApiFuncAddrInLib(void* handler, const char* libName, const char* apiName) { + auto funcAddr = dlsym(handler, apiName); + if (funcAddr == nullptr) { + OPS_LOG_W("aclnnfallback", "dlsym %s from %s failed, error:%s.", apiName, libName, dlerror()); + } + return funcAddr; +} + +inline void* GetOpApiLibHandler(const char* libName) { + auto handler = dlopen(libName, RTLD_LAZY); + if (handler == nullptr) { + OPS_LOG_W("aclnnfallback", "dlopen %s failed, error:%s.", libName, dlerror()); + } + return handler; +} + +inline void* GetAclnnArrdByApiName(const char *apiName) { + vector libs = {"libaclnn_ops_infer.so", "libaclnn_ops_train.so", "libaclnn_math.so", + "libaclnn_rand.so", "libaclnn_sparse.so", "libaclnn_fft.so"}; + for (const auto &libName : libs) { + static auto libHandler = GetOpApiLibHandler(libName.c_str()); + if (libHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(libHandler, libName.c_str(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + } + OPS_LOG_E("aclnnfallback", "api %s can't find in any aclnn lib.", apiName); + return nullptr; +} + +inline void* GetOpApiFuncAddr(const char* apiName) { + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, 
GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + OPS_LOG_D("aclnnfallback", "opapi lib is not exist,will use aclnn lib."); + return GetAclnnArrdByApiName(apiName); +} + +inline aclTensor* ConvertType(aclTensor* ge_tensor) { + return ge_tensor; +} + +inline aclIntArray* ConvertType(const std::vector &arr) { + if (arr.empty()) { + return nullptr; + } + static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray); + auto array = aclCreateIntArray(arr.data(), arr.size()); + return array; +} + +inline aclDataType GetConvertType(const gert::Tensor* ge_tensor) { + // convert data type + auto dataType_ge = ge_tensor->GetDataType(); + auto dataType = aclDataType::ACL_FLOAT16; + if (dataType_ge == DT_FLOAT) { + dataType = aclDataType::ACL_FLOAT; + } else if (dataType_ge == DT_BF16) { + dataType = aclDataType::ACL_BF16; + } else if (dataType_ge == DT_BOOL) { + dataType = aclDataType::ACL_BOOL; + } else if (dataType_ge == DT_INT64) { + dataType = aclDataType::ACL_INT64; + } else if (dataType_ge == DT_INT32) { + dataType = aclDataType::ACL_INT32; + } else if (dataType_ge == DT_UINT64) { + dataType = aclDataType::ACL_UINT64; + } else if (dataType_ge == DT_UINT32) { + dataType = aclDataType::ACL_UINT32; + } else if (dataType_ge == DT_INT8) { + dataType = aclDataType::ACL_INT8; + } else if (dataType_ge == DT_UINT8) { + dataType = aclDataType::ACL_UINT8; + } else if (dataType_ge == DT_INT4) { + dataType = aclDataType::ACL_INT4; + } else { + dataType = aclDataType::ACL_FLOAT16; + } + + return dataType; +} + +inline aclTensor* ConvertType(const gert::Tensor* ge_tensor) { + if (ge_tensor == nullptr) { + return nullptr; + } + + static const auto aclCreateTensor = 
GET_OP_API_FUNC(aclCreateTensor); + OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr); + + void* device_addr = nullptr; + auto tensor_place = ge_tensor->GetPlacement(); + device_addr = const_cast(ge_tensor->GetAddr()); + + auto dataType = GetConvertType(ge_tensor); + + OPS_LOG_D("aclnnfallback", "aclCreateTensor: tensor type is %d", dataType); + + // convert shape + auto gert_shape = ge_tensor->GetStorageShape(); + std::vector shape; + for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) { + shape.push_back(gert_shape.GetDim(i)); + } + + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + aclTensor* out = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), + 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), device_addr); + + OPS_ERR_IF(out == nullptr, + OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr); + + return out; +} + +inline aclTensorList* ConvertType(std::vector& ge_tenserList) { + OPS_ERR_IF(ge_tenserList.size() == 0, + OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr); + + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + OPS_ERR_IF(aclCreateTensorList == nullptr, + OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr); + + std::vector tmp; + for (size_t i = 0; i < ge_tenserList.size(); i++) { + auto t_acl = ConvertType(ge_tenserList[i]); + tmp.push_back(t_acl); + } + + aclTensorList* tensorList = aclCreateTensorList(tmp.data(), tmp.size()); + return tensorList; +} + +template +inline aclScalar* ConvertScalarType(T value) { + static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar); + OPS_ERR_IF(aclCreateScalar == nullptr, + OPS_LOG_E("aclnnfallback", "aclCreateScalar nullptr"), return nullptr); + if (typeid(value) == typeid(float)) { + return aclCreateScalar(&value, 
aclDataType::ACL_FLOAT); + } + return nullptr; +} + +template +T ConvertType(T value) { + return value; +} + +inline aclTensor* ConvertMmType(const gert::Tensor* ge_tensor, bool transpose, bool enable_NZ=false) { + if (ge_tensor == nullptr) { + return nullptr; + } + auto gert_shape = ge_tensor->GetStorageShape(); + if (gert_shape.GetDimNum() <= 1) { + return ConvertType(ge_tensor); + } + + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr); + + void* device_addr = const_cast(ge_tensor->GetAddr()); + // convert data type + auto dataType_ge = ge_tensor->GetDataType(); + auto dataType = ToAclDataType(dataType_ge); + // convert shape + std::vector shape; + for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) { + shape.push_back(gert_shape.GetDim(i)); + } + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + auto viewShape = shape; + // 对于transpose后的tensor对后两维度进行strides, viewShape转换 + if (transpose) { + // dimM 为倒数第二维, dimN 为倒数第一维度 + auto dimM = shape.size() - 2; + auto dimN = shape.size() - 1; + auto swap = strides[dimN]; + strides[dimN] = strides[dimM]; + strides[dimM] = swap; + // 修改viewShape + viewShape[dimN] = shape[dimM]; + viewShape[dimM] = shape[dimN]; + } + auto acl_format = aclFormat::ACL_FORMAT_ND; + if (enable_NZ && GetPrimaryFormat(ge_tensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) { + acl_format = aclFormat::ACL_FORMAT_FRACTAL_NZ; + } + aclTensor* out = aclCreateTensor(viewShape.data(), shape.size(), dataType, strides.data(), + 0, acl_format, shape.data(), shape.size(), device_addr); + OPS_ERR_IF(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr); + + return out; +} + +inline void Release(aclTensor* p) { + static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor); 
+ OPS_ERR_IF(aclDestroyTensor == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyTensor is null"), return); + aclDestroyTensor(p); +} + +inline void Release(aclScalar* p) { + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + OPS_ERR_IF(aclDestroyScalar == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyScalar is null"), return); + aclDestroyScalar(p); +} + +inline void Release(aclIntArray* p) { + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + OPS_ERR_IF(aclDestroyIntArray == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyIntArray is null"), return); + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray* p) { + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + OPS_ERR_IF(aclDestroyBoolArray == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyBoolArray is null"), return); + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList* p) { + static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList); + OPS_ERR_IF(aclDestroyTensorList == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyTensorList is null"), return); + aclDestroyTensorList(p); +} + +template +void Release(T value) { + (void)value; +} + +template +void CallRelease(Tuple t, std_utils::index_sequence) { + (void)std::initializer_list{(Release(std::get(t)), 0)...}; +} + +template +void ReleaseConvertTypes(Tuple& t) { + static constexpr auto size = std::tuple_size::value; + CallRelease(t, std_utils::make_index_sequence{}); +} + +template +auto ConvertTypes(Ts&... 
args) -> decltype(std::make_tuple(ConvertType(args)...)) { + auto tp = std::make_tuple(ConvertType(args)...); + return tp; +} + +template +auto call(Function f, Tuple t, std_utils::index_sequence) -> int { + return f(std::get(t)...); +} + +template +auto call(Function f, Tuple t) -> int { + static constexpr auto size = std::tuple_size::value; + return call(f, t, std_utils::make_index_sequence{}); +} + +template +auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr, std_utils::index_sequence) + -> int (*)(typename std::decay(params))>::type...) { + using OpApiFunc = int (*)(typename std::decay(params))>::type...); + auto func = reinterpret_cast(opApiAddr); + return func; +} + +template +auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr) + -> typename std::enable_if::value != 0, + decltype(ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence::value>{}))>::type { + static constexpr auto size = std::tuple_size::value; + return ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence{}); +} + +template +class ConvertedParams { + public: + ConvertedParams(Tuple&& convertedParams) : convertedParams_(std::move(convertedParams)){}; + ConvertedParams(ConvertedParams&& other) : convertedParams_(std::move(other.convertedParams_)) { + other.validParams_ = false; + }; + ConvertedParams& operator=(ConvertedParams&& other) { + if (this == &other) { + return *this; + } + + convertedParams_ = std::move(other.convertedParams_); + validParams_ = true; + other.validParams_ = false; + return *this; + } + + ConvertedParams() = delete; + ConvertedParams(const ConvertedParams& other) = delete; + ConvertedParams& operator=(const ConvertedParams& other) = delete; + + ~ConvertedParams() { + if (validParams_) { + ReleaseConvertTypes(convertedParams_); + } + } + + const Tuple& GetConvertedParams() const { + return convertedParams_; + } + + private: + Tuple convertedParams_; + bool validParams_{true}; +}; + +using InitHugeMemThreadLocal = int 
(*)(void*, bool); +using UnInitHugeMemThreadLocal = void (*)(void*, bool); +using ReleaseHugeMem = void (*)(void*, bool); +using PTAGetExecCache = aclOpExecutor* (*)(uint64_t, uint64_t*); +using InitPTACacheThreadLocal = void (*)(); +using SetPTAHashKey = void (*)(uint64_t); +using CanUsePTACache = bool (*)(const char*); + +using ResetCacheThreadLocal = void (*)(); + +#define EXEC_OPAPI_CMD(aclnn_api, ...) \ + ({ \ + static auto ret = GRAPH_SUCCESS; \ + do { \ + static const auto ResetCacheThreadLocalAddr = GetOpApiFuncAddr("ResetCacheThreadLocal"); \ + static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr || ResetCacheThreadLocalAddr == nullptr) { \ + OPS_LOG_E("aclnnfallback", "%s or %s not in %s or %s or ResetCacheThreadLocal not found.", \ + #aclnn_api "GetWorkspaceSize", #aclnn_api, GetOpApiLibName(), GetOpApiLibName()); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + auto ResetCacheThreadLocalFunc = reinterpret_cast(ResetCacheThreadLocalAddr); \ + ResetCacheThreadLocalFunc(); \ + uint64_t workspace_size = 0; \ + uint64_t* workspace_size_addr = &workspace_size; \ + aclOpExecutor* executor = nullptr; \ + aclOpExecutor** executor_addr = &executor; \ + auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + if (workspace_status != 0) { \ + OPS_LOG_E("aclnnfallback", "call %s failed:", #aclnn_api); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + void* workspace_addr = nullptr; \ + if (workspace_size > 0) { \ + workspace_addr = host_api_ctx->MallocWorkspace(workspace_size); \ + if (workspace_addr == nullptr) { \ + OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed", 
#aclnn_api); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + } \ + auto acl_stream = host_api_ctx->GetStream(); \ + auto acl_call = [converted_params, workspace_addr, workspace_size, host_api_ctx, acl_stream, \ + executor]() -> int { \ + using OpApiFunc = int (*)(void*, uint64_t, aclOpExecutor*, const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + ReleaseConvertTypes(converted_params); \ + host_api_ctx->FreeWorkspace(); \ + if (api_ret != 0) { \ + OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed api_ret: %d", #aclnn_api, api_ret); \ + return GRAPH_FAILED; \ + } \ + return api_ret; \ + }; \ + \ + ret = acl_call(); \ + } while (false); \ + (ret); \ + }) + +} // namespace fallback + +#endif // ACLNNFALLBACK_OPAPI_H_ diff --git a/csrc/utils/inc/fallback_comm.h b/csrc/utils/inc/fallback_comm.h new file mode 100644 index 00000000..a2dd5cfd --- /dev/null +++ b/csrc/utils/inc/fallback_comm.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file fallback_comm.h + * \brief + */ + +#ifndef INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ +#define INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ + +#include "aclnn/aclnn_base.h" +#include "exe_graph/runtime/op_execute_context.h" +#include "exe_graph/runtime/tensor.h" +#include "register/op_impl_registry.h" +#include "runtime/base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +namespace fallback { + +aclDataType ToAclDataType(ge::DataType dtype); +} // namespace fallback + +#ifdef __cplusplus +} +#endif + +#endif // INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ diff --git a/csrc/utils/inc/kernel/dropmask.h b/csrc/utils/inc/kernel/dropmask.h new file mode 100644 index 00000000..13ed9c35 --- /dev/null +++ b/csrc/utils/inc/kernel/dropmask.h @@ -0,0 +1,121 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file dropmask.h + * \brief + */ + +#ifndef DROPMASK_H +#define DROPMASK_H + +#include "util.h" + +using AscendC::DROPOUT_MODE_BIT_MISALIGN; +using AscendC::DropOutShapeInfo; +using AscendC::DropOut; + +struct DropMaskInfo { + // for compute dropout mask offset + // 参数按B N G S1 S2全部切分设置进行偏移计算,没有切分的轴对应的参数设置为合适的0或者原始值 + int64_t n2G; // n2 * g + int64_t gSize; // g + int64_t s1Size; // s1 + int64_t s2Size; // s2 + int64_t gOutIdx; // g out index + int64_t bSSOffset; // boidx * s1 * s2 ===bSSOffset + int64_t n2OutIdx; // n out index + int64_t s1OutIdx; // s1 out index ===s1oIdx + int64_t s1InnerIdx; // s1 inner index, 配比 ===loopIdx + int64_t s1BaseSize; // S1基本块大小 + int64_t splitS1BaseSize; // s1 split size ===vec1S1BaseSize + int64_t s2StartIdx; // s2 start index + int64_t s2Idx; // s2 index =====s2LoopCount + int64_t s2BaseNratioSize; // s2的配比长度: s2BaseSize(S2基本块大小) * nRatio + + // for copy in dropout mask + uint32_t s1CopySize; + uint32_t s2CopySize; + int64_t s2TotalSize; + + // for compute dropout mask + uint32_t firstAxis; + uint32_t lstAxis; + uint32_t maskLstAxis; + int64_t vecCoreOffset = 0; + float keepProb; + + bool boolMode; +}; + +template +__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo) +{ + if constexpr (hasDrop == true) { + // boidx * n2 * g* s1 * s2 + int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G; + // n2oIdx * g * s1 *s2 + int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size; + // goIdx * s1 * s2 + int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size; + // s1oIdx * s1BaseSize * s2Size + s1innerindex * vec1S1BaseSize * s2Size + int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset + + dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size; + // s2StartIdx + s2index * s2BaseNratioSize + int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * 
dropMaskInfo.s2BaseNratioSize; + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; + } +} + +template +__aicore__ inline void CopyInDropMask(LocalTensor&dstTensor, GlobalTensor& srcBoolTensor, + GlobalTensor& srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes) +{ + if constexpr (hasDrop == true) { + int64_t dropMaskOffset = ComputeDropOffset(dropMaskInfo); + if (unlikely(dropMaskInfo.boolMode)) { + BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset, + dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); + } else { + Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1, + dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); + } + return; + } +} + +template +__aicore__ inline void ComputeDropMask(LocalTensor& dstTensor, LocalTensor& srcTensor, + LocalTensor& dropoutBuffer, LocalTensor& tmpDropBuffer, DropMaskInfo &dropMaskInfo) +{ + if constexpr (hasDrop == true) { + DropOutShapeInfo dropOutShapeInfo; + dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis; + dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis; + + if (unlikely(dropMaskInfo.boolMode)) { + dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes; + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); + } else { + dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes; + if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) { + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); + } else { + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, + dropMaskInfo.keepProb, dropOutShapeInfo); + } + } + return; + } +} + +#endif // DROPMASK_H diff --git a/csrc/utils/inc/kernel/pse.h b/csrc/utils/inc/kernel/pse.h new file mode 100644 index 00000000..e6cd8e7b --- /dev/null +++ 
b/csrc/utils/inc/kernel/pse.h @@ -0,0 +1,483 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file pse.h + * \brief + */ + +#ifndef FLASH_ATTENTION_SCORE_PSE_H +#define FLASH_ATTENTION_SCORE_PSE_H + +#include "kernel_operator.h" +#include "util.h" + +constexpr static int64_t pseS1S2 = 0; +constexpr static int64_t pse1S2 = 1; +constexpr static int64_t pseSlopeBn = 2; +constexpr static int64_t pseSlopeN = 3; + +constexpr static uint8_t pseEncodeALibiS2Full = 0x11; + +enum class PseTypeEnum { + PSE_OUTER_MUL_ADD_TYPE = 0, // default + PSE_OUTER_ADD_MUL_TYPE, + PSE_INNER_MUL_ADD_TYPE, + PSE_INNER_MUL_ADD_SQRT_TYPE, + PSE_INVALID_TYPE +}; + +struct PseInfo { + int64_t blockCount; + int64_t bSSOffset; // boidx * s1 * s2 + int64_t boIdx; + int64_t gSize; + int64_t goIdx; + int64_t loopIdx; + int64_t n2G; + int64_t n2oIdx; + int64_t pseBSize; + int64_t pseS1Size; // for alibi + int64_t pseS2ComputeSize; // for alibi, do not need assignment + int64_t pseS2Size; // for alibi + uint32_t pseShapeType; + int64_t readS2Size; // for alibi, do not need assignment + int64_t s1BaseSize; + int64_t s1Size; + int64_t s1oIdx; + int64_t s2AlignedSize; + int64_t s2BaseNratioSize; + int64_t s2LoopCount; + int64_t s2RealSize; + int64_t s2Size; + int64_t s2SizeAcc; // accumulated sum of s2 size + int64_t s2StartIdx; + int64_t vec1S1BaseSize; + int64_t vec1S1RealSize; + uint32_t pseEncodeType; // for distinguish alibi + 
uint32_t pseType; // 0: outer, mul-add 1:outer, add-mul 2:inner, mul-add 3:inner, mul-add-sqrt + int64_t pseAlibiBaseS1; + int64_t pseAlibiBaseS2; + int64_t qStartIdx; + int64_t kvStartIdx; + int64_t vecCoreOffset = 0; + bool needCast; + bool align8 = false; + bool pseEndogenous = false; +}; + +template +__aicore__ inline void DataCopyInCommon(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize, + int32_t alignedS2Size) +{ + if constexpr (hasPse == true) { + uint32_t shapeArray[] = {static_cast(s1Size), static_cast(alignedS2Size)}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(s1Size * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = s1Size; + dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // 单位32B + dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap + if (actualS2Len * dtypeSize % blockBytes == 0) { + dataCopyParams.srcStride = + (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap + DataCopy(dstTensor, srcTensor[offset], dataCopyParams); + } else { + dataCopyParams.blockLen = s2Size * dtypeSize; // 单位Byte + dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen); + dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = false; + DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams); + } + } +} + +template +__aicore__ inline void DataCopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + int32_t dtypeSize = sizeof(INPUT_T); + int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize; + DataCopyInCommon(dstTensor, srcTensor, 
offset, s1Size, s2Size, + actualS2Len, dtypeSize, alignedS2Size); + } +} + +template +__aicore__ inline void DataCopyInAlign8(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t actualS2Len) +{ + if constexpr (hasPse == true) { + int32_t dtypeSize = sizeof(INPUT_T); + if (dtypeSize == 0){ + return; + } + int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize); + DataCopyInCommon(dstTensor, srcTensor, offset, s1Size, s2Size, + actualS2Len, dtypeSize, alignedS2Size); + } +} + +/* +dst = BroadcastAdd(src0, src1) +src0 shape: (s1, s2) +src1 shape: (1, s2) +dst shape: (s1, s2) +*/ +template +__aicore__ inline void BroadcastAdd(const LocalTensor &src0Tensor, const LocalTensor &src1Tensor, + int64_t src0Offset, int32_t src1Size, int32_t repeatTimes) +{ + if constexpr (hasPse == true) { + /* Total data number of single step should be smaller than 256bytes. + * If larger, we need to do add multiple times. */ + int32_t innerLoop = src1Size / repeatMaxSize; // s2轴整块计算次数 + int32_t innerRemain = src1Size % repeatMaxSize; // s2轴尾块计算量 + BinaryRepeatParams binaryRepeatParams; + binaryRepeatParams.src0BlkStride = 1; + binaryRepeatParams.src0RepStride = src1Size / blockSize; + binaryRepeatParams.src1BlkStride = 1; + binaryRepeatParams.src1RepStride = 0; + binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride; + binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride; + + for (int32_t j = 0; j < innerLoop; j++) { + auto innerOffset = j * repeatMaxSize; + auto ubOffset = src0Offset + innerOffset; + Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes, + binaryRepeatParams); + } + if (innerRemain > 0) { + auto innerOffset = innerLoop * repeatMaxSize; + auto ubOffset = src0Offset + innerOffset; + Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes, + binaryRepeatParams); + } + } +} + +template 
+__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor &pseUb, + const LocalTensor &dstTensor, uint32_t pseShapeType) +{ + if constexpr (hasPse == true) { + if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) { + Add(dstTensor, dstTensor, pseUb, computeSize); + } else { + /* Total repeated times should be <= repeatMaxTimes. If larger, + * we need to do multiple inner loops. */ + int32_t s1OuterLoop = s1Size / repeatMaxTimes; + int32_t s1OuterRemain = s1Size % repeatMaxTimes; + for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) { + int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size; + BroadcastAdd(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes); + } + if (s1OuterRemain > 0) { + int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size; + BroadcastAdd(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain); + } + } + } +} +template __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = 0; + int64_t s1Offset = 0; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + int64_t gOffset = 0; + if (pseInfo.pseShapeType == pseS1S2) { + // b, n2, g, s1, s2 + bOffset = pseInfo.bSSOffset * pseInfo.n2G; + n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size; + gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size; + s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size; + } else if (pseInfo.pseShapeType == pse1S2) { + // b, n2, g, 1, s2 + bOffset = pseInfo.s2SizeAcc * pseInfo.n2G; + n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size; + gOffset = pseInfo.goIdx * pseInfo.s2Size; + } + if (pseInfo.pseBSize == 1) { + bOffset = 0; + } + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; 
+ } +} + +template __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize; + int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + int64_t m = 0; + int64_t k = 0; + if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) { + int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size; + if (row >= threshold) { + m = row - threshold; + k = column; + } else { + m = row % pseInfo.pseS1Size; + k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m); + } + } else { + int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size; + int64_t posVal = row - column - threshold; + if (threshold >= 0) { + if (posVal >= 0) { + m = posVal; + k = 0; + } else { + m = 0; + k = -posVal; + } + } else { + m = posVal; + k = 0; + } + } + int64_t s1Offset = m * pseInfo.pseS2Size; + int64_t s2Offset = k; + pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k); + pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size); + + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; + } +} + +template __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + // Alibi编码只计算下三角 + if (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize <= + pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize) { + return false; + } + return true; + } else { + return false; + } +} + +template +__aicore__ inline void PseAlibiCopyIn(LocalTensor &dstTensor, LocalTensor &tmpTensor, + 
GlobalTensor &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + if (!NeedPseAlibiCompute(pseInfo)) { + return; + } + int64_t offset = PseAlibiComputeOffset(pseInfo); + if constexpr (IsSameType::value) { + if (!pseInfo.align8){ + DataCopyIn(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size, + pseInfo.pseS2Size, alignedSize); + } else { + DataCopyInAlign8(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, + pseInfo.readS2Size, pseInfo.pseS2Size); + } + return; + } + + DataCopyIn(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size, + pseInfo.pseS2Size, alignedSize); + if (pseInfo.needCast) { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize); + } + return; + } +} + +template +__aicore__ inline void PseSlopeCopyIn(LocalTensor &dstTensor, LocalTensor &helpTensor, + __gm__ uint8_t *pseSlope, GlobalTensor &alibiGm, PseInfo &pseInfo, + int64_t alignedSize = 16) { + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize; + int64_t gOffset = pseInfo.goIdx; + + if (pseInfo.pseShapeType == pseSlopeBn) { + bOffset = pseInfo.boIdx * pseInfo.n2G; + } + int64_t offset = bOffset + n2Offset + gOffset; + + DataCopyIn(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize, + pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + if (pseInfo.needCast) { + int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize; + Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize); + pipe_barrier(PIPE_V); + + int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * 
pseInfo.vec1S1BaseSize; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + + float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx); + + Adds(dstTensor, dstTensor, posShift, computeSize); + pipe_barrier(PIPE_V); + Abs(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + float slopes = ((__gm__ T *)pseSlope)[offset] * -1; + if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + Sqrt(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + } + Muls(dstTensor, dstTensor, slopes, computeSize); + pipe_barrier(PIPE_V); + } + } +} + +template +__aicore__ inline void PseSlopeCast(LocalTensor &dstTensor, LocalTensor &helpTensor, + __gm__ uint8_t *pseSlope, PseInfo &pseInfo) { + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize; + int64_t gOffset = pseInfo.goIdx; + + if (pseInfo.pseShapeType == pseSlopeBn) { + bOffset = pseInfo.boIdx * pseInfo.n2G; + } + int64_t offset = bOffset + n2Offset + gOffset; + int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize; + Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize); + pipe_barrier(PIPE_V); + + int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + + float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx); + + Adds(dstTensor, dstTensor, posShift, computeSize); + pipe_barrier(PIPE_V); + Abs(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + float slopes = ((__gm__ T *)pseSlope)[offset] * -1; + if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + Sqrt(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + } + Muls(dstTensor, dstTensor, slopes, computeSize); + pipe_barrier(PIPE_V); + } +} + +template +__aicore__ 
inline void PseCopyIn(LocalTensor &dstTensor, LocalTensor &tmpTensor, + GlobalTensor &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) { + return PseAlibiCopyIn(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize); + } + int64_t offset = PseComputeOffset(pseInfo); + int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) : + pseInfo.vec1S1RealSize; + + if constexpr (IsSameType::value) { + if (!pseInfo.align8){ + DataCopyIn(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, + pseInfo.s2Size, alignedSize); + } else { + DataCopyInAlign8(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size); + } + return; + } + DataCopyIn(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size, + alignedSize); + if (pseInfo.needCast) { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize); + } + return; + } +} + +template +__aicore__ inline void PseAlibiCompute(LocalTensor &dstTensor, LocalTensor &pseTensor, PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + if (!NeedPseAlibiCompute(pseInfo)) { + return; + } + Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize); + return; + } +} + +template +__aicore__ inline void PseCompute(LocalTensor &dstTensor, LocalTensor &pseTensor, PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) { + return PseAlibiCompute(dstTensor, pseTensor, pseInfo); + } + int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn || + pseInfo.pseShapeType == pseSlopeN) + ? 
pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize + : pseInfo.s2AlignedSize; + PseBroadcastAdd(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor, + dstTensor, pseInfo.pseShapeType); + return; + } +} + +template +__aicore__ inline void PseInnerAlibiCreate(GlobalTensor &dstTensor, LocalTensor &helpTensor, PseInfo &pseInfo) { + if constexpr (hasPse == true) { + if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE && pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + return; + } + event_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V)); + event_t eventIdMte3ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S)); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + float tmpValue = -1.0; + + for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) { + CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + SetFlag(eventIdMte3ToS); + WaitFlag(eventIdMte3ToS); + } + } +} +#endif diff --git a/csrc/utils/inc/kernel/util.h b/csrc/utils/inc/kernel/util.h new file mode 100644 index 00000000..2c7d2089 --- /dev/null +++ b/csrc/utils/inc/kernel/util.h @@ -0,0 +1,144 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file util.h + * \brief + */ + +#ifndef FLASH_ATTENTION_UTIL_H +#define FLASH_ATTENTION_UTIL_H + +constexpr int32_t blockBytes = 32; +constexpr int32_t byteBitRatio = 8; +constexpr int64_t prefixAttenMaskDownHeight = 1024; +constexpr static int32_t blockSize = blockBytes / 4; // 4 means sizeof(T) +constexpr static int32_t repeatMaxBytes = 256; +constexpr static int32_t repeatMaxTimes = 255; +constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T) + +using AscendC::LocalTensor; +using AscendC::GlobalTensor; +using AscendC::DataFormat; +using AscendC::ShapeInfo; +using AscendC::DataCopyParams; +using AscendC::DataCopyPadParams; +using AscendC::BinaryRepeatParams; +using AscendC::IsSameType; +using AscendC::HardEvent; +using AscendC::SetFlag; +using AscendC::WaitFlag; + +enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5}; + +namespace math { +template __aicore__ inline T Ceil(T a, T b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +template __aicore__ inline T Align(T a, T b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b * b; +} +} + +template +__aicore__ inline T1 CeilDiv(T1 a, T2 b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +template +__aicore__ inline T1 Max(T1 a, T2 b) +{ + return (a > b) ? (a) : (b); +} + +template +__aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? 
(b) : (a); +} + +__aicore__ inline void BoolCopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, + int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes) +{ + uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize; + uint32_t shapeArray[] = {s1Size, alignedS2Size}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(s1Size * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = s1Size; + dataCopyParams.dstStride = 0; + if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0 + dataCopyParams.dstStride = 1; + alignedSize = blockBytes; + alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes; + } + if (totalS2Size % alignedSize == 0) { + dataCopyParams.blockLen = alignedS2Size / blockBytes; + dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes; + DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams); + } else { + dataCopyParams.blockLen = s2Size; + dataCopyParams.srcStride = totalS2Size - s2Size; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = true; + dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes); + dataCopyPadParams.paddingValue = 1; + DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams); + } +} + +__aicore__ inline void Bit2Int8CopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, + int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize, + int64_t alignedSize = blockBytes) +{ + uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize; + uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = batchSize * s1BaseSize; + 
dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes); + dataCopyParams.dstStride = 0; + if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) { + dataCopyParams.srcStride = + (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes; + DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams); + } else { + dataCopyParams.blockLen = CeilDiv(s2BaseSize , byteBitRatio); + dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = true; + dataCopyPadParams.rightPadding = 0; + dataCopyPadParams.paddingValue = 0; + DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams); + } +} + +__aicore__ inline int32_t Align(int32_t shape) +{ + int32_t alignFactor = 16; + int32_t alignedSize = CeilDiv(shape, alignFactor) * alignFactor; + return alignedSize; +} + +#endif // FLASH_ATTENTION_UTIL_H diff --git a/csrc/utils/inc/log/inner/dfx_base.h b/csrc/utils/inc/log/inner/dfx_base.h new file mode 100644 index 00000000..0fd1edb4 --- /dev/null +++ b/csrc/utils/inc/log/inner/dfx_base.h @@ -0,0 +1,190 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file dfx_base.h + * \brief 外部模块不应直接引用本头文件 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ops { +namespace utils { + +class LogBase { +public: + static constexpr const int MAX_LOG_LEN = 16000; + static constexpr const int MSG_HDR_LEN = 200; + + static inline uint64_t GetTid() + { + return static_cast(syscall(__NR_gettid)); + } + + static inline const char *GetStr(const std::string &str) + { + return str.c_str(); + } + + static inline const char *GetStr(const char *str) + { + return str; + } + + static inline const std::string &GetOpInfo(const std::string &str) + { + return str; + } + + static inline const char *GetOpInfo(const char *str) + { + return str; + } + + static inline std::string GetOpInfo(const gert::TilingContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::TilingParseContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::InferShapeContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::InferDataTypeContext *context) + { + return GetOpInfoFromContext(context); + } + +private: + template static inline std::string GetOpInfoFromContext(T context) + { + if (context == nullptr) { + return "nil:nil"; + } + std::string opInfo = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil"; + opInfo += ":"; + opInfo += context->GetNodeName() != nullptr ? 
context->GetNodeName() : "nil"; + return opInfo; + } +}; + +} // namespace utils + +template +std::string Shape2String(const T& shape) { + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} +} // namespace ops + +// 使用本宏前需预定义标识子模块名称的 OPS_UTILS_LOG_SUB_MOD_NAME +// 如: #define OPS_UTILS_LOG_SUB_MOD_NAME "OP_TILING" 或通过 CMake 传递预定义宏 +#define OPS_LOG_STUB(MOD_ID, LOG_LEVEL, OPS_DESC, FMT, ...) \ + do { \ + if (AlogCheckDebugLevel(static_cast(MOD_ID), (LOG_LEVEL)) == 1) { \ + AlogRecord(static_cast(MOD_ID), DLOG_TYPE_DEBUG, (LOG_LEVEL), \ + "[%s:%d][%s]%s[%s][%lu] OpName:[%s] " #FMT, \ + __FILE__, __LINE__, (OPS_UTILS_LOG_SUB_MOD_NAME), \ + (OPS_UTILS_LOG_PACKAGE_TYPE), __FUNCTION__, ops::utils::LogBase::GetTid(), \ + ops::utils::LogBase::GetStr(ops::utils::LogBase::GetOpInfo(OPS_DESC)), ##__VA_ARGS__); \ + } \ + }while (0) + +#define OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (__builtin_expect((COND), 0)) { \ + LOG_FUNC; \ + EXPR; \ + } \ + } while (0) + +#define OPS_INNER_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \ + do { \ + OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \ + REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \ + } while (0) + +#define OPS_CALL_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \ + do { \ + OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \ + REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \ + } while (0) + +#define OPS_LOG_STUB_D(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_DEBUG, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_I(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_INFO, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_W(OPS_DESC, FMT, ...) 
OPS_LOG_STUB(OP, DLOG_WARN, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_E(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_EVENT(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_EVENT, OPS_DESC, FMT, ##__VA_ARGS__) + +#define OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, FMT, ...) \ + do { \ + if (0 == AlogCheckDebugLevel(OP, (LEVEL))) { \ + break; \ + } \ + char msgbufxyz[ops::utils::LogBase::MAX_LOG_LEN]; \ + size_t msgmaxlen = (MSG_LENGTH - ops::utils::LogBase::MSG_HDR_LEN); \ + int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, FMT, ##__VA_ARGS__); \ + if (rettmp == -1) { \ + msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \ + } \ + size_t msglength = std::strlen(msgbufxyz); \ + if (msglength < msgmaxlen) { \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgbufxyz); \ + break; \ + } \ + char *msgchunkbegin = msgbufxyz; \ + char *msgchunkend = nullptr; \ + while (msgchunkbegin < msgbufxyz + msglength) { \ + if (msgchunkbegin[0] == '\n') { \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), ""); \ + msgchunkbegin += 1; \ + continue; \ + } \ + msgchunkend = std::strchr(msgchunkbegin, '\n'); \ + if (msgchunkend == nullptr) { \ + msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \ + } \ + while (msgchunkend > msgchunkbegin) { \ + std::string msgchunk(msgchunkbegin, \ + std::min(msgmaxlen, static_cast(msgchunkend - msgchunkbegin))); \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgchunk.c_str()); \ + msgchunkbegin += msgchunk.size(); \ + } \ + msgchunkbegin += 1; \ + } \ + } while (0) diff --git a/csrc/utils/inc/log/ops_log.h b/csrc/utils/inc/log/ops_log.h new file mode 100644 index 00000000..e7653a89 --- /dev/null +++ b/csrc/utils/inc/log/ops_log.h @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ops_log.h + * \brief + */ + +#pragma once + +#include "log/inner/dfx_base.h" + +/* 基础日志 */ +#define OPS_LOG_D(OPS_DESC, ...) OPS_LOG_STUB_D(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_I(OPS_DESC, ...) OPS_LOG_STUB_I(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_W(OPS_DESC, ...) OPS_LOG_STUB_W(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_E(OPS_DESC, ...) OPS_INNER_ERR_STUB("EZ9999", OPS_DESC, __VA_ARGS__) +#define OPS_LOG_E_WITHOUT_REPORT(OPS_DESC, ...) OPS_LOG_STUB_E(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_EVENT(OPS_DESC, ...) OPS_LOG_STUB_EVENT(OPS_DESC, __VA_ARGS__) + +/* 全量日志 + * 输出超长日志, 若日志超长, 则会被分为多行输出 */ +#define OPS_LOG_FULL(LEVEL, OPS_DESC, ...) OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_D_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_DEBUG, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_I_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_INFO, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_W_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_WARN, OPS_DESC, __VA_ARGS__) + +/* 条件日志 */ +#define OPS_LOG_D_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_D(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_I_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_I(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_W_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_W(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_E_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_E(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_EVENT_IF(COND, OP_DESC, EXPR, ...) 
OPS_LOG_STUB_IF(COND, OPS_LOG_EVENT(OP_DESC, __VA_ARGS__), EXPR) + +#define OPS_LOG_E_IF_NULL(OPS_DESC, PTR, EXPR) \ + if (__builtin_expect((PTR) == nullptr, 0)) { \ + OPS_LOG_STUB_E(OPS_DESC, "%s is nullptr!", #PTR); \ + OPS_CALL_ERR_STUB("EZ9999", OPS_DESC, "%s is nullptr!", #PTR); \ + EXPR; \ + } + +#define OPS_CHECK(COND, LOG_FUNC, EXPR) \ + if (COND) { \ + LOG_FUNC; \ + EXPR; \ + } + +#define OP_CHECK(COND, LOG_FUNC, EXPR) \ + if (COND) { \ + LOG_FUNC; \ + EXPR; \ + } diff --git a/csrc/utils/inc/tiling/data_copy_transpose_tiling.h b/csrc/utils/inc/tiling/data_copy_transpose_tiling.h new file mode 100644 index 00000000..7e8d15d7 --- /dev/null +++ b/csrc/utils/inc/tiling/data_copy_transpose_tiling.h @@ -0,0 +1,47 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file data_copy_transpose_tiling.h + * \brief + */ + +#pragma once + +#include +#include +#include "data_copy_transpose_tiling_def.h" + +namespace optiling { + +inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize, + optiling::CopyTransposeTiling &tiling) +{ + std::vector dstShapeInfo = dstShape.GetDims(); + std::vector srcShapeInfo = srcShape.GetDims(); + + tiling.set_dstShapeB(dstShapeInfo[0]); + tiling.set_dstShapeN(dstShapeInfo[1]); + tiling.set_dstShapeS(dstShapeInfo[2]); + tiling.set_dstShapeH(dstShapeInfo[3]); + tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN()); + + tiling.set_srcShapeB(srcShapeInfo[0]); + tiling.set_srcShapeN(srcShapeInfo[1]); + tiling.set_srcShapeS(srcShapeInfo[2]); + tiling.set_srcShapeHN(srcShapeInfo[3]); + tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize); + tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH()); + tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS()); + tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN()); + tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH()); +} + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h b/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h new file mode 100644 index 00000000..510b5cda --- /dev/null +++ b/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h @@ -0,0 +1,43 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file data_copy_transpose_tiling_def.h + * \brief + */ + +#pragma once + +#include +#include + +namespace optiling { + +BEGIN_TILING_DATA_DEF(CopyTransposeTiling) +TILING_DATA_FIELD_DEF(uint32_t, dstShapeB); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeS); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeH); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeB); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeN); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeS); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen); +TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue); +TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling); +TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue); +TILING_DATA_FIELD_DEF(uint32_t, paramsAlign); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling) + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_base.h b/csrc/utils/inc/tiling/tiling_base.h new file mode 100644 index 00000000..9776d90c --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_base.h @@ -0,0 +1,225 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_base.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include +#include "log/ops_log.h" + +#ifdef ASCENDC_OP_TEST +#define ASCENDC_EXTERN_C extern "C" +#else +#define ASCENDC_EXTERN_C +#endif + +namespace optiling { + +struct AiCoreParams { + uint64_t ubSize; + uint64_t blockDim; + uint64_t aicNum; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; +}; + +struct FlashAttentionScoreGradCompileInfo { + uint32_t aivNum; + uint32_t aicNum; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + uint64_t l2CacheSize; + int64_t coreNum; +}; + +class TilingBaseClass { +public: + TilingBaseClass() = default; + + explicit TilingBaseClass(gert::TilingContext *context) : context_(context) + { + } + + virtual ~TilingBaseClass() = default; + + // Tiling执行框架 + // 1、GRAPH_SUCCESS: 成功,并且不需要继续执行后续Tiling类的实现 + // 2、GRAPH_FAILED: 失败,中止整个Tiling流程 + // 3、GRAPH_PARAM_INVALID: 本类不支持,需要继续往下执行其他Tiling类的实现 + ge::graphStatus DoTiling() + { + auto ret = GetShapeAttrsInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetPlatformInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + if (!IsCapable()) { + return ge::GRAPH_PARAM_INVALID; + } + ret = DoOpTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = DoLibApiTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetWorkspaceSize(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = PostTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + context_->SetTilingKey(GetTilingKey()); + DumpTilingInfo(); + return ge::GRAPH_SUCCESS; + } + + // 更新 context + 
virtual void Reset(gert::TilingContext *context) + { + context_ = context; + } + +protected: + virtual bool IsCapable() = 0; + // 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小 + virtual ge::graphStatus GetPlatformInfo() = 0; + // 2、获取INPUT/OUTPUT/ATTR信息 + virtual ge::graphStatus GetShapeAttrsInfo() = 0; + // 3、计算数据切分TilingData + virtual ge::graphStatus DoOpTiling() = 0; + // 4、计算高阶API的TilingData + virtual ge::graphStatus DoLibApiTiling() = 0; + // 5、计算TilingKey + [[nodiscard]] virtual uint64_t GetTilingKey() const = 0; + // 6、计算Workspace 大小 + virtual ge::graphStatus GetWorkspaceSize() = 0; + // 7、保存Tiling数据 + virtual ge::graphStatus PostTiling() = 0; + // 8、Dump Tiling数据 + virtual void DumpTilingInfo() + { + int32_t enable = AlogCheckDebugLevel(static_cast(OP), DLOG_DEBUG); + if (enable != 1) { + return; + } + auto buf = (uint32_t *)context_->GetRawTilingData()->GetData(); + auto bufLen = context_->GetRawTilingData()->GetDataSize(); + std::ostringstream oss; + oss << "Start to dump tiling info. tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen + << ", content:"; + for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) { + oss << *(buf + i) << ","; + if (oss.str().length() > 640) { // Split according to 640 to avoid truncation + OPS_LOG_D(context_, "%s", oss.str().c_str()); + oss.str(""); + } + } + OPS_LOG_D(context_, "%s", oss.str().c_str()); + } + + static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) + { + uint32_t ration; + if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) { + return sliceNum; + } + ration = aivCoreNum / aicCoreNum; + return (sliceNum + (ration - 1)) / ration; + } + + template [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const + { + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); 
+ } + + [[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape, + const gert::CompileTimeTensorDesc *tensor) + { + if (shape == nullptr || tensor == nullptr) { + return "nil "; + } + std::ostringstream oss; + oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),"; + oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),"; + oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),"; + oss << "(format: " + << ge::TypeUtils::FormatToSerialString( + static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) + << "),"; + oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") "; + return oss.str(); + } + + [[nodiscard]] std::string GetTilingContextDebugStr() + { + std::ostringstream oss; + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) { + oss << "input" << i << ": "; + oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i)); + } + + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) { + oss << "output" << i << ": "; + oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i)); + } + return oss.str(); + } + + [[nodiscard]] std::string GetTilingDataDebugStr() const + { + auto rawTilingData = context_->GetRawTilingData(); + auto rawTilingDataSize = rawTilingData->GetDataSize(); + auto data = reinterpret_cast(rawTilingData->GetData()); + size_t len = rawTilingDataSize / sizeof(int32_t); + std::ostringstream oss; + for (size_t i = 0; i < len; i++) { + oss << data[i] << ", "; + } + return oss.str(); + } + +protected: + gert::TilingContext *context_ = nullptr; + std::unique_ptr ascendcPlatform_{nullptr}; + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0}; +}; + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_templates_registry.h 
b/csrc/utils/inc/tiling/tiling_templates_registry.h new file mode 100644 index 00000000..53fc590a --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_templates_registry.h @@ -0,0 +1,162 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_templates_registry.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include +#include "tiling/tiling_base.h" +#include "log/ops_log.h" +#include "error/ops_error.h" + +namespace optiling { + +template std::unique_ptr TILING_CLASS(gert::TilingContext *context) +{ + return std::unique_ptr(new (std::nothrow) T(context)); +} + +using TilingClassCase = std::unique_ptr (*)(gert::TilingContext *); + +class TilingCases { +public: + explicit TilingCases(std::string op_type) : op_type_(std::move(op_type)) + { + } + + template void AddTiling(int32_t priority) + { + OPS_ERR_IF(cases_.find(priority) != cases_.end(), + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return); + cases_[priority] = TILING_CLASS; + OPS_ERR_IF( + cases_[priority] == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."), + return); + } + + const std::map &GetTilingCases() + { + return cases_; + } + +private: + std::map cases_; + const std::string op_type_; +}; + +class TilingRegistry { +public: + TilingRegistry() = default; + +#ifdef ASCENDC_OP_TEST + static TilingRegistry 
&GetInstance(); +#else + static TilingRegistry &GetInstance() + { + static TilingRegistry registry_impl_; + return registry_impl_; + } +#endif + + std::shared_ptr RegisterOp(const std::string &op_type) + { + if (registry_map_.find(op_type) == registry_map_.end()) { + registry_map_[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); + } + OPS_ERR_IF(registry_map_[op_type] == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."), + return nullptr); + return registry_map_[op_type]; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext *context) + { + const char *op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { + auto tilingTemplate = it->second(context); + if (tilingTemplate != nullptr) { + ge::graphStatus status = tilingTemplate->DoTiling(); + if (status != ge::GRAPH_PARAM_INVALID) { + OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first); + return status; + } + OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first); + } + } + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found."); + return ge::GRAPH_FAILED; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector &priorities) + { + const char *op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto priorityId : priorities) { + auto templateFunc = tilingTemplateRegistryMap[priorityId](context); + if (templateFunc != nullptr) { + ge::graphStatus status = templateFunc->DoTiling(); + if (status == ge::GRAPH_SUCCESS) { + OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId); + return status; + } + OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId); + } + } + return ge::GRAPH_FAILED; + } + + const std::map 
&GetTilingTemplates(const std::string &op_type) + { + OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(), + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."), + return empty_tiling_case_); + return registry_map_[op_type]->GetTilingCases(); + } + +private: + std::map> registry_map_; + const std::map empty_tiling_case_ {}; +}; + +class Register { +public: + explicit Register(std::string op_type) : op_type_(std::move(op_type)) + { + } + + template Register &tiling(int32_t priority) + { + auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_); + OPS_ERR_IF(tilingCases == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please the op name."), + return *this); + tilingCases->AddTiling(priority); + return *this; + } + +private: + const std::string op_type_; +}; + +// op_type: 算子名称, class_name: 注册的 tiling 类, +// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大 +#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \ + static Register VAR_UNUSED##op_type_##class_name##priority_register = Register(op_type).tiling(priority) + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_type.h b/csrc/utils/inc/tiling/tiling_type.h new file mode 100644 index 00000000..d417b0b6 --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_type.h @@ -0,0 +1,136 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file tiling_type.h + * \brief + */ + +#pragma once + +#include + +namespace optiling { + +enum class AxisEnum { + B = 0, + N2 = 1, + G = 2, + S1 = 3, + S2 = 4, + D = 5, + NONE = 9, +}; + +enum class DtypeEnum { + FLOAT16 = 0, + FLOAT32 = 1, + BFLOAT16 = 2, + FLOAT16_PRECISION = 3, +}; + +enum class PerformanceOrientedEnum { + BIG_BUFFER = 1, + BIG_DOUBLE_BUFFER = 2, +}; + +enum class MatmulConfig { + NULL_CONFIG = 0, + NORMAL_CONFIG = 1, + MDL_CONFIG = 2 +}; + +enum class PseConfig { + NO_PSE = 0, + EXIST_PSE = 1 +}; + +enum class AttenMaskConfig { + NO_ATTEN_MASK = 0, + EXIST_ATTEN_MASK = 1 +}; + +enum class DropOutConfig { + NO_DROP_OUT = 0, + EXIST_DROP_OUT = 1 +}; + +enum class CubeFormatEnum { + ND = 0, + NZ = 1 +}; +enum class LayoutEnum { + BSND = 0, + SBND = 1, + BNSD = 2, + TND = 3 +}; + +enum class CubeInputSourceEnum { + GM = 0, + L1 = 1 +}; + +enum class OptionEnum { + DISABLE = 0, + ENABLE = 1 +}; + +enum class SparseEnum { + ALL = 0, + NONE = 1, + ANY = 2, + CAUSAL = 3, + BAND = 4, + PREFIX = 5, + BAND_COMPRESS = 6, + RIGHT_DOWN_CAUSAL = 7, + RIGHT_DOWN_CAUSAL_BAND = 8, + BAND_LEFT_UP_CAUSAL = 9 +}; + +constexpr uint64_t RecursiveSum() +{ + return 0; +} + +template constexpr uint64_t RecursiveSum(T templateId, Args... 
templateIds) +{ + return static_cast(templateId) + 10 * RecursiveSum(templateIds...); +} + +// TilingKey 的生成规则: +// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1, +// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1: +// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分, +// 那么填AXIS_NONE。UB0和UB1各占一个十进制位; +// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位; +// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位 +// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位 +// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位 +// 其余特化场景,定义自己的位域和值 +// usage: get tilingKey from inputed types +// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, +// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL) + +constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19 +template constexpr uint64_t GET_TILINGKEY(Args... templateIds) +{ + return TILINGKEYOFFSET + RecursiveSum(templateIds...); +} + +// usage: get tilingKey from inputed types +// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL) + +#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \ + (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \ + SparseEnum::sparse)) + +} // namespace optiling diff --git a/csrc/utils/src/fallback_comm.cpp b/csrc/utils/src/fallback_comm.cpp new file mode 100644 index 00000000..949cb728 --- /dev/null +++ b/csrc/utils/src/fallback_comm.cpp @@ -0,0 +1,53 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file fallback_comm.cpp + * \brief + */ + +#include "fallback_comm.h" + +#include +#include +#include +#include + +#include "aclnn/aclnn_base.h" +#include "runtime/base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +namespace fallback { +using namespace std; +using namespace gert; +using namespace ge; + +aclDataType ToAclDataType(ge::DataType dtype) { + static const std::vector CANN_CONVERT_TO_ACL_DataType_LIST = { + ge::DataType::DT_FLOAT, ge::DataType::DT_FLOAT16, ge::DataType::DT_INT8, ge::DataType::DT_INT32, + ge::DataType::DT_UINT8, ge::DataType::DT_INT16, ge::DataType::DT_UINT16, ge::DataType::DT_UINT32, + ge::DataType::DT_INT64, ge::DataType::DT_DOUBLE, ge::DataType::DT_BOOL, ge::DataType::DT_STRING, + ge::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX128, ge::DataType::DT_BF16, ge::DataType::DT_UINT64, + ge::DataType::DT_INT4}; + auto iter = std::find(CANN_CONVERT_TO_ACL_DataType_LIST.begin(), CANN_CONVERT_TO_ACL_DataType_LIST.end(), dtype); + if (iter == CANN_CONVERT_TO_ACL_DataType_LIST.end()) { + return aclDataType::ACL_DT_UNDEFINED; + } + return static_cast(dtype); +} + +} // namespace fallback + +#ifdef __cplusplus +} +#endif diff --git a/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md b/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md new file mode 100644 index 00000000..79a923a0 --- /dev/null +++ b/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md @@ -0,0 +1,25 @@ +# Adding a custom aclnn operation + +This document describes how to add a custom aclnn operation to vllm-ascend. + +## How custom aclnn operation works in vllm-ascend? 
+ +Custom aclnn operations are built and installed into `vllm_ascend/cann_ops_custom` directory during the build process of vllm-ascend. Then the aclnn operators are bound to `torch.ops._C_ascend` module, enabling users to invoke them in vllm-ascend python code. + +To enable custom operations, use the following code: + +```python +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() +``` + +## How to add a custom aclnn operation? + +- Create a new operation folder under `csrc` directory +- Create `op_host` and `op_kernel` directories for host and kernel source code +- Add build options in `csrc/build_aclnn.sh` for supported SOC. Note that multiple ops should be separated with `;`, i.e. `CUSTOM_OPS=op1;op2;op3` +- Bind aclnn operators to torch.ops._C_ascend module in `csrc/torch_binding.cpp` +- Write a meta implementation in `csrc/torch_binding_meta.cpp` for op being captured into aclgraph + +After a successful build of vllm-ascend, the custom aclnn operation can be invoked in python code. 
diff --git a/docs/source/developer_guide/feature_guide/index.md b/docs/source/developer_guide/feature_guide/index.md index 91f6badb..592850e6 100644 --- a/docs/source/developer_guide/feature_guide/index.md +++ b/docs/source/developer_guide/feature_guide/index.md @@ -12,4 +12,5 @@ eplb_swift_balancer.md Multi_Token_Prediction ACL_Graph KV_Cache_Pool_Guide +add_custom_aclnn_op ::: diff --git a/pyproject.toml b/pyproject.toml index 7a97edc3..a10ff9a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,11 @@ [build-system] # Should be mirrored in requirements.txt requires = [ + "attrs", "cmake>=3.26", "decorator", "einops", + "googleapis-common-protos", "numpy<2.0.0", "packaging", "pip", @@ -12,6 +14,7 @@ requires = [ "scipy", "pandas", "pandas-stubs", + "psutil", "setuptools>=64", "setuptools-scm>=8", "transformers<=4.57.1", diff --git a/setup.py b/setup.py index 0cee690e..1bf80081 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ import sys from sysconfig import get_paths from typing import Dict, List -from setuptools import Extension, find_packages, setup +from setuptools import Command, Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py from setuptools.command.develop import develop @@ -199,6 +199,27 @@ class custom_build_info(build_py): super().run() +class build_and_install_aclnn(Command): + description = "Build and install AclNN by running build_aclnn.sh" + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + try: + print("Running bash build_aclnn.sh ...") + subprocess.check_call( + ["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION]) + print("buid_aclnn.sh executed successfully!") + except subprocess.CalledProcessError as e: + print(f"Error running build_aclnn.sh: {e}") + raise SystemExit(e.returncode) + + class cmake_build_ext(build_ext): # A dict of extension directories that have been 
configured. did_config: Dict[str, bool] = {} @@ -385,8 +406,22 @@ class cmake_build_ext(build_ext): shutil.copy(src_path, dst_path) print(f"Copy: {src_path} -> {dst_path}") + # copy back _cann_ops_custom directory + src_cann_ops_custom = os.path.join(ROOT_DIR, "vllm_ascend", + "_cann_ops_custom") + dst_cann_ops_custom = os.path.join(self.build_lib, "vllm_ascend", + "_cann_ops_custom") + if os.path.exists(src_cann_ops_custom): + import shutil + if os.path.exists(dst_cann_ops_custom): + shutil.rmtree(dst_cann_ops_custom) + shutil.copytree(src_cann_ops_custom, dst_cann_ops_custom) + print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}") + def run(self): - # First, run the standard build_ext command to compile the extensions + # First, ensure ACLNN custom-ops is built and installed. + self.run_command("build_aclnn") + # Then, run the standard build_ext command to compile the extensions super().run() @@ -450,6 +485,7 @@ def get_requirements() -> List[str]: cmdclass = { "develop": custom_develop, "build_py": custom_build_info, + "build_aclnn": build_and_install_aclnn, "build_ext": cmake_build_ext, "install": custom_install } diff --git a/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py b/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py new file mode 100644 index 00000000..7e87e6d4 --- /dev/null +++ b/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py @@ -0,0 +1,148 @@ +import gc + +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +# enable internal format +torch_npu.npu.config.allow_internal_format = True +# enable vllm-ascend custom ops +enable_custom_op() + + +def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor, + perChannelScale: torch.Tensor, + perTokenScale: torch.Tensor, m: int): + """ + Perform quantized GMM (Grouped Matrix Multiplication) operation with SwiGLU activation function. + + Parameters: + x (torch.Tensor): Input tensor with shape (m, k). 
+ weight (torch.Tensor): Weight tensor with shape (k, n). + perChannelScale (torch.Tensor): Per-channel scaling factor with shape (n,). + perTokenScale (torch.Tensor): Per-token scaling factor with shape (m,). + m (int): Number of tokens (rows of x). + + Returns: + quantOutput (torch.Tensor): Quantized output tensor with shape (m, k // 2). + quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (m,). + """ + # Perform matrix multiplication with int32 precision + c_temp1 = torch.matmul(x.to(torch.int32), weight.to(torch.int32)) + c_temp1 = c_temp1.to(torch.float32) # Convert back to float32 for scaling + + # Apply per-channel and per-token scaling + c_temp2 = torch.mul(c_temp1, perChannelScale) + c_temp3 = torch.mul(c_temp2, perTokenScale.reshape(m, 1)) + + # Split the result into two parts to apply SwiGLU activation function + c_temp4, gate = c_temp3.chunk(2, dim=-1) + c_temp5 = c_temp4 * torch.sigmoid(c_temp4) # SwiGLU activation + c_temp6 = c_temp5 * gate # Element-wise multiplication with gating values + + # Quantize the output + max = torch.max( + torch.abs(c_temp6), + -1).values # Find maximum absolute value to calculate scaling factor + quantScaleOutput = 127 / max # Calculate quantization scaling factor + quantOutput = torch.round(c_temp6 * quantScaleOutput.reshape(m, 1)).to( + torch.int8) # Quantize to int8 + quantScaleOutput = 1 / quantScaleOutput # Inverse quantization scaling factor for subsequent dequantization + + return quantOutput, quantScaleOutput + + +def process_groups(x: torch.Tensor, weight: torch.Tensor, + perChannelScale: torch.Tensor, perTokenScale: torch.Tensor, + groupList: torch.Tensor): + """ + Process input data by groups and call GMM_Swiglu_quant function for quantized computation. + + Parameters: + x (torch.Tensor): Input tensor with shape (M, K). + weight (torch.Tensor): List of weight tensors, each with shape (E, K, N). + perChannelScale (torch.Tensor): List of per-channel scaling factors, each with shape (E, N). 
+ perTokenScale (torch.Tensor): Per-token scaling factor with shape (M,).
+ groupList (torch.Tensor): Cumulative token counts; entry i marks the end index of group i.
+
+ Returns:
+ quantOutput (torch.Tensor): Quantized output tensor with shape (M, N // 2).
+ quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (M,).
+ """
+ M, N = x.shape[0], weight.shape[2] # Get the shape of the input tensor
+ quantOutput = torch.zeros(M, N // 2).to(
+ torch.int8) # Initialize quantized output tensor
+ quantScaleOutput = torch.zeros(M).to(
+ torch.float32) # Initialize quantization scaling factor tensor
+
+ start_idx = 0 # Starting index
+ preV = 0 # Number of tokens in the previous group
+ groupList = groupList.tolist()
+ # Iterate through groupList to process data by groups
+ for i, v in enumerate(groupList):
+ currV = v
+ tempV = currV - preV # Calculate number of tokens in the current group
+ preV = currV # Update number of tokens in the previous group
+ if tempV > 0:
+ # Call GMM_Swiglu_quant to process the current group
+ quantOutput[start_idx:start_idx + tempV], quantScaleOutput[start_idx:start_idx + tempV] = \
+ gmm_swiglu_quant(x[start_idx:start_idx + tempV],
+ weight[i],
+ perChannelScale[i],
+ perTokenScale[start_idx:start_idx + tempV],
+ tempV)
+
+ start_idx += tempV # Update starting index to process the next group
+ return quantOutput, quantScaleOutput
+
+
+@torch.inference_mode()
+def test_gmm_swiglu_quant_weight_nz_tensor_list():
+ M, K, E, N = 8192, 7168, 4, 4096
+
+ # x (M, K) - int8
+ x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
+
+ # weight (E, K, N) - int8
+ weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)
+
+ # weight_scale (E, N) - float32
+ weight_scale = torch.rand(E, N) * 0.9 + 0.1 # uniform(0.1, 1.0)
+ weight_scale = weight_scale.to(torch.float32)
+
+ weight_nz_npu = []
+ weight_scale_npu = []
+ for i in range(E):
+ weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29))
+ 
weight_scale_npu.append(weight_scale[i].npu()) + + # x_scale (M,) - float32 + x_scale = torch.rand(M) * 0.9 + 0.1 # uniform(0.1, 1.0) + x_scale = x_scale.to(torch.float32) + + group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64) + + output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale, + x_scale, group_list) + output_npu, output_scale_npu, _ = \ + torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(), + weight_nz_npu, + weight_scale_npu, + x_scale.npu(), + group_list.npu()) + output_npu_valid = output_npu[:group_list[-1], :] + output_scale_npu_valid = output_scale_npu[:group_list[-1]] + + torch.testing.assert_close(output_npu_valid.cpu(), + output_cpu, + atol=1, + rtol=2**-13) + torch.testing.assert_close(output_scale_npu_valid.cpu(), + output_scale_cpu, + atol=1e-9, + rtol=1e-6) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/vllm_ascend/_cann_ops_custom/.gitkeep b/vllm_ascend/_cann_ops_custom/.gitkeep new file mode 100644 index 00000000..df36e2ec --- /dev/null +++ b/vllm_ascend/_cann_ops_custom/.gitkeep @@ -0,0 +1,3 @@ +# This folder is reserved for the installation of custom aclnn operators tailored for vLLM-Ascend. +# Source code of the operators can be found in the `src` folder. +# The operators are compiled into a custom CANN software package and installed to this folder automatically. 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 9e8b2593..7cc84fc6 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -38,6 +38,27 @@ from vllm_ascend.utils import ( prefill_context_parallel_enable, update_aclgraph_sizes, update_cudagraph_capture_sizes, update_default_aclgraph_sizes) +# set custom ops path +CUR_DIR = os.path.dirname(os.path.realpath(__file__)) +CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "vllm_ascend", "_cann_ops_custom", + "vendors", "customize") +CUSTOM_LIB_PATH = os.path.join(CUSTOM_OPP_PATH, "op_api", "lib") + +if os.path.exists(CUSTOM_OPP_PATH): + current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "") + if current_cust_opp_path: + os.environ[ + "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" + else: + os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH + +if os.path.exists(CUSTOM_LIB_PATH): + current_lib_path = os.environ.get("LD_LIBRARY_PATH", "") + if current_lib_path: + os.environ["LD_LIBRARY_PATH"] = f"{CUSTOM_LIB_PATH}:{current_lib_path}" + else: + os.environ["LD_LIBRARY_PATH"] = CUSTOM_LIB_PATH + if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig from vllm.utils import FlexibleArgumentParser