[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)
### What this PR does / why we need it? This PR introduces support for adding custom CANN `aclnn` ops to `vllm-ascend`, allowing users to define and use their own custom operators. Key changes include: - Building and installing custom ops into the `vllm-ascend`-specified directory - Binding the `aclnn` op interface to the `torch.ops._C_ascend` module - Enabling invocation of these ops within `vllm-ascend` This PR includes a sample custom op: `aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from the CANN operator [`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md). Its input parameters `weight` and `weight_scale` now accept `list[torch.Tensor]` (i.e., `at::TensorList`). ### Does this PR introduce _any_ user-facing change? No. - vLLM version: v0.11.2 --------- Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
This commit is contained in:
5
.github/workflows/release_whl.yml
vendored
5
.github/workflows/release_whl.yml
vendored
@@ -96,6 +96,11 @@ jobs:
|
||||
--exclude libge_common_base.so \
|
||||
--exclude libc10.so \
|
||||
--exclude libc_sec.so \
|
||||
--exclude libnnopbase.so \
|
||||
--exclude libprofapi.so \
|
||||
--exclude libgraph_base.so \
|
||||
--exclude libgraph.so \
|
||||
--exclude libexe_graph.so \
|
||||
--exclude "libascend*.so" \
|
||||
--exclude "libtorch*.so" \
|
||||
--exclude "libopapi.so" \
|
||||
|
||||
@@ -12,7 +12,7 @@ repos:
|
||||
- id: codespell
|
||||
args: [
|
||||
--toml, pyproject.toml,
|
||||
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
|
||||
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
|
||||
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
|
||||
]
|
||||
additional_dependencies:
|
||||
@@ -37,7 +37,7 @@ repos:
|
||||
- id: typos
|
||||
args: [
|
||||
"--force-exclude",
|
||||
"--exclude", "csrc/mla_preprocess/**"
|
||||
"--exclude", "csrc/**"
|
||||
]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 6.0.1
|
||||
|
||||
@@ -82,6 +82,7 @@ set(
|
||||
${TORCH_NPU_INCLUDE_DIRS}
|
||||
${ASCEND_HOME_PATH}/include
|
||||
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
|
||||
${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
|
||||
)
|
||||
|
||||
pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
|
||||
|
||||
642
csrc/CMakeLists.txt
Normal file
642
csrc/CMakeLists.txt
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
project(cann_ops_custom)
|
||||
|
||||
option(BUILD_OPEN_PROJECT "Build open ascend ops project." ON)
|
||||
option(ENABLE_CCACHE "Enable ccache capability" ON)
|
||||
set(ASCEND_COMPUTE_UNIT "ascend910b" CACHE STRING "soc that need to be compiled")
|
||||
set(ASCEND_OP_NAME "ALL" CACHE STRING "operators that need to be compiled")
|
||||
set(VENDOR_NAME "customize" CACHE STRING "vendor name")
|
||||
|
||||
include(cmake/config.cmake)
|
||||
include(cmake/func.cmake)
|
||||
include(cmake/intf.cmake)
|
||||
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
set(_op_host_aclnn_link
|
||||
$<BUILD_INTERFACE:intf_pub>
|
||||
exe_graph
|
||||
register
|
||||
c_sec
|
||||
)
|
||||
set(CMAKE_MODULE_PATH
|
||||
${CMAKE_MODULE_PATH}
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/modules
|
||||
)
|
||||
|
||||
set(CMAKE_PREFIX_PATH
|
||||
${CMAKE_PREFIX_PATH}
|
||||
${ASCEND_CANN_PACKAGE_PATH}
|
||||
)
|
||||
|
||||
find_package(alog MODULE REQUIRED)
|
||||
add_library(op_host_aclnn SHARED EXCLUDE_FROM_ALL)
|
||||
target_link_libraries(op_host_aclnn PRIVATE
|
||||
${_op_host_aclnn_link}
|
||||
)
|
||||
target_compile_options(op_host_aclnn PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
|
||||
)
|
||||
|
||||
add_library(op_host_aclnnInner SHARED EXCLUDE_FROM_ALL)
|
||||
target_link_libraries(op_host_aclnnInner PRIVATE
|
||||
${_op_host_aclnn_link}
|
||||
)
|
||||
target_compile_options(op_host_aclnnInner PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
|
||||
)
|
||||
|
||||
add_library(op_host_aclnnExc SHARED EXCLUDE_FROM_ALL)
|
||||
target_link_libraries(op_host_aclnnExc PRIVATE
|
||||
${_op_host_aclnn_link}
|
||||
)
|
||||
target_compile_options(op_host_aclnnExc PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
|
||||
)
|
||||
|
||||
# op api
|
||||
add_library(opapi SHARED)
|
||||
# When compiling a specified operator without aclnn src
|
||||
if (NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
|
||||
COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
|
||||
)
|
||||
target_sources(opapi PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
target_compile_options(opapi PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
|
||||
)
|
||||
target_include_directories(opapi PRIVATE
|
||||
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include>
|
||||
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn>
|
||||
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn_kernels>
|
||||
)
|
||||
target_compile_options(opapi PRIVATE
|
||||
-Werror=format
|
||||
)
|
||||
target_compile_definitions(opapi PRIVATE
|
||||
-DACLNN_LOG_FMT_CHECK
|
||||
)
|
||||
target_link_libraries(opapi PRIVATE
|
||||
$<BUILD_INTERFACE:intf_pub>
|
||||
-Wl,--whole-archive
|
||||
ops_aclnn
|
||||
-Wl,--no-whole-archive
|
||||
-lopapi
|
||||
nnopbase
|
||||
profapi
|
||||
ge_common_base
|
||||
ascend_dump
|
||||
ascendalog
|
||||
dl
|
||||
)
|
||||
set_target_properties(opapi PROPERTIES OUTPUT_NAME
|
||||
cust_opapi
|
||||
)
|
||||
install(TARGETS opapi
|
||||
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_api/lib
|
||||
)
|
||||
|
||||
# op proto
|
||||
add_library(opsproto SHARED)
|
||||
target_compile_options(opsproto PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
|
||||
-fvisibility=hidden
|
||||
)
|
||||
target_compile_definitions(opsproto PRIVATE
|
||||
LOG_CPP
|
||||
PROCESS_LOG
|
||||
)
|
||||
target_link_libraries(opsproto PRIVATE
|
||||
$<BUILD_INTERFACE:intf_pub>
|
||||
$<BUILD_INTERFACE:ops_utils_proto_headers>
|
||||
-Wl,--whole-archive
|
||||
rt2_registry
|
||||
-Wl,--no-whole-archive
|
||||
-Wl,--no-as-needed
|
||||
exe_graph
|
||||
graph
|
||||
graph_base
|
||||
register
|
||||
ascendalog
|
||||
error_manager
|
||||
platform
|
||||
-Wl,--as-needed
|
||||
c_sec
|
||||
)
|
||||
set_target_properties(opsproto PROPERTIES OUTPUT_NAME
|
||||
cust_opsproto_rt2.0
|
||||
)
|
||||
install(TARGETS opsproto
|
||||
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR}
|
||||
)
|
||||
|
||||
# op tiling
|
||||
add_library(optiling SHARED)
|
||||
target_sources(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/src/fallback_comm.cpp
|
||||
)
|
||||
target_compile_options(optiling PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
|
||||
-fvisibility=hidden
|
||||
)
|
||||
target_compile_definitions(optiling PRIVATE
|
||||
LOG_CPP
|
||||
PROCESS_LOG
|
||||
)
|
||||
target_link_libraries(optiling PRIVATE
|
||||
$<BUILD_INTERFACE:intf_pub>
|
||||
$<BUILD_INTERFACE:ops_utils_tiling_headers>
|
||||
-Wl,--whole-archive
|
||||
rt2_registry
|
||||
-Wl,--no-whole-archive
|
||||
-Wl,--no-as-needed
|
||||
graph
|
||||
graph_base
|
||||
exe_graph
|
||||
platform
|
||||
register
|
||||
ascendalog
|
||||
error_manager
|
||||
-Wl,--as-needed
|
||||
-Wl,--whole-archive
|
||||
tiling_api
|
||||
-Wl,--no-whole-archive
|
||||
mmpa
|
||||
c_sec
|
||||
)
|
||||
set_target_properties(optiling PROPERTIES OUTPUT_NAME
|
||||
cust_opmaster_rt2.0
|
||||
)
|
||||
add_custom_command(TARGET optiling
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${TILING_CUSTOM_DIR}
|
||||
COMMAND ln -sf $<TARGET_FILE:optiling> ${TILING_CUSTOM_FILE}
|
||||
)
|
||||
install(TARGETS optiling
|
||||
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR}
|
||||
)
|
||||
|
||||
# optiling compat
|
||||
set(compat_optiling_dir ${CMAKE_CURRENT_BINARY_DIR}/compat)
|
||||
set(compat_optiling_file ${compat_optiling_dir}/liboptiling.so)
|
||||
add_custom_target(optiling_compat ALL
|
||||
DEPENDS ${compat_optiling_file}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${compat_optiling_file}
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${compat_optiling_dir}
|
||||
COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$<TARGET_FILE_NAME:optiling> ${compat_optiling_file}
|
||||
)
|
||||
|
||||
install(FILES ${compat_optiling_file}
|
||||
DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling
|
||||
)
|
||||
|
||||
add_ops_tiling_keys(
|
||||
OP_NAME "ALL"
|
||||
TILING_KEYS ${TILING_KEY}
|
||||
)
|
||||
|
||||
add_opc_config(
|
||||
OP_NAME "ALL"
|
||||
CONFIG ${OP_DEBUG_CONFIG}
|
||||
)
|
||||
|
||||
if(ADD_OPS_COMPILE_OPTION_V2)
|
||||
add_ops_compile_options(
|
||||
OP_NAME "ALL"
|
||||
OPTIONS ${OPS_COMPILE_OPTIONS}
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
add_subdirectory(utils)
|
||||
|
||||
set(OP_LIST)
|
||||
set(OP_DIR_LIST)
|
||||
op_add_subdirectory(OP_LIST OP_DIR_LIST)
|
||||
|
||||
foreach (OP_DIR ${OP_DIR_LIST})
|
||||
add_subdirectory(${OP_DIR}/op_host)
|
||||
endforeach ()
|
||||
|
||||
set(OP_DEPEND_DIR_LIST)
|
||||
op_add_depend_directory(
|
||||
OP_LIST ${OP_LIST}
|
||||
OP_DIR_LIST OP_DEPEND_DIR_LIST
|
||||
)
|
||||
foreach (OP_DEPEND_DIR ${OP_DEPEND_DIR_LIST})
|
||||
add_subdirectory(${OP_DEPEND_DIR}/op_host)
|
||||
endforeach ()
|
||||
|
||||
# ------------------------------------------------ aclnn ------------------------------------------------
|
||||
get_target_property(base_aclnn_srcs op_host_aclnn SOURCES)
|
||||
get_target_property(base_aclnn_inner_srcs op_host_aclnnInner SOURCES)
|
||||
get_target_property(base_aclnn_exclude_srcs op_host_aclnnExc SOURCES)
|
||||
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
set(base_aclnn_binary_dir ${ASCEND_AUTOGEN_DIR})
|
||||
else()
|
||||
get_target_property(base_aclnn_binary_dir op_host_aclnn BINARY_DIR)
|
||||
endif ()
|
||||
|
||||
set(generate_aclnn_srcs)
|
||||
set(generate_aclnn_inner_srcs)
|
||||
set(generate_aclnn_headers)
|
||||
set(generate_proto_dir ${base_aclnn_binary_dir})
|
||||
set(generate_exclude_proto_srcs)
|
||||
set(generate_proto_srcs)
|
||||
set(generate_proto_headers)
|
||||
|
||||
if (base_aclnn_srcs)
|
||||
foreach (_src ${base_aclnn_srcs})
|
||||
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
|
||||
if (is_match)
|
||||
get_filename_component(name_without_ext ${_src} NAME_WE)
|
||||
|
||||
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
|
||||
list(APPEND generate_aclnn_srcs ${base_aclnn_binary_dir}/aclnn_${_op_name}.cpp)
|
||||
list(APPEND generate_aclnn_headers ${base_aclnn_binary_dir}/aclnn_${_op_name}.h)
|
||||
list(APPEND generate_proto_srcs ${generate_proto_dir}/${_op_name}_proto.cpp)
|
||||
list(APPEND generate_proto_headers ${generate_proto_dir}/${_op_name}_proto.h)
|
||||
endif ()
|
||||
endforeach ()
|
||||
else ()
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
|
||||
COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (base_aclnn_inner_srcs)
|
||||
foreach (_src ${base_aclnn_inner_srcs})
|
||||
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
|
||||
if (is_match)
|
||||
get_filename_component(name_without_ext ${_src} NAME_WE)
|
||||
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
|
||||
list(APPEND generate_aclnn_inner_srcs ${base_aclnn_binary_dir}/inner/aclnnInner_${_op_name}.cpp)
|
||||
list(APPEND generate_proto_srcs ${generate_proto_dir}/inner/${_op_name}_proto.cpp)
|
||||
list(APPEND generate_proto_headers ${generate_proto_dir}/inner/${_op_name}_proto.h)
|
||||
endif ()
|
||||
endforeach ()
|
||||
else ()
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
|
||||
COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (base_aclnn_exclude_srcs)
|
||||
foreach (_src ${base_aclnn_exclude_srcs})
|
||||
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
|
||||
if (is_match)
|
||||
get_filename_component(name_without_ext ${_src} NAME_WE)
|
||||
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
|
||||
list(APPEND generate_exclude_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp)
|
||||
list(APPEND generate_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp)
|
||||
list(APPEND generate_proto_headers ${generate_proto_dir}/exc/${_op_name}_proto.h)
|
||||
endif ()
|
||||
endforeach ()
|
||||
else()
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
|
||||
COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnExc PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
if (generate_aclnn_srcs OR generate_aclnn_inner_srcs)
|
||||
set(ops_aclnn_src ${generate_aclnn_srcs} ${generate_aclnn_inner_srcs})
|
||||
else ()
|
||||
set(ops_aclnn_src ${CMAKE_CURRENT_BINARY_DIR}/ops_aclnn_src_stub.cpp)
|
||||
|
||||
add_custom_command(OUTPUT ${ops_aclnn_src}
|
||||
COMMAND touch ${ops_aclnn_src}
|
||||
)
|
||||
endif ()
|
||||
|
||||
set_source_files_properties(${ops_aclnn_src}
|
||||
PROPERTIES GENERATED TRUE
|
||||
)
|
||||
add_library(ops_aclnn STATIC
|
||||
${ops_aclnn_src}
|
||||
)
|
||||
target_compile_options(ops_aclnn PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
|
||||
)
|
||||
target_link_libraries(ops_aclnn PRIVATE
|
||||
$<BUILD_INTERFACE:intf_pub>
|
||||
)
|
||||
add_dependencies(ops_aclnn opbuild_gen_default opbuild_gen_inner)
|
||||
|
||||
set_source_files_properties(${generate_proto_srcs}
|
||||
PROPERTIES GENERATED TRUE
|
||||
)
|
||||
target_sources(opsproto PRIVATE
|
||||
${generate_proto_srcs}
|
||||
)
|
||||
add_dependencies(opsproto ops_proto_headers)
|
||||
|
||||
install(FILES ${generate_proto_headers}
|
||||
DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/inc OPTIONAL
|
||||
)
|
||||
|
||||
redefine_file_macro(
|
||||
TARGET_NAME
|
||||
op_host_aclnn
|
||||
op_host_aclnnInner
|
||||
op_host_aclnnExc
|
||||
opapi
|
||||
opsproto
|
||||
optiling
|
||||
ops_aclnn
|
||||
)
|
||||
else()
|
||||
if (generate_aclnn_srcs OR generate_aclnn_inner_srcs)
|
||||
set_source_files_properties(${generate_aclnn_srcs} ${generate_aclnn_inner_srcs}
|
||||
TARGET_DIRECTORY acl_op_builtin
|
||||
PROPERTIES GENERATED TRUE
|
||||
)
|
||||
|
||||
target_sources(acl_op_builtin PRIVATE
|
||||
${generate_aclnn_srcs}
|
||||
${generate_aclnn_inner_srcs}
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (generate_proto_srcs)
|
||||
set_source_files_properties(${generate_proto_srcs}
|
||||
TARGET_DIRECTORY opsproto opsproto_rt2.0
|
||||
PROPERTIES GENERATED TRUE
|
||||
)
|
||||
target_sources(opsproto PRIVATE
|
||||
${generate_proto_srcs}
|
||||
)
|
||||
add_dependencies(opsproto ops_proto_headers)
|
||||
|
||||
target_sources(opsproto_rt2.0 PRIVATE
|
||||
${generate_proto_srcs}
|
||||
)
|
||||
add_dependencies(opsproto_rt2.0 ops_proto_headers)
|
||||
endif ()
|
||||
|
||||
add_target_source(
|
||||
TARGET_NAME opmaster_rt2.0 opmaster_static_rt2.0
|
||||
BASE_TARGET optiling
|
||||
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
add_target_source(
|
||||
TARGET_NAME opsproto_rt2.0 opsproto_static_rt2.0
|
||||
BASE_TARGET opsproto
|
||||
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
add_static_ops(
|
||||
ACLNN_SRC ${generate_aclnn_srcs}
|
||||
ACLNN_INNER_SRC ${generate_aclnn_inner_srcs}
|
||||
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (generate_aclnn_headers)
|
||||
install(FILES ${generate_aclnn_headers}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
endif ()
|
||||
|
||||
add_library(ops_proto_headers INTERFACE)
|
||||
|
||||
target_include_directories(ops_proto_headers INTERFACE
|
||||
$<BUILD_INTERFACE:${generate_proto_dir}>
|
||||
$<BUILD_INTERFACE:${generate_proto_dir}/inner>
|
||||
$<BUILD_INTERFACE:${generate_proto_dir}/exc>
|
||||
$<INSTALL_INTERFACE:include/ops_adv/proto>
|
||||
)
|
||||
|
||||
if ((NOT BUILD_OPEN_PROJECT) AND ("${PRODUCT_SIDE}" STREQUAL "device"))
|
||||
ExternalProject_Add(extern_opbuild_gen_default
|
||||
SOURCE_DIR ${TOP_DIR}/cmake/superbuild
|
||||
CONFIGURE_COMMAND ${CMAKE_COMMAND}
|
||||
-G ${CMAKE_GENERATOR}
|
||||
-DHOST_PACKAGE=opp
|
||||
-DBUILD_MOD=ops
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/opbuild_output
|
||||
-DFEATURE_LIST=custom_opbuild_out_dir=${generate_proto_dir}
|
||||
<SOURCE_DIR>
|
||||
BUILD_COMMAND TARGETS=opbuild_gen_all $(MAKE)
|
||||
INSTALL_COMMAND ""
|
||||
LIST_SEPARATOR ::
|
||||
EXCLUDE_FROM_ALL TRUE
|
||||
)
|
||||
add_dependencies(ops_proto_headers extern_opbuild_gen_default)
|
||||
else()
|
||||
add_dependencies(ops_proto_headers opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
|
||||
endif ()
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
if (generate_proto_srcs)
|
||||
install_package(
|
||||
PACKAGE ops_adv
|
||||
TARGETS ops_proto_headers
|
||||
FILES ${generate_proto_headers}
|
||||
DESTINATION include/ops_adv/proto
|
||||
)
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
# ------------------------------------------------ opbuild ------------------------------------------------
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
if (generate_aclnn_srcs)
|
||||
add_custom_command(OUTPUT ${generate_aclnn_srcs} ${generate_aclnn_headers}
|
||||
COMMAND mkdir -p ${base_aclnn_binary_dir}
|
||||
COMMAND OPS_PROTO_SEPARATE=1
|
||||
OPS_ACLNN_GEN=1
|
||||
OPS_PROJECT_NAME=aclnn
|
||||
${OP_BUILD_TOOL}
|
||||
$<TARGET_FILE:op_host_aclnn>
|
||||
${base_aclnn_binary_dir}
|
||||
)
|
||||
endif ()
|
||||
|
||||
add_custom_target(opbuild_gen_default
|
||||
DEPENDS ${generate_aclnn_srcs} ${generate_aclnn_headers} op_host_aclnn
|
||||
)
|
||||
|
||||
if (generate_aclnn_inner_srcs)
|
||||
add_custom_command(OUTPUT ${generate_aclnn_inner_srcs}
|
||||
COMMAND mkdir -p ${base_aclnn_binary_dir}/inner
|
||||
COMMAND OPS_PROTO_SEPARATE=1
|
||||
OPS_ACLNN_GEN=1
|
||||
OPS_PROJECT_NAME=aclnnInner
|
||||
${OP_BUILD_TOOL}
|
||||
$<TARGET_FILE:op_host_aclnnInner>
|
||||
${base_aclnn_binary_dir}/inner
|
||||
)
|
||||
endif ()
|
||||
|
||||
add_custom_target(opbuild_gen_inner
|
||||
DEPENDS ${generate_aclnn_inner_srcs} op_host_aclnnInner
|
||||
)
|
||||
|
||||
if (generate_exclude_proto_srcs)
|
||||
add_custom_command(OUTPUT ${generate_exclude_proto_srcs}
|
||||
COMMAND mkdir -p ${base_aclnn_binary_dir}/exc
|
||||
COMMAND OPS_PROTO_SEPARATE=1
|
||||
OPS_ACLNN_GEN=0
|
||||
OPS_PROJECT_NAME=aclnnExc
|
||||
${OP_BUILD_TOOL}
|
||||
$<TARGET_FILE:op_host_aclnnExc>
|
||||
${base_aclnn_binary_dir}/exc
|
||||
)
|
||||
endif ()
|
||||
|
||||
add_custom_target(opbuild_gen_exc
|
||||
DEPENDS ${generate_exclude_proto_srcs} op_host_aclnnExc
|
||||
)
|
||||
endif ()
|
||||
|
||||
# ------------------------------------------------ generate adapt py ------------------------------------------------
|
||||
add_custom_target(generate_adapt_py
|
||||
COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_impl_build.py
|
||||
\"\"
|
||||
\"\"
|
||||
\"\"
|
||||
\"\"
|
||||
${ASCEND_IMPL_OUT_DIR}
|
||||
${ASCEND_AUTOGEN_DIR}
|
||||
--opsinfo-dir ${base_aclnn_binary_dir} ${base_aclnn_binary_dir}/inner ${base_aclnn_binary_dir}/exc
|
||||
)
|
||||
|
||||
add_dependencies(generate_adapt_py opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
|
||||
|
||||
foreach (_op_name ${OP_LIST})
|
||||
install(FILES ${ASCEND_IMPL_OUT_DIR}/dynamic/${_op_name}.py
|
||||
DESTINATION ${IMPL_DYNAMIC_INSTALL_DIR}
|
||||
OPTIONAL
|
||||
)
|
||||
endforeach ()
|
||||
|
||||
install(DIRECTORY ${OPS_ADV_UTILS_KERNEL_INC}/
|
||||
DESTINATION ${IMPL_INSTALL_DIR}/ascendc/common
|
||||
)
|
||||
|
||||
foreach (op_dir ${OP_DIR_LIST})
|
||||
get_filename_component(_op_name "${op_dir}" NAME)
|
||||
|
||||
file(GLOB KERNEL_FILES
|
||||
${op_dir}/op_kernel/*.cpp
|
||||
${op_dir}/op_kernel/*.h
|
||||
)
|
||||
|
||||
install(FILES ${KERNEL_FILES}
|
||||
DESTINATION ${IMPL_INSTALL_DIR}/ascendc/${_op_name}
|
||||
OPTIONAL
|
||||
)
|
||||
endforeach ()
|
||||
|
||||
# ------------------------------------------------ generate compile cmd ------------------------------------------------
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
add_custom_target(prepare_build ALL)
|
||||
add_custom_target(generate_compile_cmd ALL)
|
||||
add_custom_target(generate_ops_info ALL)
|
||||
add_dependencies(prepare_build generate_adapt_py generate_compile_cmd)
|
||||
|
||||
foreach (compute_unit ${ASCEND_COMPUTE_UNIT})
|
||||
add_compile_cmd_target(
|
||||
COMPUTE_UNIT ${compute_unit}
|
||||
)
|
||||
|
||||
add_ops_info_target(
|
||||
COMPUTE_UNIT ${compute_unit}
|
||||
)
|
||||
endforeach ()
|
||||
else()
|
||||
add_dependencies(tbe_ops_json_info generate_adapt_py)
|
||||
endif ()
|
||||
|
||||
# ------------------------------------------------ opp kernel ------------------------------------------------
|
||||
if (ENABLE_OPS_KERNEL)
|
||||
add_custom_target(ops_kernel ALL)
|
||||
add_custom_target(ops_config ALL)
|
||||
add_dependencies(ops_kernel ops_config)
|
||||
|
||||
foreach (compute_unit ${ASCEND_COMPUTE_UNIT})
|
||||
add_bin_compile_target(
|
||||
COMPUTE_UNIT
|
||||
${compute_unit}
|
||||
OP_INFO
|
||||
${OP_DIR_LIST}
|
||||
)
|
||||
endforeach ()
|
||||
endif ()
|
||||
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
add_custom_target(modify_vendor ALL
|
||||
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh
|
||||
)
|
||||
|
||||
# modify VENDOR_NAME in install.sh and upgrade.sh
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh
|
||||
COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/scripts
|
||||
COMMAND cp -r ${ASCEND_PROJECT_DIR}/scripts/* ${CMAKE_CURRENT_BINARY_DIR}/scripts/
|
||||
COMMAND chmod +w ${CMAKE_CURRENT_BINARY_DIR}/scripts/*
|
||||
COMMAND sed -i "s/vendor_name=customize/vendor_name=${VENDOR_NAME}/g" ${CMAKE_CURRENT_BINARY_DIR}/scripts/*
|
||||
)
|
||||
|
||||
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/scripts/
|
||||
DESTINATION . FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ
|
||||
)
|
||||
|
||||
# gen version.info
|
||||
set(version_info_dir ${CMAKE_CURRENT_BINARY_DIR})
|
||||
set(version_info_file ${version_info_dir}/version.info)
|
||||
add_custom_target(gen_version_info ALL
|
||||
DEPENDS ${version_info_file}
|
||||
)
|
||||
|
||||
add_custom_command(OUTPUT ${version_info_file}
|
||||
COMMAND bash ${ASCENDC_CMAKE_UTIL_DIR}/gen_version_info.sh ${ASCEND_CANN_PACKAGE_PATH} ${version_info_dir}
|
||||
)
|
||||
|
||||
install(FILES ${version_info_file}
|
||||
DESTINATION packages/vendors/${VENDOR_NAME}/
|
||||
)
|
||||
|
||||
# CPack config
|
||||
set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME})
|
||||
set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION})
|
||||
set(CPACK_PACKAGE_DESCRIPTION "CPack ops project")
|
||||
set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack ops project")
|
||||
set(CPACK_PACKAGE_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
set(CPACK_PACKAGE_FILE_NAME "CANN-custom_ops-${CANN_VERSION}-linux.${CMAKE_SYSTEM_PROCESSOR}.run")
|
||||
set(CPACK_GENERATOR External)
|
||||
set(CPACK_CMAKE_GENERATOR "Unix Makefiles")
|
||||
set(CPACK_EXTERNAL_ENABLE_STAGING TRUE)
|
||||
set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${ASCEND_CMAKE_DIR}/makeself.cmake)
|
||||
set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME})
|
||||
include(CPack)
|
||||
endif ()
|
||||
189
csrc/build.sh
Normal file
189
csrc/build.sh
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
set -e
|
||||
|
||||
CURRENT_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]}))
|
||||
BUILD_DIR=${CURRENT_DIR}/build
|
||||
OUTPUT_DIR=${CURRENT_DIR}/output
|
||||
USER_ID=$(id -u)
|
||||
PARENT_JOB="false"
|
||||
CHECK_COMPATIBLE="true"
|
||||
ASAN="false"
|
||||
COV="false"
|
||||
VERBOSE="false"
|
||||
|
||||
if [ "${USER_ID}" != "0" ]; then
|
||||
DEFAULT_TOOLKIT_INSTALL_DIR="${HOME}/Ascend/ascend-toolkit/latest"
|
||||
DEFAULT_INSTALL_DIR="${HOME}/Ascend/latest"
|
||||
else
|
||||
DEFAULT_TOOLKIT_INSTALL_DIR="/usr/local/Ascend/ascend-toolkit/latest"
|
||||
DEFAULT_INSTALL_DIR="/usr/local/Ascend/latest"
|
||||
fi
|
||||
|
||||
CUSTOM_OPTION="-DBUILD_OPEN_PROJECT=ON"
|
||||
|
||||
function help_info() {
|
||||
echo "Usage: $0 [options]"
|
||||
echo "Options:"
|
||||
echo
|
||||
echo "-h|--help Displays help message."
|
||||
echo
|
||||
echo "-n|--op-name Specifies the compiled operator. If there are multiple values, separate them with semicolons and use quotation marks. The default is all."
|
||||
echo " For example: -n \"flash_attention_score\" or -n \"flash_attention_score;flash_attention_score_grad\""
|
||||
echo
|
||||
echo "-c|--compute-unit Specifies the chip type. If there are multiple values, separate them with semicolons and use quotation marks. The default is ascend910b."
|
||||
echo " For example: -c \"ascend910b\" or -c \"ascend910b;ascend310p\""
|
||||
echo
|
||||
echo "--cov Compiles with cov."
|
||||
echo
|
||||
echo "--verbose Displays more compilation information."
|
||||
echo
|
||||
}
|
||||
|
||||
function log() {
|
||||
local current_time=`date +"%Y-%m-%d %H:%M:%S"`
|
||||
echo "[$current_time] "$1
|
||||
}
|
||||
|
||||
function set_env()
|
||||
{
|
||||
source $ASCEND_CANN_PACKAGE_PATH/bin/setenv.bash || echo "0"
|
||||
|
||||
export BISHENG_REAL_PATH=$(which bisheng || true)
|
||||
|
||||
if [ -z "${BISHENG_REAL_PATH}" ];then
|
||||
log "Error: bisheng compilation tool not found, Please check whether the cann package or environment variables are set."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
function clean()
|
||||
{
|
||||
if [ -n "${BUILD_DIR}" ];then
|
||||
rm -rf ${BUILD_DIR}
|
||||
fi
|
||||
mkdir -p ${BUILD_DIR} ${OUTPUT_DIR}
|
||||
}
|
||||
|
||||
function cmake_config()
|
||||
{
|
||||
local extra_option="$1"
|
||||
log "Info: cmake config ${CUSTOM_OPTION} ${extra_option} ."
|
||||
cmake .. ${CUSTOM_OPTION} ${extra_option}
|
||||
}
|
||||
|
||||
function build()
|
||||
{
|
||||
local target="$1"
|
||||
if [ "${VERBOSE}" == "true" ];then
|
||||
local option="--verbose"
|
||||
fi
|
||||
cmake --build . --target ${target} ${JOB_NUM} ${option}
|
||||
}
|
||||
|
||||
function gen_bisheng(){
|
||||
local ccache_program=$1
|
||||
local gen_bisheng_dir=${BUILD_DIR}/gen_bisheng_dir
|
||||
|
||||
if [ ! -d "${gen_bisheng_dir}" ];then
|
||||
mkdir -p ${gen_bisheng_dir}
|
||||
fi
|
||||
|
||||
pushd ${gen_bisheng_dir}
|
||||
$(> bisheng)
|
||||
echo "#!/bin/bash" >> bisheng
|
||||
echo "ccache_args=""\"""${ccache_program} ${BISHENG_REAL_PATH}""\"" >> bisheng
|
||||
echo "args=""$""@" >> bisheng
|
||||
|
||||
if [ "${VERBOSE}" == "true" ];then
|
||||
echo "echo ""\"""$""{ccache_args} ""$""args""\"" >> bisheng
|
||||
fi
|
||||
|
||||
echo "eval ""\"""$""{ccache_args} ""$""args""\"" >> bisheng
|
||||
chmod +x bisheng
|
||||
|
||||
export PATH=${gen_bisheng_dir}:$PATH
|
||||
popd
|
||||
}
|
||||
|
||||
function build_package(){
|
||||
build package
|
||||
}
|
||||
|
||||
function build_host(){
|
||||
build_package
|
||||
}
|
||||
|
||||
function build_kernel(){
|
||||
build ops_kernel
|
||||
}
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
help_info
|
||||
exit
|
||||
;;
|
||||
-n|--op-name)
|
||||
ascend_op_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
-c|--compute-unit)
|
||||
ascend_compute_unit="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
help_info
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -n "${ascend_compute_unit}" ];then
|
||||
CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_COMPUTE_UNIT=${ascend_compute_unit}"
|
||||
fi
|
||||
|
||||
if [ -n "${ascend_op_name}" ];then
|
||||
CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_OP_NAME=${ascend_op_name}"
|
||||
fi
|
||||
|
||||
if [ -n "${ASCEND_HOME_PATH}" ];then
|
||||
ASCEND_CANN_PACKAGE_PATH=${ASCEND_HOME_PATH}
|
||||
elif [ -n "${ASCEND_OPP_PATH}" ];then
|
||||
ASCEND_CANN_PACKAGE_PATH=$(dirname ${ASCEND_OPP_PATH})
|
||||
elif [ -d "${DEFAULT_TOOLKIT_INSTALL_DIR}" ];then
|
||||
ASCEND_CANN_PACKAGE_PATH=${DEFAULT_TOOLKIT_INSTALL_DIR}
|
||||
elif [ -d "${DEFAULT_INSTALL_DIR}" ];then
|
||||
ASCEND_CANN_PACKAGE_PATH=${DEFAULT_INSTALL_DIR}
|
||||
else
|
||||
log "Error: Please set the toolkit package installation directory through parameter -p|--package-path."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${PARENT_JOB}" == "false" ];then
|
||||
CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2))
|
||||
JOB_NUM="-j${CPU_NUM}"
|
||||
fi
|
||||
|
||||
CUSTOM_OPTION="${CUSTOM_OPTION} -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE}"
|
||||
|
||||
set_env
|
||||
clean
|
||||
|
||||
ccache_system=$(which ccache || true)
|
||||
if [ -n "${ccache_system}" ];then
|
||||
CUSTOM_OPTION="${CUSTOM_OPTION} -DENABLE_CCACHE=ON -DCUSTOM_CCACHE=${ccache_system}"
|
||||
gen_bisheng ${ccache_system}
|
||||
fi
|
||||
|
||||
cd ${BUILD_DIR}
|
||||
cmake_config
|
||||
build_package
|
||||
34
csrc/build_aclnn.sh
Normal file
34
csrc/build_aclnn.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
# Build and install the custom CANN aclnn ops for the detected SoC family.
#   $1: repository root directory
#   $2: SoC version string (e.g. ascend910b..., ascend910_93...)
# Robustness fix: abort on any failure so a broken op build can no longer
# fall through and install a stale CANN-custom_ops*.run package.
set -e

ROOT_DIR=$1
SOC_VERSION=$2

if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
    # ASCEND310P series: currently no custom aclnn ops.
    # CUSTOM_OPS=""
    # SOC_ARG="ascend310p"
    exit 0
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
    # ASCEND910B (A2) series
    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
    SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
    # ASCEND910C (A3) series
    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
    SOC_ARG="ascend910_93"
else
    # Other series: currently no custom aclnn ops.
    exit 0
fi

# Build custom ops from a clean tree.
cd csrc
rm -rf build output
echo "building custom ops $CUSTOM_OPS for $SOC_VERSION"
bash build.sh -n "$CUSTOM_OPS" -c "$SOC_ARG"

# Install custom ops to vllm_ascend/_cann_ops_custom and bring them into the env.
./output/CANN-custom_ops*.run --install-path="$ROOT_DIR/vllm_ascend/_cann_ops_custom"
source "$ROOT_DIR/vllm_ascend/_cann_ops_custom/vendors/customize/bin/set_env.bash"
|
||||
235
csrc/cmake/config.cmake
Normal file
235
csrc/cmake/config.cmake
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
########################################################################################################################
|
||||
# Environment Check
|
||||
########################################################################################################################
|
||||
|
||||
# Python3: required by the AscendC toolchain helper scripts.
find_package(Python3)
# Bug fix: ${Python3_EXECUTABLE} must be quoted — when the variable is empty or
# undefined the unquoted form expands to `if((NOT ...) OR ( STREQUAL ""))`,
# which is a CMake syntax error instead of the intended fatal diagnostic.
if ((NOT Python3_FOUND) OR ("${Python3_EXECUTABLE}" STREQUAL ""))
    message(FATAL_ERROR "Can't find python3.")
endif ()
set(HI_PYTHON "${Python3_EXECUTABLE}" CACHE STRING "python executor")
|
||||
|
||||
# Get the base CANN path, in priority order:
#   1. explicit -DCUSTOM_ASCEND_CANN_PACKAGE_PATH from the build script
#   2. ASCEND_HOME_PATH environment variable (CANN set_env.sh sourced)
#   3. parent directory of ASCEND_OPP_PATH
#   4. conventional default install location
if (CUSTOM_ASCEND_CANN_PACKAGE_PATH)
    set(ASCEND_CANN_PACKAGE_PATH ${CUSTOM_ASCEND_CANN_PACKAGE_PATH})
elseif (DEFINED ENV{ASCEND_HOME_PATH})
    set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH})
elseif (DEFINED ENV{ASCEND_OPP_PATH})
    # ASCEND_OPP_PATH points at <cann>/opp; its parent is the CANN root.
    get_filename_component(ASCEND_CANN_PACKAGE_PATH "$ENV{ASCEND_OPP_PATH}/.." ABSOLUTE)
else()
    set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/latest")
endif ()
message(STATUS "ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH}")
|
||||
|
||||
########################################################################################################################
# Common Configuration
########################################################################################################################

# Switches
option(PREPARE_BUILD "Prepare build." OFF)
option(ENABLE_OPS_HOST "Build ops host." ON)
option(ENABLE_OPS_KERNEL "Build ops kernel." ON)
# Test builds compile only the host side; device kernels are skipped.
if (TESTS_EXAMPLE_OPS_TEST OR TESTS_UT_OPS_TEST)
    set(ENABLE_OPS_KERNEL OFF)
endif ()
set(OP_DEBUG_CONFIG "false" CACHE STRING "op debug config")

# Path configuration
# Source tree related paths
get_filename_component(OPS_ADV_DIR "${CMAKE_CURRENT_SOURCE_DIR}" REALPATH)
get_filename_component(OPS_ADV_CMAKE_DIR "${OPS_ADV_DIR}/cmake" REALPATH)
get_filename_component(OPS_ADV_UTILS_KERNEL_INC "${OPS_ADV_DIR}/utils/inc/kernel" REALPATH)


# Build tree related paths
set(ASCEND_IMPL_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/impl CACHE STRING "ascend impl output directories")
set(ASCEND_BINARY_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary CACHE STRING "ascend binary output directories")
set(ASCEND_AUTOGEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/autogen CACHE STRING "Auto generate file directories")
# Per-configure accumulator files; op subdirectories append to them below.
set(ASCEND_CUSTOM_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_compile_options.ini)
set(ASCEND_CUSTOM_TILING_KEYS ${ASCEND_AUTOGEN_DIR}/custom_tiling_keys.ini)
set(ASCEND_CUSTOM_OPC_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_opc_options.ini)
set(OP_BUILD_TOOL ${ASCEND_CANN_PACKAGE_PATH}/tools/opbuild/op_build CACHE STRING "op_build tool")
# Reset the accumulator files on every configure so stale entries from a
# previous run cannot leak into this build.
file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_DIR})
file(REMOVE ${ASCEND_CUSTOM_OPTIONS})
file(TOUCH ${ASCEND_CUSTOM_OPTIONS})
file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS})
file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS})
file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS})
file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS})
|
||||
# Select the AscendC project template / install layout.
# Open build: use templates shipped with the installed CANN package.
# Closed build: use the in-tree samples under TOP_DIR.
if (BUILD_OPEN_PROJECT)
    # Newer CANN ships the templates under tools/ascend_project; older
    # releases keep them under tools/op_project_templates.
    if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project/cmake)
        set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project)
    else()
        set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/op_project_templates/ascendc/customize)
    endif()
    set(ASCEND_CMAKE_DIR ${ASCEND_PROJECT_DIR}/cmake CACHE STRING "ascend project cmake")
    set(IMPL_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl)
    set(IMPL_DYNAMIC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl/dynamic)
    set(ACLNN_INC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_api/include)
else()
    set(ASCEND_CMAKE_DIR ${TOP_DIR}/asl/ops/cann/ops/built-in/ascendc/samples/customize/cmake CACHE STRING "ascend project cmake")
    set(IMPL_INSTALL_DIR lib/ascendc/impl)
    set(IMPL_DYNAMIC_INSTALL_DIR lib/ascendc/impl/dynamic)
    set(ACLNN_INC_INSTALL_DIR lib/include)
    set(OPS_STATIC_TYPES infer train)
    set(OPS_STATIC_SCRIPT ${TOP_DIR}/asl/ops/cann/ops/built-in/kernel/binary_script/build_opp_kernel_static.py)
endif ()
set(ASCENDC_CMAKE_UTIL_DIR ${ASCEND_CMAKE_DIR}/util)
# Staging area used at build time via ASCEND_CUSTOM_OPP_PATH.
set(CUSTOM_DIR ${CMAKE_BINARY_DIR}/custom)
set(TILING_CUSTOM_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/op_tiling)
set(TILING_CUSTOM_FILE ${TILING_CUSTOM_DIR}/liboptiling.so)

# Temporary adaptation for ascendc changes, to be removed after switching to the new version of ascendc.
# The presence of ascendc_gen_options.py marks the V2 option-passing flow.
if(EXISTS ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py)
    set(ADD_OPS_COMPILE_OPTION_V2 ON)
else()
    set(ADD_OPS_COMPILE_OPTION_V2 OFF)
endif()
|
||||
|
||||
########################################################################################################################
# CMake Options, Default Parameters Setting
# Configure CMake options and default parameters according to the CMake build process
# CMake build process: 1) Configuration phase; 2) Build phase; 3) Installation phase;
########################################################################################################################
if (BUILD_OPEN_PROJECT)
    # Build type defaults.
    # Multi-configuration generators (Xcode, Visual Studio) select the concrete
    # type at build time via --config, so only CMAKE_CONFIGURATION_TYPES is
    # defaulted; single-configuration generators (Ninja, Unix Makefiles) take
    # the type from CMAKE_BUILD_TYPE, which defaults to Debug when unset.
    get_property(GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
    if (GENERATOR_IS_MULTI_CONFIG)
        if (NOT CMAKE_CONFIGURATION_TYPES)
            set(CMAKE_CONFIGURATION_TYPES "Debug;Release;MinSizeRel;RelWithDebInfo" CACHE STRING "Configuration Build type" FORCE)
        endif ()
    else ()
        if (NOT CMAKE_BUILD_TYPE)
            set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type(default Debug)" FORCE)
        endif ()
    endif ()

    # RPATH handling: keep RPATH only for UTest / Example executables so they
    # can locate their runtime libraries from the build tree.
    if (TESTS_UT_OPS_TEST OR TESTS_EXAMPLE_OPS_TEST)
        set(CMAKE_SKIP_RPATH FALSE)
    else ()
        set(CMAKE_SKIP_RPATH TRUE)
    endif ()

    # CCACHE configuration: an explicit CUSTOM_CCACHE path wins over discovery.
    if (ENABLE_CCACHE)
        if (CUSTOM_CCACHE)
            set(CCACHE_PROGRAM ${CUSTOM_CCACHE})
        else()
            find_program(CCACHE_PROGRAM ccache)
        endif ()
        if (CCACHE_PROGRAM)
            set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "C cache Compiler")
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "CXX cache Compiler")
        endif ()
    endif ()

    # Installation path: when CMAKE_INSTALL_PREFIX was not explicitly set,
    # redirect it to an "output" directory next to the build tree root.
    if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
        get_filename_component(_Install_Path_Prefix "${CMAKE_CURRENT_BINARY_DIR}/../output" REALPATH)
        set(CMAKE_INSTALL_PREFIX "${_Install_Path_Prefix}" CACHE STRING "Install path" FORCE)
    endif ()
endif ()
|
||||
|
||||
########################################################################################################################
# Public Compilation Parameters
########################################################################################################################
# Compute units are matched case-insensitively downstream; normalize to lowercase.
list(TRANSFORM ASCEND_COMPUTE_UNIT TOLOWER)
# Log the effective build parameters for diagnostics.
if (BUILD_OPEN_PROJECT)
    message(STATUS "ENABLE_CCACHE=${ENABLE_CCACHE}, CUSTOM_CCACHE=${CUSTOM_CCACHE}")
    message(STATUS "CCACHE_PROGRAM=${CCACHE_PROGRAM}")
    message(STATUS "ASCEND_COMPUTE_UNIT=${ASCEND_COMPUTE_UNIT}")
    message(STATUS "ASCEND_OP_NAME=${ASCEND_OP_NAME}")
    message(STATUS "TILING_KEY=${TILING_KEY}")
    message(STATUS "TESTS_UT_OPS_TEST=${TESTS_UT_OPS_TEST}")
    message(STATUS "TESTS_EXAMPLE_OPS_TEST=${TESTS_EXAMPLE_OPS_TEST}")
endif ()
|
||||
|
||||
########################################################################################################################
# Preprocessing
########################################################################################################################
if (BUILD_OPEN_PROJECT)
    if (NOT PREPARE_BUILD AND ENABLE_OPS_KERNEL)
        # CMake lists are ';'-separated; re-join with "::" so they survive as a
        # single shell argument to prepare.sh, FALSE standing in for "absent".
        if (TILING_KEY)
            string(REPLACE ";" "::" EP_TILING_KEY "${TILING_KEY}")
        else()
            set(EP_TILING_KEY FALSE)
        endif ()

        if (OPS_COMPILE_OPTIONS)
            string(REPLACE ";" "::" EP_OPS_COMPILE_OPTIONS "${OPS_COMPILE_OPTIONS}")
        else()
            set(EP_OPS_COMPILE_OPTIONS FALSE)
        endif ()

        string(REPLACE ";" "::" EP_ASCEND_COMPUTE_UNIT "${ASCEND_COMPUTE_UNIT}")

        # Configure-time side effect: run the prepare step (opbuild codegen,
        # autogen population) before the main build graph is generated.
        execute_process(COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/prepare.sh
                        -s ${CMAKE_CURRENT_SOURCE_DIR}
                        -b ${CMAKE_CURRENT_BINARY_DIR}/prepare_build
                        -p ${ASCEND_CANN_PACKAGE_PATH}
                        --autogen-dir ${ASCEND_AUTOGEN_DIR}
                        --build-open-project ${BUILD_OPEN_PROJECT}
                        --binary-out-dir ${ASCEND_BINARY_OUT_DIR}
                        --impl-out-dir ${ASCEND_IMPL_OUT_DIR}
                        --op-build-tool ${OP_BUILD_TOOL}
                        --ascend-cmake-dir ${ASCEND_CMAKE_DIR}
                        --tiling-key ${EP_TILING_KEY}
                        --ops-compile-options ${EP_OPS_COMPILE_OPTIONS}
                        --check-compatible ${CHECK_COMPATIBLE}
                        --ascend-compute_unit ${EP_ASCEND_COMPUTE_UNIT}
                        --op_debug_config ${OP_DEBUG_CONFIG}
                        --ascend-op-name "${ASCEND_OP_NAME}"
                        RESULT_VARIABLE result
                        OUTPUT_STRIP_TRAILING_WHITESPACE
                        OUTPUT_VARIABLE PREPARE_BUILD_OUTPUT_VARIABLE)
        if (result)
            message(FATAL_ERROR "Error: ops prepare build failed.")
        endif ()

        # The prepare step may have appended to the accumulator .ini files;
        # reset them again so the op subdirectories start from a clean slate.
        file(REMOVE ${ASCEND_CUSTOM_OPTIONS})
        file(TOUCH ${ASCEND_CUSTOM_OPTIONS})
        file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS})
        file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS})
        file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS})
        file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS})
    endif ()
endif ()
|
||||
|
||||
########################################################################################################################
# Other Configuration
########################################################################################################################
# Unit-test builds pull in the dedicated UTest configuration.
if (BUILD_OPEN_PROJECT)
    if (TESTS_UT_OPS_TEST)
        include(${OPS_ADV_CMAKE_DIR}/config_utest.cmake)
    endif ()
endif ()
|
||||
609
csrc/cmake/func.cmake
Normal file
609
csrc/cmake/func.cmake
Normal file
@@ -0,0 +1,609 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
# Propagate to each TARGET_NAME (as PRIVATE) the sources and include
# directories of BASE_TARGET that live under SRC_DIR.
#   BASE_TARGET: target whose SOURCES / INCLUDE_DIRECTORIES are scanned
#   SRC_DIR:     directory prefix used as the filter
#   TARGET_NAME: one or more targets that receive the filtered entries
function(add_target_source)
    cmake_parse_arguments(ADD "" "BASE_TARGET;SRC_DIR" "TARGET_NAME" ${ARGN})

    # Filter sources by directory prefix. Plain prefix comparison replaces the
    # previous `string(REGEX MATCH "^${ADD_SRC_DIR}" ...)`, which misbehaved
    # when the path contained regex metacharacters (e.g. '+', '.').
    get_target_property(all_srcs ${ADD_BASE_TARGET} SOURCES)
    set(add_srcs)
    foreach(_src ${all_srcs})
        string(FIND "${_src}" "${ADD_SRC_DIR}" _prefix_pos)
        if (_prefix_pos EQUAL 0)
            list(APPEND add_srcs ${_src})
        endif ()
    endforeach()

    # Same prefix filter for include directories.
    get_target_property(all_includes ${ADD_BASE_TARGET} INCLUDE_DIRECTORIES)
    set(add_includes)
    foreach(_include ${all_includes})
        string(FIND "${_include}" "${ADD_SRC_DIR}" _prefix_pos)
        if (_prefix_pos EQUAL 0)
            list(APPEND add_includes ${_include})
        endif ()
    endforeach()

    foreach(_target_name ${ADD_TARGET_NAME})
        target_sources(${_target_name} PRIVATE
                ${add_srcs}
        )

        target_include_directories(${_target_name} PRIVATE
                ${add_includes}
        )
    endforeach()
endfunction()
|
||||
|
||||
# Discover op directories (any <op>/op_host/CMakeLists.txt under the current
# source dir) and return the selected op names / directories via the two
# out-parameters. ASCEND_OP_NAME filters the selection: empty or "all"/"ALL"
# selects every discovered op, otherwise only the listed ops are kept.
function(op_add_subdirectory OP_LIST OP_DIR_LIST)
    set(_OP_LIST)
    set(_OP_DIR_LIST)

    file(GLOB OP_HOST_CMAKE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/**/op_host/CMakeLists.txt")

    foreach(OP_CMAKE_FILE ${OP_HOST_CMAKE_FILES})
        # <op>/op_host/CMakeLists.txt -> op dir and op name.
        get_filename_component(OP_HOST_DIR "${OP_CMAKE_FILE}" DIRECTORY)
        get_filename_component(OP_DIR "${OP_HOST_DIR}" DIRECTORY)
        get_filename_component(OP_NAME "${OP_DIR}" NAME)

        # Closed-source tree: skip ops that already exist as built-ins.
        if (NOT BUILD_OPEN_PROJECT)
            if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${OP_NAME})
                continue()
            endif ()
        endif ()

        if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "")
            if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
                # Fix: quote OP_NAME so IN_LIST compares the string itself;
                # the unquoted form is re-dereferenced by if() when the op name
                # coincides with a defined variable name.
                if (NOT "${OP_NAME}" IN_LIST ASCEND_OP_NAME)
                    continue()
                endif ()
            endif ()
        endif ()

        list(APPEND _OP_LIST ${OP_NAME})
        list(APPEND _OP_DIR_LIST ${OP_DIR})
    endforeach()

    list(REMOVE_DUPLICATES _OP_LIST)
    list(REMOVE_DUPLICATES _OP_DIR_LIST)
    # NOTE(review): the two lists are sorted independently, so positional
    # correspondence between names and dirs is not guaranteed afterwards —
    # confirm callers only consume them independently.
    list(SORT _OP_LIST)
    list(SORT _OP_DIR_LIST)
    set(${OP_LIST} ${_OP_LIST} PARENT_SCOPE)
    set(${OP_DIR_LIST} ${_OP_DIR_LIST} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Collect the directories of op dependencies that are NOT already in OP_LIST.
#   OP_LIST:     ops selected for this build (multi-value)
#   OP_DIR_LIST: out-parameter receiving the extra dependency directories
# Dependencies come from per-op `<op>_depends` variables declared elsewhere.
function(op_add_depend_directory)
    cmake_parse_arguments(DEP "" "OP_DIR_LIST" "OP_LIST" ${ARGN})
    set(_OP_DEPEND_DIR_LIST)
    foreach(op_name ${DEP_OP_LIST})
        if (DEFINED ${op_name}_depends)
            foreach(depend_info ${${op_name}_depends})
                # A dependency counts only if it is a real op directory here.
                if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}/op_host/CMakeLists.txt)
                    continue()
                endif ()

                get_filename_component(_depend_op_name "${depend_info}" NAME)
                # Closed-source tree: built-in ops are already available.
                if (NOT BUILD_OPEN_PROJECT)
                    if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${_depend_op_name})
                        continue()
                    endif ()
                endif ()

                # Only add dependencies not already scheduled as primary ops.
                if (NOT ${_depend_op_name} IN_LIST DEP_OP_LIST)
                    list(APPEND _OP_DEPEND_DIR_LIST ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info})
                endif ()
            endforeach()
        endif()
    endforeach()

    list(SORT _OP_DEPEND_DIR_LIST)
    set(${DEP_OP_DIR_LIST} ${_OP_DEPEND_DIR_LIST} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Create the per-compute-unit target that generates kernel compile scripts.
#   COMPUTE_UNIT: target SoC compute unit (e.g. ascend910b)
# Runs ascendc_bin_param_build.py over the default / inner / exc ops-info .ini
# files produced by the opbuild_gen_* targets.
# NOTE(review): relies on `base_aclnn_binary_dir` being set in the caller's
# scope — confirm it is defined before this function is invoked.
function(add_compile_cmd_target)
    cmake_parse_arguments(CMD "" "COMPUTE_UNIT" "" ${ARGN})

    # V2 toolchain consumes the opc config file; legacy takes separate
    # debug-config and tiling-keys options.
    if(ADD_OPS_COMPILE_OPTION_V2)
        set(OP_DEBUG_CONFIG_OPTION --opc-config-file ${ASCEND_CUSTOM_OPC_OPTIONS})
    else()
        if(OP_DEBUG_CONFIG)
            set(OP_DEBUG_CONFIG_OPTION --op-debug-config ${OP_DEBUG_CONFIG})
        endif()
        set(OP_TILING_KEY_OPTION --tiling-keys ${ASCEND_CUSTOM_TILING_KEYS})
    endif()

    set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${CMD_COMPUTE_UNIT})
    set(GEN_OUT_DIR ${_OUT_DIR}/gen)
    set(COMPILE_CMD_TARGET generate_compile_cmd_${CMD_COMPUTE_UNIT})
    add_custom_target(${COMPILE_CMD_TARGET} ALL
            COMMAND mkdir -p ${GEN_OUT_DIR}
            COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
            COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/inner/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
            COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/exc/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
    )

    # The .ini inputs come from the opbuild codegen targets.
    add_dependencies(${COMPILE_CMD_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
    add_dependencies(generate_compile_cmd ${COMPILE_CMD_TARGET})
endfunction()
|
||||
|
||||
# Create the per-compute-unit target that merges the generated ops-info .ini
# files into one JSON, stages it for the build, and installs it.
#   COMPUTE_UNIT: target SoC compute unit
# NOTE(review): reads `base_aclnn_binary_dir` from the caller's scope.
function(add_ops_info_target)
    cmake_parse_arguments(OPINFO "" "COMPUTE_UNIT" "" ${ARGN})

    set(OPS_INFO_TARGET generate_ops_info_${OPINFO_COMPUTE_UNIT})
    set(OPS_INFO_JSON ${ASCEND_AUTOGEN_DIR}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.json)
    set(CUSTOM_OPS_INFO_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT})

    # Inputs: the default / inner / exclude .ini files from opbuild codegen.
    set(OPS_INFO_INI ${base_aclnn_binary_dir}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)
    set(OPS_INFO_INNER_INI ${base_aclnn_binary_dir}/inner/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)
    set(OPS_INFO_EXCLUDE_INI ${base_aclnn_binary_dir}/exc/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)

    # Merge the .ini files into the JSON and copy it into the staging opp dir.
    add_custom_command(OUTPUT ${OPS_INFO_JSON}
            COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/parse_ini_to_json.py
                ${OPS_INFO_INI}
                ${OPS_INFO_INNER_INI}
                ${OPS_INFO_EXCLUDE_INI}
                ${OPS_INFO_JSON}
            COMMAND mkdir -p ${CUSTOM_OPS_INFO_DIR}
            COMMAND cp -f ${OPS_INFO_JSON} ${CUSTOM_OPS_INFO_DIR}
    )

    add_custom_target(${OPS_INFO_TARGET} ALL
            DEPENDS ${OPS_INFO_JSON}
    )

    add_dependencies(${OPS_INFO_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
    add_dependencies(generate_ops_info ${OPS_INFO_TARGET})

    install(FILES ${OPS_INFO_JSON}
            DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT} OPTIONAL
    )
endfunction()
|
||||
|
||||
# Record extra compile options for one op.
#   OP_NAME:      op the options apply to
#   COMPUTE_UNIT: compute unit(s) the options apply to
#   OPTIONS:      the compiler options themselves
# V2 toolchain: delegate to ascendc_gen_options.py; legacy: append a CSV line
# to the shared custom-options file.
function(add_ops_compile_options)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;OPTIONS" ${ARGN})

    # Nothing to record.
    if(NOT OP_COMPILE_OPTIONS)
        return()
    endif()

    if(NOT ADD_OPS_COMPILE_OPTION_V2)
        # Legacy flow: one "op,compute_unit,options" line per call.
        file(APPEND ${ASCEND_CUSTOM_OPTIONS}
            "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_OPTIONS}\n"
        )
        return()
    endif()

    # V2 flow: the generator script maintains the options file itself.
    execute_process(COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py
                    ${ASCEND_CUSTOM_OPTIONS} ${OP_COMPILE_OP_NAME}
                    ${OP_COMPILE_COMPUTE_UNIT} ${OP_COMPILE_OPTIONS}
                    RESULT_VARIABLE _gen_rc
                    OUTPUT_VARIABLE _gen_out
                    ERROR_VARIABLE _gen_err)
    if (_gen_rc)
        message("add ops compile options info: ${_gen_out}")
        message("add ops compile options error: ${_gen_err}")
        message(FATAL_ERROR "Error: add ops compile options failed!")
    endif ()
endfunction()
|
||||
|
||||
# Register the tiling keys to compile for one op.
#   OP_NAME:      op the tiling keys belong to
#   COMPUTE_UNIT: compute unit(s) (used only by the legacy flow)
#   TILING_KEYS:  list of tiling-key values
function(add_ops_tiling_keys)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;TILING_KEYS" ${ARGN})

    if(NOT OP_COMPILE_TILING_KEYS)
        return()
    endif()

    if(ADD_OPS_COMPILE_OPTION_V2)
        # V2 flow: forward the keys as a single --tiling_key= option.
        # NOTE(review): COMPUTE_UNIT is not forwarded here (add_opc_config does
        # the same), so the option seems to apply to all compute units in the
        # V2 flow — confirm this is intended.
        list(JOIN OP_COMPILE_TILING_KEYS "," STRING_TILING_KEYS)
        add_ops_compile_options(
            OP_NAME ${OP_COMPILE_OP_NAME}
            OPTIONS --tiling_key=${STRING_TILING_KEYS}
        )
    else()
        # Legacy flow: append "op,compute_unit,keys" to the tiling-keys file.
        file(APPEND ${ASCEND_CUSTOM_TILING_KEYS}
            "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_TILING_KEYS}\n"
        )
    endif()
endfunction()
|
||||
|
||||
# Translate symbolic op-debug config names (ccec_g, ccec_O0, oom, dump_cce)
# into real opc compiler flags and register them for the op.
# Only meaningful under the V2 option flow; a no-op otherwise.
function(add_opc_config)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;CONFIG" ${ARGN})

    # Only the V2 toolchain consumes opc options, and only when any were given.
    if(NOT ADD_OPS_COMPILE_OPTION_V2 OR NOT OP_COMPILE_CONFIG)
        return()
    endif()

    string(REPLACE "," ";" _cfg_items "${OP_COMPILE_CONFIG}")

    # Map each symbolic name onto its opc flag; unknown names are ignored.
    set(_flags)
    foreach(_item ${_cfg_items})
        if("${_item}" STREQUAL "ccec_g")
            list(APPEND _flags "-g")
        elseif("${_item}" STREQUAL "ccec_O0")
            list(APPEND _flags "-O0")
        elseif("${_item}" STREQUAL "oom")
            list(APPEND _flags "--oom")
        elseif("${_item}" STREQUAL "dump_cce")
            list(APPEND _flags "--save-temp-files")
        endif()
    endforeach()

    if(_flags)
        add_ops_compile_options(
            OP_NAME ${OP_COMPILE_OP_NAME}
            OPTIONS ${_flags}
        )
    endif()
endfunction()
|
||||
|
||||
# Create a target that copies an op's op_kernel sources into the build tree
# (plus the shared utils/inc/kernel headers, once per compute unit).
#   TARGET_NAME:  name of the copy target to create (idempotent)
#   SRC:          op source directory (containing op_kernel/)
#   DST:          destination directory inside the build tree
#   COMPUTE_UNIT: compute unit (namespaces the shared-headers target)
#   BE_RELIED:    optional target that must wait for this copy
# The shell globs (*.*) require cp/mkdir here; `cmake -E copy` has no globbing.
function(add_ops_src_copy)
    cmake_parse_arguments(SRC_COPY "" "TARGET_NAME;SRC;DST;BE_RELIED;COMPUTE_UNIT" "" ${ARGN})

    # Shared kernel headers are copied once per compute unit into <src-root>/ascendc/common.
    set(OPS_UTILS_INC_KERNEL_TARGET ops_utils_inc_kernel_${SRC_COPY_COMPUTE_UNIT})
    if (EXISTS ${OPS_ADV_UTILS_KERNEL_INC})
        if (NOT TARGET ${OPS_UTILS_INC_KERNEL_TARGET})
            get_filename_component(_ROOT_OPS_SRC_DIR "${SRC_COPY_DST}" DIRECTORY)
            set(OPS_UTILS_INC_KERNEL_DIR ${_ROOT_OPS_SRC_DIR}/ascendc/common)
            add_custom_command(OUTPUT ${OPS_UTILS_INC_KERNEL_DIR}
                    COMMAND mkdir -p ${OPS_UTILS_INC_KERNEL_DIR}
                    COMMAND cp -rf ${OPS_ADV_UTILS_KERNEL_INC}/*.* ${OPS_UTILS_INC_KERNEL_DIR}
            )

            add_custom_target(${OPS_UTILS_INC_KERNEL_TARGET}
                    DEPENDS ${OPS_UTILS_INC_KERNEL_DIR}
            )
        endif ()
    endif ()

    # Copy the op's kernel sources; a .done stamp file makes the rule incremental.
    if (NOT TARGET ${SRC_COPY_TARGET_NAME})
        set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
        add_custom_command(OUTPUT ${_BUILD_FLAG}
                COMMAND mkdir -p ${SRC_COPY_DST}
                COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
                COMMAND touch ${_BUILD_FLAG}
        )

        add_custom_target(${SRC_COPY_TARGET_NAME}
                DEPENDS ${_BUILD_FLAG}
        )
    endif ()

    # Copy shared headers before the op sources are consumed.
    if (TARGET ${OPS_UTILS_INC_KERNEL_TARGET})
        add_dependencies(${SRC_COPY_TARGET_NAME} ${OPS_UTILS_INC_KERNEL_TARGET})
    endif ()

    # Optionally order a dependent target after this copy.
    if (DEFINED SRC_COPY_BE_RELIED)
        add_dependencies(${SRC_COPY_BE_RELIED} ${SRC_COPY_TARGET_NAME})
    endif ()

endfunction()
|
||||
|
||||
# Wire up the kernel binary compilation for one compute unit.
#   COMPUTE_UNIT: target SoC compute unit
#   OP_INFO:      list of op source directories (basename == op name)
# For every generated compile script <op_type>-<op_file>-<index>.sh under
# <binary-out>/<unit>/gen, creates targets that stage sources, run the script,
# and install the resulting kernel binaries and config JSONs.
function(add_bin_compile_target)
    cmake_parse_arguments(BINARY "" "COMPUTE_UNIT" "OP_INFO" ${ARGN})

    set(_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/kernel)
    set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${BINARY_COMPUTE_UNIT})

    set(BIN_OUT_DIR ${_OUT_DIR}/bin)
    set(GEN_OUT_DIR ${_OUT_DIR}/gen)
    set(SRC_OUT_DIR ${_OUT_DIR}/src)
    file(MAKE_DIRECTORY ${BIN_OUT_DIR})

    # Index op directories by op name: defines <op>_dir for each entry.
    foreach(_op_info ${BINARY_OP_INFO})
        get_filename_component(_op_name "${_op_info}" NAME)
        set(${_op_name}_dir ${_op_info})
    endforeach()

    # Compile scripts were generated at configure time by the prepare step.
    set(_ops_target_list)
    set(compile_scripts)
    file(GLOB scripts_list ${GEN_OUT_DIR}/*.sh)
    list(APPEND compile_scripts ${scripts_list})

    foreach(bin_script ${compile_scripts})
        # Script name encodes "<op_type>-<op_file>-<op_index>".
        get_filename_component(bin_file ${bin_script} NAME_WE)
        string(REPLACE "-" ";" bin_sep ${bin_file})
        list(GET bin_sep 0 op_type)
        list(GET bin_sep 1 op_file)
        list(GET bin_sep 2 op_index)

        # Skip scripts for ops that are not part of this build.
        if (NOT DEFINED ${op_file}_dir)
            continue()
        endif ()

        # Umbrella target per op, hooked under the global ops_kernel target.
        if (NOT TARGET ${op_file})
            add_custom_target(${op_file})
            add_dependencies(ops_kernel ${op_file})
        endif ()

        set(OP_TARGET_NAME ${op_file}_${BINARY_COMPUTE_UNIT})

        # One-time per (op, compute unit) setup: source staging, py copy,
        # output dirs and install rules.
        if (NOT TARGET ${OP_TARGET_NAME})
            add_custom_target(${OP_TARGET_NAME})
            add_dependencies(${op_file} ${OP_TARGET_NAME})
            list(APPEND _ops_target_list ${OP_TARGET_NAME})

            set(OP_SRC_OUT_DIR ${SRC_OUT_DIR}/${op_file})
            set(OP_BIN_OUT_DIR ${BIN_OUT_DIR}/${op_file})
            file(MAKE_DIRECTORY ${OP_SRC_OUT_DIR})

            add_ops_src_copy(
                TARGET_NAME
                ${OP_TARGET_NAME}_src_copy
                SRC
                ${${op_file}_dir}
                DST
                ${OP_SRC_OUT_DIR}
                COMPUTE_UNIT
                ${BINARY_COMPUTE_UNIT}
            )

            # Also stage the sources of every declared dependency op.
            if (DEFINED ${op_file}_depends)
                foreach(depend_info ${${op_file}_depends})
                    get_filename_component(_depend_op_name "${depend_info}" NAME)
                    set(_depend_op_target ${_depend_op_name}_${BINARY_COMPUTE_UNIT}_src_copy)
                    add_ops_src_copy(
                        TARGET_NAME
                        ${_depend_op_target}
                        SRC
                        ${CMAKE_SOURCE_DIR}/${depend_info}
                        DST
                        ${SRC_OUT_DIR}/${_depend_op_name}
                        COMPUTE_UNIT
                        ${BINARY_COMPUTE_UNIT}
                        BE_RELIED
                        ${OP_TARGET_NAME}_src_copy
                    )
                endforeach()
            endif ()

            # The generated dynamic impl .py is renamed to the op_type the
            # compile script expects.
            set(DYNAMIC_PY_FILE ${OP_SRC_OUT_DIR}/${op_type}.py)
            add_custom_command(OUTPUT ${DYNAMIC_PY_FILE}
                    COMMAND cp -rf ${ASCEND_IMPL_OUT_DIR}/dynamic/${op_file}.py ${DYNAMIC_PY_FILE}
            )

            add_custom_target(${OP_TARGET_NAME}_py_copy
                    DEPENDS ${DYNAMIC_PY_FILE}
            )

            add_custom_command(OUTPUT ${OP_BIN_OUT_DIR}
                    COMMAND mkdir -p ${OP_BIN_OUT_DIR}
            )

            add_custom_target(${OP_TARGET_NAME}_mkdir
                    DEPENDS ${OP_BIN_OUT_DIR}
            )

            install(DIRECTORY ${OP_BIN_OUT_DIR}
                    DESTINATION ${_INSTALL_DIR}/${BINARY_COMPUTE_UNIT} OPTIONAL
            )

            install(FILES ${BIN_OUT_DIR}/${op_file}.json
                    DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL
            )
        endif ()

        # Optional sharding: an op name in ASCEND_OP_NAME may be followed by a
        # "<step>-<start>" group spec selecting which script indices to build.
        # Default "1-0" selects every index.
        set(_group "1-0")
        if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "")
            if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
                if (${op_file} IN_LIST ASCEND_OP_NAME)
                    list(LENGTH ASCEND_OP_NAME _len)
                    list(FIND ASCEND_OP_NAME ${op_file} _index)
                    math(EXPR _next_index "${_index} + 1")
                    if (${_next_index} LESS ${_len})
                        list(GET ASCEND_OP_NAME ${_next_index} _group_str)
                        set(_regex "^[0-9]+-[0-9]+$")
                        string(REGEX MATCH "${_regex}" match "${_group_str}")
                        if (match)
                            set(_group ${_group_str})
                        endif ()
                    endif ()
                endif ()
            endif ()
        endif ()

        string(REPLACE "-" ";" _group_sep ${_group})

        list(GET _group_sep 1 start_index)
        set(end_index ${op_index})
        list(GET _group_sep 0 step)

        # The script is built when op_index lies on the arithmetic sequence
        # start_index, start_index+step, ...
        set(_compile_flag false)
        if (${start_index} LESS ${end_index})
            foreach(i RANGE ${start_index} ${end_index} ${step})
                if (${i} EQUAL ${end_index})
                    set(_compile_flag true)
                    break()
                endif ()
            endforeach()
        elseif (${start_index} EQUAL ${end_index})
            set(_compile_flag true)
        else()
            set(_compile_flag false)
        endif ()

        if (_compile_flag)
            # Compose the shell command: environment exports chained with &&,
            # then the generated compile script itself.
            set(_BUILD_COMMAND)
            set(_BUILD_FLAG ${GEN_OUT_DIR}/${OP_TARGET_NAME}_${op_index}.done)
            if (ENABLE_OPS_HOST)
                # Point the compiler at the staged custom opp dir.
                list(APPEND _BUILD_COMMAND export ASCEND_CUSTOM_OPP_PATH=${CUSTOM_DIR} &&)
            endif ()
            list(APPEND _BUILD_COMMAND export HI_PYTHON="python3" &&)
            list(APPEND _BUILD_COMMAND export TILINGKEY_PAR_COMPILE=1 &&)
            list(APPEND _BUILD_COMMAND export BIN_FILENAME_HASHED=1 &&)
            list(APPEND _BUILD_COMMAND bash ${bin_script} ${OP_SRC_OUT_DIR}/${op_type}.py ${OP_BIN_OUT_DIR})
            if(CMAKE_GENERATOR MATCHES "Unix Makefiles")
                # Forward the make jobserver so nested builds can parallelize.
                list(APPEND _BUILD_COMMAND && echo $(MAKE))
            endif()

            add_custom_command(OUTPUT ${_BUILD_FLAG}
                    COMMAND ${_BUILD_COMMAND}
                    COMMAND touch ${_BUILD_FLAG}
                    WORKING_DIRECTORY ${GEN_OUT_DIR}
            )

            add_custom_target(${OP_TARGET_NAME}_${op_index}
                    DEPENDS ${_BUILD_FLAG}
            )

            if (ENABLE_OPS_HOST)
                add_dependencies(${OP_TARGET_NAME}_${op_index} optiling generate_ops_info)
            endif ()
            add_dependencies(${OP_TARGET_NAME}_${op_index} ${OP_TARGET_NAME}_src_copy ${OP_TARGET_NAME}_py_copy ${OP_TARGET_NAME}_mkdir)
            add_dependencies(${OP_TARGET_NAME} ${OP_TARGET_NAME}_${op_index})
        endif ()
    endforeach()

    # After all kernels are built, aggregate their metadata into
    # binary_info_config.json and install it.
    if (_ops_target_list)
        set(OPS_CONFIG_TARGET ops_config_${BINARY_COMPUTE_UNIT})
        set(BINARY_INFO_CONFIG_FILE ${BIN_OUT_DIR}/binary_info_config.json)

        add_custom_command(OUTPUT ${BINARY_INFO_CONFIG_FILE}
                COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_ops_config.py -p ${BIN_OUT_DIR} -s ${BINARY_COMPUTE_UNIT}
        )

        add_custom_target(${OPS_CONFIG_TARGET}
                DEPENDS ${BINARY_INFO_CONFIG_FILE}
        )

        add_dependencies(ops_config ${OPS_CONFIG_TARGET})

        foreach(_op_target ${_ops_target_list})
            add_dependencies(${OPS_CONFIG_TARGET} ${_op_target})
        endforeach()

        install(FILES ${BINARY_INFO_CONFIG_FILE}
                DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL
        )
    endif ()
endfunction()
|
||||
|
||||
# Redefine the builtin __FILE__ macro for every source file of the given
# targets so compiled binaries embed only the basename instead of the full
# (potentially absolute) build path.
#
# Keyword arguments:
#   TARGET_NAME  One or more targets whose sources get the trimmed __FILE__.
function(redefine_file_macro)
    cmake_parse_arguments(_FILE "" "" "TARGET_NAME" ${ARGN})

    foreach(_target_name ${_FILE_TARGET_NAME})
        # Redefining a builtin macro normally triggers a warning; silence it
        # for this target only.
        target_compile_options(${_target_name} PRIVATE
            -Wno-builtin-macro-redefined
        )

        get_target_property(_srcs ${_target_name} SOURCES)
        # get_target_property() yields "<var>-NOTFOUND" for targets without
        # sources (e.g. INTERFACE libraries); skip them instead of iterating
        # over the NOTFOUND string.
        if(NOT _srcs)
            continue()
        endif()

        foreach(_src ${_srcs})
            get_filename_component(_src_name "${_src}" NAME)
            # Per-source definition wins over the compiler-provided __FILE__.
            set_source_files_properties(${_src}
                PROPERTIES COMPILE_DEFINITIONS __FILE__="${_src_name}"
            )
        endforeach()
    endforeach()
endfunction()
|
||||
|
||||
# Wires "static" variants of the aclnn op libraries.
#
# Keyword arguments:
#   SRC_DIR          Root directory whose sources should get a static variant.
#   ACLNN_SRC        Candidate generated aclnn_*.cpp sources.
#   ACLNN_INNER_SRC  Candidate generated aclnnInner_*.cpp sources.
#
# For every ops type in OPS_STATIC_TYPES this function:
#   * mirrors the matching sources of aclnn_ops_<type> into a temp dir (where
#     OPS_STATIC_SCRIPT post-processes them) and adds the mirrored copies to
#     aclnn_ops_<type>_static;
#   * attaches the matching generated aclnn sources to acl_op_<type>_builtin.
function(add_static_ops)
    cmake_parse_arguments(STATIC "" "SRC_DIR" "ACLNN_SRC;ACLNN_INNER_SRC" ${ARGN})
    set(prepare_ops_adv_static_target prepare_ops_adv_static)
    set(static_src_temp_dir ${CMAKE_CURRENT_BINARY_DIR}/static_src_temp_dir)
    set(modified_files)
    foreach(ops_type ${OPS_STATIC_TYPES})
        get_target_property(all_srcs aclnn_ops_${ops_type} SOURCES)
        set(add_srcs)
        set(generate_aclnn_srcs)
        # Select only sources living under SRC_DIR.  string(FIND) is used
        # instead of REGEX MATCH so a path containing regex metacharacters
        # ('+', '.', ...) can neither break nor falsify the prefix test.
        foreach(_src ${all_srcs})
            string(FIND "${_src}" "${STATIC_SRC_DIR}" _prefix_pos)
            if (_prefix_pos EQUAL 0)
                list(APPEND add_srcs ${_src})
            endif ()
        endforeach()

        # For each selected op, pick up its generated aclnn / aclnnInner source.
        foreach(_src ${add_srcs})
            get_filename_component(name_without_ext ${_src} NAME_WE)
            string(REGEX REPLACE "^aclnn_" "" _op_name ${name_without_ext})

            foreach(_aclnn_src ${STATIC_ACLNN_SRC})
                get_filename_component(aclnn_name ${_aclnn_src} NAME_WE)
                if("aclnn_${_op_name}" STREQUAL "${aclnn_name}")
                    list(APPEND generate_aclnn_srcs ${_aclnn_src})
                    break()
                endif()
            endforeach()

            foreach(_aclnn_inner_src ${STATIC_ACLNN_INNER_SRC})
                get_filename_component(aclnn_inner_name ${_aclnn_inner_src} NAME_WE)
                if("aclnnInner_${_op_name}" STREQUAL "${aclnn_inner_name}")
                    list(APPEND generate_aclnn_srcs ${_aclnn_inner_src})
                    break()
                endif()
            endforeach()
        endforeach()

        if(add_srcs)
            # The static variant compiles the mirrored copies from the temp dir.
            # NOTE(review): list(TRANSFORM ... REPLACE) also treats SRC_DIR as a
            # regex; acceptable here because it already matched as a prefix.
            list(TRANSFORM add_srcs REPLACE "${STATIC_SRC_DIR}" "${static_src_temp_dir}" OUTPUT_VARIABLE add_static_srcs)
            list(APPEND modified_files ${add_static_srcs})
            set(aclnn_ops_static_target aclnn_ops_${ops_type}_static)
            # The mirrored files only exist after the prepare step runs.
            set_source_files_properties(${add_static_srcs}
                TARGET_DIRECTORY ${aclnn_ops_static_target}
                PROPERTIES GENERATED TRUE
            )
            target_sources(${aclnn_ops_static_target} PRIVATE
                ${add_static_srcs}
            )
            add_dependencies(${aclnn_ops_static_target} ${prepare_ops_adv_static_target})
        endif()

        if(generate_aclnn_srcs)
            list(REMOVE_DUPLICATES generate_aclnn_srcs)
            set(aclnn_op_target acl_op_${ops_type}_builtin)
            set_source_files_properties(${generate_aclnn_srcs}
                TARGET_DIRECTORY ${aclnn_op_target}
                PROPERTIES GENERATED TRUE
            )
            target_sources(${aclnn_op_target} PRIVATE
                ${generate_aclnn_srcs}
            )
        endif()
    endforeach()

    # Create the preparation step exactly once, no matter how often this
    # function is called.
    if(NOT TARGET ${prepare_ops_adv_static_target})
        list(REMOVE_DUPLICATES modified_files)
        # Portable file ops via "cmake -E" (was raw mkdir/cp); VERBATIM for
        # platform-independent argument escaping; DEPENDS so edits to the
        # post-processing script re-trigger the step.
        add_custom_command(OUTPUT ${static_src_temp_dir}
            COMMAND ${CMAKE_COMMAND} -E make_directory ${static_src_temp_dir}
            COMMAND ${CMAKE_COMMAND} -E copy_directory ${STATIC_SRC_DIR}/src ${static_src_temp_dir}/src
            COMMAND ${HI_PYTHON} -B ${OPS_STATIC_SCRIPT} InsertIni -p ${static_src_temp_dir} -f ${modified_files}
            DEPENDS ${OPS_STATIC_SCRIPT}
            VERBATIM
        )
        add_custom_target(${prepare_ops_adv_static_target}
            DEPENDS ${static_src_temp_dir}
        )
    endif()
endfunction()
|
||||
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
if (TESTS_UT_OPS_TEST)
|
||||
include(${OPS_ADV_CMAKE_DIR}/func_utest.cmake)
|
||||
endif ()
|
||||
if (TESTS_EXAMPLE_OPS_TEST)
|
||||
include(${OPS_ADV_CMAKE_DIR}/func_examples.cmake)
|
||||
endif ()
|
||||
endif ()
|
||||
12
csrc/cmake/intf.cmake
Normal file
12
csrc/cmake/intf.cmake
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Pull in the public interface-target definitions only for the open-source
# (standalone) build; the built-in package build provides its own.
if (BUILD_OPEN_PROJECT)
    include(${OPS_ADV_CMAKE_DIR}/intf_pub.cmake)
endif ()
|
||||
75
csrc/cmake/intf_pub.cmake
Normal file
75
csrc/cmake/intf_pub.cmake
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
# Custom package scenario, public compilation configuration for Host side targets
# Note: To ensure compatibility with the built-in package compilation process, the intf_pub name cannot be changed
add_library(intf_pub INTERFACE)
target_include_directories(intf_pub
        INTERFACE
        ${ASCEND_CANN_PACKAGE_PATH}/include
        ${ASCEND_CANN_PACKAGE_PATH}/include/external
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)
target_link_directories(intf_pub
        INTERFACE
        ${ASCEND_CANN_PACKAGE_PATH}/lib64
)
# Warning set mirrors the built-in CANN package build.  The trailing -Wno-*
# entries deliberately relax some warnings enabled earlier (for gcc the last
# matching -W/-Wno- option wins), and the -Wno-error= entries keep those
# diagnostics as warnings if a consumer adds -Werror.
# Fixed: a second, duplicate -Wall has been removed from the -Wno- group.
target_compile_options(intf_pub
        INTERFACE
        -fPIC
        -O2
        -Wall -Wundef -Wcast-qual -Wpointer-arith -Wdate-time
        -Wfloat-equal -Wformat=2 -Wshadow
        -Wsign-compare -Wunused-macros -Wvla -Wdisabled-optimization -Wempty-body -Wignored-qualifiers
        -Wimplicit-fallthrough=3 -Wtype-limits -Wshift-negative-value -Wswitch-default
        -Wframe-larger-than=32768 -Woverloaded-virtual
        -Wnon-virtual-dtor -Wshift-overflow=2 -Wshift-count-overflow
        -Wwrite-strings -Wmissing-format-attribute -Wformat-nonliteral
        -Wdelete-non-virtual-dtor -Wduplicated-cond
        -Wtrampolines -Wsized-deallocation -Wlogical-op -Wsuggest-attribute=format
        -Wduplicated-branches
        -Wmissing-include-dirs -Wformat-signedness
        -Wreturn-local-addr -Wextra
        -Wredundant-decls -Wfloat-conversion
        -Wno-write-strings -Wno-dangling-else -Wno-comment -Wno-conversion-null -Wno-return-type
        -Wno-unknown-pragmas -Wno-sign-compare
        -Wno-error=undef
        -Wno-error=comment
        -Wno-error=conversion-null
        -Wno-error=dangling-else
        -Wno-error=return-type
        -Wno-error=shadow
        -Wno-error=sign-compare
        -Wno-error=unknown-pragmas
        -Wno-error=unused-parameter
        -Wno-error=cast-qual
        -Wno-error=format=
        -Wno-error=maybe-uninitialized
        -Wno-error=missing-field-initializers
        -Wno-error=redundant-decls
        -Wno-error=unused-variable
        $<$<COMPILE_LANGUAGE:C>:-Wnested-externs>
        $<$<CONFIG:Debug>:-g>
        $<IF:$<VERSION_GREATER:${CMAKE_C_COMPILER_VERSION},4.8.5>,-fstack-protector-strong,-fstack-protector-all>
)
target_compile_definitions(intf_pub
        INTERFACE
        $<$<COMPILE_LANGUAGE:CXX>:_GLIBCXX_USE_CXX11_ABI=0>
        $<$<CONFIG:Release>:_FORTIFY_SOURCE=2>
)
# Hardened link defaults: RELRO + immediate binding + non-executable stack,
# PIE for executables, stripped Release binaries.
target_link_options(intf_pub
        INTERFACE
        $<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-pie>
        $<$<CONFIG:Release>:-s>
        -Wl,-z,relro
        -Wl,-z,now
        -Wl,-z,noexecstack
)
|
||||
113
csrc/cmake/modules/Findalog.cmake
Normal file
113
csrc/cmake/modules/Findalog.cmake
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
# Find module for the CANN "alog" logging component.
# Defines imported targets: slog and alog (both backed by libascendalog.so)
# plus alog_headers (include directories only).
if (alog_FOUND)
    message(STATUS "Package alog has been found.")
    return()
endif()

# Figure out which of the expected imported targets already exist.
set(_cmake_targets_defined "")
set(_cmake_targets_not_defined "")
set(_cmake_expected_targets "")
foreach(_cmake_expected_target IN ITEMS slog alog alog_headers)
    list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
    if(TARGET "${_cmake_expected_target}")
        list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
    else()
        list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
    endif()
endforeach()
unset(_cmake_expected_target)

if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
    # All targets already exist: nothing left to do.
    # Fixed: removed a stray cmake_policy(POP) copied from a generated export
    # file — this module never pushes a policy scope, so popping here was a
    # fatal configure error whenever this branch was reached.
    unset(_cmake_targets_defined)
    unset(_cmake_targets_not_defined)
    unset(_cmake_expected_targets)
    unset(CMAKE_IMPORT_FILE_VERSION)
    return()
endif()

# A partially-defined target set means two different providers collided.
if(NOT _cmake_targets_defined STREQUAL "")
    string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
    string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
    message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)

find_path(_INCLUDE_DIR
        NAMES base/alog_pub.h
        NO_CMAKE_SYSTEM_PATH
        NO_CMAKE_FIND_ROOT_PATH)

# slog and alog intentionally resolve to the same shared object.
find_library(slog_SHARED_LIBRARY
        NAMES libascendalog.so
        PATH_SUFFIXES lib64
        NO_CMAKE_SYSTEM_PATH
        NO_CMAKE_FIND_ROOT_PATH)

find_library(alog_SHARED_LIBRARY
        NAMES libascendalog.so
        PATH_SUFFIXES lib64
        NO_CMAKE_SYSTEM_PATH
        NO_CMAKE_FIND_ROOT_PATH)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(alog
        FOUND_VAR
            alog_FOUND
        REQUIRED_VARS
            _INCLUDE_DIR
            slog_SHARED_LIBRARY
            alog_SHARED_LIBRARY
)

if(alog_FOUND)
    set(alog_INCLUDE_DIR "${_INCLUDE_DIR}")
    include(CMakePrintHelpers)
    message(STATUS "Variables in alog module:")
    cmake_print_variables(alog_INCLUDE_DIR)
    cmake_print_variables(slog_SHARED_LIBRARY)
    cmake_print_variables(alog_SHARED_LIBRARY)

    add_library(slog SHARED IMPORTED)
    set_target_properties(slog PROPERTIES
            INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG"
            INTERFACE_LINK_LIBRARIES "alog_headers"
            IMPORTED_LOCATION "${slog_SHARED_LIBRARY}"
    )

    add_library(alog SHARED IMPORTED)
    set_target_properties(alog PROPERTIES
            INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG"
            INTERFACE_LINK_LIBRARIES "alog_headers"
            IMPORTED_LOCATION "${alog_SHARED_LIBRARY}"
    )

    add_library(alog_headers INTERFACE IMPORTED)
    set_target_properties(alog_headers PROPERTIES
            INTERFACE_INCLUDE_DIRECTORIES "${alog_INCLUDE_DIR}"
    )

    include(CMakePrintHelpers)
    cmake_print_properties(TARGETS slog
            PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION
    )
    cmake_print_properties(TARGETS alog
            PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION
    )
    cmake_print_properties(TARGETS alog_headers
            PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
    )
endif()

# Cleanup temporary variables.
set(_INCLUDE_DIR)
|
||||
130
csrc/cmake/scripts/prepare.sh
Normal file
130
csrc/cmake/scripts/prepare.sh
Normal file
@@ -0,0 +1,130 @@
|
||||
#!/bin/bash
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Prepare-phase driver: configures a throwaway CMake build tree and runs the
# "prepare_build" target that generates op glue code before the real build.

# Parallel build width: twice the processor count (was a useless-use-of-cat).
CPU_NUM=$(($(grep -c "^processor" /proc/cpuinfo) * 2))
JOB_NUM="-j${CPU_NUM}"

while [[ $# -gt 0 ]]; do
    case $1 in
        -s)
            PATH_TO_SOURCE="$2"
            shift 2
            ;;
        -b)
            PATH_TO_BUILD="$2"
            shift 2
            ;;
        -p)
            ASCEND_CANN_PACKAGE_PATH="$2"
            shift 2
            ;;
        --autogen-dir)
            ASCEND_AUTOGEN_DIR="$2"
            shift 2
            ;;
        --build-open-project)
            BUILD_OPEN_PROJECT="$2"
            shift 2
            ;;
        --binary-out-dir)
            ASCEND_BINARY_OUT_DIR="$2"
            shift 2
            ;;
        --impl-out-dir)
            ASCEND_IMPL_OUT_DIR="$2"
            shift 2
            ;;
        --op-build-tool)
            OP_BUILD_TOOL="$2"
            shift 2
            ;;
        --ascend-cmake-dir)
            ASCEND_CMAKE_DIR="$2"
            shift 2
            ;;
        --tiling-key)
            TILING_KEY="$2"
            shift 2
            ;;
        --ops-compile-options)
            OPS_COMPILE_OPTIONS="$2"
            shift 2
            ;;
        --check-compatible)
            CHECK_COMPATIBLE="$2"
            shift 2
            ;;
        --ascend-compute_unit)
            ASCEND_COMPUTE_UNIT="$2"
            shift 2
            ;;
        --ascend-op-name)
            ASCEND_OP_NAME="$2"
            shift 2
            ;;
        --op_debug_config)
            OP_DEBUG_CONFIG="$2"
            shift 2
            ;;
        *)
            break
            ;;
    esac
done

# Recreate the build tree from scratch.  Paths are quoted so a value with
# spaces cannot expand into a destructive multi-argument rm.
function clean() {
    if [ -n "${PATH_TO_BUILD}" ]; then
        rm -rf "${PATH_TO_BUILD}"
        mkdir -p "${PATH_TO_BUILD}"
    fi
}

# Turn a '::'-separated command-line value back into a ';'-separated CMake list.
function convert_string() {
    local _input=$1
    _output=$(echo "$_input" | sed 's/::/;/g')
    echo "${_output}"
}

function set_env() {
    CONVERT_TILING_KEY="$(convert_string "${TILING_KEY}")"

    CONVERT_OPS_COMPILE_OPTIONS="$(convert_string "${OPS_COMPILE_OPTIONS}")"

    CONVERT_ASCEND_COMPUTE_UNIT="$(convert_string "${ASCEND_COMPUTE_UNIT}")"
}

function build() {
    # Abort instead of configuring/building in the wrong directory if cd fails.
    cd "${PATH_TO_BUILD}" || exit 1
    cmake "${PATH_TO_SOURCE}" \
            -DBUILD_OPEN_PROJECT=${BUILD_OPEN_PROJECT} \
            -DPREPARE_BUILD=ON \
            -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} \
            -DASCEND_AUTOGEN_DIR=${ASCEND_AUTOGEN_DIR} \
            -DASCEND_BINARY_OUT_DIR=${ASCEND_BINARY_OUT_DIR} \
            -DASCEND_IMPL_OUT_DIR=${ASCEND_IMPL_OUT_DIR} \
            -DOP_BUILD_TOOL=${OP_BUILD_TOOL} \
            -DASCEND_CMAKE_DIR=${ASCEND_CMAKE_DIR} \
            -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE} \
            -DTILING_KEY="${CONVERT_TILING_KEY}" \
            -DOPS_COMPILE_OPTIONS="${CONVERT_OPS_COMPILE_OPTIONS}" \
            -DASCEND_COMPUTE_UNIT=${CONVERT_ASCEND_COMPUTE_UNIT} \
            -DOP_DEBUG_CONFIG=${OP_DEBUG_CONFIG} \
            -DASCEND_OP_NAME=${ASCEND_OP_NAME}

    make ${JOB_NUM} prepare_build
}

function main() {
    clean
    set_env
    build
}

main
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
# Per-op compile options for the GroupedMatmulSwigluQuantWeightNzTensorList kernel.
add_ops_compile_options(
    OP_NAME GroupedMatmulSwigluQuantWeightNzTensorList
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
)

# The L0 implementation plus its public aclnn wrapper are attached to several
# targets; keep the pair in one place.
set(_gmm_swiglu_aclnn_srcs
    grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
    aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
)

target_sources(op_host_aclnnExc PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp
)

target_sources(opapi PRIVATE
    ${_gmm_swiglu_aclnn_srcs}
)

if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        ${_gmm_swiglu_aclnn_srcs}
    )

    target_sources(aclnn_ops_infer PRIVATE
        ${_gmm_swiglu_aclnn_srcs}
    )
endif ()

target_sources(optiling PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp
)

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

target_sources(opsproto PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp
)

# Ship the public aclnn header alongside the other installed aclnn headers.
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h")

install(FILES ${_GMM_Aclnn_header}
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
|
||||
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include <dlfcn.h>
|
||||
#include <new>
|
||||
#include "aclnn_kernels/contiguous.h"
|
||||
#include "acl/acl.h"
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "aclnn_kernels/common/op_error_check.h"
|
||||
#include "opdev/common_types.h"
|
||||
#include "opdev/data_type_utils.h"
|
||||
#include "opdev/format_utils.h"
|
||||
#include "opdev/op_dfx.h"
|
||||
#include "opdev/op_executor.h"
|
||||
#include "opdev/op_log.h"
|
||||
#include "opdev/platform.h"
|
||||
#include "opdev/shape_utils.h"
|
||||
#include "opdev/tensor_view_utils.h"
|
||||
#include "opdev/make_op_executor.h"
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
|
||||
#include "aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
|
||||
|
||||
using namespace op;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static constexpr int64_t SPLIT = 2;
|
||||
static constexpr int64_t K_LIMIT = 65536;
|
||||
static constexpr int64_t N_LIMIT = 4096;
|
||||
static constexpr int64_t NZ_DIM_3 = 32;
|
||||
static constexpr int64_t NZ_DIM_2 = 16;
|
||||
static constexpr int64_t OUTPUT_IDX_0 = 0;
|
||||
static constexpr int64_t OUTPUT_IDX_1 = 1;
|
||||
static constexpr size_t X_DIM_LIMIT = 2;
|
||||
static constexpr size_t WEIGHT_ND_DIM_LIMIT = 2;
|
||||
static constexpr size_t WEIGHT_NZ_DIM_LIMIT = 4;
|
||||
static constexpr size_t WEIGHT_SCALE_DIM_LIMIT = 1;
|
||||
static constexpr size_t TOKEN_SCALE_DIM_LIMIT = 1;
|
||||
static constexpr size_t GROUP_LIST_DIM_LIMIT = 1;
|
||||
static constexpr size_t QUANTOUT_DIM_LIMIT = 2;
|
||||
static constexpr size_t QUANTSCALEOUT_DIM_LIMIT = 1;
|
||||
|
||||
static const std::initializer_list<DataType> X_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
|
||||
static const std::initializer_list<DataType> WEIGHT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
|
||||
static const std::initializer_list<DataType> WEIGHT_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16};
|
||||
static const std::initializer_list<DataType> X_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16};
|
||||
static const std::initializer_list<DataType> GROUP_LIST_DTYPE_SUPPORT_LIST = {DataType::DT_INT64};
|
||||
static const std::initializer_list<DataType> QUANTOUT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
|
||||
static const std::initializer_list<DataType> QUANTSCALEOUT_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT};
|
||||
|
||||
static bool CheckNotNull(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset,
|
||||
const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList,
|
||||
const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset)
|
||||
{
|
||||
OP_CHECK_NULL(x, return false);
|
||||
OP_CHECK_NULL(weight, return false);
|
||||
OP_CHECK_NULL(weightScale, return false);
|
||||
OP_CHECK_NULL(xScale, return false);
|
||||
OP_CHECK_NULL(groupList, return false);
|
||||
OP_CHECK_NULL(output, return false);
|
||||
OP_CHECK_NULL(outputScale, return false);
|
||||
if (bias != nullptr) {
|
||||
OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where bias is not 0. "
|
||||
"Features and accuracy are not guaranteed if inputting bias with values other than 0s.");
|
||||
}
|
||||
if (offset != nullptr) {
|
||||
OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where offset is not 0. "
|
||||
"Features and accuracy are not guaranteed if inputting bias with values other than 0s.");
|
||||
}
|
||||
if (outputOffset != nullptr) {
|
||||
OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where outputOffset is not 0. "
|
||||
"Features and accuracy are not guaranteed if inputting bias with values other than 0s.");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckInputOutDims(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
|
||||
const aclTensor* xScale, const aclTensor* groupList,
|
||||
const aclTensor* output, const aclTensor* outputScale)
|
||||
{
|
||||
OP_CHECK_WRONG_DIMENSION(x, X_DIM_LIMIT, return false);
|
||||
op::Format weightViewFormat = (*weight)[0]->GetViewFormat();
|
||||
if (IsPrivateFormat(weightViewFormat)){
|
||||
OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_NZ_DIM_LIMIT, return false);
|
||||
} else {
|
||||
OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_ND_DIM_LIMIT, return false);
|
||||
}
|
||||
OP_CHECK_WRONG_DIMENSION((*weightScale)[0], WEIGHT_SCALE_DIM_LIMIT, return false);
|
||||
OP_CHECK_WRONG_DIMENSION(xScale, TOKEN_SCALE_DIM_LIMIT, return false);
|
||||
OP_CHECK_WRONG_DIMENSION(groupList, GROUP_LIST_DIM_LIMIT, return false);
|
||||
OP_CHECK_WRONG_DIMENSION(output, QUANTOUT_DIM_LIMIT, return false);
|
||||
OP_CHECK_WRONG_DIMENSION(outputScale, QUANTSCALEOUT_DIM_LIMIT, return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckInputOutShape(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
|
||||
const aclTensor* xScale, const aclTensor* groupList,
|
||||
const aclTensor* output, const aclTensor* outputScale)
|
||||
{
|
||||
int64_t m = x->GetViewShape().GetDim(0);
|
||||
int64_t k = x->GetViewShape().GetDim(1);
|
||||
int64_t n = (*weightScale)[0]->GetViewShape().GetDim(0);
|
||||
int64_t e = weight->Size();
|
||||
if (n % SPLIT != 0){
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
|
||||
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, N is %ld , not an even number.", n);
|
||||
return false;
|
||||
}
|
||||
int64_t nAfterHalve = static_cast<int64_t>(n / SPLIT);
|
||||
// x shape is expected to be [M, K]
|
||||
op::Shape xExpectShape = {m, k};
|
||||
// The ND shape of each weight in TensorList is expected to be [K, N]
|
||||
op::Shape weightNDExpectShape = {k, n};
|
||||
// The NZ shape of each weight in TensorList is expected to be [N // 32, K // 16, 16, 32]
|
||||
op::Shape weightNZExpectShape = {static_cast<int64_t>(n / NZ_DIM_3),
|
||||
static_cast<int64_t>(k / NZ_DIM_2),
|
||||
NZ_DIM_2, NZ_DIM_3};
|
||||
// weightScale shape is expected to be [N]
|
||||
op::Shape weightScaleExpectShape = {n};
|
||||
// xScale shape is expected to be [E, N]
|
||||
op::Shape xScaleExpectShape = {m};
|
||||
// output shape is expected to be [M, N]
|
||||
op::Shape outputExpectShape = {m, nAfterHalve};
|
||||
// outputScale shape is expected to be [M]
|
||||
op::Shape outputScaleExpectShape = {m};
|
||||
for (size_t i = 0; i < weight->Size(); ++i) {
|
||||
op::Format weightViewFormat = (*weight)[i]->GetViewFormat();
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(x, xExpectShape, return false);
|
||||
if (IsPrivateFormat(weightViewFormat)){
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNZExpectShape, return false);
|
||||
} else {
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNDExpectShape, return false);
|
||||
}
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weightScale)[i], weightScaleExpectShape, return false);
|
||||
}
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(xScale, xScaleExpectShape, return false);
|
||||
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(output, outputExpectShape, return false);
|
||||
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(outputScale, outputScaleExpectShape, return false);
|
||||
// The length of groupList should be less than or equal to the number of experts in weight
|
||||
int64_t groupListLen = groupList->GetViewShape().GetDim(0);
|
||||
if(groupListLen > e) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
|
||||
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, Length of 'groupList' out of range (expected to be in range of [1, %ld], but got %ld)",
|
||||
e, groupListLen);
|
||||
return false;
|
||||
}
|
||||
if(nAfterHalve > N_LIMIT) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
|
||||
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
|
||||
where N after halve is %ld greater than %ld.",
|
||||
nAfterHalve, N_LIMIT);
|
||||
return false;
|
||||
}
|
||||
if(k >= K_LIMIT) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
|
||||
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
|
||||
The tail axis dimension of input0(x) is %ld, which need lower than %ld.",
|
||||
k, K_LIMIT);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckDtypeValid(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
|
||||
const aclTensor* xScale, const aclTensor* groupList,
|
||||
const aclTensor* output, const aclTensor* outputScale)
|
||||
{
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(x, X_DTYPE_SUPPORT_LIST, return false);
|
||||
for (size_t i = 0; i < weight->Size(); ++i) {
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT((*weight)[i], WEIGHT_DTYPE_SUPPORT_LIST, return false);
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT((*weightScale)[i], WEIGHT_SCALE_DTYPE_SUPPORT_LIST, return false);
|
||||
}
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(xScale, X_SCALE_DTYPE_SUPPORT_LIST, return false);
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(groupList, GROUP_LIST_DTYPE_SUPPORT_LIST, return false);
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(output, QUANTOUT_DTYPE_SUPPORT_LIST, return false);
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(outputScale, QUANTSCALEOUT_DTYPE_SUPPORT_LIST, return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckFormat(const aclTensor* x, const aclTensorList* weight, const aclTensor* output)
|
||||
{
|
||||
bool isNZ = (*weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ;
|
||||
if (!isNZ) {
|
||||
// fp16 in fp32 out that is split k template, not precision-advanced now
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
|
||||
weight Format expect is FRACTAL_NZ, but got [%s].", op::ToString((*weight)[0]->GetStorageFormat()).GetString());
|
||||
return false;
|
||||
}
|
||||
if (IsPrivateFormat(x->GetStorageFormat())) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
|
||||
x Format Not support Private Format.");
|
||||
return false;
|
||||
}
|
||||
if (IsPrivateFormat(output->GetStorageFormat())) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
|
||||
output Format Not support Private Format.");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs the full parameter-validation pipeline and maps each failing stage to
// its aclnn error code: missing mandatory tensors -> ACLNN_ERR_PARAM_NULLPTR,
// all other violations -> ACLNN_ERR_PARAM_INVALID.
static aclnnStatus CheckParams(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset,
                               const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList,
                               const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) {
    // Stage 1: null-pointer screening of all mandatory tensors.
    CHECK_RET(CheckNotNull(x, weight, bias, offset, weightScale, xScale, groupList, output, outputScale, outputOffset),
              ACLNN_ERR_PARAM_NULLPTR);

    // Stage 2: rank (dimension-count) validation.
    CHECK_RET(CheckInputOutDims(x, weight, weightScale, xScale, groupList, output, outputScale),
              ACLNN_ERR_PARAM_INVALID);

    // Stage 3: cross-tensor shape consistency.
    CHECK_RET(CheckInputOutShape(x, weight, weightScale, xScale, groupList, output, outputScale),
              ACLNN_ERR_PARAM_INVALID);

    // Stage 4: supported dtype sets.
    CHECK_RET(CheckDtypeValid(x, weight, weightScale, xScale, groupList, output, outputScale),
              ACLNN_ERR_PARAM_INVALID);

    // Stage 5: storage-format constraints.
    CHECK_RET(CheckFormat(x, weight, output), ACLNN_ERR_PARAM_INVALID);

    return ACLNN_SUCCESS;
}
|
||||
|
||||
// Common phase-1 implementation shared by the public GetWorkspaceSize entry.
// Validates the arguments, short-circuits the empty-tensor case, converts the
// plain tensor inputs to contiguous layout, runs the L0 op, and view-copies
// the two results into the caller's output tensors. On success, *workspaceSize
// holds the device workspace requirement and the executor is released to
// *executor for phase 2.
// Note: bias, offset and outputOffset are only validated by CheckParams; the
// visible compute path below does not use them.
static aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(const aclTensor *x, const aclTensorList *weight,
                                                                                         const aclTensor *bias, const aclTensor *offset,
                                                                                         const aclTensorList *weightScale, const aclTensor *xScale,
                                                                                         const aclTensor *groupList,
                                                                                         aclTensor *output, aclTensor *outputScale,
                                                                                         aclTensor *outputOffset, uint64_t *workspaceSize,
                                                                                         aclOpExecutor **executor){
    // Fixed pattern, create OpExecutor
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    // Fixed pattern, parameter check
    auto ret = CheckParams(x, weight, bias, offset, weightScale, xScale,
                           groupList, output, outputScale, outputOffset);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    // Empty tensor scenario: nothing to compute, no workspace required
    if (output->IsEmpty() || groupList->IsEmpty() || outputScale->IsEmpty()) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }
    // Convert to contiguous
    x = l0op::Contiguous(x, uniqueExecutor.get());
    CHECK_RET(x != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Sync each weight's original shape with its view shape so the NZ format
    // metadata set by the phase-1 entry stays self-consistent.
    for (size_t i = 0; i < weight->Size(); ++i) {
        (*weight)[i]->SetOriginalShape((*weight)[i]->GetViewShape());
    }
    xScale = l0op::Contiguous(xScale, uniqueExecutor.get());
    CHECK_RET(xScale != nullptr, ACLNN_ERR_INNER_NULLPTR);
    groupList = l0op::Contiguous(groupList, uniqueExecutor.get());
    CHECK_RET(groupList != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Call L0 operator capability; it returns (quantized out, out scale)
    auto ret_0 = l0op::GroupedMatmulSwigluQuantWeightNzTensorList(x, weight, weightScale, xScale, groupList, uniqueExecutor.get());
    CHECK_RET(ret_0 != std::tuple(nullptr, nullptr), ACLNN_ERR_INNER_NULLPTR);
    // Copy the L0 results into the caller-provided output tensors
    auto out0 = std::get<OUTPUT_IDX_0>(ret_0);
    auto ret_1 = l0op::ViewCopy(out0, output, uniqueExecutor.get());
    CHECK_RET(ret_1 != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto out1 = std::get<OUTPUT_IDX_1>(ret_0);
    auto ret_2 = l0op::ViewCopy(out1, outputScale, uniqueExecutor.get());
    CHECK_RET(ret_2 != nullptr, ACLNN_ERR_INNER_NULLPTR);
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
|
||||
|
||||
// Phase-1 public entry: validates the workspace/executor out-pointers,
// normalizes each weight tensor's format metadata to FRACTAL_NZ, then
// delegates to the common implementation above.
aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(const aclTensor *x, const aclTensorList *weight,
                                                                            const aclTensor *bias, const aclTensor *offset,
                                                                            const aclTensorList *weightScale, const aclTensor *xScale,
                                                                            const aclTensor *groupList,
                                                                            aclTensor *output, aclTensor *outputScale,
                                                                            aclTensor *outputOffset, uint64_t *workspaceSize,
                                                                            aclOpExecutor **executor) {
    OP_CHECK_COMM_INPUT(workspaceSize, executor);
    L2_DFX_PHASE_1(aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
                   DFX_IN(x, weight, bias, offset, weightScale, xScale, groupList),
                   DFX_OUT(output, outputScale, outputOffset));
    // weight is forcibly bound to StorageFormat and ViewFormat as NZ in this scenario
    CHECK_RET(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR);
    for (size_t i = 0; i < weight->Size(); ++i) {
        auto storgeShape = (*weight)[i]->GetStorageShape();
        auto viewShape = (*weight)[i]->GetViewShape();
        // const_cast: the format fields of the caller's weight tensors are
        // rewritten in place so downstream format checks see FRACTAL_NZ.
        aclTensor* weightNZ = const_cast<aclTensor*>((*weight)[i]);
        // NZ storage must be 4-dimensional.
        CHECK_COND((storgeShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT),
                   ACLNN_ERR_PARAM_INVALID,
                   "aclnnGroupedMatmulSwigluQuantWeightNZTensorList, The dimnum of storageShape for second input (weight) \
must be 4. \n But StorageShape got %s , and dimNum is %lu.",
                   op::ToString(storgeShape).GetString(), storgeShape.GetDimNum());
        // The StorageFormat of weight is unconditionally regarded as NZ
        weightNZ->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
        if (viewShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT){
            // If the viewShape of weight is 4-dimensional, it is regarded as NZ
            weightNZ->SetViewFormat(op::Format::FORMAT_FRACTAL_NZ);
        } else if (viewShape.GetDimNum() == WEIGHT_ND_DIM_LIMIT){
            // If the viewShape of weight is 2-dimensional, it is regarded as ND
            weightNZ->SetViewFormat(op::Format::FORMAT_ND);
        }
        // NOTE(review): a viewShape that is neither 2-D nor 4-D leaves the
        // ViewFormat untouched — confirm this fall-through is intended.
    }
    // Call the common interface
    return aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(x, weight, bias, offset, weightScale, xScale, groupList,
                                                                                 output, outputScale, outputOffset, workspaceSize, executor);
}
|
||||
|
||||
// Phase-2 entry: launches the prepared executor on the given stream using the
// workspace sized by the phase-1 call.
aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void *workspace,
                                                            uint64_t workspaceSize,
                                                            aclOpExecutor *executor,
                                                            aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnGroupedMatmulSwigluQuantWeightNzTensorList);
    const aclnnStatus launchStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    CHECK_COND(launchStatus == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "This is an error in GroupedMatmulSwigluQuantWeightNzTensorList launch aicore");
    return ACLNN_SUCCESS;
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
#define OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
#include "aclnn/aclnn_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief The first interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, which calculates the workspace size according to the specific calculation process.
|
||||
* @domain aclnn_ops_infer
|
||||
*
|
||||
* @param [in] x: Represents x in the formula. The data type supports INT8, and the data format supports ND.
|
||||
* @param [in] weight:
|
||||
* Represents weight in the formula. The data type supports INT8, and the data format supports NZ.
|
||||
 * @param [in] weightScale: Represents the per-channel quantization parameters of weight. The data type supports FLOAT32, BFLOAT16, and FLOAT16. The data format supports ND.
|
||||
* @param [in] xScale:
|
||||
* Represents per Token quantization parameters. The data type supports FLOAT32, and the data format supports ND.
|
||||
* @param [in] groupList: Required parameter, representing the index situation on the input and output grouping axes. The data type supports INT64.
|
||||
 * @param [out] output: Represents out in the formula. The data type supports INT8, and the data format supports ND.
 * @param [out] outputScale: Represents outQuantScale in the formula. The data type supports FLOAT32.
|
||||
* @param [out] workspaceSize: Returns the workspace size that users need to apply for on the npu device side.
|
||||
* @param [out] executor: Returns the op executor, containing the operator calculation process.
|
||||
* @return aclnnStatus: Returns the status code.
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(
|
||||
const aclTensor *x, const aclTensorList *weight, const aclTensor *bias, const aclTensor *offset,
|
||||
const aclTensorList *weightScale, const aclTensor *xScale, const aclTensor *groupList,
|
||||
aclTensor *output, aclTensor *outputScale, aclTensor *outputOffset, uint64_t *workspaceSize, aclOpExecutor **executor);
|
||||
|
||||
/**
|
||||
* @brief The second interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, used to execute calculations.
|
||||
* @param [in] workspace: The starting address of the workspace memory applied for on the npu device side.
|
||||
* @param [in] workspaceSize: The workspace size applied for on the npu device side, obtained from the first interface aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize.
|
||||
* @param [in] stream: acl stream.
|
||||
* @param [in] executor: op executor, containing the operator calculation process.
|
||||
* @return aclnnStatus: Returns the status code.
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void* workspace,
|
||||
uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#include "opdev/op_log.h"
|
||||
#include "opdev/op_dfx.h"
|
||||
#include "opdev/make_op_executor.h"
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
|
||||
|
||||
using namespace op;
|
||||
|
||||
namespace l0op {
|
||||
OP_TYPE_REGISTER(GroupedMatmulSwigluQuantWeightNzTensorList);
|
||||
|
||||
const std::tuple<aclTensor*, aclTensor*> GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x,
|
||||
const aclTensorList *weight,
|
||||
const aclTensorList *perChannelScale,
|
||||
const aclTensor *perTokenScale,
|
||||
const aclTensor *groupList,
|
||||
aclOpExecutor *executor) {
|
||||
L0_DFX(GroupedMatmulSwigluQuantWeightNzTensorList, x, weight, perChannelScale, perTokenScale, groupList);
|
||||
if (x == nullptr) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "x is nullptr.");
|
||||
return std::tuple(nullptr, nullptr);
|
||||
}
|
||||
int64_t m = perTokenScale->GetViewShape().GetDim(0);
|
||||
int64_t n = (*perChannelScale)[0]->GetViewShape().GetDim(0);
|
||||
int64_t nAfterHalve = static_cast<int64_t>(n / 2);
|
||||
gert::Shape outShape({m, nAfterHalve});
|
||||
gert::Shape scaleOutShape({m});
|
||||
auto out = executor->AllocTensor(outShape, DataType::DT_INT8, ge::FORMAT_ND);
|
||||
auto scaleOut = executor->AllocTensor(scaleOutShape, DataType::DT_FLOAT, ge::FORMAT_ND);
|
||||
auto ret = INFER_SHAPE(GroupedMatmulSwigluQuantWeightNzTensorList,
|
||||
OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList),
|
||||
OP_OUTPUT(out, scaleOut));
|
||||
if (ret != ACLNN_SUCCESS) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "InferShape failed.");
|
||||
return std::tuple(nullptr, nullptr);
|
||||
}
|
||||
ret = ADD_TO_LAUNCHER_LIST_AICORE(GroupedMatmulSwigluQuantWeightNzTensorList,
|
||||
OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList),
|
||||
OP_OUTPUT(out, scaleOut));
|
||||
if (ret != ACLNN_SUCCESS) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed.");
|
||||
return std::tuple(nullptr, nullptr);
|
||||
}
|
||||
return std::tie(out, scaleOut);
|
||||
}
|
||||
|
||||
} // namespace l0op
|
||||
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H
|
||||
#define OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H
|
||||
|
||||
#include "opdev/op_executor.h"
|
||||
|
||||
namespace l0op {
|
||||
const std::tuple<aclTensor*, aclTensor*> GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x,
|
||||
const aclTensorList *weight,
|
||||
const aclTensorList *perChannelScale,
|
||||
const aclTensor *perTokenScale,
|
||||
const aclTensor *groupList,
|
||||
aclOpExecutor *executor);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include "register/op_def_registry.h"
|
||||
namespace ops {
// Op registration (GE op definition) for GroupedMatmulSwigluQuantWeightNzTensorList.
// Each DataType/Format list holds three entries — one per supported dtype
// configuration; only weight_scale varies across them (FLOAT / BF16 / FLOAT16),
// all other tensors keep the same dtype in every configuration.
class GroupedMatmulSwigluQuantWeightNzTensorList : public OpDef {
public:
    explicit GroupedMatmulSwigluQuantWeightNzTensorList(const char* name) : OpDef(name)
    {
        // int8 activation matrix, ND layout
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // dynamic (tensor-list) int8 weights in FRACTAL_NZ layout
        this->Input("weight")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // dynamic per-channel weight scales: FLOAT32 / BF16 / FLOAT16
        this->Input("weight_scale")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // per-token activation scale, float32
        this->Input("x_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // grouping indices on the grouped axis, int64
        this->Input("group_list")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT64,ge::DT_INT64,ge::DT_INT64})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // quantized int8 result
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // per-token float32 output scale
        this->Output("y_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true);

        // Supported SoC targets
        this->AICore().AddConfig("ascend910b", aicore_config);
        this->AICore().AddConfig("ascend910_93", aicore_config);
    }
};

OP_ADD(GroupedMatmulSwigluQuantWeightNzTensorList);
}
|
||||
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "platform/platform_info.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace ops {
|
||||
const int64_t X_INDEX = 0;
|
||||
const int64_t WEIGHTSCALE_INDEX = 2;
|
||||
const int64_t M_DIM_INDEX = 0;
|
||||
const int64_t N_DIM_INDEX = 0;
|
||||
// Infers output shapes: y is [m, n] where m comes from x's row count and
// n is half of the first weight_scale's length (SwiGLU halves the width);
// y_scale is [m].
static ge::graphStatus InferShape4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferShapeContext* context) {
    const gert::Shape* xShape = context->GetInputShape(X_INDEX);
    const gert::Shape* weightScaleShape = context->GetDynamicInputShape(WEIGHTSCALE_INDEX, 0);
    // Fix: these lookups can return nullptr; guard before dereferencing.
    if (xShape == nullptr || weightScaleShape == nullptr) {
        return GRAPH_FAILED;
    }
    int64_t m = xShape->GetDim(M_DIM_INDEX);
    // SwiGLU consumes channel pairs, so the output width is half the scale length.
    int64_t n = static_cast<int64_t>(weightScaleShape->GetDim(N_DIM_INDEX) / 2);
    auto outShape = context->GetOutputShape(0);
    auto outScaleShape = context->GetOutputShape(1);
    if (outShape == nullptr || outScaleShape == nullptr) {
        return GRAPH_FAILED;
    }
    outShape->SetDimNum(2);
    outShape->SetDim(0, m);
    outShape->SetDim(1, n);
    outScaleShape->SetDimNum(1);
    outScaleShape->SetDim(0, m);
    return GRAPH_SUCCESS;
}
|
||||
|
||||
// Output dtypes are fixed by the op contract: output 0 (y) is the quantized
// INT8 activation, output 1 (y_scale) is its FLOAT32 per-token scale.
static graphStatus InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferDataTypeContext* context) {
    context->SetOutputDataType(0, DataType::DT_INT8);
    context->SetOutputDataType(1, DataType::DT_FLOAT);
    return GRAPH_SUCCESS;
}

// Register shape and dtype inference for the op.
IMPL_OP_INFERSHAPE(GroupedMatmulSwigluQuantWeightNzTensorList)
    .InferShape(InferShape4GroupedMatmulSwigluQuantWeightNzTensorList)
    .InferDataType(InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList);
} // namespace ops
|
||||
@@ -0,0 +1,188 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <climits>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "tiling/tiling_base.h"
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h"
|
||||
using namespace ge;
|
||||
using namespace AscendC;
|
||||
using namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling;
|
||||
|
||||
// Integer ceiling division: the smallest count of `denominator`-sized chunks
// that covers `numerator`. Returns 0 when the denominator is 0 so callers
// never trigger a division-by-zero.
template <typename T1, typename T2>
static T1 CeilDiv(T1 numerator, T2 denominator)
{
    if (denominator == 0) {
        return 0;
    }
    return (numerator + denominator - 1) / denominator;
}
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Platform info gathered once in TilingPrepareForGMMSwigluQuant and consumed
// by TilingGMMSwigluQuant on every tiling call.
struct GMMSwigluCompileInfo {
    uint64_t ubSize_ = 0;   // Unified Buffer size in bytes (queried from the platform)
    uint32_t aicNum_ = 0;   // number of AI-Core cube cores (used as block dim)
    uint32_t baseM_ = 128;  // matmul base block M — NOTE(review): fixed default, not tuned per shape
    uint32_t baseN_ = 256;  // matmul base block N
};
|
||||
|
||||
// Estimates the temporary UB buffer (in bytes) needed to process `row` rows of
// length `n`: takes the midpoint of the max/min temp sizes reported by the
// SwiGLU, AscendQuant and AscendDequant high-level APIs and returns the
// largest of the three midpoints.
// NOTE(review): the literal 4 passed to the Get*TmpSize helpers is presumably
// the element byte width (float32) — confirm against the AscendC tiling docs.
static uint64_t CalcMaxTmpSize(const uint32_t row, const uint64_t n) {
    std::vector<int64_t> shape_vec = {static_cast<int64_t>(row * n)};
    Shape shape(shape_vec);
    uint32_t max;  // out-param, written by each Get*MaxMinTmpSize call below
    uint32_t min;  // out-param
    GetSwiGLUMaxMinTmpSize(shape, 4, max, min, false);
    uint32_t averageTmp = (max + min) >> 1;  // midpoint of the SwiGLU demand
    GetAscendQuantMaxMinTmpSize(shape, 4, max, min);
    uint32_t average = (max + min) >> 1;
    average = average > averageTmp ? average : averageTmp;
    GetAscendDequantMaxMinTmpSize(shape, 4, max, min);
    averageTmp = (max + min) >> 1;
    return average > averageTmp ? average : averageTmp;
}
|
||||
|
||||
// Computes how many rows (tokens) of length n fit in the UB budget per pass.
// Starts from an optimistic row count and shrinks it until the per-row data
// (tokenSize bytes each, plus an 8-byte per-row overhead — TODO confirm the 8)
// together with the API temp buffers fits in the budget.
// NOTE(review): assumes ubSize > n * 4; otherwise `expectSize` wraps around
// (unsigned subtraction) — confirm callers guarantee this.
static uint64_t CalRows(const uint64_t ubSize, const uint64_t n) {
    uint64_t tokenSize = n << 2;               // bytes per row (n elements * 4 bytes)
    uint64_t expectSize = ubSize - tokenSize;  // budget left after one row buffer
    uint64_t rows = expectSize / (8 + tokenSize);
    uint64_t realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n);
    while (expectSize < realSize) {
        // Shrink proportionally to the overshoot, then re-measure.
        rows -= CeilDiv(realSize - expectSize, (8 + tokenSize) << 2);
        realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n);
    }
    return rows;
}
|
||||
|
||||
// Selects the kernel tiling key: 1 for the split-workspace schedule, 0 for the
// single-pass schedule. Both paths use batch-mode scheduling, so the common
// SetScheduleMode call is hoisted out of the branch.
static void SetTilingKey(gert::TilingContext* context, bool isSplitWorkSpace) {
    context->SetTilingKey(isSplitWorkSpace ? 1 : 0);
    context->SetScheduleMode(BATCH_MODE_SCHEDULE);
}
|
||||
|
||||
// Heuristic: decide whether this launch looks like a prefill-phase call.
// Returns true only when groupNum == 128, m >= PREFILL_M_MIN_SIZE, and the
// (K, N) pair appears in PREFILL_WHITE_LIST.
static bool IsPreFill(GMMSwigluQuantTilingData &tilingData) {
    int64_t k = tilingData.gmmSwigluBaseParams.get_K();
    int64_t n = tilingData.gmmSwigluBaseParams.get_N();
    int64_t m = tilingData.gmmSwigluBaseParams.get_M();
    int64_t groupNum = tilingData.gmmSwigluBaseParams.get_groupNum();
    if (groupNum == 128 && m >= PREFILL_M_MIN_SIZE) { // 128: prefill groupNum
        std::array<int64_t, 2> kNList = {k, n}; // 2: kNList size
        if (PREFILL_WHITE_LIST.count(kNList)) {
            return true;
        }
    }
    return false;
}
|
||||
|
||||
// Tiling entry for GroupedMatmulSwigluQuantWeightNzTensorList.
// Reads the problem sizes (m, k, n, groupNum) from the input tensors, fills
// the GMMSwigluQuantTilingData blob (base params, vector params, cube matmul
// tiling), sizes the double-buffered int32 workspace, and selects the tiling
// key (split-workspace vs. single-pass schedule).
ASCENDC_EXTERN_C graphStatus TilingGMMSwigluQuant(gert::TilingContext* context) {
    // set info
    OPS_LOG_I(context->GetNodeName(), "Begin Run GMM Swiglu Tiling .");

    auto compileInfoPtr = context->GetCompileInfo<GMMSwigluCompileInfo>();
    // Fix: compileInfoPtr was dereferenced below without the null guard that
    // every other context lookup in this function receives.
    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return GRAPH_FAILED);
    auto xTensor = context->GetInputTensor(X_INDEX);
    OPS_LOG_E_IF_NULL(context, xTensor, return GRAPH_FAILED);
    const int64_t m = xTensor->GetStorageShape().GetDim(0);
    const int64_t k = xTensor->GetStorageShape().GetDim(1);
    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
    OPS_LOG_E_IF_NULL(context, wTensor, return GRAPH_FAILED);
    // NZ weight storage is 4-D; dims 0 and 3 multiply back to the full n.
    const int64_t n = wTensor->GetStorageShape().GetDim(0) * wTensor->GetStorageShape().GetDim(3);
    // NOTE(review): group_list is REQUIRED in the op def but fetched with
    // GetDynamicInputTensor — confirm the index/API pairing is intended.
    auto groupListTensor = context->GetDynamicInputTensor(GROUPLIST_INDEX, 0);
    OPS_LOG_E_IF_NULL(context, groupListTensor, return GRAPH_FAILED);
    const int64_t groupNum = groupListTensor->GetStorageShape().GetDim(0);
    GMMSwigluQuantTilingData tilingData;
    // Rows of length n that fit in one UB pass.
    const int64_t row = CalRows(compileInfoPtr->ubSize_, n);
    tilingData.gmmSwigluBaseParams.set_groupNum(groupNum);
    tilingData.gmmSwigluBaseParams.set_coreNum(compileInfoPtr->aicNum_);
    tilingData.gmmSwigluBaseParams.set_K(k);
    tilingData.gmmSwigluBaseParams.set_N(n);
    tilingData.gmmSwigluBaseParams.set_M(m);
    tilingData.gmmSwiglu.set_maxProcessRowNum(row);
    tilingData.gmmSwiglu.set_groupListLen(groupNum);
    tilingData.gmmSwiglu.set_tokenLen(n);

    OPS_LOG_D(context->GetNodeName(),"grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.");
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.groupNum: %ld", groupNum);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.coreNum: %u ", compileInfoPtr->aicNum_);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.M: %ld", m);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.K: %ld", k);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.N: %ld", n);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.maxProcessRowNum: %ld", row);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.groupListLen: %ld", groupNum);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.tokenLen: %ld", n);

    // Cube matmul tiling: int8 x int8 -> int32, B matrix in NZ format.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    using namespace matmul_tiling;
    MatmulApiTiling tiling(ascendcPlatform);
    tiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT8);
    tiling.SetBType(TPosition::GM, CubeFormat::NZ, matmul_tiling::DataType::DT_INT8);
    tiling.SetCType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT32);
    tiling.SetBias(false);
    tiling.SetShape(compileInfoPtr->baseM_, compileInfoPtr->baseN_, k);
    tiling.SetOrgShape(m, n, k);
    tiling.SetBufferSpace(-1, -1, -1);  // -1: let the API size its own buffers
    OPS_ERR_IF(tiling.GetTiling(tilingData.mmTilingData) == -1,
               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling, get tiling failed"),
               return GRAPH_FAILED);
    auto workspaceSizes = context->GetWorkspaceSizes(1);
    bool isPreFill = IsPreFill(tilingData);
    tilingData.gmmSwigluBaseParams.set_isPreFill(isPreFill);
    // Prefill launches use a smaller per-user workspace budget.
    int64_t userWorkspaceLimit = isPreFill ? PREFILL_USER_WORKSPACE_LIMIT : USER_WORKSPACE_LIMIT;
    // Max rows whose int32 matmul output fits in half of the budget.
    int64_t mLimit = ((userWorkspaceLimit / DOUBLE_WORKSPACE_SPLIT) / INT32_DTYPE_SIZE) / n;
    OPS_ERR_IF(mLimit <= 0,
               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),"mLimit is %ld must over then 0.", mLimit),
               return GRAPH_FAILED);
    tilingData.gmmSwigluBaseParams.set_mLimit(mLimit);
    // Workspace holds at most min(m, 2 * mLimit) rows of int32 matmul output.
    workspaceSizes[0] = SYS_WORKSPACE_SIZE + ((mLimit * DOUBLE_WORKSPACE_SPLIT > m \
                                                   ? m \
                                                   : mLimit * DOUBLE_WORKSPACE_SPLIT) * n * sizeof(int32_t));
    // The split-workspace schedule is needed when m exceeds both halves.
    bool isSplitWorkSpace = m > mLimit * DOUBLE_WORKSPACE_SPLIT;
    OPS_LOG_D(context->GetNodeName(), "USER_WORKSPACE_LIMIT: %ld", userWorkspaceLimit);
    OPS_LOG_D(context->GetNodeName(), "mLimit: %ld", mLimit);
    OPS_LOG_D(context->GetNodeName(), "workspaceSizes: %lu", workspaceSizes[0]);
    OPS_LOG_D(context->GetNodeName(), "isSplitWorkSpace: %s", isSplitWorkSpace ? "true" : "false");
    OPS_LOG_D(context->GetNodeName(), "isPreFill: %s", isPreFill ? "true" : "false");
    SetTilingKey(context, isSplitWorkSpace);
    tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->SetBlockDim(compileInfoPtr->aicNum_); // block dim is the number of aicube
    context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());

    OPS_LOG_D(context->GetNodeName(), "End Run GMM Swiglu Tiling.");
    return GRAPH_SUCCESS;
}
|
||||
|
||||
// Compile-info parse step: runs once per compiled op to cache the platform's
// AI-Core count and UB size into GMMSwigluCompileInfo, which every subsequent
// TilingGMMSwigluQuant call reads.
ASCENDC_EXTERN_C graphStatus TilingPrepareForGMMSwigluQuant(gert::TilingParseContext* context) {
    // get info
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OPS_LOG_E_IF_NULL(context, platformInfoPtr, return GRAPH_FAILED);
    auto compileInfoPtr = context->GetCompiledInfo<GMMSwigluCompileInfo>();
    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    compileInfoPtr->aicNum_ = ascendcPlatform.GetCoreNumAic();
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize_);
    OPS_LOG_D(context->GetNodeName(), "ubSize is %lu, aicNum is %u.", compileInfoPtr->ubSize_, compileInfoPtr->aicNum_);
    return GRAPH_SUCCESS;
}

// Register the tiling entry points for the op.
IMPL_OP_OPTILING(GroupedMatmulSwigluQuantWeightNzTensorList)
    .Tiling(TilingGMMSwigluQuant)
    .TilingParse<GMMSwigluCompileInfo>(TilingPrepareForGMMSwigluQuant);
} // namespace optiling
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
#define AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
|
||||
#include <set>
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
|
||||
namespace optiling {
// Scalar tiling parameters shared between host tiling and the kernel.
BEGIN_TILING_DATA_DEF(GMMSwigluBaseParams)
    TILING_DATA_FIELD_DEF(uint32_t, groupNum);   // number of groups (group_list length)
    TILING_DATA_FIELD_DEF(uint32_t, coreNum);    // AI-Core cube count, used as block dim
    TILING_DATA_FIELD_DEF(uint32_t, K);          // reduction dimension of x
    TILING_DATA_FIELD_DEF(uint32_t, N);          // full output width before the SwiGLU halving
    TILING_DATA_FIELD_DEF(uint32_t, M);          // token count (rows of x)
    TILING_DATA_FIELD_DEF(uint32_t, mLimit);     // max rows per workspace half (see TilingGMMSwigluQuant)
    TILING_DATA_FIELD_DEF(uint64_t, isPreFill);  // prefill heuristic flag (0/1)
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(GMMSwigluBaseParamsOp, GMMSwigluBaseParams)

// Vector-side (SwiGLU + quant) tiling parameters.
BEGIN_TILING_DATA_DEF(GMMSwiglu)
    TILING_DATA_FIELD_DEF(uint32_t, maxProcessRowNum);  // rows per UB pass (from CalRows)
    TILING_DATA_FIELD_DEF(uint32_t, groupListLen);      // length of group_list
    TILING_DATA_FIELD_DEF(uint32_t, tokenLen);          // per-row vector length (equals N)
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(GMMSwigluOp, GMMSwiglu)

// Top-level tiling data blob serialized to the kernel.
BEGIN_TILING_DATA_DEF(GMMSwigluQuantTilingData)
    TILING_DATA_FIELD_DEF_STRUCT(GMMSwigluBaseParams, gmmSwigluBaseParams);
    TILING_DATA_FIELD_DEF_STRUCT(GMMSwiglu, gmmSwiglu);
    TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData);
END_TILING_DATA_DEF;

REGISTER_TILING_DATA_CLASS(GroupedMatmulSwigluQuantWeightNzTensorList, GMMSwigluQuantTilingData)
}

// Constants shared by the tiling implementation.
namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling {
constexpr uint32_t X_INDEX = 0;          // input slot of x
constexpr uint32_t WEIGHT_INDEX = 1;     // dynamic input slot of weight
constexpr uint32_t GROUPLIST_INDEX = 4;  // input slot of group_list
constexpr uint32_t BATCH_MODE_SCHEDULE = 1;                        // schedule mode value passed to SetScheduleMode
constexpr uint32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024;          // reserved system workspace, bytes
constexpr int64_t USER_WORKSPACE_LIMIT = 256 * 1024 * 1024;        // non-prefill workspace budget, bytes
constexpr int64_t PREFILL_USER_WORKSPACE_LIMIT = 64 * 1024 * 1024; // prefill workspace budget, bytes
constexpr int64_t DOUBLE_WORKSPACE_SPLIT = 2;  // workspace is split into two halves
constexpr int64_t INT32_DTYPE_SIZE = 4;        // bytes per int32 matmul result element
constexpr uint32_t PREFILL_M_MIN_SIZE = 16 * 1024;  // min m for the prefill heuristic

// (K, N) pairs that qualify for the prefill heuristic (see IsPreFill).
const std::set<std::array<int64_t, 2>> PREFILL_WHITE_LIST = { // used for preFill case
    {{2048, 1536}},
    {{4096, 3072}}
};
} // namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling
|
||||
|
||||
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
@@ -0,0 +1,80 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
|
||||
#include <typeinfo>
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h"
|
||||
using namespace AscendC;
|
||||
using namespace matmul;
|
||||
using namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST;
|
||||
using MM_DTYPE_Y = int32_t;
|
||||
|
||||
// Matmul operand descriptors for the AscendC Matmul API.
// NOTE(review): the `trans` template parameter exists so GMM_CV_SPLIT_IMP can
// spell aType<transA>/bType<transB>, but it is NOT forwarded into MatmulType,
// so the transpose flag comes from MatmulType's own default -- confirm this
// is intentional (both call sites pass false, so behavior matches today).
template <bool trans = false>
using xType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, DTYPE_X>;  // activations: ND layout, global memory

template <bool trans = false>
using weightType = MatmulType<AscendC::TPosition::GM, CubeFormat::NZ, DTYPE_WEIGHT>;  // weights: NZ fractal layout, global memory

// Matmul output: int32 accumulators, ND layout.
using yType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, MM_DTYPE_Y>;
|
||||
|
||||
// Instantiates and runs one compute pipeline for a given tiling key:
//  - builds the concrete Matmul type from the operand descriptors and config,
//  - unpacks the three tiling sub-structures from the raw `tiling` blob,
//  - on cube (AIC) cores initializes the Matmul engine against TCubeTiling,
//  - constructs `computeClass` and drives its Init()/Process().
// A macro (not a template function) because GET_TILING_DATA_MEMBER must run
// in kernel scope and the Matmul type depends on the transpose arguments.
// Deliberately no '//' comments inside: they would swallow the
// line-continuation backslashes.
#define GMM_CV_SPLIT_IMP(computeClass, dtypeC, transA, transB, sync, cfg, aType, bType, cType) \
    do { \
        using matmulType = MMImplType<aType<transA>, bType<transB>, cType, cType, cfg>; \
        matmulType::MT mm; \
        GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwigluBaseParams, gmmSwigluBaseParams_, tiling); \
        GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, mmTilingData, mmTilingData_, tiling); \
        GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwiglu, gmmSwiglu_, tiling); \
        if ASCEND_IS_AIC { \
            mm.SetSubBlockIdx(0); \
            mm.Init(&mmTilingData_, &tPipe); \
        } \
        computeClass<matmulType, sync, dtypeC> computeOp(mm); \
        computeOp.Init(x, weight, perChannelScale, perTokenScale, groupList, quantOutput, quantScaleOutput, \
                       user1, &gmmSwigluBaseParams_, &mmTilingData_, &gmmSwiglu_, &tPipe); \
        computeOp.Process(); \
    } while (0)
|
||||
|
||||
// Kernel entry point. Dispatches on the tiling key:
//   key 0 -> GMMSwigluCompute (single staging workspace),
//   key 1 -> GMMSwigluSplitWorkSpaceCompute (split/ping-pong workspace).
// Both paths run as mixed cube+vector (1 AIC : 2 AIV) tasks with
// non-transposed A/B and the NZ multi-data-load matmul config.
extern "C" __global__ __aicore__ void grouped_matmul_swiglu_quant_weight_nz_tensor_list(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                                                                       GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                                                                       GM_ADDR workspace, GM_ADDR tiling) {
    TPipe tPipe;
    // Enable saturation/overflow mode for arithmetic (vendor-defined flag).
    AscendCUtils::SetOverflow(1);
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
    // User workspace starts after the framework-reserved region.
    GM_ADDR user1 = GetUserWorkspace(workspace);
    if (TILING_KEY_IS(0)) { // single-workspace path
        KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_2);
        GMM_CV_SPLIT_IMP(
            GMMSwigluCompute, // computeClass
            DTYPE_WEIGHT_SCALE,
            false, // transA
            false, // transB
            false, // sync
            NZ_CFG_MDL, // cfg
            xType, // aType
            weightType, // bType
            yType); // cType
    } else if(TILING_KEY_IS(1)){ // split-workspace path
        KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_2);
        GMM_CV_SPLIT_IMP(
            GMMSwigluSplitWorkSpaceCompute, // computeClass
            DTYPE_WEIGHT_SCALE,
            false, // transA
            false, // transB
            false, // sync
            NZ_CFG_MDL, // cfg
            xType, // aType
            weightType, // bType
            yType); // cType
    }
}
|
||||
@@ -0,0 +1,498 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h"
|
||||
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
|
||||
/** @brief internal computation class
 */
*/
|
||||
// Grouped matmul + dequant + swiglu + int8 quant pipeline. Cube (AIC) cores
// run the per-group int8 matmuls into an int32 workspace buffer; vector (AIV)
// cores then dequantize, apply swiglu, and requantize to int8 with a dynamic
// per-row scale.
template <class mmType, bool sync = false, typename CHANNELDTYPE = float>
class GMMSwigluCompute{
public:
    // Element-type aliases extracted from the Matmul implementation type.
    using AT = typename mmType::AT::T;        // activation element type
    using BT = typename mmType::BT::T;        // weight element type
    using B = typename mmType::BT;            // full weight descriptor (carries CubeFormat)
    using CT = typename mmType::CT::T;        // matmul output element type
    using BiasT = typename mmType::BiasT::T;  // bias element type (not used in this class's visible code)
    using WT = int8_t;
    constexpr static bool transposeX = mmType::AT::isTrans;
    constexpr static bool transposeW = mmType::BT::isTrans;

    /** @brief constructor */
    __aicore__ inline GMMSwigluCompute(typename mmType::MT& mm_): mm(mm_) {}

    // Binds global buffers and tiling pointers; must precede Process().
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                GM_ADDR workspace,
                                const GMMSwigluBaseParams* __restrict gmmBaseParamsIN,
                                const TCubeTiling* __restrict mmTilingDataIN,
                                const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN);
    // Runs the cube matmul loop (AIC) and the dequant/swiglu/quant loop (AIV).
    __aicore__ inline void Process();
private:
    // One matmul tile for group `groupIdx` at the (mIdx, nIdx) in mnConfig.
    __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx);

    // Advances per-group base offsets past the previous group's data.
    __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig);

    // Fills M/K/N and blocking constants for the current group.
    __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);

    __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);

    // Element offset of the N-tile at column tailN inside one weight tensor.
    __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k);

    // Maps a flat block index onto (mIdx, nIdx) tile coordinates.
    __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
                                             const uint32_t count, const uint32_t thresholdM_dimN);
    // Reloads the per-channel scale row when crossing a group boundary.
    template <typename DTYPE_CS>
    __aicore__ inline void UpdateChannelScale(uint32_t loopidx);
    // Dequant -> Swiglu -> Quant for one row.
    __aicore__ inline void VectorCompute(uint32_t loopidx);
    // Preloads the first group's per-channel scale row into UB.
    template <typename DTYPE_CS>
    __aicore__ inline void PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal);
    // Computes this AIV core's row range, loop counts, and UB buffers.
    __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig);
    // Loads one batch of int32 rows and applies the per-token scale.
    __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx);
    // Writes one batch of int8 rows and their quant scales back to GM.
    __aicore__ inline void customDataCopyOut();
    __aicore__ inline void Dequant(uint32_t loopidx);
    __aicore__ inline void Quant(uint32_t loopidx);
    __aicore__ inline void Swiglu(uint32_t loopidx);
private:
    typename mmType::MT& mm;  // Matmul engine owned by the kernel scope
    const GMMSwigluBaseParams* __restrict gmmBaseParams;
    const GMMSwiglu* __restrict gmmSwiglu;
    const TCubeTiling* __restrict mmTilingData;
    uint32_t blockIdx;    // this core's index (AIC or AIV numbering)
    VecConfig vecConfig;  // per-AIV-core work assignment and loop state
    TPipe* pipe;
    GlobalTensor<int8_t> xGM, weightGM;
    GlobalTensor<CHANNELDTYPE> perChannelScaleGM;
    GlobalTensor<float> perTokenScaleGM;
    GlobalTensor<int64_t> groupListGM;        // cumulative row counts per group
    GlobalTensor<int8_t> quantOutputGM;       // int8 output, tokenLen/2 per row
    GlobalTensor<float> quantScaleOutputGM;   // per-row dynamic quant scale
    GlobalTensor<int32_t> mmOutGM;            // int32 matmul staging buffer in workspace
    // define the que
    TQue<QuePosition::VECIN, 1> mmOutQueue;
    TQue<QuePosition::VECIN, 1> perChannelScaleInQueue;
    TQue<QuePosition::VECOUT, 1> quantOutQueue;
    TQue<QuePosition::VECOUT, 1> quantScaleOutQueue;
    TBuf<TPosition::VECCALC> reduceWorkspace;  // scratch for swiglu output / reduce-max
    TBuf<TPosition::VECCALC> castWorkspace;    // scratch for float->int8 cast
    bool sequentialWrite = true;
    uint32_t cubeNum;   // NOTE(review): never written in visible code -- confirm use
    uint32_t groupNum;  // number of groups; set from groupListLen in Init()
    int32_t preOffset;
    int64_t aicCoreNum;  // cube core count (GetBlockNum())
    int64_t aivCoreNum;  // vector core count = 2 * aicCoreNum
    // Raw tensor-list base pointers; per-group addresses resolved via GetTensorAddr.
    GM_ADDR xTensorPtr;
    GM_ADDR weightTensorPtr;
    GM_ADDR perChannelScalePtr;
};
|
||||
|
||||
// Binds all global buffers and caches tiling pointers.
// AIC cores only need groupList and the int32 staging buffer; AIV cores
// additionally bind the scales and both outputs. Must run before Process().
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                                                          GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                                                          GM_ADDR workspace,
                                                                          const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn,
                                                                          const TCubeTiling* __restrict mmTilingDataIN,
                                                                          const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN)
{
    // Core geometry: each AIC (cube) core pairs with two AIV (vector) cores.
    aicCoreNum = GetBlockNum();
    aivCoreNum = aicCoreNum * 2;
    blockIdx = GetBlockIdx();
    mmTilingData = mmTilingDataIN;
    gmmBaseParams = gmmSwigluBaseParamsIn;
    gmmSwiglu = gmmSwigluIN;
    pipe = tPipeIN;
    // Keep raw tensor-list bases; per-group entries resolved via GetTensorAddr.
    xTensorPtr = x;
    weightTensorPtr = weight;
    perChannelScalePtr = perChannelScale;
    groupNum = gmmSwiglu->groupListLen;
    if ASCEND_IS_AIC {
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        // Workspace staging buffer: M rows of tokenLen int32 accumulators.
        mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen);
    }
    if ASCEND_IS_AIV {
        mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen);
        perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, gmmSwiglu->groupListLen * gmmSwiglu->tokenLen);
        perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmSwiglu->maxProcessRowNum);
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        // Output row width is halved by swiglu (gate/up pairs collapse).
        quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2);
        quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmSwiglu->maxProcessRowNum);
    }
}
|
||||
|
||||
// Main driver.
// AIC: walks the groups, derives each group's row count from consecutive
// groupList entries (the list is cumulative), and round-robins the (M,N)
// tiles of every group over the cube cores; then synchronizes with the AIVs.
// AIV: after the sync, each participating core streams its assigned rows
// through customDataCopyIn -> (UpdateChannelScale, VectorCompute)* ->
// customDataCopyOut, batch by batch.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Process() {
    MNConfig mnConfig;
    if ASCEND_IS_AIC {
        preOffset = 0;
        int32_t prevSplitValue = 0;
        for (uint32_t groupIdx = 0, count = 0; groupIdx < gmmSwiglu->groupListLen; ++groupIdx) {
            // Advance base offsets past the previous group before reading shapes.
            UpdateMnConfig(mnConfig);
            // groupList holds cumulative counts; the difference is this group's M.
            int32_t currSplitValue = static_cast<int32_t>(groupListGM.GetValue(groupIdx));
            int32_t splitValue = currSplitValue - prevSplitValue;
            prevSplitValue = currSplitValue;
            SetMNConfig(splitValue, groupIdx, mnConfig);
            // Empty groups produce no tiles.
            if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) {
                continue;
            }
            mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM);
            mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN);

            // `count` is the core index where this group's tiles start, so tile
            // assignment stays contiguous across group boundaries.
            uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN;
            uint32_t curBlock = blockIdx >= count ? blockIdx : blockIdx + gmmBaseParams->coreNum;
            uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN;

            // Each core takes every aicCoreNum-th tile of this group.
            while (curBlock < curCount) {
                MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN);
                MMCompute(groupIdx, mnConfig, blockIdx);
                curBlock += aicCoreNum;
            }
            count = curCount % gmmBaseParams->coreNum;
        }
        // Hand off to the vector cores once all matmuls are issued.
        SyncAll<false>();
    }

    if ASCEND_IS_AIV {
        UpdateVecConfig(blockIdx, vecConfig);
        if (blockIdx < vecConfig.usedCoreNum) {
            // Allocate UB tensors once and pre-enqueue them; the loop below
            // cycles each buffer through DeQue/EnQue instead of reallocating.
            LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.AllocTensor<float>();
            LocalTensor<int32_t> mmLocal = mmOutQueue.AllocTensor<int32_t>();
            LocalTensor<int8_t> quantLocal = quantOutQueue.AllocTensor<int8_t>();
            LocalTensor<float> quantScaleLocal = quantScaleOutQueue.AllocTensor<float>();
            mmOutQueue.EnQue(mmLocal);
            quantScaleOutQueue.EnQue(quantScaleLocal);
            quantOutQueue.EnQue(quantLocal);
            PreLoadTokenAndChannel<CHANNELDTYPE>(channelScaleLocal);
        }
        // Wait for the cube cores to finish writing the staging buffer.
        SyncAll<false>();
        if (blockIdx < vecConfig.usedCoreNum) {
            for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) {
                // Last batch may be short.
                vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1)
                                             ? vecConfig.tailLoopNum
                                             : gmmSwiglu->maxProcessRowNum;
                customDataCopyIn(outLoopIdx);
                for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) {
                    UpdateChannelScale<CHANNELDTYPE>(innerLoopIdx);
                    VectorCompute(innerLoopIdx);
                }
                customDataCopyOut();
            }

            // Drain and release every pre-enqueued UB tensor.
            LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.DeQue<float>();
            LocalTensor<int32_t> mmLocal = mmOutQueue.DeQue<int32_t>();
            LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
            LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
            perChannelScaleInQueue.FreeTensor(channelScaleLocal);
            mmOutQueue.FreeTensor(mmLocal);
            quantScaleOutQueue.FreeTensor(quantScaleLocal);
            quantOutQueue.FreeTensor(quantLocal);
        } else {
            return;
        }
    }
}
|
||||
|
||||
// Loads the per-channel scale row of the first group this core touches into
// `channelScaleLocal` (float), converting from DTYPE_CS when it is not float,
// then enqueues the buffer for Dequant.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal)
{
    // Resolve this group's scale tensor out of the tensor list.
    GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
    perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));

    DataCopyExtParams copyChannelParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
    DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
    if constexpr(!IsSameType<DTYPE_CS, float>::value) {
        // Stage the raw dtype at element offset tokenLen of the reinterpreted
        // buffer (for 2-byte dtypes that is the upper half of the float
        // buffer), then cast down into the float view at offset 0.
        LocalTensor<DTYPE_CS> dstLocalT = channelScaleLocal.template ReinterpretCast<DTYPE_CS>();
        DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams);
        PipeBarrier<PIPE_ALL>();
        Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
    } else {
        DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams);
    }
    perChannelScaleInQueue.EnQue(channelScaleLocal);
}
|
||||
|
||||
// Issues one matmul tile for group `groupIdx` at the (mIdx, nIdx) coordinate
// held in mnConfig, accumulating int32 results straight into the workspace
// staging buffer.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx)
{
    // Column origin of this N tile, and tile extents clipped at the edges.
    uint32_t tailN = mnConfig.nIdx * mnConfig.singleN;
    uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN;
    uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM
                                                                 : mnConfig.m - mnConfig.mIdx * mnConfig.singleM;
    // Row origin inside this group's activation slab (column origin when x is
    // transposed).
    uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k;
    if constexpr (transposeX) {
        xOffset = mnConfig.mIdx * mnConfig.singleM;
    }
    uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN;
    // x is one flat tensor advanced by xBaseOffset per group; the weight is a
    // per-group tensor-list entry resolved through GetTensorAddr.
    xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset);
    weightGM.SetGlobalBuffer(GetTensorAddr<int8_t>(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k));
    if (mnConfig.blockDimM == 1){
        // Single M block: each weight tile is read once, so bypass L2 cache.
        weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
    }
    mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset;
    mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k);
    mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k);
    mm.SetTensorA(xGM[xOffset], transposeX);
    mm.SetTensorB(weightGM, transposeW);
    mm.template IterateAll<sync>(mmOutGM[mnConfig.workSpaceOffset], 0);
}
|
||||
|
||||
// Advances every per-group base offset past the PREVIOUS group's data. Called
// at the top of the group loop; effectively a no-op on the first iteration
// assuming MNConfig zero-initializes m/n/k -- confirm in the utils header.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateMnConfig(MNConfig &mnConfig) {
    if constexpr (B::format == CubeFormat::NZ) {
        // NZ fractal storage pads K to 16 and N to 32 elements.
        mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size
    } else {
        mnConfig.wBaseOffset += mnConfig.k * mnConfig.n;
    }
    mnConfig.nAxisBaseOffset += mnConfig.n;
    mnConfig.mAxisBaseOffset += mnConfig.m;
    mnConfig.xBaseOffset += mnConfig.m * mnConfig.k;
    mnConfig.yBaseOffset += mnConfig.m * mnConfig.n;
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) {
|
||||
SetMKN(splitValue, groupIdx, mnConfig);
|
||||
mnConfig.baseM = BASIC_M;
|
||||
mnConfig.baseN = BASIC_N;
|
||||
mnConfig.singleM = SINGLE_CORE_M;
|
||||
mnConfig.singleN = SINGLE_CORE_N;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig)
|
||||
{
|
||||
mnConfig.m = static_cast<uint32_t>(splitValue);
|
||||
mnConfig.k = gmmBaseParams->K; // tilingData
|
||||
mnConfig.n = gmmBaseParams->N; // tilingData
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline uint64_t GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::GetWOffset(uint32_t tailN, uint32_t k) {
|
||||
uint64_t wOffset = 0;
|
||||
if constexpr (mmType::BT::format == CubeFormat::NZ) {
|
||||
wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size
|
||||
} else {
|
||||
wOffset = tailN;
|
||||
}
|
||||
return wOffset;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
|
||||
const uint32_t count, const uint32_t thresholdM_dimN) {
|
||||
mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN;
|
||||
mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN;
|
||||
}
|
||||
|
||||
// Computes this AIV core's share of the total rows, locates the group that
// contains its first row, sizes the batching loops, and (re)initializes all
// UB buffers. Must run before the row-processing loops in Process().
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig)
{
    // Step 1: Read grouplist reduceSum to calculate total data count
    // (groupList is cumulative, so the sum of consecutive differences is the
    // last entry; the loop keeps the per-group differences available).
    int64_t prevM = 0;
    for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){
        int64_t currM = groupListGM.GetValue(groupIdx);
        int64_t tempM = currM - prevM;
        prevM = currM;
        vecConfig.M += tempM;
    }
    // Step 2: Calculate core allocation: the first tailCoreIdx cores take
    // eachCoreTaskNum rows, the rest take one fewer.
    uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum;
    vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M;
    uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum;
    vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1;
    vecConfig.startIdx = blockIdx < tailCoreIdx
                             ? eachCoreTaskNum * blockIdx
                             :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx);
    vecConfig.curIdx = vecConfig.startIdx;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    vecConfig.curOffset = vecConfig.startOffset;
    // Locate the group containing startIdx, and how many rows of it remain
    // before the next group boundary (nextUpadteInterVal -- [sic] field name
    // typo "Upadte" comes from VecConfig in the utils header).
    int64_t curStartIdx = vecConfig.startIdx;
    prevM = 0;
    for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){
        int64_t currM = groupListGM.GetValue(groupIdx);
        int64_t tempM = currM - prevM;
        prevM = currM;
        if (curStartIdx >= 0 && curStartIdx - tempM < 0) {
            vecConfig.curGroupIdx = groupIdx;
            vecConfig.nextUpadteInterVal = tempM - curStartIdx;
        }
        curStartIdx -= tempM;
    }
    // Step 3: Calculate total data volume (batches of maxProcessRowNum rows,
    // last batch possibly short).
    vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum;
    vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                : gmmSwiglu->maxProcessRowNum;
    pipe->Reset();
    // Step 4: Allocate space
    pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t));
    pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float));
    pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t));
    pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp<int32_t>(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float));
    pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float));
    pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t));
}
|
||||
|
||||
// Loads one batch (innerLoopNum rows) of int32 matmul output from the
// workspace into UB, casts it to float in place, and multiplies each row by
// its per-token dequant scale. Advances curIdx/curOffset past the batch.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::customDataCopyIn(uint32_t outLoopIdx)
{
    // Stage the int32 rows from global workspace into UB.
    LocalTensor<int32_t> _inMMLocal_0 = mmOutQueue.DeQue<int32_t>();
    DataCopyExtParams copyParams_0{1, static_cast<uint32_t>(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0};
    DataCopyPadExtParams<int32_t> padParams_0{false, 0 ,0, 0};
    DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0);

    // EnQue/DeQue round-trips insert the queue's copy/compute synchronization.
    mmOutQueue.EnQue(_inMMLocal_0);

    LocalTensor<int32_t> _inMMLocal_1 = mmOutQueue.DeQue<int32_t>();

    // In-place int32 -> float conversion: source and destination are the same
    // UB buffer viewed through different element types.
    Cast(_inMMLocal_1.ReinterpretCast<float>(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen);

    mmOutQueue.EnQue(_inMMLocal_1);
    LocalTensor<float> _inMMLocal_2 = mmOutQueue.DeQue<float>();
    // Scalar/vector interleave: S->V flags serialize each scalar GetValue of
    // the per-token scale against the vector Muls that consumes it.
    set_flag(PIPE_S, PIPE_V, EVENT_ID0);
    for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        float scale = perTokenScaleGM.GetValue(vecConfig.curIdx);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        // Apply this row's per-token dequant factor across the whole row.
        Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        vecConfig.curIdx++;
    }
    wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
    vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen;
    mmOutQueue.EnQue(_inMMLocal_2);
}
|
||||
|
||||
// When the current group's rows are exhausted (nextUpadteInterVal hit 0),
// advances curGroupIdx past any empty groups and reloads that group's
// per-channel scale row into UB (casting to float if needed). `loopIdx` is
// unused; the state lives entirely in vecConfig.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateChannelScale(uint32_t loopIdx){
    // Update perChannel
    if (unlikely(vecConfig.nextUpadteInterVal == 0)) {
        // Skip groups whose cumulative count does not change (empty groups).
        // NOTE(review): when curGroupIdx reaches the last entry, the post-
        // increment GetValue reads index groupListLen -- confirm groupList is
        // sized/padded to make that read safe.
        int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx;
        while (loop--) {
            int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            vecConfig.curGroupIdx++;
            int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            if(nextTemp != curTemp){
                vecConfig.nextUpadteInterVal = nextTemp - curTemp;
                break;
            }
        }
        LocalTensor<float> _inChannel = perChannelScaleInQueue.DeQue<float>();
        DataCopyExtParams copyParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
        DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};

        // Resolve the new group's scale tensor from the tensor list.
        GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
        perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));

        if constexpr(!IsSameType<DTYPE_CS, float>::value) {
            // Stage the raw dtype at element offset tokenLen of the
            // reinterpreted buffer, then cast down into the float view.
            LocalTensor<DTYPE_CS> dstLocalT = _inChannel.template ReinterpretCast<DTYPE_CS>();
            DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams);
            PipeBarrier<PIPE_ALL>();
            Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
        } else {
            DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams);
        }
        PipeBarrier<PIPE_ALL>();
        perChannelScaleInQueue.EnQue(_inChannel);
    }
}
|
||||
|
||||
// Per-row post-processing pipeline. Order is the contract: dequantize the row
// (per-channel scale), apply the swiglu activation, then requantize to int8.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::VectorCompute(uint32_t loopIdx) {
    Dequant(loopIdx);
    Swiglu(loopIdx);
    Quant(loopIdx);
}
|
||||
|
||||
// Multiplies row `loopIdx` element-wise by the current group's per-channel
// scale (the per-token factor was already applied in customDataCopyIn), and
// counts the row against the current group for the group-switch logic.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Dequant(uint32_t loopIdx) {
    LocalTensor<float> mmLocal = mmOutQueue.DeQue<float>();
    LocalTensor<float> perChannelLocal = perChannelScaleInQueue.DeQue<float>();
    Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen);
    // One row of the current group consumed; UpdateChannelScale reloads the
    // scale row when this reaches zero.
    vecConfig.nextUpadteInterVal--;
    mmOutQueue.EnQue(mmLocal);
    perChannelScaleInQueue.EnQue(perChannelLocal);
}
|
||||
|
||||
// High-level API swiglu on row `loopIdx`: the row holds two tokenLen/2
// halves; SwiGLU combines them (src0 = second half, src1 = first half,
// beta = 1.0) into scratch, which is then copied back over the row's first
// half. The second half becomes scratch space for Quant.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Swiglu(uint32_t loopIdx) {
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    float beta = 1.0f;
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    LocalTensor<float> src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2];
    LocalTensor<float> src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen];
    SwiGLU<float, false>(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2);
    PipeBarrier<PIPE_ALL>();
    // Burst length in 32-byte blocks: (tokenLen/2) floats * 4 bytes / 32.
    DataCopyParams repeatParams{1, static_cast<uint16_t>((gmmSwiglu->tokenLen / 2) / 8), 0, 0};
    DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams);
    mmOutQueue.EnQue(_inMMLocal);
}
|
||||
|
||||
// Dynamic per-row int8 quantization of row `loopIdx`'s first tokenLen/BISECT
// elements (the swiglu result): scale = max|x| / QUANT_SCALE_INT8, values are
// divided by the scale and cast to int8; the scale is recorded per row.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Quant(uint32_t loopIdx) {
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    // |x| into the row's (now free) second half, feeding the max reduction.
    Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT],
        _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
        gmmSwiglu->tokenLen / BISECT);
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    PipeBarrier<PIPE_V>();
    ReduceMaxTemplate(workspaceLocal,
        _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_ALL>();
    // NOTE(review): an all-zero row makes quantScale 0 and the reciprocal
    // below divides by zero -- confirm upstream guarantees non-zero rows.
    float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8;
    PipeBarrier<PIPE_ALL>();
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    PipeBarrier<PIPE_ALL>();
    quantScaleLocal.SetValue(loopIdx, quantScale);
    PipeBarrier<PIPE_ALL>();
    // Multiply by the reciprocal instead of dividing per element.
    quantScale = 1 / quantScale;
    PipeBarrier<PIPE_ALL>();
    Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
         quantScale, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_V>();
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    // Output rows are tokenLen/BISECT wide after swiglu.
    int32_t dstTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen / BISECT);
    int32_t srcTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen);
    int32_t tempCount = static_cast<int32_t>(gmmSwiglu->tokenLen / BISECT);
    LocalTensor<int8_t> castSpace = castWorkspace.Get<int8_t>();
    CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount);
    mmOutQueue.EnQue(_inMMLocal);
    quantOutQueue.EnQue(quantLocal);
}
|
||||
|
||||
// Flushes one processed batch back to global memory: first the per-row quant
// scales, then the int8 rows (tokenLen/2 wide after swiglu). Advances the
// write cursor (startIdx/startOffset) past the batch.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::customDataCopyOut() {
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantScaleOutputGM[vecConfig.startIdx], quantScaleLocal, copyParams_0);
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantOutputGM[vecConfig.startIdx * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1);
    PipeBarrier<PIPE_ALL>();
    vecConfig.startIdx += vecConfig.innerLoopNum;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    // Return the buffers to the queues for the next batch.
    quantOutQueue.EnQue(quantLocal);
    quantScaleOutQueue.EnQue(quantScaleLocal);
}
|
||||
|
||||
} // namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST
|
||||
#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
|
||||
@@ -0,0 +1,588 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H
|
||||
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H
|
||||
|
||||
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h"
|
||||
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
|
||||
/** @brief internal computation class
|
||||
*/
|
||||
|
||||
template <class mmType, bool sync = false, typename CHANNELDTYPE = float>
|
||||
class GMMSwigluSplitWorkSpaceCompute{
|
||||
public:
|
||||
using AT = typename mmType::AT::T;
|
||||
using BT = typename mmType::BT::T;
|
||||
using B = typename mmType::BT;
|
||||
using CT = typename mmType::CT::T;
|
||||
using BiasT = typename mmType::BiasT::T;
|
||||
using WT = int8_t;
|
||||
constexpr static bool transposeX = mmType::AT::isTrans;
|
||||
constexpr static bool transposeW = mmType::BT::isTrans;
|
||||
|
||||
/** @brief constructor */
|
||||
__aicore__ inline GMMSwigluSplitWorkSpaceCompute(typename mmType::MT& mm_): mm(mm_) {}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
|
||||
GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
|
||||
GM_ADDR workspace,
|
||||
const GMMSwigluBaseParams* __restrict gmmBaseParamsIN,
|
||||
const TCubeTiling* __restrict mmTilingDataIN,
|
||||
const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN);
|
||||
__aicore__ inline void Process();
|
||||
|
||||
private:
|
||||
__aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor<int32_t> &mmOutGM);
|
||||
|
||||
__aicore__ inline void UpdateMnConfig(MNConfig &mnConfig);
|
||||
|
||||
__aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
|
||||
|
||||
__aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
|
||||
|
||||
__aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k);
|
||||
|
||||
__aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
|
||||
const uint32_t count, const uint32_t thresholdM_dimN);
|
||||
|
||||
template <typename DTYPE_CS>
|
||||
__aicore__ inline void UpdateChannelScale(uint32_t loopidx, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void VectorCompute(uint32_t loopidx, VecConfig& vecConfig);
|
||||
|
||||
template <typename DTYPE_CS>
|
||||
__aicore__ inline void PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx);
|
||||
|
||||
__aicore__ inline void InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig);
|
||||
|
||||
__aicore__ inline void customDataCopyIn(uint32_t outLoopIdx, GlobalTensor<int32_t> &mmOutGM, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void customDataCopyOut(VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void Dequant(uint32_t loopidx, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void Quant(uint32_t loopidx, VecConfig& vecConfig);
|
||||
|
||||
__aicore__ inline void Swiglu(uint32_t loopidx, VecConfig& vecConfig);
|
||||
|
||||
private:
|
||||
typename mmType::MT& mm;
|
||||
const GMMSwigluBaseParams* __restrict gmmBaseParams;
|
||||
const GMMSwiglu* __restrict gmmSwiglu;
|
||||
const TCubeTiling* __restrict mmTilingData;
|
||||
uint32_t blockIdx;
|
||||
WorkSpaceSplitConfig workspaceSplitConfig;
|
||||
TPipe* pipe;
|
||||
GlobalTensor<int8_t> xGM;
|
||||
GlobalTensor<int8_t> weightGM;
|
||||
GlobalTensor<CHANNELDTYPE> perChannelScaleGM;
|
||||
GlobalTensor<float> perTokenScaleGM;
|
||||
GlobalTensor<int64_t> groupListGM;
|
||||
GlobalTensor<int8_t> quantOutputGM;
|
||||
GlobalTensor<float> quantScaleOutputGM;
|
||||
GlobalTensor<int32_t> mmOutGM1;
|
||||
GlobalTensor<int32_t> mmOutGM2;
|
||||
// define the que
|
||||
TQue<QuePosition::VECIN, 1> mmOutQueue;
|
||||
TQue<QuePosition::VECIN, 1> perChannelScaleInQueue;
|
||||
TQue<QuePosition::VECOUT, 1> quantOutQueue;
|
||||
TQue<QuePosition::VECOUT, 1> quantScaleOutQueue;
|
||||
TBuf<TPosition::VECCALC> reduceWorkspace;
|
||||
TBuf<TPosition::VECCALC> castWorkspace;
|
||||
bool sequentialWrite = true;
|
||||
uint32_t cubeNum; // Matmul completions on the kernel
|
||||
uint32_t groupNum; // Matmul completions on the kernel
|
||||
int64_t aicCoreNum;
|
||||
int64_t aivCoreNum;
|
||||
GM_ADDR xTensorPtr;
|
||||
GM_ADDR weightTensorPtr;
|
||||
GM_ADDR perChannelScalePtr;
|
||||
};
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Init(GM_ADDR x, GM_ADDR weight,
|
||||
GM_ADDR perChannelScale, GM_ADDR perTokenScale,
|
||||
GM_ADDR groupList, GM_ADDR quantOutput,
|
||||
GM_ADDR quantScaleOutput, GM_ADDR workspace,
|
||||
const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn,
|
||||
const TCubeTiling* __restrict mmTilingDataIN,
|
||||
const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN)
|
||||
{
|
||||
aicCoreNum = GetBlockNum();
|
||||
aivCoreNum = aicCoreNum * 2;
|
||||
blockIdx = GetBlockIdx();
|
||||
pipe = tPipeIN;
|
||||
xTensorPtr = x;
|
||||
weightTensorPtr = weight;
|
||||
perChannelScalePtr = perChannelScale;
|
||||
mmTilingData = mmTilingDataIN;
|
||||
gmmBaseParams = gmmSwigluBaseParamsIn;
|
||||
gmmSwiglu = gmmSwigluIN;
|
||||
groupNum = gmmSwiglu->groupListLen;
|
||||
if ASCEND_IS_AIC {
|
||||
groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
|
||||
mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
|
||||
mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen,
|
||||
gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
|
||||
mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen,
|
||||
gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
|
||||
perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale,
|
||||
gmmSwiglu->groupListLen * gmmSwiglu->tokenLen);
|
||||
perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmBaseParams->M);
|
||||
groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
|
||||
quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2);
|
||||
quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmBaseParams->M);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig)
|
||||
{
|
||||
workspaceSplitConfig.M = groupListGM.GetValue(gmmSwiglu->groupListLen - 1);
|
||||
workspaceSplitConfig.loopCount = Ceil(workspaceSplitConfig.M, gmmBaseParams->mLimit);
|
||||
workspaceSplitConfig.notLastTaskSize = gmmBaseParams->mLimit;
|
||||
workspaceSplitConfig.lastLoopTaskSize = workspaceSplitConfig.M - (workspaceSplitConfig.loopCount - 1) * gmmBaseParams->mLimit;
|
||||
workspaceSplitConfig.leftMatrixStartIndex = 0;
|
||||
workspaceSplitConfig.rightMatrixExpertStartIndex = 0;
|
||||
workspaceSplitConfig.rightMatrixExpertNextStartIndex = 0;
|
||||
workspaceSplitConfig.isLastLoop = false;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx)
|
||||
{
|
||||
workspaceSplitConfig.leftMatrixStartIndex = workspaceSplitLoopIdx * gmmBaseParams->mLimit;
|
||||
workspaceSplitConfig.rightMatrixExpertStartIndex = workspaceSplitConfig.rightMatrixExpertNextStartIndex;
|
||||
workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertStartIndex;
|
||||
// Calculate the right expert matrix end index (rightMatrixExpertEndIndex) and the next start index (rightMatrixExpertNextStartIndex)
|
||||
int32_t curTaskNum = 0;
|
||||
int32_t nextTaskNum = 0;
|
||||
while(workspaceSplitConfig.rightMatrixExpertEndIndex < gmmSwiglu->groupListLen)
|
||||
{
|
||||
curTaskNum = groupListGM.GetValue(workspaceSplitConfig.rightMatrixExpertEndIndex) - workspaceSplitConfig.leftMatrixStartIndex;
|
||||
int32_t nextTaskIdx = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen - 1 \
|
||||
? gmmSwiglu->groupListLen - 1 \
|
||||
: workspaceSplitConfig.rightMatrixExpertEndIndex + 1;
|
||||
nextTaskNum = groupListGM.GetValue(nextTaskIdx) - workspaceSplitConfig.leftMatrixStartIndex;
|
||||
if (curTaskNum > gmmBaseParams->mLimit){
|
||||
workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex;
|
||||
break;
|
||||
} else if (curTaskNum == gmmBaseParams->mLimit && nextTaskNum > gmmBaseParams->mLimit){
|
||||
workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex + 1;
|
||||
break;
|
||||
} else if (nextTaskNum > gmmBaseParams->mLimit){
|
||||
workspaceSplitConfig.rightMatrixExpertEndIndex++;
|
||||
workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex;
|
||||
break;
|
||||
}
|
||||
workspaceSplitConfig.rightMatrixExpertEndIndex++;
|
||||
}
|
||||
workspaceSplitConfig.isLastLoop = workspaceSplitLoopIdx == workspaceSplitConfig.loopCount - 1 ? true : false;
|
||||
|
||||
if (workspaceSplitConfig.isLastLoop) {
|
||||
workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen \
|
||||
? gmmSwiglu->groupListLen - 1 \
|
||||
: workspaceSplitConfig.rightMatrixExpertEndIndex;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Process() {
|
||||
InitWorkSpaceSplitConfig(workspaceSplitConfig);
|
||||
int32_t parallelNum = gmmBaseParams->isPreFill ? 2 : 1; // 2: double workspace buffer
|
||||
for (int32_t workspaceSplitLoopIdx = 0; workspaceSplitLoopIdx < workspaceSplitConfig.loopCount; workspaceSplitLoopIdx++) {
|
||||
UpdateWorkSpaceSplitConfig(workspaceSplitConfig, workspaceSplitLoopIdx);
|
||||
GlobalTensor<int32_t> mmOutGM = (workspaceSplitLoopIdx % 2 == 0 ) ? mmOutGM1 : mmOutGM2;
|
||||
|
||||
if ASCEND_IS_AIC {
|
||||
if (workspaceSplitLoopIdx >= parallelNum){ // first parallelNum core no need to wait
|
||||
SyncAll<false>();
|
||||
}
|
||||
MNConfig mnConfig;
|
||||
int32_t prevSplitValue = workspaceSplitConfig.leftMatrixStartIndex;
|
||||
for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex, count = 0; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; ++groupIdx) {
|
||||
UpdateMnConfig(mnConfig);
|
||||
int32_t currSplitValue = static_cast<int32_t>(groupListGM.GetValue(groupIdx));
|
||||
currSplitValue = currSplitValue > (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \
|
||||
? (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \
|
||||
: currSplitValue;
|
||||
int32_t splitValue = currSplitValue - prevSplitValue;
|
||||
prevSplitValue = currSplitValue;
|
||||
SetMNConfig(splitValue, groupIdx, mnConfig);
|
||||
if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) {
|
||||
continue;
|
||||
}
|
||||
mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM);
|
||||
mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN);
|
||||
|
||||
uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN;
|
||||
uint32_t curBlock = blockIdx >= count ? blockIdx : blockIdx + gmmBaseParams->coreNum;
|
||||
uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN;
|
||||
|
||||
while (curBlock < curCount) {
|
||||
MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN);
|
||||
MMCompute(groupIdx, mnConfig, blockIdx, mmOutGM);
|
||||
curBlock += aicCoreNum;
|
||||
}
|
||||
count = curCount % gmmBaseParams->coreNum;
|
||||
}
|
||||
SyncAll<false>();
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
VecConfig vecConfig;
|
||||
UpdateVecConfig(blockIdx, vecConfig);
|
||||
if (blockIdx < vecConfig.usedCoreNum) {
|
||||
LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.AllocTensor<float>();
|
||||
LocalTensor<int32_t> mmLocal = mmOutQueue.AllocTensor<int32_t>();
|
||||
LocalTensor<int8_t> quantLocal = quantOutQueue.AllocTensor<int8_t>();
|
||||
LocalTensor<float> quantScaleLocal = quantScaleOutQueue.AllocTensor<float>();
|
||||
mmOutQueue.EnQue(mmLocal);
|
||||
quantScaleOutQueue.EnQue(quantScaleLocal);
|
||||
quantOutQueue.EnQue(quantLocal);
|
||||
PreLoadTokenAndChannel<CHANNELDTYPE>(channelScaleLocal, vecConfig);
|
||||
}
|
||||
SyncAll<false>();
|
||||
if (blockIdx < vecConfig.usedCoreNum) {
|
||||
for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) {
|
||||
vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1)
|
||||
? vecConfig.tailLoopNum
|
||||
: gmmSwiglu->maxProcessRowNum;
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
customDataCopyIn(outLoopIdx, mmOutGM, vecConfig);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) {
|
||||
UpdateChannelScale<CHANNELDTYPE>(innerLoopIdx, vecConfig);
|
||||
VectorCompute(innerLoopIdx, vecConfig);
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
customDataCopyOut(vecConfig);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.DeQue<float>();
|
||||
LocalTensor<int32_t> mmLocal = mmOutQueue.DeQue<int32_t>();
|
||||
LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
|
||||
LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
|
||||
perChannelScaleInQueue.FreeTensor(channelScaleLocal);
|
||||
mmOutQueue.FreeTensor(mmLocal);
|
||||
quantScaleOutQueue.FreeTensor(quantScaleLocal);
|
||||
quantOutQueue.FreeTensor(quantLocal);
|
||||
}
|
||||
if (workspaceSplitLoopIdx < workspaceSplitConfig.loopCount - parallelNum){
|
||||
SyncAll<false>();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
template <typename DTYPE_CS>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal, VecConfig& vecConfig)
|
||||
{
|
||||
GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
|
||||
perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
|
||||
|
||||
DataCopyExtParams copyChannelParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
|
||||
DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
|
||||
if constexpr(!IsSameType<DTYPE_CS, float>::value) {
|
||||
LocalTensor<DTYPE_CS> dstLocalT = channelScaleLocal.template ReinterpretCast<DTYPE_CS>();
|
||||
DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
|
||||
} else {
|
||||
DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams);
|
||||
}
|
||||
perChannelScaleInQueue.EnQue(channelScaleLocal);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor<int32_t> &mmOutGM)
|
||||
{
|
||||
uint32_t tailN = mnConfig.nIdx * mnConfig.singleN;
|
||||
uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN;
|
||||
uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM
|
||||
: mnConfig.m - mnConfig.mIdx * mnConfig.singleM;
|
||||
uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k;
|
||||
if constexpr (transposeX) {
|
||||
xOffset = mnConfig.mIdx * mnConfig.singleM;
|
||||
}
|
||||
uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN;
|
||||
xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset + workspaceSplitConfig.leftMatrixStartIndex * mnConfig.k);
|
||||
weightGM.SetGlobalBuffer(GetTensorAddr<int8_t>(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k));
|
||||
if (mnConfig.blockDimM == 1){
|
||||
weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
|
||||
} else {
|
||||
weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL);
|
||||
}
|
||||
mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset;
|
||||
mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k);
|
||||
mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k);
|
||||
mm.SetTensorA(xGM[xOffset], transposeX);
|
||||
mm.SetTensorB(weightGM, transposeW);
|
||||
mm.template IterateAll<sync>(mmOutGM[mnConfig.workSpaceOffset], 0);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateMnConfig(MNConfig &mnConfig) {
|
||||
if constexpr (B::format == CubeFormat::NZ) {
|
||||
mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size
|
||||
} else {
|
||||
mnConfig.wBaseOffset += mnConfig.k * mnConfig.n;
|
||||
}
|
||||
mnConfig.nAxisBaseOffset += mnConfig.n;
|
||||
mnConfig.mAxisBaseOffset += mnConfig.m;
|
||||
mnConfig.xBaseOffset += mnConfig.m * mnConfig.k;
|
||||
mnConfig.yBaseOffset += mnConfig.m * mnConfig.n;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) {
|
||||
SetMKN(splitValue, groupIdx, mnConfig);
|
||||
mnConfig.baseM = BASIC_M;
|
||||
mnConfig.baseN = BASIC_N;
|
||||
mnConfig.singleM = SINGLE_CORE_M;
|
||||
mnConfig.singleN = SINGLE_CORE_N;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig)
|
||||
{
|
||||
mnConfig.m = static_cast<int64_t>(splitValue);
|
||||
mnConfig.k = gmmBaseParams->K; // tilingData
|
||||
mnConfig.n = gmmBaseParams->N; // tilingData
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline uint64_t GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::GetWOffset(uint32_t tailN, uint32_t k) {
|
||||
uint64_t wOffset = 0;
|
||||
if constexpr (mmType::BT::format == CubeFormat::NZ) {
|
||||
wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size
|
||||
} else {
|
||||
wOffset = tailN;
|
||||
}
|
||||
return wOffset;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
|
||||
const uint32_t count, const uint32_t thresholdM_dimN) {
|
||||
mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN;
|
||||
mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN;
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig)
|
||||
{
|
||||
// Step 1: Read grouplist reduceSum to calculate total data count
|
||||
vecConfig.M = workspaceSplitConfig.isLastLoop \
|
||||
? workspaceSplitConfig.lastLoopTaskSize\
|
||||
: workspaceSplitConfig.notLastTaskSize;
|
||||
// Step 2: Calculate core allocation
|
||||
uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum;
|
||||
vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M;
|
||||
uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum;
|
||||
vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1;
|
||||
vecConfig.startIdx = blockIdx < tailCoreIdx
|
||||
? eachCoreTaskNum * blockIdx
|
||||
:((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx);
|
||||
vecConfig.curIdx = vecConfig.startIdx;
|
||||
vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
|
||||
vecConfig.curOffset = vecConfig.startOffset;
|
||||
int64_t curStartIdx = vecConfig.startIdx;
|
||||
int64_t prevM = workspaceSplitConfig.leftMatrixStartIndex;
|
||||
for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; groupIdx++){
|
||||
int64_t currM = groupListGM.GetValue(groupIdx);
|
||||
int64_t tempM = currM - prevM;
|
||||
prevM = currM;
|
||||
if (curStartIdx >= 0 && curStartIdx - tempM < 0) {
|
||||
vecConfig.curGroupIdx = groupIdx;
|
||||
vecConfig.nextUpadteInterVal = tempM - curStartIdx;
|
||||
}
|
||||
curStartIdx -= tempM;
|
||||
}
|
||||
// Step 3: Calculate total data volume
|
||||
vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum;
|
||||
vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
|
||||
? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
|
||||
: gmmSwiglu->maxProcessRowNum;
|
||||
pipe->Reset();
|
||||
// Step 4: Allocate space
|
||||
pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t));
|
||||
pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float));
|
||||
pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t));
|
||||
pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp<int32_t>(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float));
|
||||
pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float));
|
||||
pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t));
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::customDataCopyIn(uint32_t outLoopIdx, GlobalTensor<int32_t> &mmOutGM, VecConfig& vecConfig)
|
||||
{
|
||||
LocalTensor<int32_t> _inMMLocal_0 = mmOutQueue.DeQue<int32_t>();
|
||||
DataCopyExtParams copyParams_0{1, static_cast<uint32_t>(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> padParams_0{false, 0 ,0, 0};
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0);
|
||||
mmOutQueue.EnQue(_inMMLocal_0);
|
||||
|
||||
LocalTensor<int32_t> _inMMLocal_1 = mmOutQueue.DeQue<int32_t>();
|
||||
|
||||
Cast(_inMMLocal_1.ReinterpretCast<float>(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen);
|
||||
|
||||
mmOutQueue.EnQue(_inMMLocal_1);
|
||||
LocalTensor<float> _inMMLocal_2 = mmOutQueue.DeQue<float>();
|
||||
set_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){
|
||||
wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
float scale = perTokenScaleGM.GetValue(vecConfig.curIdx + workspaceSplitConfig.leftMatrixStartIndex);
|
||||
set_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen);
|
||||
set_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
vecConfig.curIdx++;
|
||||
}
|
||||
wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
|
||||
vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen;
|
||||
mmOutQueue.EnQue(_inMMLocal_2);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
template <typename DTYPE_CS>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateChannelScale(uint32_t loopIdx, VecConfig& vecConfig){
|
||||
// Update perChannel
|
||||
if (unlikely(vecConfig.nextUpadteInterVal == 0)) {
|
||||
int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx;
|
||||
while (loop--) {
|
||||
int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
|
||||
vecConfig.curGroupIdx++;
|
||||
int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
|
||||
if(nextTemp != curTemp){
|
||||
vecConfig.nextUpadteInterVal = nextTemp - curTemp;
|
||||
break;
|
||||
}
|
||||
}
|
||||
LocalTensor<float> _inChannel = perChannelScaleInQueue.DeQue<float>();
|
||||
DataCopyExtParams copyParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
|
||||
DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
|
||||
|
||||
GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
|
||||
perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
|
||||
|
||||
if constexpr(!IsSameType<DTYPE_CS, float>::value) {
|
||||
LocalTensor<DTYPE_CS> dstLocalT = _inChannel.template ReinterpretCast<DTYPE_CS>();
|
||||
DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
|
||||
} else {
|
||||
DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams);
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
|
||||
perChannelScaleInQueue.EnQue(_inChannel);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::VectorCompute(uint32_t loopIdx, VecConfig& vecConfig) {
|
||||
Dequant(loopIdx, vecConfig);
|
||||
Swiglu(loopIdx, vecConfig);
|
||||
Quant(loopIdx, vecConfig);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Dequant(uint32_t loopIdx, VecConfig& vecConfig) {
|
||||
// perChanelScale * perTokenScale
|
||||
LocalTensor<float> mmLocal = mmOutQueue.DeQue<float>();
|
||||
LocalTensor<float> perChannelLocal = perChannelScaleInQueue.DeQue<float>();
|
||||
Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen);
|
||||
vecConfig.nextUpadteInterVal--;
|
||||
mmOutQueue.EnQue(mmLocal);
|
||||
perChannelScaleInQueue.EnQue(perChannelLocal);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Swiglu(uint32_t loopIdx, VecConfig& vecConfig) {
|
||||
// High-level API swiglu
|
||||
LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
|
||||
float beta = 1.0f;
|
||||
LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
|
||||
LocalTensor<float> src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2];
|
||||
LocalTensor<float> src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen];
|
||||
SwiGLU<float, false>(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
DataCopyParams repeatParams{1, static_cast<uint16_t>((gmmSwiglu->tokenLen / 2) / 8), 0, 0};
|
||||
DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams);
|
||||
mmOutQueue.EnQue(_inMMLocal);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Quant(uint32_t loopIdx, VecConfig& vecConfig) {
|
||||
LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
|
||||
Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT],
|
||||
_inMMLocal[loopIdx * gmmSwiglu->tokenLen],
|
||||
gmmSwiglu->tokenLen / BISECT);
|
||||
LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
ReduceMaxTemplate(workspaceLocal,
|
||||
_inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8;
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
quantScaleLocal.SetValue(loopIdx, quantScale);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
quantScale = 1 / quantScale;
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
|
||||
quantScale, gmmSwiglu->tokenLen / BISECT);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
|
||||
int32_t dstTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen / BISECT);
|
||||
int32_t srcTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen);
|
||||
int32_t tempCount = static_cast<int32_t>(gmmSwiglu->tokenLen / BISECT);
|
||||
LocalTensor<int8_t> castSpace = castWorkspace.Get<int8_t>();
|
||||
CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount);
|
||||
mmOutQueue.EnQue(_inMMLocal);
|
||||
quantOutQueue.EnQue(quantLocal);
|
||||
}
|
||||
|
||||
template <typename mmType, bool sync, typename CHANNELDTYPE>
|
||||
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::customDataCopyOut(VecConfig& vecConfig) {
|
||||
LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
|
||||
DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0};
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
DataCopyPad(quantScaleOutputGM[workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx], quantScaleLocal, copyParams_0);
|
||||
LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
|
||||
DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0};
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
DataCopyPad(quantOutputGM[(workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx) * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
vecConfig.startIdx += vecConfig.innerLoopNum;
|
||||
vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
|
||||
quantOutQueue.EnQue(quantLocal);
|
||||
quantScaleOutQueue.EnQue(quantScaleLocal);
|
||||
}
|
||||
|
||||
} // namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST
|
||||
#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H
|
||||
// ---- grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h ----
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H
|
||||
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H
|
||||
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "kernel_operator.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
|
||||
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
|
||||
using namespace AscendC;
|
||||
constexpr uint32_t INT8_BITS = 8; // a int8 number has 8 bits
|
||||
constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32; // 32: a block has 32 bytes data
|
||||
constexpr uint32_t THRESHOLD_BLOCK_NUM = 8;
|
||||
constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64; // 64: a block has 64 bytes data
|
||||
constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2; // 2: a float16 data has two bytes
|
||||
constexpr uint32_t SINGLE_CORE_M = 128;
|
||||
constexpr uint32_t SINGLE_CORE_N = 512;
|
||||
constexpr uint32_t SINGLE_CORE_K = 7168;
|
||||
constexpr uint32_t BASIC_M = 128;
|
||||
constexpr uint32_t BASIC_N = 256;
|
||||
constexpr uint32_t BASIC_K = 128;
|
||||
constexpr uint32_t STEP_M = 1;
|
||||
constexpr uint32_t STEP_N = 1;
|
||||
constexpr uint32_t STEP_Ka = 4;
|
||||
constexpr uint32_t STEP_Kb = 4;
|
||||
constexpr uint32_t DEPTH_A1 = 8;
|
||||
constexpr uint32_t DEPTH_B1 = 8;
|
||||
constexpr uint32_t VEC_LEN_ONCE_REPEAT_ELE = 64;
|
||||
constexpr uint32_t VEC_LEN_ONCE_REPEAT_BLOCK = 8;
|
||||
constexpr uint32_t BISECT = 2;
|
||||
constexpr uint32_t MOD_32_MASK = 0x1F;
|
||||
constexpr uint32_t MOD_16_MASK = 0x0F;
|
||||
constexpr uint32_t ALIGN_8_ELE = 8;
|
||||
constexpr uint32_t ALIGN_16_ELE = 16;
|
||||
constexpr float QUANT_SCALE_INT8 = 127.0f;
|
||||
constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, true);
|
||||
/** @brief Build the matmul config: NZ MDL base plus the fixed tile sizes above. */
constexpr MatmulConfig GetMMCFG() {
    MatmulConfig MM_CFG = NZ_CFG_MDL;
    MM_CFG.singleCoreM = SINGLE_CORE_M;
    MM_CFG.singleCoreN = SINGLE_CORE_N;
    MM_CFG.singleCoreK = SINGLE_CORE_K;
    MM_CFG.basicM = BASIC_M;
    MM_CFG.basicN = BASIC_N;
    MM_CFG.basicK = BASIC_K;
    return MM_CFG;
}
|
||||
|
||||
/** @brief Overlay the fixed step/depth parameters onto a static matmul tiling. */
constexpr static MatmulApiStaticTiling GetMMTiling(const MatmulApiStaticTiling& mmTiling)
{
    MatmulApiStaticTiling tiling = mmTiling;
    tiling.stepM = STEP_M;
    tiling.stepN = STEP_N;
    tiling.stepKa = STEP_Ka;
    tiling.stepKb = STEP_Kb;
    tiling.depthA1 = DEPTH_A1;
    tiling.depthB1 = DEPTH_B1;
    return tiling;
}
|
||||
// Bundles the matmul type descriptors (A/B/C/bias) with the statically
// computed tiling, and exposes the concrete MatmulImpl instantiation as MT.
template<class AT_, class BT_, class CT_, class BiasT_, const MatmulConfig& MM_CFG>
struct MMImplType {
    using AT = AT_;        // A-matrix type descriptor
    using BT = BT_;        // B-matrix type descriptor
    using CT = CT_;        // C (result) type descriptor
    using BiasT = BiasT_;  // bias type descriptor
    // Config and tiling are resolved entirely at compile time.
    static constexpr MatmulConfig cfg = GetMMCFG();
    static constexpr MatmulApiStaticTiling mdl = GetMMTiling(GetMatmulApiTiling<AT, BT, CT, BiasT>(cfg));
    using MT = matmul::MatmulImpl<AT, BT, CT, BiasT, mdl>;
};
|
||||
|
||||
// Per-tile M/N split bookkeeping for one grouped-matmul task: problem sizes,
// block/tile indices, and the base offsets of each operand in global memory.
struct MNConfig {
    int64_t m = 0;                 // rows of the current group
    int64_t k = 0;                 // reduction dimension
    int64_t n = 0;                 // cols of the current group
    int64_t baseM = 0;             // basic tile size in M
    int64_t baseN = 0;             // basic tile size in N
    int64_t mIdx = 0;              // current tile index along M
    int64_t nIdx = 0;              // current tile index along N
    int64_t blockDimM = 0;         // number of tiles along M
    int64_t blockDimN = 0;         // number of tiles along N
    int64_t singleM = 0;           // per-core M extent
    int64_t singleN = 0;           // per-core N extent
    int64_t wBaseOffset = 0;       // weight base offset
    int64_t nAxisBaseOffset = 0;   // offset along the N axis
    int64_t mAxisBaseOffset = 0;   // offset along the M axis
    int64_t xBaseOffset = 0;       // activation base offset
    int64_t yBaseOffset = 0;       // output base offset
    int64_t wOutOffset = 0;        // output weight offset
    int64_t workSpaceOffset = 0;   // scratch workspace offset
};
|
||||
|
||||
// Vector-core work assignment: which rows/tasks this core processes and the
// loop structure used to walk them.
struct VecConfig {
    int64_t M = 0;                   // total rows handled by the vector stage
    int64_t usedCoreNum = 0;         // number of active vector cores
    int64_t startOffset = 0;         // first element offset for this core
    int64_t curOffset = 0;           // current element offset
    int64_t startIdx = 0;            // first task index for this core
    int64_t curIdx = 0;              // current task index
    int64_t taskNum = 0;             // total tasks assigned
    int64_t curGroupIdx = 0;         // group currently being processed
    int64_t outLoopNum = 0;          // outer loop iterations
    int64_t innerLoopNum = 0;        // inner loop iterations
    int64_t tailLoopNum = 0;         // remainder iterations
    // NOTE(review): name typo — should be nextUpdateInterval; kept as-is
    // because other translation units reference this member.
    int64_t nextUpadteInterVal = 0;
};
|
||||
|
||||
// Describes how the intermediate workspace is split across loop iterations
// when streaming the left matrix against a range of right-matrix experts.
struct WorkSpaceSplitConfig {
    int64_t M = 0;                              // rows covered by this split
    int64_t loopCount = 0;                      // number of processing loops
    int64_t leftMatrixStartIndex = 0;           // first row of the left matrix
    int64_t rightMatrixExpertStartIndex = 0;    // first expert in this split
    int64_t rightMatrixExpertNextStartIndex = 0;// first expert of the next split
    int64_t rightMatrixExpertEndIndex = 0;      // last expert in this split
    int64_t notLastTaskSize = 0;                // task size for non-final loops
    int64_t lastLoopTaskSize = 0;               // task size for the final loop
    bool isLastLoop = false;                    // true on the final loop
};
|
||||
|
||||
// Round `a` up to the nearest multiple of the compile-time constant `base`.
// `base` must be non-zero (enforced by the specializations below being the
// only intended instantiations for small powers of two).
template <uint32_t base, typename T = uint32_t>
__aicore__ inline T AlignUp(T a) {
    return (a + base - 1) / base * base;
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T AlignUp(T a, T base) {
|
||||
return (a + base - 1) / base * base;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T AlignDown(T a, T base) {
|
||||
if (unlikely(base == 0)) {
|
||||
return a;
|
||||
}
|
||||
return a / base * base;
|
||||
}
|
||||
|
||||
template <>
|
||||
__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) {
|
||||
// to be Multiple of 4, result should be in a format of b(xxxx,x100).
|
||||
// This means last two bits should be zero, requiring that
|
||||
// result = num & b(1111,1100) = num & (~3).
|
||||
// &(~3) operator may reduces num into the range [num, num - 3].
|
||||
// As the result should be no less than a (result >= a), it means num - 3 >= a in the worst case.
|
||||
// In this case, num >= a+3. On the other hand, num should also be less then a+4, otherwise,
|
||||
// the result will not be least multiple of 4 for 3. In other cases like [num, num - 2],
|
||||
// num = a + 3 also satisfies the goal condition.
|
||||
return (a + 3) & ~3; // & ~3: set last two bits of (a+3) to be zero
|
||||
}
|
||||
|
||||
template <>
__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) {
    // In general, the least multiple of b (b a power of 2) that is >= a is
    // (a + (b - 1)) & ~(b - 1).
    return (a + 7) & ~7; // & ~7: clear the last three bits of (a+7)
}
|
||||
|
||||
template <>
__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) {
    // Least multiple of 16 that is >= a: (a + (b - 1)) & ~(b - 1) with b = 16.
    return (a + 15) & ~15; // & ~15: clear the last four bits of (a+15)
}
|
||||
|
||||
template <>
__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) {
    // Least multiple of 32 that is >= a; same scheme as the smaller bases.
    return (a + 31) & ~31; // & ~31: clear the last five bits of (a+31)
}
|
||||
|
||||
// Reduce `count` float elements of srcLocal (starting at srcOffset) to their
// maximum, written into dstLocal. Chooses between a two-stage WholeReduceMax,
// a single-repeat WholeReduceMax, and the generic ReduceMax fallback.
__aicore__ inline void ReduceMaxTemplate(LocalTensor<float>& dstLocal, LocalTensor<float>& srcLocal,
                                        uint32_t srcOffset, uint32_t count)
{
    if (likely(count > VEC_LEN_ONCE_REPEAT_ELE && count % VEC_LEN_ONCE_REPEAT_ELE == 0)){
        // Stage 1: one max per 64-element repeat.
        WholeReduceMax(dstLocal,
            srcLocal[srcOffset], VEC_LEN_ONCE_REPEAT_ELE,
            count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1,
            VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
        PipeBarrier<PIPE_V>();
        // Stage 2: reduce the per-repeat maxima (count/64 of them) in place.
        // NOTE(review): assumes count / 64 fits in a single repeat, i.e.
        // count <= 64 * 64 — TODO confirm against callers.
        WholeReduceMax(dstLocal, dstLocal,
            count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1, 1,
            VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
    } else if (count <= VEC_LEN_ONCE_REPEAT_ELE) {
        // Everything fits in one repeat: a single WholeReduceMax suffices.
        WholeReduceMax(dstLocal,
            srcLocal[srcOffset],
            count, 1, 1, 1, VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
    } else {
        // count > 64 but not a multiple of 64: use the generic reduction.
        ReduceMax(dstLocal, srcLocal[srcOffset], dstLocal, count, false);
    }
}
|
||||
|
||||
// Cast `count` float values (srcLocal at srcOffset) to int8 (dstLocal at
// dstOffset) via an intermediate fp16 round-to-nearest-even conversion done
// in place in the source buffer. Handles the 16-byte-aligned-but-not-32
// destination case by staging one block in `oneBlockWorkspace` and copying
// it element by element.
__aicore__ inline void CastFp32ToInt8Template(LocalTensor<int8_t>& dstLocal, LocalTensor<float>& srcLocal,
                                            LocalTensor<int8_t>& oneBlockWorkspace,
                                            int32_t dstOffset, int32_t srcOffset, int32_t count)
{
    // fp32 -> fp16 in place (the half results occupy the front of the buffer).
    Cast(srcLocal[srcOffset].ReinterpretCast<half>(), srcLocal[srcOffset], RoundMode::CAST_RINT, count);
    PipeBarrier<PIPE_V>();
    if ((dstOffset & MOD_32_MASK) == 0) {
        // Destination is 32-byte aligned: cast straight into place.
        Cast(dstLocal[dstOffset],
            srcLocal[srcOffset].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, count);
    } else if ((dstOffset & MOD_16_MASK) == 0) {
        // Destination only 16-byte aligned: cast the tail (beyond the first
        // 16 elements) directly, then route the first 16 through a scratch
        // block and copy them scalar-by-scalar to the unaligned start.
        Cast(dstLocal[dstOffset + ALIGN_16_ELE],
            srcLocal[srcOffset + ALIGN_8_ELE].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, count - ALIGN_16_ELE);
        PipeBarrier<PIPE_V>();
        Cast(oneBlockWorkspace, srcLocal[srcOffset].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, ALIGN_16_ELE);
        PipeBarrier<PIPE_ALL>();
        for (int32_t i = 0; i < ALIGN_16_ELE; i++) {
            int8_t temp = oneBlockWorkspace.GetValue(i);
            dstLocal.SetValue(dstOffset + i, temp);
        }
        PipeBarrier<PIPE_ALL>();
    }
    // NOTE(review): offsets that are not even 16-byte aligned are silently
    // ignored — presumably callers guarantee at least 16-element alignment;
    // TODO confirm.
}
|
||||
|
||||
// Resolve the device address of the `index`-th tensor in a packed tensor-list
// blob: the blob starts with the offset (in bytes) of the address table,
// which holds one 64-bit pointer per tensor.
template <typename T>
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
    uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address.
    // Shifting right by 3 bits divides the byte offset by sizeof(uint64_t).
    __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
    return reinterpret_cast<__gm__ T*>(*(retPtr + index));
}
|
||||
|
||||
} // namespace GROUPED_MATMUL
|
||||
|
||||
#endif // ASCENDC_GROUPED_MATMUL_UTILS_H
|
||||
@@ -552,6 +552,41 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
|
||||
output_offset);
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
|
||||
// Grouped matmul + SwiGLU + int8 quantization with NZ-format weights passed
// as tensor lists. Dispatches to the custom CANN op
// aclnnGroupedMatmulSwigluQuantWeightNzTensorList.
//
// x:            [m, k] activations
// weight:       per-group weight tensors; n is taken from weight[0]
//               (assumes every group shares the same n — TODO confirm)
// weight_scale: per-group dequant scales
// x_scale:      per-row activation scales
// group_list:   group boundary indices
// bias/offset:  optional per-group bias and quant offset
// Returns (output [m, n/2] int8, output_scale [m] f32, output_offset [m] f32).
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
    const at::Tensor & x,
    const at::TensorList & weight,
    const at::TensorList & weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & bias,
    const c10::optional<at::Tensor> & offset)
{
    auto x_size = x.sizes();
    // int64_t instead of int: tensor dims are int64_t, avoid narrowing.
    // (The original also read k = x_size[1] into an unused local; removed.)
    int64_t n = weight[0].sizes()[1];
    int64_t m = x_size[0];

    // SwiGLU halves the feature dimension: output is [m, n/2].
    at::Tensor output = at::zeros({m, n / 2}, x.options().dtype(at::kChar));
    at::Tensor output_scale = at::zeros({m}, x.options().dtype(at::kFloat));
    at::Tensor output_offset = at::zeros({m}, x.options().dtype(at::kFloat));

    EXEC_NPU_CMD(
        aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
        x,
        weight,
        bias,
        offset,
        weight_scale,
        x_scale,
        group_list,
        output,
        output_scale,
        output_offset);

    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -614,4 +649,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" Tensor group_list, *, Tensor? bias=None,"
|
||||
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
|
||||
ops.def(
|
||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
|
||||
" Tensor group_list, *,"
|
||||
" Tensor? bias=None, Tensor? offset=None) ->"
|
||||
" (Tensor output, Tensor output_scale, Tensor output_offset)"
|
||||
);
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list);
|
||||
}
|
||||
|
||||
@@ -130,14 +130,34 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
|
||||
return {output, output_scale, output_offset};
|
||||
}
|
||||
|
||||
// Meta (shape-only) implementation for symbolic tracing / aclgraph capture:
// mirrors the output shapes of the real kernel without launching anything.
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta(
    const at::Tensor & x,
    const at::TensorList & weight,
    const at::TensorList & weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & bias,
    const c10::optional<at::Tensor> & offset)
{
    auto x_size = x.sizes();
    // int64_t instead of int to match tensor dim types; the unused k local
    // from the original was dropped.
    int64_t n = weight[0].sizes()[1];
    int64_t m = x_size[0];

    // NOTE(review): tensors are created with default (CPU) options, matching
    // the sibling meta shims in this file — confirm the Meta dispatch key
    // rewraps them as meta tensors.
    at::Tensor output = at::zeros({m, n / 2}, c10::dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::zeros({m}, c10::dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::zeros({m}, c10::dtype(c10::ScalarType::Float));

    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
namespace {
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// the custom kernel been captured into aclgraph
|
||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// the custom kernel been captured into aclgraph
|
||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
// Rotary embedding meta implementation
|
||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||
// Masked input and mask meta implementation
|
||||
@@ -150,5 +170,7 @@ namespace {
|
||||
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
|
||||
// grouped_matmul_swiglu_quant meta implementation
|
||||
ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
|
||||
// Grouped matmul swiglu quant weight nz tensor list
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
|
||||
}
|
||||
}
|
||||
|
||||
48
csrc/utils/CMakeLists.txt
Normal file
48
csrc/utils/CMakeLists.txt
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
# Interface target carrying the tiling-utility headers; open-project builds
# additionally pull the experimental CANN headers from the package path.
add_library(ops_utils_tiling_headers INTERFACE)

target_include_directories(ops_utils_tiling_headers INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/inc>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/slog>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof>>
    $<INSTALL_INTERFACE:include/ops_adv/utils>
)

# Logging identity for consumers of the tiling headers; the package type tag
# is "[Custom]" only for open-project builds.
target_compile_definitions(ops_utils_tiling_headers INTERFACE
    OPS_UTILS_LOG_SUB_MOD_NAME="OP_TILING"
    OPS_UTILS_LOG_PACKAGE_TYPE=$<IF:$<BOOL:${BUILD_OPEN_PROJECT}>,"[Custom]","">
)

# Interface target carrying the op-proto utility headers.
add_library(ops_utils_proto_headers INTERFACE)

target_include_directories(ops_utils_proto_headers INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/inc>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/slog>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn/opdev>>
    $<INSTALL_INTERFACE:include/ops_adv/utils>
)

target_compile_definitions(ops_utils_proto_headers INTERFACE
    OPS_UTILS_LOG_SUB_MOD_NAME="OP_PROTO"
    OPS_UTILS_LOG_PACKAGE_TYPE=$<IF:$<BOOL:${BUILD_OPEN_PROJECT}>,"[Custom]","">
)

# In-tree (non-open) builds install both header sets into the ops_adv package.
if(NOT BUILD_OPEN_PROJECT)
    install_package(
        PACKAGE ops_adv
        TARGETS ops_utils_tiling_headers ops_utils_proto_headers
        DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/inc/
        DESTINATION include/ops_adv/utils
    )
endif()
|
||||
14
csrc/utils/inc/aclnn_util.h
Normal file
14
csrc/utils/inc/aclnn_util.h
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef OP_API_INC_ACLNN_UTIL_H
#define OP_API_INC_ACLNN_UTIL_H

// Marks aclnn entry points as exported from the shared library (symbols are
// hidden by default when building with -fvisibility=hidden).
#define ACLNN_API __attribute__((visibility("default")))
#endif // OP_API_INC_ACLNN_UTIL_H
|
||||
25
csrc/utils/inc/error/ops_error.h
Normal file
25
csrc/utils/inc/error/ops_error.h
Normal file
@@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file ops_error.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once

#include "log/ops_log.h"

/* Basic error reporting: vector-core (E89999) and cube-core (E69999)
 * internal-error codes, forwarded to the inner-error stub. */
#define OPS_REPORT_VECTOR_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E89999", OPS_DESC, __VA_ARGS__)
#define OPS_REPORT_CUBE_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E69999", OPS_DESC, __VA_ARGS__)

/* Conditional error reporting: if COND holds, run LOG_FUNC and then EXPR. */
#define OPS_ERR_IF(COND, LOG_FUNC, EXPR) OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR)
|
||||
497
csrc/utils/inc/fallback.h
Normal file
497
csrc/utils/inc/fallback.h
Normal file
@@ -0,0 +1,497 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file fallback.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef ACLNNFALLBACK_OPAPI_H_
|
||||
#define ACLNNFALLBACK_OPAPI_H_
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "fallback_comm.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "runtime/base.h"
|
||||
|
||||
namespace fallback {
|
||||
using namespace std;
|
||||
using namespace gert;
|
||||
using namespace ge;
|
||||
using namespace std;
|
||||
|
||||
// Minimal C++11 re-implementation of std::index_sequence /
// std::make_index_sequence (which require C++14).
namespace std_utils {
template <std::size_t... Is>
struct index_sequence {};

// Recursively prepend N-1, N-2, ..., 0 until the base case is reached.
template <std::size_t N, std::size_t... Is>
struct make_index_sequence_helper : make_index_sequence_helper<N - 1, N - 1, Is...> {};

template <std::size_t... Is>
struct make_index_sequence_helper<0, Is...> {
    using type = index_sequence<Is...>;
};

// make_index_sequence<N> == index_sequence<0, 1, ..., N-1>.
template <std::size_t N>
using make_index_sequence = typename make_index_sequence_helper<N>::type;
}
|
||||
|
||||
// Opaque ACL handle types (defined inside the op-api libraries; only pointers
// to them cross this boundary).
using aclOpExecutor = struct aclOpExecutor;
using aclTensor = struct aclTensor;
using aclScalar = struct aclScalar;
using aclIntArray = struct aclIntArray;
using aclFloatArray = struct aclFloatArray;
using aclBoolArray = struct aclBoolArray;
using aclTensorList = struct aclTensorList;

// Function-pointer types matching the op-api creator entry points, resolved
// at runtime via dlsym (see GET_OP_API_FUNC below).
using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, aclDataType data_type,
                                        const int64_t* stride, int64_t offset, aclFormat format,
                                        const int64_t* storage_dims, uint64_t storage_dims_num, void* tensor_data);

using _aclCreateScalar = aclScalar* (*)(void* value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray* (*)(const int64_t* value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray* (*)(const float* value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray* (*)(const bool* value, uint64_t size);
using _aclCreateTensorList = aclTensorList* (*)(const aclTensor* const* value, uint64_t size);

// Matching destructor entry points.
using _aclDestroyTensor = int (*)(const aclTensor* tensor);
using _aclDestroyScalar = int (*)(const aclScalar* scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray* array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray* array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray* array);
using _aclDestroyTensorList = int (*)(const aclTensorList* array);

// Resolve op-api symbol `apiName` and cast it to its `_apiName` pointer type.
#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))
|
||||
|
||||
// Name of the primary CANN op-api shared library.
inline const char* GetOpApiLibName(void) {
    return "libopapi.so";
}
|
||||
|
||||
// Name of the custom-operator op-api shared library (searched first).
inline const char* GetCustOpApiLibName(void) {
    return "libcust_opapi.so";
}
|
||||
|
||||
inline void* GetOpApiFuncAddrInLib(void* handler, const char* libName, const char* apiName) {
|
||||
auto funcAddr = dlsym(handler, apiName);
|
||||
if (funcAddr == nullptr) {
|
||||
OPS_LOG_W("aclnnfallback", "dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
|
||||
}
|
||||
return funcAddr;
|
||||
}
|
||||
|
||||
inline void* GetOpApiLibHandler(const char* libName) {
|
||||
auto handler = dlopen(libName, RTLD_LAZY);
|
||||
if (handler == nullptr) {
|
||||
OPS_LOG_W("aclnnfallback", "dlopen %s failed, error:%s.", libName, dlerror());
|
||||
}
|
||||
return handler;
|
||||
}
|
||||
|
||||
inline void* GetAclnnArrdByApiName(const char *apiName) {
|
||||
vector<std:: string> libs = {"libaclnn_ops_infer.so", "libaclnn_ops_train.so", "libaclnn_math.so",
|
||||
"libaclnn_rand.so", "libaclnn_sparse.so", "libaclnn_fft.so"};
|
||||
for (const auto &libName : libs) {
|
||||
static auto libHandler = GetOpApiLibHandler(libName.c_str());
|
||||
if (libHandler != nullptr) {
|
||||
auto funcAddr = GetOpApiFuncAddrInLib(libHandler, libName.c_str(), apiName);
|
||||
if (funcAddr != nullptr) {
|
||||
return funcAddr;
|
||||
}
|
||||
}
|
||||
}
|
||||
OPS_LOG_E("aclnnfallback", "api %s can't find in any aclnn lib.", apiName);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline void* GetOpApiFuncAddr(const char* apiName) {
|
||||
static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
|
||||
if (custOpApiHandler != nullptr) {
|
||||
auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
|
||||
if (funcAddr != nullptr) {
|
||||
return funcAddr;
|
||||
}
|
||||
}
|
||||
|
||||
static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
|
||||
if (opApiHandler != nullptr) {
|
||||
auto funcAddr = GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
|
||||
if (funcAddr != nullptr) {
|
||||
return funcAddr;
|
||||
}
|
||||
}
|
||||
OPS_LOG_D("aclnnfallback", "opapi lib is not exist,will use aclnn lib.");
|
||||
return GetAclnnArrdByApiName(apiName);
|
||||
}
|
||||
|
||||
// Identity overload: an aclTensor is already in op-api form.
inline aclTensor* ConvertType(aclTensor* ge_tensor) {
    return ge_tensor;
}
|
||||
|
||||
// Convert an int64 vector into an aclIntArray; an empty vector maps to
// nullptr (and, as in the original, skips resolving the creator symbol).
inline aclIntArray* ConvertType(const std::vector<int64_t> &arr) {
    if (arr.empty()) {
        return nullptr;
    }
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    return aclCreateIntArray(arr.data(), arr.size());
}
|
||||
|
||||
// Map a GE tensor's data type to the corresponding aclDataType.
// Unrecognized types (including DT_FLOAT16 itself) fall back to ACL_FLOAT16,
// exactly as the original if/else chain did.
inline aclDataType GetConvertType(const gert::Tensor* ge_tensor) {
    switch (ge_tensor->GetDataType()) {
        case DT_FLOAT:  return aclDataType::ACL_FLOAT;
        case DT_BF16:   return aclDataType::ACL_BF16;
        case DT_BOOL:   return aclDataType::ACL_BOOL;
        case DT_INT64:  return aclDataType::ACL_INT64;
        case DT_INT32:  return aclDataType::ACL_INT32;
        case DT_UINT64: return aclDataType::ACL_UINT64;
        case DT_UINT32: return aclDataType::ACL_UINT32;
        case DT_INT8:   return aclDataType::ACL_INT8;
        case DT_UINT8:  return aclDataType::ACL_UINT8;
        case DT_INT4:   return aclDataType::ACL_INT4;
        default:        return aclDataType::ACL_FLOAT16;
    }
}
|
||||
|
||||
// Wrap a GE runtime tensor as an ND aclTensor sharing the same device memory.
// Returns nullptr on a null input or if creation fails.
inline aclTensor* ConvertType(const gert::Tensor* ge_tensor) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }

    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);

    // Fix: the original also fetched GetPlacement() into an unused local.
    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());

    auto dataType = GetConvertType(ge_tensor);

    OPS_LOG_D("aclnnfallback", "aclCreateTensor: tensor type is %d", dataType);

    // Storage shape -> int64 dims.
    auto gert_shape = ge_tensor->GetStorageShape();
    std::vector<int64_t> shape;
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }

    // Row-major strides of a contiguous tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    aclTensor* out = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(),
                                     0, aclFormat::ACL_FORMAT_ND,
                                     shape.data(), shape.size(), device_addr);

    OPS_ERR_IF(out == nullptr,
               OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);

    return out;
}
|
||||
|
||||
// Convert a vector of GE tensors into an aclTensorList. Empty input and a
// missing creator symbol both report nullptr.
inline aclTensorList* ConvertType(std::vector<const gert::Tensor*>& ge_tenserList) {
    OPS_ERR_IF(ge_tenserList.size() == 0,
               OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr);

    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
    // Fix: the original logged "ge_tenserList size 0" here too (copy-paste
    // from the size check); report the actual failure instead.
    OPS_ERR_IF(aclCreateTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclCreateTensorList nullptr"), return nullptr);

    std::vector<aclTensor*> tmp;
    tmp.reserve(ge_tenserList.size());
    for (size_t i = 0; i < ge_tenserList.size(); i++) {
        tmp.push_back(ConvertType(ge_tenserList[i]));
    }

    return aclCreateTensorList(tmp.data(), tmp.size());
}
|
||||
|
||||
template <typename T>
|
||||
inline aclScalar* ConvertScalarType(T value) {
|
||||
static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
|
||||
OPS_ERR_IF(aclCreateScalar == nullptr,
|
||||
OPS_LOG_E("aclnnfallback", "aclCreateScalar nullptr"), return nullptr);
|
||||
if (typeid(value) == typeid(float)) {
|
||||
return aclCreateScalar(&value, aclDataType::ACL_FLOAT);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Fallback overload: plain values (ints, enums, pointers the op-api accepts
// directly) pass through unchanged.
template <typename T>
T ConvertType(T value) {
    return value;
}
|
||||
|
||||
// Wrap a GE tensor for matmul use: builds an aclTensor sharing the device
// memory, optionally presenting the last two dims as transposed (via swapped
// strides and view shape), and tagging FRACTAL_NZ storage when enable_NZ is
// set. Tensors of rank <= 1 fall back to the plain ConvertType path.
inline aclTensor* ConvertMmType(const gert::Tensor* ge_tensor, bool transpose, bool enable_NZ=false) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }
    auto gert_shape = ge_tensor->GetStorageShape();
    if (gert_shape.GetDimNum() <= 1) {
        return ConvertType(ge_tensor);
    }

    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);

    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());
    // convert data type
    auto dataType_ge = ge_tensor->GetDataType();
    auto dataType = ToAclDataType(dataType_ge);
    // convert shape
    std::vector<int64_t> shape;
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }
    // Strides of a contiguous (row-major) tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    auto viewShape = shape;
    // For a transposed tensor, swap the strides and view shape of the last
    // two dimensions (storage is untouched).
    if (transpose) {
        // dimM is the second-to-last dim, dimN the last dim.
        auto dimM = shape.size() - 2;
        auto dimN = shape.size() - 1;
        auto swap = strides[dimN];
        strides[dimN] = strides[dimM];
        strides[dimM] = swap;
        // Adjust the view shape to the transposed extents.
        viewShape[dimN] = shape[dimM];
        viewShape[dimM] = shape[dimN];
    }
    auto acl_format = aclFormat::ACL_FORMAT_ND;
    if (enable_NZ && GetPrimaryFormat(ge_tensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) {
        acl_format = aclFormat::ACL_FORMAT_FRACTAL_NZ;
    }
    aclTensor* out = aclCreateTensor(viewShape.data(), shape.size(), dataType, strides.data(),
                                     0, acl_format, shape.data(), shape.size(), device_addr);
    OPS_ERR_IF(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);

    return out;
}
|
||||
|
||||
// Destroy an aclTensor via the dynamically-resolved op-api destructor.
inline void Release(aclTensor* p) {
    static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
    OPS_ERR_IF(aclDestroyTensor == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensor is null"), return);
    aclDestroyTensor(p);
}

// Destroy an aclScalar.
inline void Release(aclScalar* p) {
    static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
    OPS_ERR_IF(aclDestroyScalar == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyScalar is null"), return);
    aclDestroyScalar(p);
}

// Destroy an aclIntArray.
inline void Release(aclIntArray* p) {
    static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
    OPS_ERR_IF(aclDestroyIntArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyIntArray is null"), return);
    aclDestroyIntArray(p);
}

// Destroy an aclBoolArray.
inline void Release(aclBoolArray* p) {
    static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
    OPS_ERR_IF(aclDestroyBoolArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyBoolArray is null"), return);
    aclDestroyBoolArray(p);
}

// Destroy an aclTensorList.
inline void Release(aclTensorList* p) {
    static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
    OPS_ERR_IF(aclDestroyTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensorList is null"), return);
    aclDestroyTensorList(p);
}

// No-op fallback for values that do not own acl resources.
template <typename T>
void Release(T value) {
    (void)value;
}
|
||||
|
||||
// Release every element of a converted-parameter tuple by expanding the
// index pack over the Release overloads above.
template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std_utils::index_sequence<I...>) {
    (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

// Entry point: release all converted acl objects held in `t`.
template <typename Tuple>
void ReleaseConvertTypes(Tuple& t) {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    CallRelease(t, std_utils::make_index_sequence<size>{});
}
|
||||
|
||||
// Convert every argument to its op-api representation (via the ConvertType
// overload set) and collect the results in a tuple. The caller owns the
// converted objects and must release them with ReleaseConvertTypes.
template <typename... Ts>
auto ConvertTypes(Ts&... args) -> decltype(std::make_tuple(ConvertType(args)...)) {
    auto tp = std::make_tuple(ConvertType(args)...);
    return tp;
}
|
||||
|
||||
template <typename Function, typename Tuple, size_t... I>
|
||||
auto call(Function f, Tuple t, std_utils::index_sequence<I...>) -> int {
|
||||
return f(std::get<I>(t)...);
|
||||
}
|
||||
|
||||
template <typename Function, typename Tuple>
|
||||
auto call(Function f, Tuple t) -> int {
|
||||
static constexpr auto size = std::tuple_size<Tuple>::value;
|
||||
return call(f, t, std_utils::make_index_sequence<size>{});
|
||||
}
|
||||
|
||||
// Cast a raw dlsym address to a function pointer whose parameter types are
// the (decayed) element types of the converted-parameter tuple.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr, std_utils::index_sequence<I...>)
    -> int (*)(typename std::decay<decltype(std::get<I>(params))>::type...) {
    using OpApiFunc = int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
    auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
    return func;
}

// Entry point; SFINAE-restricted to non-empty tuples.
template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr)
    -> typename std::enable_if<std::tuple_size<Tuple>::value != 0,
        decltype(ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<std::tuple_size<Tuple>::value>{}))>::type {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<size>{});
}
|
||||
|
||||
// RAII owner of a tuple of converted aclnn parameters.
// Exactly one object in a move chain holds validParams_ == true; only that
// owner runs ReleaseConvertTypes on destruction. Copying is disabled.
template <typename Tuple>
class ConvertedParams {
public:
    // Takes ownership of an already-converted parameter tuple.
    ConvertedParams(Tuple&& convertedParams) : convertedParams_(std::move(convertedParams)){};
    // Move construction transfers ownership; the source is invalidated.
    ConvertedParams(ConvertedParams&& other) : convertedParams_(std::move(other.convertedParams_)) {
        other.validParams_ = false;
    };
    ConvertedParams& operator=(ConvertedParams&& other) {
        if (this == &other) {
            return *this;
        }

        // Bug fix: release the parameters this object currently owns before
        // overwriting them, otherwise their ACL resources would leak.
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
        convertedParams_ = std::move(other.convertedParams_);
        validParams_ = true;
        other.validParams_ = false;
        return *this;
    }

    ConvertedParams() = delete;
    ConvertedParams(const ConvertedParams& other) = delete;
    ConvertedParams& operator=(const ConvertedParams& other) = delete;

    // Release the held resources only if ownership was not moved away.
    ~ConvertedParams() {
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
    }

    // Read-only access to the converted parameter tuple.
    const Tuple& GetConvertedParams() const {
        return convertedParams_;
    }

private:
    Tuple convertedParams_;   // converted aclnn parameters (owned while validParams_)
    bool validParams_{true};  // true while this object owns the resources
};
|
||||
|
||||
// Function-pointer types for symbols resolved at runtime from the opapi
// library via GetOpApiFuncAddr (dlsym-style lookup).
using InitHugeMemThreadLocal = int (*)(void*, bool);
using UnInitHugeMemThreadLocal = void (*)(void*, bool);
using ReleaseHugeMem = void (*)(void*, bool);
using PTAGetExecCache = aclOpExecutor* (*)(uint64_t, uint64_t*);
using InitPTACacheThreadLocal = void (*)();
using SetPTAHashKey = void (*)(uint64_t);
using CanUsePTACache = bool (*)(const char*);

using ResetCacheThreadLocal = void (*)();
|
||||
|
||||
// Resolve, invoke, and clean up a two-stage aclnn op:
//   1. <aclnn_api>GetWorkspaceSize(args..., &workspace_size, &executor)
//   2. <aclnn_api>(workspace, workspace_size, executor, stream)
// Symbol lookups are cached in function-local statics so they happen once per
// expansion site. Evaluates (GNU statement-expression) to GRAPH_SUCCESS or
// GRAPH_FAILED. Expects `host_api_ctx` in scope at the expansion site.
// NOTE(review): `static auto ret` persists across invocations; every path
// below does reassign it, but a non-static local would be clearer — confirm
// the `static` is intentional.
// NOTE(review): `static auto getWorkspaceSizeFunc` caches a function pointer
// typed from the first invocation's converted_params; assumes the argument
// types are identical on every call at a given expansion site.
// (Comments cannot go inside the macro body: a `//` comment would swallow the
// line-continuation backslash.)
#define EXEC_OPAPI_CMD(aclnn_api, ...)                                                                        \
    ({                                                                                                        \
        static auto ret = GRAPH_SUCCESS;                                                                      \
        do {                                                                                                  \
            static const auto ResetCacheThreadLocalAddr = GetOpApiFuncAddr("ResetCacheThreadLocal");          \
            static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize");     \
            static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api);                                   \
            if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr || ResetCacheThreadLocalAddr == nullptr) { \
                OPS_LOG_E("aclnnfallback", "%s or %s not in %s or %s or ResetCacheThreadLocal not found.",    \
                          #aclnn_api "GetWorkspaceSize", #aclnn_api, GetOpApiLibName(), GetOpApiLibName());   \
                ret = GRAPH_FAILED;                                                                           \
                break;                                                                                        \
            }                                                                                                 \
            auto ResetCacheThreadLocalFunc = reinterpret_cast<ResetCacheThreadLocal>(ResetCacheThreadLocalAddr); \
            ResetCacheThreadLocalFunc();                                                                      \
            uint64_t workspace_size = 0;                                                                      \
            uint64_t* workspace_size_addr = &workspace_size;                                                  \
            aclOpExecutor* executor = nullptr;                                                                \
            aclOpExecutor** executor_addr = &executor;                                                        \
            auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr);            \
            static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
            auto workspace_status = call(getWorkspaceSizeFunc, converted_params);                             \
            if (workspace_status != 0) {                                                                      \
                OPS_LOG_E("aclnnfallback", "call %s failed:", #aclnn_api);                                    \
                ret = GRAPH_FAILED;                                                                           \
                break;                                                                                        \
            }                                                                                                 \
            void* workspace_addr = nullptr;                                                                   \
            if (workspace_size > 0) {                                                                         \
                workspace_addr = host_api_ctx->MallocWorkspace(workspace_size);                               \
                if (workspace_addr == nullptr) {                                                              \
                    OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed", #aclnn_api);              \
                    ret = GRAPH_FAILED;                                                                       \
                    break;                                                                                    \
                }                                                                                             \
            }                                                                                                 \
            auto acl_stream = host_api_ctx->GetStream();                                                      \
            auto acl_call = [converted_params, workspace_addr, workspace_size, host_api_ctx, acl_stream,      \
                             executor]() -> int {                                                             \
                using OpApiFunc = int (*)(void*, uint64_t, aclOpExecutor*, const aclrtStream);                \
                OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr);                             \
                auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream);               \
                ReleaseConvertTypes(converted_params);                                                        \
                host_api_ctx->FreeWorkspace();                                                                \
                if (api_ret != 0) {                                                                           \
                    OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed api_ret: %d", #aclnn_api, api_ret); \
                    return GRAPH_FAILED;                                                                      \
                }                                                                                             \
                return api_ret;                                                                               \
            };                                                                                                \
                                                                                                              \
            ret = acl_call();                                                                                 \
        } while (false);                                                                                      \
        (ret);                                                                                                \
    })
|
||||
|
||||
} // namespace fallback
|
||||
|
||||
#endif // ACLNNFALLBACK_OPAPI_H_
|
||||
38
csrc/utils/inc/fallback_comm.h
Normal file
38
csrc/utils/inc/fallback_comm.h
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file fallback_comm.h
 * \brief Shared helpers for aclnn fallback implementations.
 */

#ifndef INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
#define INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_

#include "aclnn/aclnn_base.h"
#include "exe_graph/runtime/op_execute_context.h"
#include "exe_graph/runtime/tensor.h"
#include "register/op_impl_registry.h"
#include "runtime/base.h"

// NOTE(review): wrapping a namespaced C++ declaration in extern "C" only
// changes its linkage; confirm C linkage is actually intended here.
#ifdef __cplusplus
extern "C" {
#endif

namespace fallback {

// Map a GE data type enum to the corresponding ACL data type.
aclDataType ToAclDataType(ge::DataType dtype);
} // namespace fallback

#ifdef __cplusplus
}
#endif

#endif // INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
|
||||
121
csrc/utils/inc/kernel/dropmask.h
Normal file
121
csrc/utils/inc/kernel/dropmask.h
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dropmask.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef DROPMASK_H
|
||||
#define DROPMASK_H
|
||||
|
||||
#include "util.h"
|
||||
|
||||
using AscendC::DROPOUT_MODE_BIT_MISALIGN;
|
||||
using AscendC::DropOutShapeInfo;
|
||||
using AscendC::DropOut;
|
||||
|
||||
// Describes how the dropout mask is tiled and where the current vector block
// sits inside it, plus the parameters DropOut needs.
struct DropMaskInfo {
    // --- for computing the dropout-mask offset ---
    // Offsets are computed as if axes B/N2/G/S1/S2 are all tiled; for an axis
    // that is not tiled, set the corresponding field to 0 or the full size as
    // appropriate.
    int64_t n2G;              // n2 * g
    int64_t gSize;            // g
    int64_t s1Size;           // s1
    int64_t s2Size;           // s2
    int64_t gOutIdx;          // g outer index
    int64_t bSSOffset;        // boIdx * s1 * s2 (===bSSOffset)
    int64_t n2OutIdx;         // n2 outer index
    int64_t s1OutIdx;         // s1 outer index (===s1oIdx)
    int64_t s1InnerIdx;       // s1 inner index within the n-ratio loop (===loopIdx)
    int64_t s1BaseSize;       // S1 base block size
    int64_t splitS1BaseSize;  // s1 split size (===vec1S1BaseSize)
    int64_t s2StartIdx;       // s2 start index
    int64_t s2Idx;            // s2 index (===s2LoopCount)
    int64_t s2BaseNratioSize; // s2 length per ratio group: s2BaseSize (S2 base block) * nRatio

    // --- for copying in the dropout mask ---
    uint32_t s1CopySize;
    uint32_t s2CopySize;
    int64_t s2TotalSize;

    // --- for applying the dropout mask ---
    uint32_t firstAxis;
    uint32_t lstAxis;
    uint32_t maskLstAxis;
    int64_t vecCoreOffset = 0;
    float keepProb;  // keep probability passed to DropOut

    // true: mask stored one byte per element; false: bit-packed
    // (selects BoolCopyIn vs Bit2Int8CopyIn in CopyInDropMask)
    bool boolMode;
};
|
||||
|
||||
template <bool hasDrop>
|
||||
__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo)
|
||||
{
|
||||
if constexpr (hasDrop == true) {
|
||||
// boidx * n2 * g* s1 * s2
|
||||
int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G;
|
||||
// n2oIdx * g * s1 *s2
|
||||
int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
|
||||
// goIdx * s1 * s2
|
||||
int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
|
||||
// s1oIdx * s1BaseSize * s2Size + s1innerindex * vec1S1BaseSize * s2Size
|
||||
int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset +
|
||||
dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size;
|
||||
// s2StartIdx + s2index * s2BaseNratioSize
|
||||
int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize;
|
||||
return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the dropout mask tile for the current block into UB.
// boolMode selects the byte-per-element source (srcBoolTensor) vs the
// bit-packed source (srcByteTensor); BoolCopyIn / Bit2Int8CopyIn handle the
// respective unpacking. No-op when hasDrop == false.
template <bool hasDrop>
__aicore__ inline void CopyInDropMask(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcBoolTensor,
    GlobalTensor<uint8_t> &srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes)
{
    if constexpr (hasDrop == true) {
        int64_t dropMaskOffset = ComputeDropOffset<hasDrop>(dropMaskInfo);
        if (unlikely(dropMaskInfo.boolMode)) {
            BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset,
                dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        } else {
            Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1,
                dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        }
        return;
    }
}
|
||||
|
||||
// Apply the dropout mask held in dropoutBuffer to srcTensor, writing the
// result to dstTensor via AscendC DropOut. The mask's last axis is padded up
// to a 32-byte multiple; when a bit-packed mask's last axis is not block
// aligned, the DROPOUT_MODE_BIT_MISALIGN variant is used.
// No-op when hasDrop == false.
template <typename T, bool hasDrop>
__aicore__ inline void ComputeDropMask(LocalTensor<T>& dstTensor, LocalTensor<T>& srcTensor,
    LocalTensor<uint8_t>& dropoutBuffer, LocalTensor<uint8_t>& tmpDropBuffer, DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis;
        dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis;

        if (unlikely(dropMaskInfo.boolMode)) {
            // byte-per-element mask: pad last axis to 32 bytes
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes;
            DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
        } else {
            // bit-packed mask: 8 elements per byte, then pad to 32 bytes
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes;
            if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) {
                DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
            } else {
                // misaligned bit mask needs the dedicated mode
                DropOut<T, false, DROPOUT_MODE_BIT_MISALIGN>(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer,
                    dropMaskInfo.keepProb, dropOutShapeInfo);
            }
        }
        return;
    }
}
|
||||
|
||||
#endif // DROPMASK_H
|
||||
483
csrc/utils/inc/kernel/pse.h
Normal file
483
csrc/utils/inc/kernel/pse.h
Normal file
@@ -0,0 +1,483 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file pse.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef FLASH_ATTENTION_SCORE_PSE_H
|
||||
#define FLASH_ATTENTION_SCORE_PSE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "util.h"
|
||||
|
||||
// PseInfo::pseShapeType encodings (see PseComputeOffset / PseSlopeCopyIn):
constexpr static int64_t pseS1S2 = 0;     // pse laid out over (s1, s2)
constexpr static int64_t pse1S2 = 1;      // pse laid out over (1, s2), broadcast along s1
constexpr static int64_t pseSlopeBn = 2;  // alibi slope indexed by (b, n)
constexpr static int64_t pseSlopeN = 3;   // alibi slope indexed by n only

// PseInfo::pseEncodeType value marking the alibi "S2 full" encoding.
constexpr static uint8_t pseEncodeALibiS2Full = 0x11;
|
||||
|
||||
// How the positional bias is combined with the attention scores.
// "outer": pse tensor supplied by the caller; "inner": generated on-device
// from alibi slopes (see PseSlopeCopyIn / PseInnerAlibiCreate).
enum class PseTypeEnum {
    PSE_OUTER_MUL_ADD_TYPE = 0, // default
    PSE_OUTER_ADD_MUL_TYPE,
    PSE_INNER_MUL_ADD_TYPE,
    PSE_INNER_MUL_ADD_SQRT_TYPE, // inner, with extra Sqrt before the slope multiply
    PSE_INVALID_TYPE
};
|
||||
|
||||
// Tiling coordinates and layout parameters for the positional bias (pse)
// of the current vector block. Field aliases in comments refer to the
// corresponding names used by the surrounding kernels.
struct PseInfo {
    int64_t blockCount;
    int64_t bSSOffset;        // boIdx * s1 * s2
    int64_t boIdx;            // b outer index
    int64_t gSize;            // g
    int64_t goIdx;            // g outer index
    int64_t loopIdx;          // s1 inner loop index
    int64_t n2G;              // n2 * g
    int64_t n2oIdx;           // n2 outer index
    int64_t pseBSize;         // pse batch size (1 means broadcast over b)
    int64_t pseS1Size;        // for alibi
    int64_t pseS2ComputeSize; // for alibi, do not need assignment (set by PseAlibiComputeOffset)
    int64_t pseS2Size;        // for alibi
    uint32_t pseShapeType;    // one of pseS1S2 / pse1S2 / pseSlopeBn / pseSlopeN
    int64_t readS2Size;       // for alibi, do not need assignment (set by PseAlibiComputeOffset)
    int64_t s1BaseSize;
    int64_t s1Size;
    int64_t s1oIdx;
    int64_t s2AlignedSize;
    int64_t s2BaseNratioSize;
    int64_t s2LoopCount;
    int64_t s2RealSize;
    int64_t s2Size;
    int64_t s2SizeAcc;        // accumulated sum of s2 size
    int64_t s2StartIdx;
    int64_t vec1S1BaseSize;
    int64_t vec1S1RealSize;
    uint32_t pseEncodeType;   // pseEncodeALibiS2Full distinguishes the alibi encoding
    uint32_t pseType;         // 0: outer, mul-add 1: outer, add-mul 2: inner, mul-add 3: inner, mul-add-sqrt
    int64_t pseAlibiBaseS1;
    int64_t pseAlibiBaseS2;
    int64_t qStartIdx;
    int64_t kvStartIdx;
    int64_t vecCoreOffset = 0;
    bool needCast;            // cast INPUT_T -> T after copy-in
    bool align8 = false;      // use 32-byte (8-element for fp32) aligned copy path
    bool pseEndogenous = false;
};
|
||||
|
||||
// Copy an (s1Size x s2Size) tile from GM (row stride actualS2Len) into UB,
// padding each row to alignedS2Size elements. Uses plain DataCopy when the
// source row stride is 32-byte aligned, otherwise DataCopyPad.
// No-op when hasPse == false.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInCommon(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                        int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize,
                                        int32_t alignedS2Size)
{
    if constexpr (hasPse == true) {
        uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignedS2Size)};
        dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        dstTensor.SetSize(s1Size * alignedS2Size);
        DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = s1Size;
        dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // unit: 32-byte blocks
        dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap
        if (actualS2Len * dtypeSize % blockBytes == 0) {
            // aligned source rows: block-granular DataCopy
            dataCopyParams.srcStride =
                (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap
            DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
        } else {
            // unaligned source rows: byte-granular DataCopyPad
            dataCopyParams.blockLen = s2Size * dtypeSize; // unit: bytes
            dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen);
            dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes;
            DataCopyPadParams dataCopyPadParams;
            dataCopyPadParams.isPad = false;
            DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
        }
    }
}
|
||||
|
||||
// Convenience wrapper around DataCopyInCommon that pads each row to a
// multiple of alignedSize elements (default 16). No-op when hasPse == false.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                  int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}
|
||||
|
||||
// Like DataCopyIn, but pads each row to a 32-byte boundary (i.e. 32/dtypeSize
// elements — 8 elements for a 4-byte dtype). No-op when hasPse == false.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInAlign8(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                        int64_t s1Size, int64_t s2Size, int64_t actualS2Len)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        // defensive guard against a zero-sized dtype (division below)
        if (dtypeSize == 0){
            return;
        }
        int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize);
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}
|
||||
|
||||
/*
dst = BroadcastAdd(src0, src1)
src0 shape: (s1, s2)
src1 shape: (1, s2)
dst shape: (s1, s2)
In-place: src0Tensor[src0Offset..] += src1Tensor broadcast along s1;
repeatTimes rows are processed per Add instruction.
*/
template <typename T, bool hasPse>
__aicore__ inline void BroadcastAdd(const LocalTensor<T> &src0Tensor, const LocalTensor<T> &src1Tensor,
                                    int64_t src0Offset, int32_t src1Size, int32_t repeatTimes)
{
    if constexpr (hasPse == true) {
        /* Total data number of single step should be smaller than 256bytes.
         * If larger, we need to do add multiple times. */
        int32_t innerLoop = src1Size / repeatMaxSize;   // full repeatMaxSize chunks along s2
        int32_t innerRemain = src1Size % repeatMaxSize; // tail size along s2
        BinaryRepeatParams binaryRepeatParams;
        binaryRepeatParams.src0BlkStride = 1;
        binaryRepeatParams.src0RepStride = src1Size / blockSize; // stride between s1 rows, in blocks
        binaryRepeatParams.src1BlkStride = 1;
        binaryRepeatParams.src1RepStride = 0; // reuse the same src1 row => broadcast
        binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride;
        binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride;

        for (int32_t j = 0; j < innerLoop; j++) {
            auto innerOffset = j * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes,
                binaryRepeatParams);
        }
        if (innerRemain > 0) {
            auto innerOffset = innerLoop * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes,
                binaryRepeatParams);
        }
    }
}
|
||||
|
||||
// Accumulate the pse tile into dstTensor.
// Full-shape pse (pseS1S2 / slope types): a single element-wise Add.
// Row-broadcast pse (pse1S2): BroadcastAdd over s1, chunked so the repeat
// count per instruction stays <= repeatMaxTimes.
// No-op when hasPse == false.
template <typename T, bool hasPse>
__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor<T> &pseUb,
                                       const LocalTensor<T> &dstTensor, uint32_t pseShapeType)
{
    if constexpr (hasPse == true) {
        if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) {
            Add(dstTensor, dstTensor, pseUb, computeSize);
        } else {
            /* Total repeated times should be <= repeatMaxTimes. If larger,
             * we need to do multiple inner loops. */
            int32_t s1OuterLoop = s1Size / repeatMaxTimes;
            int32_t s1OuterRemain = s1Size % repeatMaxTimes;
            for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) {
                int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes);
            }
            if (s1OuterRemain > 0) {
                int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain);
            }
        }
    }
}
|
||||
// Compute the linear GM offset of the pse tile for the current block,
// according to pseShapeType:
//   pseS1S2: layout (b, n2, g, s1, s2)
//   pse1S2:  layout (b, n2, g, 1, s2) — no s1 component
// A pseBSize of 1 broadcasts over the batch (b offset forced to 0).
// Returns 0 when hasPse == false.
template <bool hasPse> __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t s1Offset = 0;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t gOffset = 0;
        if (pseInfo.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = pseInfo.bSSOffset * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size;
            s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                        pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size;
        } else if (pseInfo.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = pseInfo.s2SizeAcc * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s2Size;
        }
        if (pseInfo.pseBSize == 1) {
            bOffset = 0; // pse broadcast over batch
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
|
||||
|
||||
// Compute the GM offset of the alibi-encoded pse tile for the current block.
// The alibi pse stores only a band of the (s1, s2) matrix (pseS1Size rows by
// pseS2Size cols); (m, k) map the current block's (row, column) into that
// band. Layouts other than TND and TND take different diagonal geometry.
// Side effects: sets pseInfo.readS2Size (columns actually readable from the
// band) and pseInfo.pseS2ComputeSize (that size aligned for compute).
// Returns 0 when hasPse == false.
template <LayOutTypeEnum layOutType, bool hasPse> __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size;
        // absolute (row, column) of the block's top-left element
        int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                      pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t m = 0; // row inside the stored band
        int64_t k = 0; // column inside the stored band
        if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
            int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size;
            if (row >= threshold) {
                // inside the bottom band stored verbatim
                m = row - threshold;
                k = column;
            } else {
                // above the band: wrap by diagonal distance
                m = row % pseInfo.pseS1Size;
                k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m);
            }
        } else {
            int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size;
            int64_t posVal = row - column - threshold;
            if (threshold >= 0) {
                if (posVal >= 0) {
                    m = posVal;
                    k = 0;
                } else {
                    m = 0;
                    k = -posVal;
                }
            } else {
                m = posVal;
                k = 0;
            }
        }
        int64_t s1Offset = m * pseInfo.pseS2Size;
        int64_t s2Offset = k;
        pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k);
        pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size);

        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
|
||||
|
||||
// Alibi encoding only contributes below the diagonal: a block can be skipped
// when its last s1 row is at or above its starting s2 column.
// Always false when hasPse == false.
template <bool hasPse> __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo)
{
    if constexpr (hasPse) {
        const int64_t rowEnd = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                               (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize;
        const int64_t colBegin = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        return rowEnd > colBegin;
    } else {
        return false;
    }
}
|
||||
|
||||
// Copy the alibi-encoded pse tile into UB, skipping upper-triangle blocks
// (NeedPseAlibiCompute). When INPUT_T == T the data is copied straight into
// dstTensor (aligned or align8 path per pseInfo.align8); otherwise it is
// staged in tmpTensor and cast to T after an MTE2->V sync.
// No-op when hasPse == false.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseAlibiCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                      GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        // also updates pseInfo.readS2Size / pseS2ComputeSize
        int64_t offset = PseAlibiComputeOffset<layOutType, hasPse>(pseInfo);
        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                            pseInfo.pseS2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize,
                                                  pseInfo.readS2Size, pseInfo.pseS2Size);
            }
            return;
        }

        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                    pseInfo.pseS2Size, alignedSize);
        if (pseInfo.needCast) {
            // wait for the copy before the vector-unit cast reads tmpTensor
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        }
        return;
    }
}
|
||||
|
||||
// Inner-pse path: build the alibi bias for the current block from a per-head
// slope table. Copies the precomputed distance pattern from alibiGm, then
// (when needCast) casts to T and evaluates slope * |distance + posShift|
// (with an extra Sqrt for PSE_INNER_MUL_ADD_SQRT_TYPE).
// NOTE(review): the post-cast arithmetic duplicates PseSlopeCast below —
// consider delegating to it.
// No-op when hasPse == false.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCopyIn(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                      __gm__ uint8_t *pseSlope, GlobalTensor<half> &alibiGm, PseInfo &pseInfo,
                                      int64_t alignedSize = 16) {
    if constexpr (hasPse == true) {
        // index of this head's slope in the slope table
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G; // per-(b, n) slopes
        }
        int64_t offset = bOffset + n2Offset + gOffset;

        DataCopyIn<half, hasPse>(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize,
                                 pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize);
        // wait for the copy before vector ops read helpTensor
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);

        if (pseInfo.needCast) {
            int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
            Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
            pipe_barrier(PIPE_V);

            int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                               pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
            int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

            // shift the base pattern to this block's diagonal position
            float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

            Adds(dstTensor, dstTensor, posShift, computeSize);
            pipe_barrier(PIPE_V);
            Abs(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
            float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
            if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
                Sqrt(dstTensor, dstTensor, computeSize);
                pipe_barrier(PIPE_V);
            }
            Muls(dstTensor, dstTensor, slopes, computeSize);
            pipe_barrier(PIPE_V);
        }
    }
}
|
||||
|
||||
// Inner-pse path without the GM copy: cast the distance pattern already in
// helpTensor to T and evaluate slope * |distance + posShift| (with an extra
// Sqrt for PSE_INNER_MUL_ADD_SQRT_TYPE). Same arithmetic as the needCast
// branch of PseSlopeCopyIn. No-op when hasPse == false.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCast(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                    __gm__ uint8_t *pseSlope, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        // index of this head's slope in the slope table
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G; // per-(b, n) slopes
        }
        int64_t offset = bOffset + n2Offset + gOffset;
        int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
        Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
        pipe_barrier(PIPE_V);

        int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                           pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

        // shift the base pattern to this block's diagonal position
        float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

        Adds(dstTensor, dstTensor, posShift, computeSize);
        pipe_barrier(PIPE_V);
        Abs(dstTensor, dstTensor, computeSize);
        pipe_barrier(PIPE_V);
        float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
        if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            Sqrt(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
        }
        Muls(dstTensor, dstTensor, slopes, computeSize);
        pipe_barrier(PIPE_V);
    }
}
|
||||
|
||||
// Copy the pse tile for the current block into UB.
// Dispatches to PseAlibiCopyIn for the alibi encoding; otherwise computes the
// GM offset via PseComputeOffset and copies (casting INPUT_T -> T through
// tmpTensor when the types differ and needCast is set).
// For pse1S2, s1Size collapses to blockCount (minimum 1) since the s1 axis is
// broadcast. No-op when hasPse == false.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                 GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCopyIn<INPUT_T, T, layOutType, hasPse>(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize);
        }
        int64_t offset = PseComputeOffset<hasPse>(pseInfo);
        int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) :
                         pseInfo.vec1S1RealSize;

        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize,
                                            pseInfo.s2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size);
            }
            return;
        }
        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size,
                                    alignedSize);
        if (pseInfo.needCast) {
            // wait for the copy before the vector-unit cast reads tmpTensor
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize);
        }
        return;
    }
}
|
||||
|
||||
// Accumulate the alibi pse tile into dstTensor (skipping upper-triangle
// blocks per NeedPseAlibiCompute). No-op when hasPse == false.
template <typename T, bool hasPse>
__aicore__ inline void PseAlibiCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        return;
    }
}
|
||||
|
||||
// Accumulate the pse tile into dstTensor. Dispatches to PseAlibiCompute for
// the alibi encoding; otherwise to PseBroadcastAdd, with the compute size
// covering the whole tile for full-shape pse and a single row for pse1S2.
// No-op when hasPse == false.
template <typename T, bool hasPse>
__aicore__ inline void PseCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCompute<T, hasPse>(dstTensor, pseTensor, pseInfo);
        }
        int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn ||
                               pseInfo.pseShapeType == pseSlopeN)
                                  ? pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize
                                  : pseInfo.s2AlignedSize;
        PseBroadcastAdd<T, hasPse>(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor,
                                   dstTensor, pseInfo.pseShapeType);
        return;
    }
}
|
||||
|
||||
// Inner-pse setup: materialize the base alibi distance pattern in GM.
// Writes pseAlibiBaseS1 rows of pseAlibiBaseS2 values each; row i holds the
// sequence -i, -i+1, ... (CreateVecIndex starting at -i). Only runs for the
// inner pse types. The V->MTE3 and MTE3->V/S flags serialize reuse of
// helpTensor across iterations. No-op when hasPse == false.
template <bool hasPse>
__aicore__ inline void PseInnerAlibiCreate(GlobalTensor<half> &dstTensor, LocalTensor<half> &helpTensor, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE && pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            return;
        }
        event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        event_t eventIdMte3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        float tmpValue = -1.0;

        for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) {
            // row i starts at value -i
            CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            SetFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
            WaitFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
        }
    }
}
|
||||
#endif
|
||||
144
csrc/utils/inc/kernel/util.h
Normal file
144
csrc/utils/inc/kernel/util.h
Normal file
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef FLASH_ATTENTION_UTIL_H
|
||||
#define FLASH_ATTENTION_UTIL_H
|
||||
|
||||
// One unified-buffer block is 32 bytes; copy lengths and strides below are
// expressed in multiples of this.
constexpr int32_t blockBytes = 32;
// Bits per byte; used to convert bit-packed mask sizes into byte counts.
constexpr int32_t byteBitRatio = 8;
constexpr int64_t prefixAttenMaskDownHeight = 1024;
// Elements per 32-byte block for 4-byte element types.
constexpr static int32_t blockSize = blockBytes / 4; // 4 means sizeof(T)
// One vector repeat processes at most 256 bytes.
constexpr static int32_t repeatMaxBytes = 256;
constexpr static int32_t repeatMaxTimes = 255;
// Elements per vector repeat for 4-byte element types.
constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T)

using AscendC::LocalTensor;
using AscendC::GlobalTensor;
using AscendC::DataFormat;
using AscendC::ShapeInfo;
using AscendC::DataCopyParams;
using AscendC::DataCopyPadParams;
using AscendC::BinaryRepeatParams;
using AscendC::IsSameType;
using AscendC::HardEvent;
using AscendC::SetFlag;
using AscendC::WaitFlag;

// Tensor layout conventions (B=batch, S=sequence, H=hidden, N=heads,
// D=head dim, T=packed token dim).
enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5};
|
||||
|
||||
namespace math {
|
||||
template <typename T> __aicore__ inline T Ceil(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T> __aicore__ inline T Align(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b * b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 CeilDiv(T1 a, T2 b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Max(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (a) : (b);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Min(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (b) : (a);
|
||||
}
|
||||
|
||||
// Copies an s1Size x s2Size uint8 (bool mask) tile from global memory into a
// local tensor whose rows are padded to alignedSize bytes.
// srcOffset    element offset into srcTensor, totalS2Size is the full row
//              pitch of the source.
// Aligned rows use a plain strided DataCopy; unaligned rows fall back to
// DataCopyPad with right-padding value 1 (pad reads as "masked").
// NOTE(review): strides/lengths are in 32-byte blocks for the DataCopy path
// and in bytes for the DataCopyPad path, per the AscendC API convention.
__aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes)
{
    uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {s1Size, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size;
    dataCopyParams.dstStride = 0;
    // Special case: source rows are exactly one block but 64-byte alignment
    // was requested; fall back to 32-byte alignment and skip one dst block
    // per row instead.
    if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0
        dataCopyParams.dstStride = 1;
        alignedSize = blockBytes;
        alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes;
    }
    if (totalS2Size % alignedSize == 0) {
        // Aligned source pitch: whole-block strided copy.
        dataCopyParams.blockLen = alignedS2Size / blockBytes;
        dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams);
    } else {
        // Unaligned source pitch: byte-granular copy with right padding.
        dataCopyParams.blockLen = s2Size;
        dataCopyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        // Pad at most one block's worth of bytes per row.
        dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes);
        dataCopyPadParams.paddingValue = 1;
        DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams);
    }
}
|
||||
|
||||
// Copies a bit-packed mask tile into a local uint8 tensor. Sizes named in
// bits (s2BaseSize, s2TotalSize, srcOffset) are divided by byteBitRatio (8)
// to obtain byte counts before copying.
// Rows whose byte widths are alignedSize-aligned use a strided DataCopy;
// otherwise DataCopyPad is used with zero right-padding.
// NOTE(review): the unpacking from bits to int8 happens elsewhere — this
// routine only moves the packed bytes.
__aicore__ inline void Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize,
    int64_t alignedSize = blockBytes)
{
    // Row width in bytes, rounded up to alignedSize.
    uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = batchSize * s1BaseSize;
    dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes);
    dataCopyParams.dstStride = 0;
    if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) {
        // Aligned path: lengths and strides in 32-byte blocks.
        dataCopyParams.srcStride =
            (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams);
    } else {
        // Unaligned path: lengths and strides in bytes, no padding bytes.
        dataCopyParams.blockLen = CeilDiv(s2BaseSize , byteBitRatio);
        dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = 0;
        dataCopyPadParams.paddingValue = 0;
        DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams);
    }
}
|
||||
|
||||
// Rounds shape up to the nearest multiple of alignFactor.
// The factor defaults to 16 — the value previously hard-coded here — so all
// existing single-argument callers keep their exact behavior; new callers
// may pass a different alignment. CeilDiv returns 0 when the factor is 0,
// so a zero alignFactor yields 0 rather than dividing by zero.
__aicore__ inline int32_t Align(int32_t shape, int32_t alignFactor = 16)
{
    return CeilDiv<int32_t, int32_t>(shape, alignFactor) * alignFactor;
}
|
||||
|
||||
#endif // FLASH_ATTENTION_UTIL_H
|
||||
190
csrc/utils/inc/log/inner/dfx_base.h
Normal file
190
csrc/utils/inc/log/inner/dfx_base.h
Normal file
@@ -0,0 +1,190 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
 * \file dfx_base.h
 * \brief External modules should not include this header directly.
 */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <unistd.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <securec.h>
|
||||
#include <base/alog_pub.h>
|
||||
#include <base/err_msg.h>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <exe_graph/runtime/tiling_parse_context.h>
|
||||
#include <exe_graph/runtime/infer_shape_context.h>
|
||||
#include <exe_graph/runtime/infer_datatype_context.h>
|
||||
|
||||
namespace ops {
|
||||
namespace utils {
|
||||
|
||||
// Static helpers shared by the logging macros: thread id lookup, string
// normalization, and "<NodeType>:<NodeName>" extraction from gert contexts.
class LogBase {
public:
    // Maximum bytes of one formatted log message (see OPS_LOG_STUB_FULL).
    static constexpr const int MAX_LOG_LEN = 16000;
    // Bytes reserved for the log header prefix.
    static constexpr const int MSG_HDR_LEN = 200;

    // Returns the caller's kernel thread id via the gettid syscall.
    static inline uint64_t GetTid()
    {
        return static_cast<uint64_t>(syscall(__NR_gettid));
    }

    // GetStr: normalizes std::string / C-string arguments to const char*.
    static inline const char *GetStr(const std::string &str)
    {
        return str.c_str();
    }

    static inline const char *GetStr(const char *str)
    {
        return str;
    }

    // GetOpInfo: passes plain strings through unchanged…
    static inline const std::string &GetOpInfo(const std::string &str)
    {
        return str;
    }

    static inline const char *GetOpInfo(const char *str)
    {
        return str;
    }

    // …and extracts "<NodeType>:<NodeName>" from any of the gert contexts.
    static inline std::string GetOpInfo(const gert::TilingContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::TilingParseContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::InferShapeContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::InferDataTypeContext *context)
    {
        return GetOpInfoFromContext(context);
    }

private:
    // Builds "<NodeType>:<NodeName>", substituting "nil" for any null part.
    template <class T> static inline std::string GetOpInfoFromContext(T context)
    {
        if (context == nullptr) {
            return "nil:nil";
        }
        std::string opInfo = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil";
        opInfo += ":";
        opInfo += context->GetNodeName() != nullptr ? context->GetNodeName() : "nil";
        return opInfo;
    }
};
|
||||
|
||||
} // namespace utils
|
||||
|
||||
// Renders a shape object (anything exposing GetDimNum()/GetDim(i)) as
// "[d0, d1, ..., dn]"; a zero-rank shape renders as "[]".
template <typename T>
std::string Shape2String(const T& shape) {
    std::ostringstream out;
    out << "[";
    const size_t rank = shape.GetDimNum();
    for (size_t idx = 0; idx < rank; ++idx) {
        if (idx != 0) {
            out << ", ";
        }
        out << shape.GetDim(idx);
    }
    out << "]";
    return out.str();
}
|
||||
} // namespace ops
|
||||
|
||||
// Before using these macros, OPS_UTILS_LOG_SUB_MOD_NAME must be predefined
// with the submodule name, e.g. #define OPS_UTILS_LOG_SUB_MOD_NAME "OP_TILING",
// or passed as a predefined macro via CMake.
// Emits one log record when the level is enabled for MOD_ID. The header
// carries file:line, submodule, package type, function, tid and op info.
#define OPS_LOG_STUB(MOD_ID, LOG_LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (AlogCheckDebugLevel(static_cast<int>(MOD_ID), (LOG_LEVEL)) == 1) { \
            AlogRecord(static_cast<int>(MOD_ID), DLOG_TYPE_DEBUG, (LOG_LEVEL), \
                "[%s:%d][%s]%s[%s][%lu] OpName:[%s] " #FMT, \
                __FILE__, __LINE__, (OPS_UTILS_LOG_SUB_MOD_NAME), \
                (OPS_UTILS_LOG_PACKAGE_TYPE), __FUNCTION__, ops::utils::LogBase::GetTid(), \
                ops::utils::LogBase::GetStr(ops::utils::LogBase::GetOpInfo(OPS_DESC)), ##__VA_ARGS__); \
        } \
    }while (0)

// If COND (which must be a bool expression) holds, runs LOG_FUNC then EXPR
// (typically a return statement). COND is marked unlikely.
#define OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) \
    static_assert(std::is_same<bool, std::decay<decltype(COND)>::type>::value, "condition should be bool"); \
    do { \
        if (__builtin_expect((COND), 0)) { \
            LOG_FUNC; \
            EXPR; \
        } \
    } while (0)

// Logs an error and reports it through the error-message subsystem.
#define OPS_INNER_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)

// NOTE(review): currently identical to OPS_INNER_ERR_STUB — both call
// REPORT_INNER_ERR_MSG; confirm whether a call-level report was intended.
#define OPS_CALL_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)

// Per-level shorthands over OPS_LOG_STUB.
#define OPS_LOG_STUB_D(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_DEBUG, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_I(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_INFO, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_W(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_WARN, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_E(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_EVENT(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_EVENT, OPS_DESC, FMT, ##__VA_ARGS__)

// Formats the message into a local buffer and, if it exceeds the per-record
// budget, splits it into chunks (and at newlines) across multiple records.
// NOTE(review): MSG_LENGTH is not defined in this header — presumably
// provided by the alog headers; verify it matches MAX_LOG_LEN's intent.
#define OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (0 == AlogCheckDebugLevel(OP, (LEVEL))) { \
            break; \
        } \
        char msgbufxyz[ops::utils::LogBase::MAX_LOG_LEN]; \
        size_t msgmaxlen = (MSG_LENGTH - ops::utils::LogBase::MSG_HDR_LEN); \
        int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, FMT, ##__VA_ARGS__); \
        if (rettmp == -1) { \
            msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \
        } \
        size_t msglength = std::strlen(msgbufxyz); \
        if (msglength < msgmaxlen) { \
            OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgbufxyz); \
            break; \
        } \
        char *msgchunkbegin = msgbufxyz; \
        char *msgchunkend = nullptr; \
        while (msgchunkbegin < msgbufxyz + msglength) { \
            if (msgchunkbegin[0] == '\n') { \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), ""); \
                msgchunkbegin += 1; \
                continue; \
            } \
            msgchunkend = std::strchr(msgchunkbegin, '\n'); \
            if (msgchunkend == nullptr) { \
                msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \
            } \
            while (msgchunkend > msgchunkbegin) { \
                std::string msgchunk(msgchunkbegin, \
                    std::min(msgmaxlen, static_cast<size_t>(msgchunkend - msgchunkbegin))); \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgchunk.c_str()); \
                msgchunkbegin += msgchunk.size(); \
            } \
            msgchunkbegin += 1; \
        } \
    } while (0)
|
||||
59
csrc/utils/inc/log/ops_log.h
Normal file
59
csrc/utils/inc/log/ops_log.h
Normal file
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file ops_log.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "log/inner/dfx_base.h"
|
||||
|
||||
/* Basic logging */
#define OPS_LOG_D(OPS_DESC, ...) OPS_LOG_STUB_D(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I(OPS_DESC, ...) OPS_LOG_STUB_I(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W(OPS_DESC, ...) OPS_LOG_STUB_W(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_E(OPS_DESC, ...) OPS_INNER_ERR_STUB("EZ9999", OPS_DESC, __VA_ARGS__)
#define OPS_LOG_E_WITHOUT_REPORT(OPS_DESC, ...) OPS_LOG_STUB_E(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_EVENT(OPS_DESC, ...) OPS_LOG_STUB_EVENT(OPS_DESC, __VA_ARGS__)

/* Full-length logging:
 * emits over-long messages; if a message exceeds the record budget it is
 * split across multiple output lines. */
#define OPS_LOG_FULL(LEVEL, OPS_DESC, ...) OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_D_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_DEBUG, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_INFO, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_WARN, OPS_DESC, __VA_ARGS__)

/* Conditional logging: log and run EXPR (e.g. a return) when COND holds. */
#define OPS_LOG_D_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_D(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_I_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_I(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_W_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_W(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_E_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_E(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_EVENT_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_EVENT(OP_DESC, __VA_ARGS__), EXPR)

// Null-pointer guard: logs, reports, then runs EXPR when PTR is null.
// NOTE(review): this emits the error message twice (OPS_LOG_STUB_E plus the
// log inside OPS_CALL_ERR_STUB) — confirm whether that is intentional.
#define OPS_LOG_E_IF_NULL(OPS_DESC, PTR, EXPR) \
    if (__builtin_expect((PTR) == nullptr, 0)) { \
        OPS_LOG_STUB_E(OPS_DESC, "%s is nullptr!", #PTR); \
        OPS_CALL_ERR_STUB("EZ9999", OPS_DESC, "%s is nullptr!", #PTR); \
        EXPR; \
    }

// Generic check: on COND, run LOG_FUNC then EXPR. Note: COND is the FAILURE
// condition here (unlike assert-style macros).
#define OPS_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    }

// NOTE(review): identical to OPS_CHECK; kept for source compatibility.
#define OP_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    }
|
||||
47
csrc/utils/inc/tiling/data_copy_transpose_tiling.h
Normal file
47
csrc/utils/inc/tiling/data_copy_transpose_tiling.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <graph/tensor.h>
|
||||
#include "data_copy_transpose_tiling_def.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Populates a CopyTransposeTiling from 4-D destination and source shapes.
// dstShape / srcShape: expected to have at least 4 dims, read as
// [B, N, S, H] (dst) and [B, N, S, HN] (src); typeSize is the element size
// in bytes and only feeds originalShapeNLen.
// Guards added: shapes with fewer than 4 dims, or a zero dst N dim, leave
// the tiling untouched instead of indexing out of range / dividing by zero.
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
                                       optiling::CopyTransposeTiling &tiling)
{
    std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
    std::vector<int64_t> srcShapeInfo = srcShape.GetDims();

    // The body reads dims [0..3] of both shapes.
    constexpr size_t requiredDimNum = 4;
    if (dstShapeInfo.size() < requiredDimNum || srcShapeInfo.size() < requiredDimNum) {
        return;
    }
    // dstShapeHN divides by dstShapeN below.
    if (dstShapeInfo[1] == 0) {
        return;
    }

    tiling.set_dstShapeB(dstShapeInfo[0]);
    tiling.set_dstShapeN(dstShapeInfo[1]);
    tiling.set_dstShapeS(dstShapeInfo[2]);
    tiling.set_dstShapeH(dstShapeInfo[3]);
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());

    tiling.set_srcShapeB(srcShapeInfo[0]);
    tiling.set_srcShapeN(srcShapeInfo[1]);
    tiling.set_srcShapeS(srcShapeInfo[2]);
    tiling.set_srcShapeHN(srcShapeInfo[3]);
    // Byte length of one src N-row; derived combined extents follow.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
|
||||
|
||||
} // namespace optiling
|
||||
43
csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h
Normal file
43
csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h
Normal file
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling_def.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <register/tilingdata_base.h>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Tiling data consumed by the data-copy-transpose kernel; fields are filled
// by GetDataCopyTransposeTiling. Dims follow the [B, N, S, H] convention.
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
// dstShapeHN = dstShapeH / dstShapeN.
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
// Byte length of one src N-row: srcShapeHN * element size.
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
// Precomputed products of dims used by the kernel's addressing.
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
||||
|
||||
} // namespace optiling
|
||||
225
csrc/utils/inc/tiling/tiling_base.h
Normal file
225
csrc/utils/inc/tiling/tiling_base.h
Normal file
@@ -0,0 +1,225 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <tiling/platform/platform_ascendc.h>
|
||||
#include "log/ops_log.h"
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
#define ASCENDC_EXTERN_C extern "C"
|
||||
#else
|
||||
#define ASCENDC_EXTERN_C
|
||||
#endif
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Per-platform AI core resource limits gathered during tiling.
struct AiCoreParams {
    uint64_t ubSize;    // unified buffer size in bytes
    uint64_t blockDim;  // block dimension used for kernel launch
    uint64_t aicNum;    // number of AI cube cores
    uint64_t l1Size;    // L1 buffer size in bytes
    uint64_t l0aSize;   // L0A buffer size in bytes
    uint64_t l0bSize;   // L0B buffer size in bytes
    uint64_t l0cSize;   // L0C buffer size in bytes
};
|
||||
|
||||
// Compile-time platform info captured for FlashAttentionScoreGrad tiling.
struct FlashAttentionScoreGradCompileInfo {
    uint32_t aivNum;       // number of AI vector cores
    uint32_t aicNum;       // number of AI cube cores
    uint64_t ubSize;       // unified buffer size in bytes
    uint64_t l1Size;       // L1 buffer size in bytes
    uint64_t l0aSize;      // L0A buffer size in bytes
    uint64_t l0bSize;      // L0B buffer size in bytes
    uint64_t l0cSize;      // L0C buffer size in bytes
    uint64_t l2CacheSize;  // L2 cache size in bytes
    int64_t coreNum;       // total usable core count
};
|
||||
|
||||
// Base class for op tiling implementations. Subclasses override the numbered
// hooks; DoTiling() drives them in a fixed order.
class TilingBaseClass {
public:
    TilingBaseClass() = default;

    explicit TilingBaseClass(gert::TilingContext *context) : context_(context)
    {
    }

    virtual ~TilingBaseClass() = default;

    // Tiling execution framework. Return values:
    // 1. GRAPH_SUCCESS: succeeded; no further tiling classes need to run.
    // 2. GRAPH_FAILED: failed; abort the whole tiling flow.
    // 3. GRAPH_PARAM_INVALID: this class does not support the case; continue
    //    with the next registered tiling class.
    ge::graphStatus DoTiling()
    {
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }

    // Update the context (for reuse of the same tiling object).
    virtual void Reset(gert::TilingContext *context)
    {
        context_ = context;
    }

protected:
    // Whether this tiling class supports the current case.
    virtual bool IsCapable() = 0;
    // 1. Fetch platform info such as core count and UB/L1/L0C resource sizes.
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Fetch INPUT/OUTPUT/ATTR info.
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Compute the data-split TilingData.
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Compute the TilingData of high-level APIs.
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Compute the TilingKey.
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Compute the workspace size.
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Save tiling data.
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump tiling data (debug level only); splits output to avoid
    //    truncation of long logs.
    virtual void DumpTilingInfo()
    {
        int32_t enable = AlogCheckDebugLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        auto buf = (uint32_t *)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
                OPS_LOG_D(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OPS_LOG_D(context_, "%s", oss.str().c_str());
    }

    // Converts a slice count into a scheduler block dim, scaling by the
    // vector/cube core ratio ("ration" is a historical misspelling of ratio).
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration;
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration;
    }

    // Renders a shape as "[d0, d1, ...]" for debug logs.
    template <typename T> [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }

    // Renders one tensor's dtype/shape/format summary; "nil " on null input.
    [[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape,
                                                const gert::CompileTimeTensorDesc *tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }

    // Renders every input and output of the tiling context for debug logs.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }

        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }

    // Renders raw tiling data as comma-separated int32 values.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t *>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }

protected:
    gert::TilingContext *context_ = nullptr;                                   // current tiling context (not owned)
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr}; // platform query helper
    uint32_t blockDim_{0};                                                     // computed launch block dim
    uint64_t workspaceSize_{0};                                                // computed workspace bytes
    uint64_t tilingKey_{0};                                                    // computed tiling key
    AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};                           // platform resource limits
};
|
||||
|
||||
} // namespace optiling
|
||||
162
csrc/utils/inc/tiling/tiling_templates_registry.h
Normal file
162
csrc/utils/inc/tiling/tiling_templates_registry.h
Normal file
@@ -0,0 +1,162 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_templates_registry.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include "tiling/tiling_base.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
template <typename T> std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext *context)
|
||||
{
|
||||
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
||||
}
|
||||
|
||||
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext *);
|
||||
|
||||
// Priority-ordered set of tiling class factories for one op type.
// Lower priority values are tried first by the registry.
class TilingCases {
public:
    explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
    {
    }

    // Registers tiling class T at the given priority; reports an error (and
    // keeps the existing entry) on duplicate priorities.
    template <typename T> void AddTiling(int32_t priority)
    {
        OPS_ERR_IF(cases_.find(priority) != cases_.end(),
                   OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return);
        cases_[priority] = TILING_CLASS<T>;
        OPS_ERR_IF(
            cases_[priority] == nullptr,
            OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."),
            return);
    }

    // Returns the factories keyed (and iterated) by ascending priority.
    const std::map<int32_t, TilingClassCase> &GetTilingCases()
    {
        return cases_;
    }

private:
    std::map<int32_t, TilingClassCase> cases_;  // priority -> factory
    const std::string op_type_;                 // op type this set belongs to
};
|
||||
|
||||
class TilingRegistry {
|
||||
public:
|
||||
TilingRegistry() = default;
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
static TilingRegistry &GetInstance();
|
||||
#else
|
||||
static TilingRegistry &GetInstance()
|
||||
{
|
||||
static TilingRegistry registry_impl_;
|
||||
return registry_impl_;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<TilingCases> RegisterOp(const std::string &op_type)
|
||||
{
|
||||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
}
|
||||
OPS_ERR_IF(registry_map_[op_type] == nullptr,
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."),
|
||||
return nullptr);
|
||||
return registry_map_[op_type];
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||
auto tilingTemplate = it->second(context);
|
||||
if (tilingTemplate != nullptr) {
|
||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first);
|
||||
return status;
|
||||
}
|
||||
OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first);
|
||||
}
|
||||
}
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector<int32_t> &priorities)
|
||||
{
|
||||
const char *op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto priorityId : priorities) {
|
||||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||||
if (templateFunc != nullptr) {
|
||||
ge::graphStatus status = templateFunc->DoTiling();
|
||||
if (status == ge::GRAPH_SUCCESS) {
|
||||
OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId);
|
||||
return status;
|
||||
}
|
||||
OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId);
|
||||
}
|
||||
}
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase> &GetTilingTemplates(const std::string &op_type)
|
||||
{
|
||||
OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(),
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."),
|
||||
return empty_tiling_case_);
|
||||
return registry_map_[op_type]->GetTilingCases();
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_ {};
|
||||
};
|
||||
|
||||
class Register {
|
||||
public:
|
||||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T> Register &tiling(int32_t priority)
|
||||
{
|
||||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||||
OPS_ERR_IF(tilingCases == nullptr,
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please the op name."),
|
||||
return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
// op_type: operator name; class_name: tiling class to register;
// priority: priority of the tiling class — a smaller value means higher
// priority, i.e. the template is tried earlier during selection.
//
// The variable name must be unique per registration. The original macro
// wrote `VAR_UNUSED##op_type_##class_name##priority_register`: `op_type_`
// and `priority_register` are not macro parameters, so they were pasted
// literally and two registrations of the same class (for different ops or
// priorities) collided on one identifier. Use __COUNTER__ instead, expanded
// through a two-level concat so it is stringized after expansion.
#define TILING_TEMPLATE_CONCAT_IMPL(a, b) a##b
#define TILING_TEMPLATE_CONCAT(a, b) TILING_TEMPLATE_CONCAT_IMPL(a, b)
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority)                         \
    static Register TILING_TEMPLATE_CONCAT(VAR_UNUSED_tiling_register_, __COUNTER__) =  \
        Register(op_type).tiling<class_name>(priority)
|
||||
|
||||
} // namespace optiling
|
||||
136
csrc/utils/inc/tiling/tiling_type.h
Normal file
136
csrc/utils/inc/tiling/tiling_type.h
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_type.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Axes a tiling strategy may split along; NONE marks "no axis split".
enum class AxisEnum {
    B = 0,
    N2 = 1,
    G = 2,
    S1 = 3,
    S2 = 4,
    D = 5,
    NONE = 9,
};

// Input/output data types a tiling key can declare support for.
enum class DtypeEnum {
    FLOAT16 = 0,
    FLOAT32 = 1,
    BFLOAT16 = 2,
    FLOAT16_PRECISION = 3,
};

// Buffering strategy selector. NOTE(review): values intentionally start at 1;
// 0 appears to be unused — confirm against the tiling implementations.
enum class PerformanceOrientedEnum {
    BIG_BUFFER = 1,
    BIG_DOUBLE_BUFFER = 2,
};

// Matmul configuration variant used by the tiling.
enum class MatmulConfig {
    NULL_CONFIG = 0,
    NORMAL_CONFIG = 1,
    MDL_CONFIG = 2
};

// Whether a PSE (positional shift encoding) input is present.
enum class PseConfig {
    NO_PSE = 0,
    EXIST_PSE = 1
};

// Whether an attention mask input is present.
enum class AttenMaskConfig {
    NO_ATTEN_MASK = 0,
    EXIST_ATTEN_MASK = 1
};

// Whether a dropout mask input is present.
enum class DropOutConfig {
    NO_DROP_OUT = 0,
    EXIST_DROP_OUT = 1
};

// Storage format of cube (matmul) operands: plain ND or fractal NZ.
enum class CubeFormatEnum {
    ND = 0,
    NZ = 1
};
// Memory layout of the attention tensors (batch/sequence/head/dim orderings).
enum class LayoutEnum {
    BSND = 0,
    SBND = 1,
    BNSD = 2,
    TND = 3
};

// Where cube inputs are sourced from: global memory or L1 buffer.
enum class CubeInputSourceEnum {
    GM = 0,
    L1 = 1
};

// Generic on/off switch for optional features.
enum class OptionEnum {
    DISABLE = 0,
    ENABLE = 1
};

// Sparse attention patterns a tiling key can support.
enum class SparseEnum {
    ALL = 0,
    NONE = 1,
    ANY = 2,
    CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    BAND_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL = 7,
    RIGHT_DOWN_CAUSAL_BAND = 8,
    BAND_LEFT_UP_CAUSAL = 9
};
|
||||
|
||||
// Base case of the decimal-digit accumulator: no ids left contributes 0.
constexpr uint64_t RecursiveSum()
{
    return 0;
}

// Packs the template ids into a decimal number: the first argument becomes
// the lowest decimal digit, each following argument the next higher digit.
// Each id is therefore expected to be < 10.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
{
    return static_cast<uint64_t>(templateId) + 10 * RecursiveSum(templateIds...);
}

// How the tiling key is generated:
// For FlashAttentionScore/FlashAttentionScoreGrad the tiling key is assembled
// from decimal digits. From the lowest digit to the highest these are:
// Ub0, Ub1, Block, DataType, Format, Sparse, then specialized-template digits.
// Ub0/Ub1: the axes split inside UB (enum AxisEnum). At most two axes may be
//   split, hence UB0 and UB1; use AXIS_NONE when there is no UB split.
//   One decimal digit each.
// Block: the axis used to split work across cores (enum AxisEnum), one digit.
// DataType: the input/output dtype this key supports (enum SupportedDtype), one digit.
// Format: the layout this key supports (enum InputLayout), one digit.
// Sparse: whether this key supports sparse (enum SparseCapability), one digit.
// Other specialized scenarios define their own digit positions and values.
// usage: get tilingKey from inputed types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)

// Fixed offset so every generated key has the same number of decimal digits.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Combines the offset with the digit-packed template ids.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}

// usage: get tilingKey from inputed types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)

// Convenience wrapper that qualifies each bare id with its enum type.
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
                   SparseEnum::sparse))
|
||||
|
||||
} // namespace optiling
|
||||
53
csrc/utils/src/fallback_comm.cpp
Normal file
53
csrc/utils/src/fallback_comm.cpp
Normal file
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file fallback_comm.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "fallback_comm.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "runtime/base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
namespace fallback {
|
||||
using namespace std;
|
||||
using namespace gert;
|
||||
using namespace ge;
|
||||
|
||||
// Maps a ge::DataType to the corresponding aclDataType.
// Only dtypes present in the whitelist below are convertible; any other
// value yields ACL_DT_UNDEFINED.
// NOTE(review): the conversion is a plain static_cast, i.e. it relies on the
// listed ge::DataType enumerators sharing their numeric values with the
// aclDataType enumerators — confirm against both enum definitions.
aclDataType ToAclDataType(ge::DataType dtype) {
    // Whitelist of ge dtypes that have a matching acl dtype.
    static const std::vector<DataType> CANN_CONVERT_TO_ACL_DataType_LIST = {
        ge::DataType::DT_FLOAT, ge::DataType::DT_FLOAT16, ge::DataType::DT_INT8, ge::DataType::DT_INT32,
        ge::DataType::DT_UINT8, ge::DataType::DT_INT16, ge::DataType::DT_UINT16, ge::DataType::DT_UINT32,
        ge::DataType::DT_INT64, ge::DataType::DT_DOUBLE, ge::DataType::DT_BOOL, ge::DataType::DT_STRING,
        ge::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX128, ge::DataType::DT_BF16, ge::DataType::DT_UINT64,
        ge::DataType::DT_INT4};
    auto iter = std::find(CANN_CONVERT_TO_ACL_DataType_LIST.begin(), CANN_CONVERT_TO_ACL_DataType_LIST.end(), dtype);
    if (iter == CANN_CONVERT_TO_ACL_DataType_LIST.end()) {
        return aclDataType::ACL_DT_UNDEFINED;
    }
    return static_cast<aclDataType>(dtype);
}
|
||||
|
||||
} // namespace fallback
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,25 @@
|
||||
# Adding a custom aclnn operation
|
||||
|
||||
This document describes how to add a custom aclnn operation to vllm-ascend.
|
||||
|
||||
## How do custom aclnn operations work in vllm-ascend?
|
||||
|
||||
Custom aclnn operations are built and installed into `vllm_ascend/cann_ops_custom` directory during the build process of vllm-ascend. Then the aclnn operators are bound to `torch.ops._C_ascend` module, enabling users to invoke them in vllm-ascend python code.
|
||||
|
||||
To enable custom operations, use the following code:
|
||||
|
||||
```python
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
```
|
||||
|
||||
## How to add a custom aclnn operation?
|
||||
|
||||
- Create a new operation folder under `csrc` directory
|
||||
- Create `op_host` and `op_kernel` directories for host and kernel source code
|
||||
- Add build options in `csrc/build_aclnn.sh` for supported SOC. Note that multiple ops should be separated with `;`, i.e. `CUSTOM_OPS=op1;op2;op3`
|
||||
- Bind aclnn operators to torch.ops._C_ascend module in `csrc/torch_binding.cpp`
|
||||
- Write a meta implementation in `csrc/torch_binding_meta.cpp` so that the op can be captured into aclgraph
|
||||
|
||||
After a successful build of vllm-ascend, the custom aclnn operation can be invoked in python code.
|
||||
@@ -12,4 +12,5 @@ eplb_swift_balancer.md
|
||||
Multi_Token_Prediction
|
||||
ACL_Graph
|
||||
KV_Cache_Pool_Guide
|
||||
add_custom_aclnn_op
|
||||
:::
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
[build-system]
|
||||
# Should be mirrored in requirements.txt
|
||||
requires = [
|
||||
"attrs",
|
||||
"cmake>=3.26",
|
||||
"decorator",
|
||||
"einops",
|
||||
"googleapis-common-protos",
|
||||
"numpy<2.0.0",
|
||||
"packaging",
|
||||
"pip",
|
||||
@@ -12,6 +14,7 @@ requires = [
|
||||
"scipy",
|
||||
"pandas",
|
||||
"pandas-stubs",
|
||||
"psutil",
|
||||
"setuptools>=64",
|
||||
"setuptools-scm>=8",
|
||||
"transformers<=4.57.1",
|
||||
|
||||
40
setup.py
40
setup.py
@@ -25,7 +25,7 @@ import sys
|
||||
from sysconfig import get_paths
|
||||
from typing import Dict, List
|
||||
|
||||
from setuptools import Extension, find_packages, setup
|
||||
from setuptools import Command, Extension, find_packages, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from setuptools.command.build_py import build_py
|
||||
from setuptools.command.develop import develop
|
||||
@@ -199,6 +199,27 @@ class custom_build_info(build_py):
|
||||
super().run()
|
||||
|
||||
|
||||
class build_and_install_aclnn(Command):
    """Setuptools command that builds and installs the custom ACLNN ops by
    running csrc/build_aclnn.sh with the project root and target SOC version.

    Aborts the build (SystemExit with the script's return code) when the
    shell script fails.
    """
    description = "Build and install AclNN by running build_aclnn.sh"
    user_options = []

    def initialize_options(self):
        # This command takes no options.
        pass

    def finalize_options(self):
        # This command takes no options.
        pass

    def run(self):
        try:
            print("Running bash build_aclnn.sh ...")
            subprocess.check_call(
                ["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION])
            # Fixed typo in the success message ("buid_aclnn.sh").
            print("build_aclnn.sh executed successfully!")
        except subprocess.CalledProcessError as e:
            print(f"Error running build_aclnn.sh: {e}")
            # Propagate the script's exit code so the overall build fails.
            raise SystemExit(e.returncode)
|
||||
|
||||
|
||||
class cmake_build_ext(build_ext):
|
||||
# A dict of extension directories that have been configured.
|
||||
did_config: Dict[str, bool] = {}
|
||||
@@ -385,8 +406,22 @@ class cmake_build_ext(build_ext):
|
||||
shutil.copy(src_path, dst_path)
|
||||
print(f"Copy: {src_path} -> {dst_path}")
|
||||
|
||||
# copy back _cann_ops_custom directory
|
||||
src_cann_ops_custom = os.path.join(ROOT_DIR, "vllm_ascend",
|
||||
"_cann_ops_custom")
|
||||
dst_cann_ops_custom = os.path.join(self.build_lib, "vllm_ascend",
|
||||
"_cann_ops_custom")
|
||||
if os.path.exists(src_cann_ops_custom):
|
||||
import shutil
|
||||
if os.path.exists(dst_cann_ops_custom):
|
||||
shutil.rmtree(dst_cann_ops_custom)
|
||||
shutil.copytree(src_cann_ops_custom, dst_cann_ops_custom)
|
||||
print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}")
|
||||
|
||||
def run(self):
|
||||
# First, run the standard build_ext command to compile the extensions
|
||||
# First, ensure ACLNN custom-ops is built and installed.
|
||||
self.run_command("build_aclnn")
|
||||
# Then, run the standard build_ext command to compile the extensions
|
||||
super().run()
|
||||
|
||||
|
||||
@@ -450,6 +485,7 @@ def get_requirements() -> List[str]:
|
||||
cmdclass = {
|
||||
"develop": custom_develop,
|
||||
"build_py": custom_build_info,
|
||||
"build_aclnn": build_and_install_aclnn,
|
||||
"build_ext": cmake_build_ext,
|
||||
"install": custom_install
|
||||
}
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
# enable internal format
|
||||
torch_npu.npu.config.allow_internal_format = True
|
||||
# enable vllm-ascend custom ops
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor,
|
||||
perChannelScale: torch.Tensor,
|
||||
perTokenScale: torch.Tensor, m: int):
|
||||
"""
|
||||
Perform quantized GMM (Grouped Matrix Multiplication) operation with SwiGLU activation function.
|
||||
|
||||
Parameters:
|
||||
x (torch.Tensor): Input tensor with shape (m, k).
|
||||
weight (torch.Tensor): Weight tensor with shape (k, n).
|
||||
perChannelScale (torch.Tensor): Per-channel scaling factor with shape (n,).
|
||||
perTokenScale (torch.Tensor): Per-token scaling factor with shape (m,).
|
||||
m (int): Number of tokens (rows of x).
|
||||
|
||||
Returns:
|
||||
quantOutput (torch.Tensor): Quantized output tensor with shape (m, k // 2).
|
||||
quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (m,).
|
||||
"""
|
||||
# Perform matrix multiplication with int32 precision
|
||||
c_temp1 = torch.matmul(x.to(torch.int32), weight.to(torch.int32))
|
||||
c_temp1 = c_temp1.to(torch.float32) # Convert back to float32 for scaling
|
||||
|
||||
# Apply per-channel and per-token scaling
|
||||
c_temp2 = torch.mul(c_temp1, perChannelScale)
|
||||
c_temp3 = torch.mul(c_temp2, perTokenScale.reshape(m, 1))
|
||||
|
||||
# Split the result into two parts to apply SwiGLU activation function
|
||||
c_temp4, gate = c_temp3.chunk(2, dim=-1)
|
||||
c_temp5 = c_temp4 * torch.sigmoid(c_temp4) # SwiGLU activation
|
||||
c_temp6 = c_temp5 * gate # Element-wise multiplication with gating values
|
||||
|
||||
# Quantize the output
|
||||
max = torch.max(
|
||||
torch.abs(c_temp6),
|
||||
-1).values # Find maximum absolute value to calculate scaling factor
|
||||
quantScaleOutput = 127 / max # Calculate quantization scaling factor
|
||||
quantOutput = torch.round(c_temp6 * quantScaleOutput.reshape(m, 1)).to(
|
||||
torch.int8) # Quantize to int8
|
||||
quantScaleOutput = 1 / quantScaleOutput # Inverse quantization scaling factor for subsequent dequantization
|
||||
|
||||
return quantOutput, quantScaleOutput
|
||||
|
||||
|
||||
def process_groups(x: torch.Tensor, weight: torch.Tensor,
                   perChannelScale: torch.Tensor, perTokenScale: torch.Tensor,
                   groupList: torch.Tensor):
    """
    Golden reference: process the rows of x group by group, running
    gmm_swiglu_quant on each group with that group's expert weight and scale.

    Parameters:
        x (torch.Tensor): Input tensor of shape (M, K), int8.
        weight (torch.Tensor): Stacked expert weights of shape (E, K, N).
        perChannelScale (torch.Tensor): Per-expert channel scales of shape (E, N).
        perTokenScale (torch.Tensor): Per-token scale of shape (M,).
        groupList (torch.Tensor): Cumulative group boundaries (prefix sums of
            the per-group token counts).

    Returns:
        quantOutput (torch.Tensor): int8 quantized output of shape (M, N // 2).
        quantScaleOutput (torch.Tensor): Dequantization scales of shape (M,).
    """
    M, N = x.shape[0], weight.shape[2]
    quantOutput = torch.zeros(M, N // 2).to(torch.int8)
    quantScaleOutput = torch.zeros(M).to(torch.float32)

    row = 0         # first row of the current group
    prev_bound = 0  # cumulative boundary of the previous group
    for expert_idx, bound in enumerate(groupList.tolist()):
        count = bound - prev_bound  # token count of this group
        prev_bound = bound
        if count > 0:
            # Quantized GMM + SwiGLU for just this group's rows.
            out_q, out_s = gmm_swiglu_quant(x[row:row + count],
                                            weight[expert_idx],
                                            perChannelScale[expert_idx],
                                            perTokenScale[row:row + count],
                                            count)
            quantOutput[row:row + count] = out_q
            quantScaleOutput[row:row + count] = out_s
        row += count  # advance to the next group's first row
    return quantOutput, quantScaleOutput
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_gmm_swiglu_quant_weight_nz_tensor_list():
    """Compare the custom aclnn op grouped_matmul_swiglu_quant_weight_nz_tensor_list
    against the CPU golden reference (process_groups)."""
    # Problem sizes: M tokens, K input dim, E experts, N output dim.
    M, K, E, N = 8192, 7168, 4, 4096

    # x (M, K) - int8
    x = torch.randint(-128, 127, (M, K), dtype=torch.int8)

    # weight (E, K, N) - int8, stacked per-expert weights
    # (the original comment said (E, N, K), which contradicts the size below)
    weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)

    # weight_scale (E, N) - float32
    weight_scale = torch.rand(E, N) * 0.9 + 0.1  # uniform(0.1, 1.0)
    weight_scale = weight_scale.to(torch.float32)

    # Move each expert's weight to the NPU as a separate tensor (the op takes
    # TensorLists) and cast to format id 29 — presumably ACL_FORMAT_FRACTAL_NZ,
    # confirm against the torch_npu format table.
    weight_nz_npu = []
    weight_scale_npu = []
    for i in range(E):
        weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29))
        weight_scale_npu.append(weight_scale[i].npu())

    # x_scale (M,) - float32
    x_scale = torch.rand(M) * 0.9 + 0.1  # uniform(0.1, 1.0)
    x_scale = x_scale.to(torch.float32)

    # Cumulative token counts per expert (2048 tokens each).
    group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64)

    # CPU golden reference.
    output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale,
                                                  x_scale, group_list)
    # Custom op under test; third return value is unused here.
    output_npu, output_scale_npu, _ = \
        torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(),
                                                                              weight_nz_npu,
                                                                              weight_scale_npu,
                                                                              x_scale.npu(),
                                                                              group_list.npu())
    # Only the rows covered by group_list carry defined outputs.
    output_npu_valid = output_npu[:group_list[-1], :]
    output_scale_npu_valid = output_scale_npu[:group_list[-1]]

    # int8 outputs are allowed to differ by 1 (rounding); scales must be tight.
    torch.testing.assert_close(output_npu_valid.cpu(),
                               output_cpu,
                               atol=1,
                               rtol=2**-13)
    torch.testing.assert_close(output_scale_npu_valid.cpu(),
                               output_scale_cpu,
                               atol=1e-9,
                               rtol=1e-6)

    # Release NPU memory so subsequent tests start from a clean slate.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
3
vllm_ascend/_cann_ops_custom/.gitkeep
Normal file
3
vllm_ascend/_cann_ops_custom/.gitkeep
Normal file
@@ -0,0 +1,3 @@
|
||||
# This folder is reserved for the installation of custom aclnn operators tailored for vLLM-Ascend.
|
||||
# Source code of the operators can be found in the `csrc` folder.
|
||||
# The operators are compiled into a custom CANN software package and installed to this folder automatically.
|
||||
@@ -38,6 +38,27 @@ from vllm_ascend.utils import (
|
||||
prefill_context_parallel_enable, update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
|
||||
|
||||
# set custom ops path
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
# NOTE(review): CUR_DIR is the directory containing this file; joining
# "vllm_ascend" again assumes this module lives one level ABOVE the
# vllm_ascend package directory — confirm, otherwise this resolves to
# .../vllm_ascend/vllm_ascend/_cann_ops_custom.
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "vllm_ascend", "_cann_ops_custom",
                               "vendors", "customize")
CUSTOM_LIB_PATH = os.path.join(CUSTOM_OPP_PATH, "op_api", "lib")

# Prepend the custom-op package directory to ASCEND_CUSTOM_OPP_PATH so the
# CANN runtime can discover the operators shipped with vllm-ascend.
if os.path.exists(CUSTOM_OPP_PATH):
    current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
    if current_cust_opp_path:
        os.environ[
            "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
    else:
        os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH

# Prepend the aclnn op_api library directory to LD_LIBRARY_PATH.
# NOTE(review): mutating LD_LIBRARY_PATH at runtime only affects child
# processes; the already-running interpreter's dynamic linker ignores it —
# confirm the libraries are loaded by a subprocess or via explicit dlopen.
if os.path.exists(CUSTOM_LIB_PATH):
    current_lib_path = os.environ.get("LD_LIBRARY_PATH", "")
    if current_lib_path:
        os.environ["LD_LIBRARY_PATH"] = f"{CUSTOM_LIB_PATH}:{current_lib_path}"
    else:
        os.environ["LD_LIBRARY_PATH"] = CUSTOM_LIB_PATH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
Reference in New Issue
Block a user