[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)

### What this PR does / why we need it?

This PR introduces support for adding custom CANN `aclnn` ops to
`vllm-ascend`, allowing users to define and use their own custom
operators.

Key changes include:
- Building custom ops and installing them into the directory specified by
`vllm-ascend`
- Binding the `aclnn` op interface to the `torch.ops._C_ascend` module
- Enabling invocation of these ops within `vllm-ascend`

This PR includes a sample custom op:
`aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from
the CANN operator
[`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md).
Its input parameters `weight` and `weight_scale` now accept
`list[torch.Tensor]` (i.e., `at::TensorList`).

### Does this PR introduce _any_ user-facing change?

No.


- vLLM version: v0.11.2

---------

Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
This commit is contained in:
Chenxi Qian
2025-11-28 18:06:39 +08:00
committed by GitHub
parent 3199fe8350
commit 554f16ae1f
50 changed files with 6934 additions and 7 deletions

View File

@@ -96,6 +96,11 @@ jobs:
--exclude libge_common_base.so \
--exclude libc10.so \
--exclude libc_sec.so \
--exclude libnnopbase.so \
--exclude libprofapi.so \
--exclude libgraph_base.so \
--exclude libgraph.so \
--exclude libexe_graph.so \
--exclude "libascend*.so" \
--exclude "libtorch*.so" \
--exclude "libopapi.so" \

View File

@@ -12,7 +12,7 @@ repos:
- id: codespell
args: [
--toml, pyproject.toml,
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
]
additional_dependencies:
@@ -37,7 +37,7 @@ repos:
- id: typos
args: [
"--force-exclude",
"--exclude", "csrc/mla_preprocess/**"
"--exclude", "csrc/**"
]
- repo: https://github.com/PyCQA/isort
rev: 6.0.1

View File

@@ -82,6 +82,7 @@ set(
${TORCH_NPU_INCLUDE_DIRS}
${ASCEND_HOME_PATH}/include
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
)
pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})

642
csrc/CMakeLists.txt Normal file
View File

@@ -0,0 +1,642 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
cmake_minimum_required(VERSION 3.16)
project(cann_ops_custom)
# Build as a standalone ("open") ops project; OFF presumably means an in-tree
# CANN build — TODO confirm against cmake/config.cmake.
option(BUILD_OPEN_PROJECT "Build open ascend ops project." ON)
option(ENABLE_CCACHE "Enable ccache capability" ON)
# SoC targets, op selection and vendor directory name; all overridable on the
# cmake command line (see csrc/build.sh -n / -c).
set(ASCEND_COMPUTE_UNIT "ascend910b" CACHE STRING "soc that need to be compiled")
set(ASCEND_OP_NAME "ALL" CACHE STRING "operators that need to be compiled")
set(VENDOR_NAME "customize" CACHE STRING "vendor name")
# Order matters: config.cmake resolves the CANN package path and tool paths
# that func.cmake / intf.cmake rely on.
include(cmake/config.cmake)
include(cmake/func.cmake)
include(cmake/intf.cmake)
if (BUILD_OPEN_PROJECT)
# Shared link interface for the three op-definition host libraries
# (aclnn / aclnnInner / aclnnExc). intf_pub is an interface target from
# cmake/intf.cmake carrying common compile settings.
set(_op_host_aclnn_link
$<BUILD_INTERFACE:intf_pub>
exe_graph
register
c_sec
)
set(CMAKE_MODULE_PATH
${CMAKE_MODULE_PATH}
${CMAKE_CURRENT_LIST_DIR}/cmake/modules
)
set(CMAKE_PREFIX_PATH
${CMAKE_PREFIX_PATH}
${ASCEND_CANN_PACKAGE_PATH}
)
find_package(alog MODULE REQUIRED)
# Host-side op-definition libraries. EXCLUDE_FROM_ALL: they are only built as
# inputs to the opbuild code generator (see the opbuild section below), not
# shipped themselves. Sources are attached later by the per-op op_host
# subdirectories.
add_library(op_host_aclnn SHARED EXCLUDE_FROM_ALL)
target_link_libraries(op_host_aclnn PRIVATE
${_op_host_aclnn_link}
)
target_compile_options(op_host_aclnn PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
)
add_library(op_host_aclnnInner SHARED EXCLUDE_FROM_ALL)
target_link_libraries(op_host_aclnnInner PRIVATE
${_op_host_aclnn_link}
)
target_compile_options(op_host_aclnnInner PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
)
add_library(op_host_aclnnExc SHARED EXCLUDE_FROM_ALL)
target_link_libraries(op_host_aclnnExc PRIVATE
${_op_host_aclnn_link}
)
target_compile_options(op_host_aclnnExc PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
)
# op api
# libcust_opapi.so: the installable aclnn entry-point library.
add_library(opapi SHARED)
# When compiling a specified operator without aclnn src
# a generated empty stub keeps the SHARED target buildable.
if (NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
)
target_sources(opapi PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp
)
endif()
target_compile_options(opapi PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
)
target_include_directories(opapi PRIVATE
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include>
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn>
$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn_kernels>
)
target_compile_options(opapi PRIVATE
-Werror=format
)
target_compile_definitions(opapi PRIVATE
-DACLNN_LOG_FMT_CHECK
)
# --whole-archive keeps every generated aclnn symbol from the static
# ops_aclnn archive so dlopen/aclnn lookup can find them.
target_link_libraries(opapi PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--whole-archive
ops_aclnn
-Wl,--no-whole-archive
-lopapi
nnopbase
profapi
ge_common_base
ascend_dump
ascendalog
dl
)
set_target_properties(opapi PROPERTIES OUTPUT_NAME
cust_opapi
)
install(TARGETS opapi
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_api/lib
)
# op proto
# libcust_opsproto_rt2.0.so: operator prototype (shape/dtype inference)
# registration library. Sources are the generated *_proto.cpp files attached
# in the aclnn section below.
add_library(opsproto SHARED)
target_compile_options(opsproto PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
-fvisibility=hidden
)
target_compile_definitions(opsproto PRIVATE
LOG_CPP
PROCESS_LOG
)
target_link_libraries(opsproto PRIVATE
$<BUILD_INTERFACE:intf_pub>
$<BUILD_INTERFACE:ops_utils_proto_headers>
-Wl,--whole-archive
rt2_registry
-Wl,--no-whole-archive
-Wl,--no-as-needed
exe_graph
graph
graph_base
register
ascendalog
error_manager
platform
-Wl,--as-needed
c_sec
)
set_target_properties(opsproto PROPERTIES OUTPUT_NAME
cust_opsproto_rt2.0
)
install(TARGETS opsproto
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR}
)
# op tiling
# libcust_opmaster_rt2.0.so: host tiling library; per-op tiling sources are
# added by the op_host subdirectories.
add_library(optiling SHARED)
target_sources(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/utils/src/fallback_comm.cpp
)
target_compile_options(optiling PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
-fvisibility=hidden
)
target_compile_definitions(optiling PRIVATE
LOG_CPP
PROCESS_LOG
)
target_link_libraries(optiling PRIVATE
$<BUILD_INTERFACE:intf_pub>
$<BUILD_INTERFACE:ops_utils_tiling_headers>
-Wl,--whole-archive
rt2_registry
-Wl,--no-whole-archive
-Wl,--no-as-needed
graph
graph_base
exe_graph
platform
register
ascendalog
error_manager
-Wl,--as-needed
-Wl,--whole-archive
tiling_api
-Wl,--no-whole-archive
mmpa
c_sec
)
set_target_properties(optiling PROPERTIES OUTPUT_NAME
cust_opmaster_rt2.0
)
# Symlink the freshly built tiling library into TILING_CUSTOM_DIR (defined in
# cmake/config.cmake, presumably for the kernel compiler to pick up — confirm).
add_custom_command(TARGET optiling
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${TILING_CUSTOM_DIR}
COMMAND ln -sf $<TARGET_FILE:optiling> ${TILING_CUSTOM_FILE}
)
install(TARGETS optiling
LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR}
)
# optiling compat
# Legacy loaders expect liboptiling.so one level up; provide a relative
# symlink so the installed layout satisfies both paths.
set(compat_optiling_dir ${CMAKE_CURRENT_BINARY_DIR}/compat)
set(compat_optiling_file ${compat_optiling_dir}/liboptiling.so)
add_custom_target(optiling_compat ALL
DEPENDS ${compat_optiling_file}
)
add_custom_command(
OUTPUT ${compat_optiling_file}
COMMAND ${CMAKE_COMMAND} -E make_directory ${compat_optiling_dir}
COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$<TARGET_FILE_NAME:optiling> ${compat_optiling_file}
)
install(FILES ${compat_optiling_file}
DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling
)
# Project-wide tiling-key / opc settings (helpers from cmake/func.cmake);
# TILING_KEY and OP_DEBUG_CONFIG come from config.cmake.
add_ops_tiling_keys(
OP_NAME "ALL"
TILING_KEYS ${TILING_KEY}
)
add_opc_config(
OP_NAME "ALL"
CONFIG ${OP_DEBUG_CONFIG}
)
if(ADD_OPS_COMPILE_OPTION_V2)
add_ops_compile_options(
OP_NAME "ALL"
OPTIONS ${OPS_COMPILE_OPTIONS}
)
endif()
endif ()
add_subdirectory(utils)
# Discover the selected ops (honouring ASCEND_OP_NAME) and pull in each op's
# op_host subdirectory; op_add_subdirectory/op_add_depend_directory are
# helpers from cmake/func.cmake and fill OP_LIST / OP_DIR_LIST by reference.
set(OP_LIST)
set(OP_DIR_LIST)
op_add_subdirectory(OP_LIST OP_DIR_LIST)
foreach (OP_DIR ${OP_DIR_LIST})
add_subdirectory(${OP_DIR}/op_host)
endforeach ()
# Also build the op_host dirs of ops the selected ops depend on.
set(OP_DEPEND_DIR_LIST)
op_add_depend_directory(
OP_LIST ${OP_LIST}
OP_DIR_LIST OP_DEPEND_DIR_LIST
)
foreach (OP_DEPEND_DIR ${OP_DEPEND_DIR_LIST})
add_subdirectory(${OP_DEPEND_DIR}/op_host)
endforeach ()
# ------------------------------------------------ aclnn ------------------------------------------------
# The op_host subdirectories have attached their *_def.cpp sources by now;
# read them back so we can predict the file names opbuild will generate.
get_target_property(base_aclnn_srcs op_host_aclnn SOURCES)
get_target_property(base_aclnn_inner_srcs op_host_aclnnInner SOURCES)
get_target_property(base_aclnn_exclude_srcs op_host_aclnnExc SOURCES)
# Generated sources land in ASCEND_AUTOGEN_DIR for the open build, or in the
# op_host_aclnn target's binary dir for the in-tree build.
if (BUILD_OPEN_PROJECT)
set(base_aclnn_binary_dir ${ASCEND_AUTOGEN_DIR})
else()
get_target_property(base_aclnn_binary_dir op_host_aclnn BINARY_DIR)
endif ()
# Accumulators for the generated aclnn wrappers and proto sources/headers.
set(generate_aclnn_srcs)
set(generate_aclnn_inner_srcs)
set(generate_aclnn_headers)
set(generate_proto_dir ${base_aclnn_binary_dir})
set(generate_exclude_proto_srcs)
set(generate_proto_srcs)
set(generate_proto_headers)
# Map each in-tree <op>_def.cpp attached to op_host_aclnn to the files the
# opbuild generator will emit (aclnn_<op>.cpp/.h and <op>_proto.cpp/.h).
# When no op contributes an aclnn definition, attach a generated empty stub
# so the SHARED target still has at least one source.
if (base_aclnn_srcs)
foreach (_src ${base_aclnn_srcs})
# Only sources living under this project count; imported/absolute CANN
# sources are skipped.
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
if (is_match)
get_filename_component(name_without_ext ${_src} NAME_WE)
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
list(APPEND generate_aclnn_srcs ${base_aclnn_binary_dir}/aclnn_${_op_name}.cpp)
list(APPEND generate_aclnn_headers ${base_aclnn_binary_dir}/aclnn_${_op_name}.h)
list(APPEND generate_proto_srcs ${generate_proto_dir}/${_op_name}_proto.cpp)
list(APPEND generate_proto_headers ${generate_proto_dir}/${_op_name}_proto.h)
endif ()
endforeach ()
else ()
# Use the portable cmake -E touch instead of the POSIX-only `touch`, and
# VERBATIM for well-defined argument escaping.
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
VERBATIM
)
target_sources(op_host_aclnn PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp
)
endif ()
# Same mapping as above for the "inner" aclnn definitions; generated files go
# into the inner/ subdirectory and there is no public header to install.
if (base_aclnn_inner_srcs)
foreach (_src ${base_aclnn_inner_srcs})
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
if (is_match)
get_filename_component(name_without_ext ${_src} NAME_WE)
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
list(APPEND generate_aclnn_inner_srcs ${base_aclnn_binary_dir}/inner/aclnnInner_${_op_name}.cpp)
list(APPEND generate_proto_srcs ${generate_proto_dir}/inner/${_op_name}_proto.cpp)
list(APPEND generate_proto_headers ${generate_proto_dir}/inner/${_op_name}_proto.h)
endif ()
endforeach ()
else ()
# Portable stub generation: cmake -E touch + VERBATIM instead of raw `touch`.
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
VERBATIM
)
target_sources(op_host_aclnnInner PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp
)
endif ()
# "Exclude" definitions generate proto files only (no aclnn wrapper), under
# the exc/ subdirectory; they also feed the combined generate_proto_srcs list.
if (base_aclnn_exclude_srcs)
foreach (_src ${base_aclnn_exclude_srcs})
string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}")
if (is_match)
get_filename_component(name_without_ext ${_src} NAME_WE)
string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext})
list(APPEND generate_exclude_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp)
list(APPEND generate_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp)
list(APPEND generate_proto_headers ${generate_proto_dir}/exc/${_op_name}_proto.h)
endif ()
endforeach ()
else()
# Portable stub generation: cmake -E touch + VERBATIM instead of raw `touch`.
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
VERBATIM
)
target_sources(op_host_aclnnExc PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp
)
endif ()
if (BUILD_OPEN_PROJECT)
# Bundle the generated aclnn wrappers into the static ops_aclnn archive that
# opapi links --whole-archive (see the opapi target above). If nothing was
# generated, build an empty stub so the archive still exists.
if (generate_aclnn_srcs OR generate_aclnn_inner_srcs)
set(ops_aclnn_src ${generate_aclnn_srcs} ${generate_aclnn_inner_srcs})
else ()
set(ops_aclnn_src ${CMAKE_CURRENT_BINARY_DIR}/ops_aclnn_src_stub.cpp)
add_custom_command(OUTPUT ${ops_aclnn_src}
COMMAND touch ${ops_aclnn_src}
)
endif ()
# The sources do not exist at configure time; mark GENERATED so CMake
# defers existence checks to build time.
set_source_files_properties(${ops_aclnn_src}
PROPERTIES GENERATED TRUE
)
add_library(ops_aclnn STATIC
${ops_aclnn_src}
)
target_compile_options(ops_aclnn PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:-std=gnu++1z>
)
target_link_libraries(ops_aclnn PRIVATE
$<BUILD_INTERFACE:intf_pub>
)
# Ordering only: the opbuild codegen targets must run before compiling.
add_dependencies(ops_aclnn opbuild_gen_default opbuild_gen_inner)
set_source_files_properties(${generate_proto_srcs}
PROPERTIES GENERATED TRUE
)
target_sources(opsproto PRIVATE
${generate_proto_srcs}
)
add_dependencies(opsproto ops_proto_headers)
install(FILES ${generate_proto_headers}
DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/inc OPTIONAL
)
# Helper from cmake/func.cmake; presumably rewrites __FILE__ to a relative
# path for reproducible logs — TODO confirm.
redefine_file_macro(
TARGET_NAME
op_host_aclnn
op_host_aclnnInner
op_host_aclnnExc
opapi
opsproto
optiling
ops_aclnn
)
else()
# In-tree CANN build: feed the generated sources into the pre-existing
# builtin targets (acl_op_builtin, opsproto_rt2.0, opmaster_rt2.0) instead
# of creating new ones.
if (generate_aclnn_srcs OR generate_aclnn_inner_srcs)
set_source_files_properties(${generate_aclnn_srcs} ${generate_aclnn_inner_srcs}
TARGET_DIRECTORY acl_op_builtin
PROPERTIES GENERATED TRUE
)
target_sources(acl_op_builtin PRIVATE
${generate_aclnn_srcs}
${generate_aclnn_inner_srcs}
)
endif ()
if (generate_proto_srcs)
set_source_files_properties(${generate_proto_srcs}
TARGET_DIRECTORY opsproto opsproto_rt2.0
PROPERTIES GENERATED TRUE
)
target_sources(opsproto PRIVATE
${generate_proto_srcs}
)
add_dependencies(opsproto ops_proto_headers)
target_sources(opsproto_rt2.0 PRIVATE
${generate_proto_srcs}
)
add_dependencies(opsproto_rt2.0 ops_proto_headers)
endif ()
add_target_source(
TARGET_NAME opmaster_rt2.0 opmaster_static_rt2.0
BASE_TARGET optiling
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
)
add_target_source(
TARGET_NAME opsproto_rt2.0 opsproto_static_rt2.0
BASE_TARGET opsproto
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
)
add_static_ops(
ACLNN_SRC ${generate_aclnn_srcs}
ACLNN_INNER_SRC ${generate_aclnn_inner_srcs}
SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}
)
endif ()
# Install the generated public aclnn headers; OPTIONAL because they only
# exist after the opbuild codegen has run.
if (generate_aclnn_headers)
install(FILES ${generate_aclnn_headers}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
endif ()
# Interface target exposing the generated proto header directories to
# consumers (opsproto links it via BUILD_INTERFACE above).
add_library(ops_proto_headers INTERFACE)
target_include_directories(ops_proto_headers INTERFACE
$<BUILD_INTERFACE:${generate_proto_dir}>
$<BUILD_INTERFACE:${generate_proto_dir}/inner>
$<BUILD_INTERFACE:${generate_proto_dir}/exc>
$<INSTALL_INTERFACE:include/ops_adv/proto>
)
# Device-side in-tree builds run opbuild through a superbuild ExternalProject;
# everywhere else the local opbuild_gen_* targets (defined below) suffice.
if ((NOT BUILD_OPEN_PROJECT) AND ("${PRODUCT_SIDE}" STREQUAL "device"))
ExternalProject_Add(extern_opbuild_gen_default
SOURCE_DIR ${TOP_DIR}/cmake/superbuild
CONFIGURE_COMMAND ${CMAKE_COMMAND}
-G ${CMAKE_GENERATOR}
-DHOST_PACKAGE=opp
-DBUILD_MOD=ops
-DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/opbuild_output
-DFEATURE_LIST=custom_opbuild_out_dir=${generate_proto_dir}
<SOURCE_DIR>
BUILD_COMMAND TARGETS=opbuild_gen_all $(MAKE)
INSTALL_COMMAND ""
LIST_SEPARATOR ::
EXCLUDE_FROM_ALL TRUE
)
add_dependencies(ops_proto_headers extern_opbuild_gen_default)
else()
add_dependencies(ops_proto_headers opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
endif ()
if (NOT BUILD_OPEN_PROJECT)
if (generate_proto_srcs)
install_package(
PACKAGE ops_adv
TARGETS ops_proto_headers
FILES ${generate_proto_headers}
DESTINATION include/ops_adv/proto
)
endif ()
endif ()
# ------------------------------------------------ opbuild ------------------------------------------------
# Run the CANN opbuild tool over the built op-definition shared libraries to
# generate the aclnn wrappers and proto sources predicted above. The env-var
# prefixes (OPS_PROTO_SEPARATE etc.) configure opbuild's output mode; note the
# commands therefore rely on a POSIX shell.
if (BUILD_OPEN_PROJECT)
if (generate_aclnn_srcs)
add_custom_command(OUTPUT ${generate_aclnn_srcs} ${generate_aclnn_headers}
COMMAND mkdir -p ${base_aclnn_binary_dir}
COMMAND OPS_PROTO_SEPARATE=1
OPS_ACLNN_GEN=1
OPS_PROJECT_NAME=aclnn
${OP_BUILD_TOOL}
$<TARGET_FILE:op_host_aclnn>
${base_aclnn_binary_dir}
)
endif ()
# Named target other rules can depend on; also forces op_host_aclnn to build.
add_custom_target(opbuild_gen_default
DEPENDS ${generate_aclnn_srcs} ${generate_aclnn_headers} op_host_aclnn
)
if (generate_aclnn_inner_srcs)
add_custom_command(OUTPUT ${generate_aclnn_inner_srcs}
COMMAND mkdir -p ${base_aclnn_binary_dir}/inner
COMMAND OPS_PROTO_SEPARATE=1
OPS_ACLNN_GEN=1
OPS_PROJECT_NAME=aclnnInner
${OP_BUILD_TOOL}
$<TARGET_FILE:op_host_aclnnInner>
${base_aclnn_binary_dir}/inner
)
endif ()
add_custom_target(opbuild_gen_inner
DEPENDS ${generate_aclnn_inner_srcs} op_host_aclnnInner
)
# exc variant: proto only (OPS_ACLNN_GEN=0), no aclnn wrapper generated.
if (generate_exclude_proto_srcs)
add_custom_command(OUTPUT ${generate_exclude_proto_srcs}
COMMAND mkdir -p ${base_aclnn_binary_dir}/exc
COMMAND OPS_PROTO_SEPARATE=1
OPS_ACLNN_GEN=0
OPS_PROJECT_NAME=aclnnExc
${OP_BUILD_TOOL}
$<TARGET_FILE:op_host_aclnnExc>
${base_aclnn_binary_dir}/exc
)
endif ()
add_custom_target(opbuild_gen_exc
DEPENDS ${generate_exclude_proto_srcs} op_host_aclnnExc
)
endif ()
# ------------------------------------------------ generate adapt py ------------------------------------------------
# Generate the per-op python adapter (impl) files from the opbuild output.
# The four escaped empty-string arguments are positional placeholders expected
# by ascendc_impl_build.py — TODO confirm their meaning against that script.
# NOTE: add_custom_target COMMANDs re-run on every build; presumably cheap.
add_custom_target(generate_adapt_py
COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_impl_build.py
\"\"
\"\"
\"\"
\"\"
${ASCEND_IMPL_OUT_DIR}
${ASCEND_AUTOGEN_DIR}
--opsinfo-dir ${base_aclnn_binary_dir} ${base_aclnn_binary_dir}/inner ${base_aclnn_binary_dir}/exc
)
add_dependencies(generate_adapt_py opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
# Install the generated dynamic-impl python file for every selected op;
# OPTIONAL because the file only exists after generate_adapt_py has run.
foreach (_op_name ${OP_LIST})
install(FILES ${ASCEND_IMPL_OUT_DIR}/dynamic/${_op_name}.py
DESTINATION ${IMPL_DYNAMIC_INSTALL_DIR}
OPTIONAL
)
endforeach ()
# Ship the shared ascendc kernel headers plus each op's kernel sources so the
# runtime can JIT-compile kernels on the target machine.
install(DIRECTORY ${OPS_ADV_UTILS_KERNEL_INC}/
DESTINATION ${IMPL_INSTALL_DIR}/ascendc/common
)
foreach (op_dir ${OP_DIR_LIST})
get_filename_component(_op_name "${op_dir}" NAME)
file(GLOB KERNEL_FILES
${op_dir}/op_kernel/*.cpp
${op_dir}/op_kernel/*.h
)
install(FILES ${KERNEL_FILES}
DESTINATION ${IMPL_INSTALL_DIR}/ascendc/${_op_name}
OPTIONAL
)
endforeach ()
# ------------------------------------------------ generate compile cmd ------------------------------------------------
# Per-SoC kernel compile-command and ops-info json generation; the
# add_*_target helpers come from cmake/func.cmake and hang their work off the
# umbrella targets declared here.
if (BUILD_OPEN_PROJECT)
add_custom_target(prepare_build ALL)
add_custom_target(generate_compile_cmd ALL)
add_custom_target(generate_ops_info ALL)
add_dependencies(prepare_build generate_adapt_py generate_compile_cmd)
foreach (compute_unit ${ASCEND_COMPUTE_UNIT})
add_compile_cmd_target(
COMPUTE_UNIT ${compute_unit}
)
add_ops_info_target(
COMPUTE_UNIT ${compute_unit}
)
endforeach ()
else()
# In-tree build: hook the adapter generation into CANN's existing target.
add_dependencies(tbe_ops_json_info generate_adapt_py)
endif ()
# ------------------------------------------------ opp kernel ------------------------------------------------
# Optional ahead-of-time kernel binary compilation per SoC (ENABLE_OPS_KERNEL
# is presumably set in cmake/config.cmake — confirm).
if (ENABLE_OPS_KERNEL)
add_custom_target(ops_kernel ALL)
add_custom_target(ops_config ALL)
add_dependencies(ops_kernel ops_config)
foreach (compute_unit ${ASCEND_COMPUTE_UNIT})
add_bin_compile_target(
COMPUTE_UNIT
${compute_unit}
OP_INFO
${OP_DIR_LIST}
)
endforeach ()
endif ()
if (BUILD_OPEN_PROJECT)
# Copy the installer scripts into the build tree and rebrand them with the
# configured vendor name before packaging.
add_custom_target(modify_vendor ALL
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh
)
# modify VENDOR_NAME in install.sh and upgrade.sh
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh
COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/scripts
COMMAND cp -r ${ASCEND_PROJECT_DIR}/scripts/* ${CMAKE_CURRENT_BINARY_DIR}/scripts/
COMMAND chmod +w ${CMAKE_CURRENT_BINARY_DIR}/scripts/*
COMMAND sed -i "s/vendor_name=customize/vendor_name=${VENDOR_NAME}/g" ${CMAKE_CURRENT_BINARY_DIR}/scripts/*
)
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/scripts/
DESTINATION . FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ
)
# gen version.info
# Records the CANN version the package was built against (consumed by the
# installer's compatibility check — TODO confirm).
set(version_info_dir ${CMAKE_CURRENT_BINARY_DIR})
set(version_info_file ${version_info_dir}/version.info)
add_custom_target(gen_version_info ALL
DEPENDS ${version_info_file}
)
add_custom_command(OUTPUT ${version_info_file}
COMMAND bash ${ASCENDC_CMAKE_UTIL_DIR}/gen_version_info.sh ${ASCEND_CANN_PACKAGE_PATH} ${version_info_dir}
)
install(FILES ${version_info_file}
DESTINATION packages/vendors/${VENDOR_NAME}/
)
# CPack config
# External generator + makeself.cmake turn the staged install tree into a
# self-extracting CANN-custom_ops-*.run package (consumed by build_aclnn.sh).
set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME})
set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION})
set(CPACK_PACKAGE_DESCRIPTION "CPack ops project")
set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack ops project")
set(CPACK_PACKAGE_DIRECTORY ${CMAKE_BINARY_DIR})
set(CPACK_PACKAGE_FILE_NAME "CANN-custom_ops-${CANN_VERSION}-linux.${CMAKE_SYSTEM_PROCESSOR}.run")
set(CPACK_GENERATOR External)
set(CPACK_CMAKE_GENERATOR "Unix Makefiles")
set(CPACK_EXTERNAL_ENABLE_STAGING TRUE)
set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${ASCEND_CMAKE_DIR}/makeself.cmake)
set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME})
include(CPack)
endif ()

189
csrc/build.sh Normal file
View File

@@ -0,0 +1,189 @@
#!/bin/bash
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
set -e
# Resolve the directory this script lives in (works through symlinks).
CURRENT_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]}))
BUILD_DIR=${CURRENT_DIR}/build
OUTPUT_DIR=${CURRENT_DIR}/output
USER_ID=$(id -u)
# Build knobs; COV/ASAN are declared here but note the visible argument
# parser below never sets them — see NOTE(review) there.
PARENT_JOB="false"
CHECK_COMPATIBLE="true"
ASAN="false"
COV="false"
VERBOSE="false"
# Default CANN install locations differ for root vs. regular users.
if [ "${USER_ID}" != "0" ]; then
DEFAULT_TOOLKIT_INSTALL_DIR="${HOME}/Ascend/ascend-toolkit/latest"
DEFAULT_INSTALL_DIR="${HOME}/Ascend/latest"
else
DEFAULT_TOOLKIT_INSTALL_DIR="/usr/local/Ascend/ascend-toolkit/latest"
DEFAULT_INSTALL_DIR="/usr/local/Ascend/latest"
fi
# Accumulates -D options handed to cmake_config later.
CUSTOM_OPTION="-DBUILD_OPEN_PROJECT=ON"
# Print usage for this build script.
# NOTE(review): --cov and --verbose are documented here; verify the argument
# parser at the bottom of the file actually accepts them.
function help_info() {
echo "Usage: $0 [options]"
echo "Options:"
echo
echo "-h|--help Displays help message."
echo
echo "-n|--op-name Specifies the compiled operator. If there are multiple values, separate them with semicolons and use quotation marks. The default is all."
echo " For example: -n \"flash_attention_score\" or -n \"flash_attention_score;flash_attention_score_grad\""
echo
echo "-c|--compute-unit Specifies the chip type. If there are multiple values, separate them with semicolons and use quotation marks. The default is ascend910b."
echo " For example: -c \"ascend910b\" or -c \"ascend910b;ascend310p\""
echo
echo "--cov Compiles with cov."
echo
echo "--verbose Displays more compilation information."
echo
}
# Print a timestamped log line: "[YYYY-mm-dd HH:MM:SS] <message>".
# Fix: quote "$1" so the message is not word-split or glob-expanded
# (the original `echo "[...] "$1` collapsed internal whitespace and would
# expand wildcards in messages). Also prefer $() over backticks.
function log() {
local current_time=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$current_time] $1"
}
# Source the CANN environment and verify the bisheng compiler is reachable.
# The `|| echo "0"` swallows a failed setenv.bash (e.g. already-sourced env);
# the hard requirement checked afterwards is only that bisheng is on PATH.
function set_env()
{
source $ASCEND_CANN_PACKAGE_PATH/bin/setenv.bash || echo "0"
export BISHENG_REAL_PATH=$(which bisheng || true)
if [ -z "${BISHENG_REAL_PATH}" ];then
log "Error: bisheng compilation tool not found, Please check whether the cann package or environment variables are set."
exit 1
fi
}
# Wipe any previous build tree and (re)create the build/output directories.
# Fix: quote the path expansions so directories containing spaces are handled
# safely (an unquoted empty/space-laden BUILD_DIR would make `rm -rf`
# dangerous or wrong).
function clean()
{
if [ -n "${BUILD_DIR}" ];then
rm -rf "${BUILD_DIR}"
fi
mkdir -p "${BUILD_DIR}" "${OUTPUT_DIR}"
}
# Configure the cmake project from the build directory (caller must cd first).
# CUSTOM_OPTION / extra_option are deliberately left unquoted: they hold
# multiple space-separated -D flags that must word-split into arguments.
function cmake_config()
{
local extra_option="$1"
log "Info: cmake config ${CUSTOM_OPTION} ${extra_option} ."
cmake .. ${CUSTOM_OPTION} ${extra_option}
}
# Build a single cmake target. JOB_NUM (e.g. "-j16") and the optional
# --verbose flag are left unquoted on purpose so empty values vanish instead
# of becoming empty arguments.
function build()
{
local target="$1"
if [ "${VERBOSE}" == "true" ];then
local option="--verbose"
fi
cmake --build . --target ${target} ${JOB_NUM} ${option}
}
# Generate a `bisheng` wrapper script that prepends ccache to the real
# bisheng compiler, then put the wrapper first on PATH so cmake picks it up.
# The heavy escaping writes literal ${ccache_args} / $@ references into the
# generated script rather than expanding them here.
function gen_bisheng(){
local ccache_program=$1
local gen_bisheng_dir=${BUILD_DIR}/gen_bisheng_dir
if [ ! -d "${gen_bisheng_dir}" ];then
mkdir -p ${gen_bisheng_dir}
fi
pushd ${gen_bisheng_dir}
# Truncate (create empty) the wrapper file before appending.
$(> bisheng)
echo "#!/bin/bash" >> bisheng
echo "ccache_args=""\"""${ccache_program} ${BISHENG_REAL_PATH}""\"" >> bisheng
echo "args=""$""@" >> bisheng
if [ "${VERBOSE}" == "true" ];then
echo "echo ""\"""$""{ccache_args} ""$""args""\"" >> bisheng
fi
echo "eval ""\"""$""{ccache_args} ""$""args""\"" >> bisheng
chmod +x bisheng
export PATH=${gen_bisheng_dir}:$PATH
popd
}
# Build the CPack `package` target (produces the CANN-custom_ops-*.run file).
function build_package(){
build package
}
# Host build currently just packages; kept separate for symmetry with kernel.
function build_host(){
build_package
}
# Ahead-of-time kernel binary compilation target (see ENABLE_OPS_KERNEL).
function build_kernel(){
build ops_kernel
}
# Parse command-line options. Fix: help_info documents --cov and --verbose,
# but the original parser had no cases for them, so passing either fell
# through to '*' and aborted with exit 1; COV/VERBOSE (initialised above)
# could never be enabled. Accept them here.
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
help_info
exit
;;
-n|--op-name)
# Semicolon-separated operator list, forwarded as -DASCEND_OP_NAME.
ascend_op_name="$2"
shift 2
;;
-c|--compute-unit)
# Semicolon-separated SoC list, forwarded as -DASCEND_COMPUTE_UNIT.
ascend_compute_unit="$2"
shift 2
;;
--cov)
COV="true"
shift
;;
--verbose)
VERBOSE="true"
shift
;;
*)
help_info
exit 1
;;
esac
done
# Forward parsed options to cmake.
if [ -n "${ascend_compute_unit}" ];then
CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_COMPUTE_UNIT=${ascend_compute_unit}"
fi
if [ -n "${ascend_op_name}" ];then
CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_OP_NAME=${ascend_op_name}"
fi
# Locate the CANN toolkit: env vars first, then the default install dirs.
# NOTE(review): the error message mentions -p|--package-path, but no such
# option exists in the parser above — confirm intended behavior.
if [ -n "${ASCEND_HOME_PATH}" ];then
ASCEND_CANN_PACKAGE_PATH=${ASCEND_HOME_PATH}
elif [ -n "${ASCEND_OPP_PATH}" ];then
ASCEND_CANN_PACKAGE_PATH=$(dirname ${ASCEND_OPP_PATH})
elif [ -d "${DEFAULT_TOOLKIT_INSTALL_DIR}" ];then
ASCEND_CANN_PACKAGE_PATH=${DEFAULT_TOOLKIT_INSTALL_DIR}
elif [ -d "${DEFAULT_INSTALL_DIR}" ];then
ASCEND_CANN_PACKAGE_PATH=${DEFAULT_INSTALL_DIR}
else
log "Error: Please set the toolkit package installation directory through parameter -p|--package-path."
exit 1
fi
# Default parallelism: 2x logical CPU count.
if [ "${PARENT_JOB}" == "false" ];then
CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2))
JOB_NUM="-j${CPU_NUM}"
fi
CUSTOM_OPTION="${CUSTOM_OPTION} -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE}"
set_env
clean
# Enable ccache (via the generated bisheng wrapper) when available.
ccache_system=$(which ccache || true)
if [ -n "${ccache_system}" ];then
CUSTOM_OPTION="${CUSTOM_OPTION} -DENABLE_CCACHE=ON -DCUSTOM_CCACHE=${ccache_system}"
gen_bisheng ${ccache_system}
fi
cd ${BUILD_DIR}
cmake_config
build_package

34
csrc/build_aclnn.sh Normal file
View File

@@ -0,0 +1,34 @@
#!/bin/bash
# Build and install the custom CANN aclnn ops for the given SoC, then expose
# them under <ROOT_DIR>/vllm_ascend/_cann_ops_custom.
#
# Usage: build_aclnn.sh <ROOT_DIR> <SOC_VERSION>
#
# Fix: fail fast with `set -e` — without it a failed build still ran the
# installer against a stale/missing package. Also quote path expansions.
set -e
ROOT_DIR=$1
SOC_VERSION=$2
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
    # ASCEND310P series
    # currently, no custom aclnn ops for ASCEND310 series
    # CUSTOM_OPS=""
    # SOC_ARG="ascend310p"
    exit 0
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
    # ASCEND910B (A2) series
    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
    SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
    # ASCEND910C (A3) series
    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
    SOC_ARG="ascend910_93"
else
    # others
    # currently, no custom aclnn ops for other series
    exit 0
fi
# build custom ops (assumes the caller's cwd contains csrc/ — TODO confirm)
cd csrc
rm -rf build output
echo "building custom ops $CUSTOM_OPS for $SOC_VERSION"
bash build.sh -n "$CUSTOM_OPS" -c "$SOC_ARG"
# install custom ops to vllm_ascend/_cann_ops_custom
# (the *.run glob must stay unquoted so the shell expands the package name)
./output/CANN-custom_ops*.run --install-path="$ROOT_DIR/vllm_ascend/_cann_ops_custom"
# NOTE(review): sourcing here only affects this script's own process;
# verify the caller re-sources set_env.bash if it needs the env.
source "$ROOT_DIR/vllm_ascend/_cann_ops_custom/vendors/customize/bin/set_env.bash"

235
csrc/cmake/config.cmake Normal file
View File

@@ -0,0 +1,235 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
########################################################################################################################
# Environment Check
########################################################################################################################
# Python3
# Require a usable python3 interpreter; all generator scripts below run under it.
find_package(Python3)
if ((NOT Python3_FOUND) OR (${Python3_EXECUTABLE} STREQUAL ""))
    message(FATAL_ERROR "Can't find python3.")
endif ()
set(HI_PYTHON "${Python3_EXECUTABLE}" CACHE STRING "python executor")
# Get the base CANN path. Priority: explicit -D option, then the
# ASCEND_HOME_PATH / ASCEND_OPP_PATH environment (read at configure time
# only), then a conventional default install location.
if (CUSTOM_ASCEND_CANN_PACKAGE_PATH)
    set(ASCEND_CANN_PACKAGE_PATH ${CUSTOM_ASCEND_CANN_PACKAGE_PATH})
elseif (DEFINED ENV{ASCEND_HOME_PATH})
    set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH})
elseif (DEFINED ENV{ASCEND_OPP_PATH})
    # The opp directory sits one level below the CANN root.
    get_filename_component(ASCEND_CANN_PACKAGE_PATH "$ENV{ASCEND_OPP_PATH}/.." ABSOLUTE)
else()
    set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/latest")
endif ()
message(STATUS "ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH}")
########################################################################################################################
# Common Configuration
########################################################################################################################
# Switches
option(PREPARE_BUILD "Prepare build." OFF)
option(ENABLE_OPS_HOST "Build ops host." ON)
option(ENABLE_OPS_KERNEL "Build ops kernel." ON)
# Test builds do not compile device kernels.
if (TESTS_EXAMPLE_OPS_TEST OR TESTS_UT_OPS_TEST)
    set(ENABLE_OPS_KERNEL OFF)
endif ()
set(OP_DEBUG_CONFIG "false" CACHE STRING "op debug config")
# Path configuration
# Source tree related paths
get_filename_component(OPS_ADV_DIR "${CMAKE_CURRENT_SOURCE_DIR}" REALPATH)
get_filename_component(OPS_ADV_CMAKE_DIR "${OPS_ADV_DIR}/cmake" REALPATH)
get_filename_component(OPS_ADV_UTILS_KERNEL_INC "${OPS_ADV_DIR}/utils/inc/kernel" REALPATH)
# Build tree related paths
set(ASCEND_IMPL_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/impl CACHE STRING "ascend impl output directories")
set(ASCEND_BINARY_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary CACHE STRING "ascend binary output directories")
set(ASCEND_AUTOGEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/autogen CACHE STRING "Auto generate file directories")
set(ASCEND_CUSTOM_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_compile_options.ini)
set(ASCEND_CUSTOM_TILING_KEYS ${ASCEND_AUTOGEN_DIR}/custom_tiling_keys.ini)
set(ASCEND_CUSTOM_OPC_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_opc_options.ini)
set(OP_BUILD_TOOL ${ASCEND_CANN_PACKAGE_PATH}/tools/opbuild/op_build CACHE STRING "op_build tool")
# Start every configure run with fresh (empty) ini accumulators; the
# add_ops_* helpers append to these files during configuration.
file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_DIR})
file(REMOVE ${ASCEND_CUSTOM_OPTIONS})
file(TOUCH ${ASCEND_CUSTOM_OPTIONS})
file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS})
file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS})
file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS})
file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS})
if (BUILD_OPEN_PROJECT)
    # Standalone (custom-op package) build: locate the ascendc project
    # templates shipped with CANN; newer packages use tools/ascend_project.
    if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project/cmake)
        set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project)
    else()
        set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/op_project_templates/ascendc/customize)
    endif()
    set(ASCEND_CMAKE_DIR ${ASCEND_PROJECT_DIR}/cmake CACHE STRING "ascend project cmake")
    set(IMPL_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl)
    set(IMPL_DYNAMIC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl/dynamic)
    set(ACLNN_INC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_api/include)
else()
    # In-tree CANN build: use the in-repo sample cmake and install layout.
    set(ASCEND_CMAKE_DIR ${TOP_DIR}/asl/ops/cann/ops/built-in/ascendc/samples/customize/cmake CACHE STRING "ascend project cmake")
    set(IMPL_INSTALL_DIR lib/ascendc/impl)
    set(IMPL_DYNAMIC_INSTALL_DIR lib/ascendc/impl/dynamic)
    set(ACLNN_INC_INSTALL_DIR lib/include)
    set(OPS_STATIC_TYPES infer train)
    set(OPS_STATIC_SCRIPT ${TOP_DIR}/asl/ops/cann/ops/built-in/kernel/binary_script/build_opp_kernel_static.py)
endif ()
set(ASCENDC_CMAKE_UTIL_DIR ${ASCEND_CMAKE_DIR}/util)
set(CUSTOM_DIR ${CMAKE_BINARY_DIR}/custom)
set(TILING_CUSTOM_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/op_tiling)
set(TILING_CUSTOM_FILE ${TILING_CUSTOM_DIR}/liboptiling.so)
# Temporary adaptation for ascendc changes, to be removed after switching to the new version of ascendc:
# the presence of ascendc_gen_options.py marks the newer ("V2") option-handling flow.
if(EXISTS ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py)
    set(ADD_OPS_COMPILE_OPTION_V2 ON)
else()
    set(ADD_OPS_COMPILE_OPTION_V2 OFF)
endif()
########################################################################################################################
# CMake Options, Default Parameters Setting
# Configure CMake options and default parameters according to the CMake build process
# CMake build process: 1) Configuration phase; 2) Build phase; 3) Installation phase;
########################################################################################################################
if (BUILD_OPEN_PROJECT)
    # Build phase
    # Build type
    # The Generator in CMake is a tool used to generate native build systems. Generally divided into two types:
    # 1. Single-configuration generator:
    #    In the configuration phase, only one build type is allowed to be specified through the variable CMAKE_BUILD_TYPE;
    #    In the build phase, the build type cannot be changed, and only the build type specified through the variable CMAKE_BUILD_TYPE in the configuration phase can be used;
    #    Common generators of this type include: Ninja, Unix Makefiles
    # 2. Multi-configuration generator:
    #    In the configuration phase, only the list of build types available in the build phase is specified through the variable CMAKE_CONFIGURATION_TYPES;
    #    In the build phase, the specific build type of the build phase is specified through the "--config" parameter;
    #    Common generators of this type include: Xcode, Visual Studio
    # Therefore:
    # 1. In the single-configuration generator scenario, if the build type (CMAKE_BUILD_TYPE) is not specified, the default is Debug;
    # 2. In the multi-configuration generator scenario, if the build types available in the build phase (CMAKE_CONFIGURATION_TYPES) are not specified,
    #    it is defaulted to the full set of build types allowed by CMake [Debug;Release;MinSizeRel;RelWithDebInfo]
    get_property(GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
    if (GENERATOR_IS_MULTI_CONFIG)
        if (NOT CMAKE_CONFIGURATION_TYPES)
            set(CMAKE_CONFIGURATION_TYPES "Debug;Release;MinSizeRel;RelWithDebInfo" CACHE STRING "Configuration Build type" FORCE)
        endif ()
    else ()
        if (NOT CMAKE_BUILD_TYPE)
            set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type(default Debug)" FORCE)
        endif ()
    endif ()
    # Build phase
    # Executable runtime library file search path RPATH
    # Do not skip RPATH in UTest and Example scenarios
    if (TESTS_UT_OPS_TEST OR TESTS_EXAMPLE_OPS_TEST)
        set(CMAKE_SKIP_RPATH FALSE)
    else ()
        set(CMAKE_SKIP_RPATH TRUE)
    endif ()
    # Build phase
    # CCACHE configuration: an explicitly provided CUSTOM_CCACHE wins over
    # whatever find_program() discovers on PATH.
    if (ENABLE_CCACHE)
        if (CUSTOM_CCACHE)
            set(CCACHE_PROGRAM ${CUSTOM_CCACHE})
        else()
            find_program(CCACHE_PROGRAM ccache)
        endif ()
        if (CCACHE_PROGRAM)
            set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "C cache Compiler")
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "CXX cache Compiler")
        endif ()
    endif ()
    # Installation phase
    # Installation path
    # When CMAKE_INSTALL_PREFIX is not explicitly set (i.e., CMAKE_INSTALL_PREFIX takes the default value),
    # correct its value to be level with the build tree root directory CMAKE_CURRENT_BINARY_DIR
    if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
        get_filename_component(_Install_Path_Prefix "${CMAKE_CURRENT_BINARY_DIR}/../output" REALPATH)
        set(CMAKE_INSTALL_PREFIX "${_Install_Path_Prefix}" CACHE STRING "Install path" FORCE)
    endif ()
endif ()
########################################################################################################################
# Public Compilation Parameters
########################################################################################################################
# Compute-unit names are treated case-insensitively; normalize to lowercase.
list(TRANSFORM ASCEND_COMPUTE_UNIT TOLOWER)
if (BUILD_OPEN_PROJECT)
    message(STATUS "ENABLE_CCACHE=${ENABLE_CCACHE}, CUSTOM_CCACHE=${CUSTOM_CCACHE}")
    message(STATUS "CCACHE_PROGRAM=${CCACHE_PROGRAM}")
    message(STATUS "ASCEND_COMPUTE_UNIT=${ASCEND_COMPUTE_UNIT}")
    message(STATUS "ASCEND_OP_NAME=${ASCEND_OP_NAME}")
    message(STATUS "TILING_KEY=${TILING_KEY}")
    message(STATUS "TESTS_UT_OPS_TEST=${TESTS_UT_OPS_TEST}")
    message(STATUS "TESTS_EXAMPLE_OPS_TEST=${TESTS_EXAMPLE_OPS_TEST}")
endif ()
########################################################################################################################
# Preprocessing
########################################################################################################################
if (BUILD_OPEN_PROJECT)
    if (NOT PREPARE_BUILD AND ENABLE_OPS_KERNEL)
        # List-valued parameters are passed to the shell script as single
        # arguments with "::" standing in for the CMake list separator ";".
        if (TILING_KEY)
            string(REPLACE ";" "::" EP_TILING_KEY "${TILING_KEY}")
        else()
            set(EP_TILING_KEY FALSE)
        endif ()
        if (OPS_COMPILE_OPTIONS)
            string(REPLACE ";" "::" EP_OPS_COMPILE_OPTIONS "${OPS_COMPILE_OPTIONS}")
        else()
            set(EP_OPS_COMPILE_OPTIONS FALSE)
        endif ()
        string(REPLACE ";" "::" EP_ASCEND_COMPUTE_UNIT "${ASCEND_COMPUTE_UNIT}")
        # Configure-time sub-build: prepare.sh runs a nested "prepare" pass
        # that generates op build inputs before the main build graph is set up.
        execute_process(COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/prepare.sh
                -s ${CMAKE_CURRENT_SOURCE_DIR}
                -b ${CMAKE_CURRENT_BINARY_DIR}/prepare_build
                -p ${ASCEND_CANN_PACKAGE_PATH}
                --autogen-dir ${ASCEND_AUTOGEN_DIR}
                --build-open-project ${BUILD_OPEN_PROJECT}
                --binary-out-dir ${ASCEND_BINARY_OUT_DIR}
                --impl-out-dir ${ASCEND_IMPL_OUT_DIR}
                --op-build-tool ${OP_BUILD_TOOL}
                --ascend-cmake-dir ${ASCEND_CMAKE_DIR}
                --tiling-key ${EP_TILING_KEY}
                --ops-compile-options ${EP_OPS_COMPILE_OPTIONS}
                --check-compatible ${CHECK_COMPATIBLE}
                --ascend-compute_unit ${EP_ASCEND_COMPUTE_UNIT}
                --op_debug_config ${OP_DEBUG_CONFIG}
                --ascend-op-name "${ASCEND_OP_NAME}"
                RESULT_VARIABLE result
                OUTPUT_STRIP_TRAILING_WHITESPACE
                OUTPUT_VARIABLE PREPARE_BUILD_OUTPUT_VARIABLE)
        if (result)
            message(FATAL_ERROR "Error: ops prepare build failed.")
        endif ()
        # The prepare pass may have appended to the ini accumulators;
        # reset them again so the main pass starts from empty files.
        file(REMOVE ${ASCEND_CUSTOM_OPTIONS})
        file(TOUCH ${ASCEND_CUSTOM_OPTIONS})
        file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS})
        file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS})
        file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS})
        file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS})
    endif ()
endif ()
########################################################################################################################
# Other Configuration
########################################################################################################################
if (BUILD_OPEN_PROJECT)
    if (TESTS_UT_OPS_TEST)
        include(${OPS_ADV_CMAKE_DIR}/config_utest.cmake)
    endif ()
endif ()

609
csrc/cmake/func.cmake Normal file
View File

@@ -0,0 +1,609 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Propagate to every target in TARGET_NAME the subset of BASE_TARGET's
# SOURCES and INCLUDE_DIRECTORIES that live under SRC_DIR (prefix match).
#
# Keyword arguments:
#   BASE_TARGET - target whose properties are read
#   SRC_DIR     - directory prefix used to select entries
#   TARGET_NAME - one or more targets that receive the selected entries
function(add_target_source)
    cmake_parse_arguments(ADD "" "BASE_TARGET;SRC_DIR" "TARGET_NAME" ${ARGN})
    get_target_property(base_srcs ${ADD_BASE_TARGET} SOURCES)
    get_target_property(base_incs ${ADD_BASE_TARGET} INCLUDE_DIRECTORIES)
    # Keep only entries anchored at SRC_DIR; a "-NOTFOUND" property value
    # never matches the prefix and so yields an empty selection.
    set(matched_srcs ${base_srcs})
    list(FILTER matched_srcs INCLUDE REGEX "^${ADD_SRC_DIR}")
    set(matched_incs ${base_incs})
    list(FILTER matched_incs INCLUDE REGEX "^${ADD_SRC_DIR}")
    foreach(dst_target ${ADD_TARGET_NAME})
        target_sources(${dst_target} PRIVATE
            ${matched_srcs}
        )
        target_include_directories(${dst_target} PRIVATE
            ${matched_incs}
        )
    endforeach()
endfunction()
# Discover operator directories: every <op>/op_host/CMakeLists.txt under the
# current source dir marks <op> as a buildable operator. When ASCEND_OP_NAME
# is set (and not "all"/"ALL"), only the listed ops are kept. Deduplicated,
# sorted results are returned through the two out-parameters.
function(op_add_subdirectory OP_LIST OP_DIR_LIST)
    set(_OP_LIST)
    set(_OP_DIR_LIST)
    file(GLOB OP_HOST_CMAKE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/**/op_host/CMakeLists.txt")
    foreach(OP_CMAKE_FILE ${OP_HOST_CMAKE_FILES})
        # <op_dir>/op_host/CMakeLists.txt -> op_dir and its basename (the op name).
        get_filename_component(OP_HOST_DIR "${OP_CMAKE_FILE}" DIRECTORY)
        get_filename_component(OP_DIR "${OP_HOST_DIR}" DIRECTORY)
        get_filename_component(OP_NAME "${OP_DIR}" NAME)
        if (NOT BUILD_OPEN_PROJECT)
            # In-tree build: skip ops that already ship as built-in ascendc impls.
            if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${OP_NAME})
                continue()
            endif ()
        endif ()
        if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "")
            if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
                if (NOT ${OP_NAME} IN_LIST ASCEND_OP_NAME)
                    continue()
                endif ()
            endif ()
        endif ()
        list(APPEND _OP_LIST ${OP_NAME})
        list(APPEND _OP_DIR_LIST ${OP_DIR})
    endforeach()
    list(REMOVE_DUPLICATES _OP_LIST)
    list(REMOVE_DUPLICATES _OP_DIR_LIST)
    list(SORT _OP_LIST)
    list(SORT _OP_DIR_LIST)
    set(${OP_LIST} ${_OP_LIST} PARENT_SCOPE)
    set(${OP_DIR_LIST} ${_OP_DIR_LIST} PARENT_SCOPE)
endfunction()
# Collect the source directories of dependency ops that are NOT themselves in
# OP_LIST. Dependencies come from per-op "<op>_depends" variables set
# elsewhere; only deps that actually have an op_host/CMakeLists.txt are kept.
# The sorted directory list is returned through the OP_DIR_LIST out-parameter.
function(op_add_depend_directory)
    cmake_parse_arguments(DEP "" "OP_DIR_LIST" "OP_LIST" ${ARGN})
    set(_OP_DEPEND_DIR_LIST)
    foreach(op_name ${DEP_OP_LIST})
        if (DEFINED ${op_name}_depends)
            foreach(depend_info ${${op_name}_depends})
                # Skip declared deps that are not buildable op directories.
                if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}/op_host/CMakeLists.txt)
                    continue()
                endif ()
                get_filename_component(_depend_op_name "${depend_info}" NAME)
                if (NOT BUILD_OPEN_PROJECT)
                    # In-tree build: built-in ascendc impls need no extra directory.
                    if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${_depend_op_name})
                        continue()
                    endif ()
                endif ()
                # Ops already selected for build are covered by op_add_subdirectory.
                if (NOT ${_depend_op_name} IN_LIST DEP_OP_LIST)
                    list(APPEND _OP_DEPEND_DIR_LIST ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info})
                endif ()
            endforeach()
        endif()
    endforeach()
    list(SORT _OP_DEPEND_DIR_LIST)
    set(${DEP_OP_DIR_LIST} ${_OP_DEPEND_DIR_LIST} PARENT_SCOPE)
endfunction()
# Create a per-compute-unit target that generates kernel compile scripts by
# running ascendc_bin_param_build.py over the default/inner/exc ops-info ini
# files. The option set differs between the legacy flow and the newer "V2"
# flow (ADD_OPS_COMPILE_OPTION_V2).
#
# NOTE(review): `base_aclnn_binary_dir` is not defined in this function; it is
# presumably set in the caller's scope -- confirm against the including file.
function(add_compile_cmd_target)
    cmake_parse_arguments(CMD "" "COMPUTE_UNIT" "" ${ARGN})
    if(ADD_OPS_COMPILE_OPTION_V2)
        set(OP_DEBUG_CONFIG_OPTION --opc-config-file ${ASCEND_CUSTOM_OPC_OPTIONS})
    else()
        if(OP_DEBUG_CONFIG)
            set(OP_DEBUG_CONFIG_OPTION --op-debug-config ${OP_DEBUG_CONFIG})
        endif()
        set(OP_TILING_KEY_OPTION --tiling-keys ${ASCEND_CUSTOM_TILING_KEYS})
    endif()
    set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${CMD_COMPUTE_UNIT})
    set(GEN_OUT_DIR ${_OUT_DIR}/gen)
    set(COMPILE_CMD_TARGET generate_compile_cmd_${CMD_COMPUTE_UNIT})
    # One invocation each for the default, inner, and excluded op info sets.
    add_custom_target(${COMPILE_CMD_TARGET} ALL
        COMMAND mkdir -p ${GEN_OUT_DIR}
        COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
        COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/inner/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
        COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py
                ${base_aclnn_binary_dir}/exc/aic-${CMD_COMPUTE_UNIT}-ops-info.ini
                ${GEN_OUT_DIR}
                ${CMD_COMPUTE_UNIT}
                ${OP_TILING_KEY_OPTION}
                ${OP_DEBUG_CONFIG_OPTION}
    )
    # The ini inputs are produced by the opbuild_gen_* targets.
    add_dependencies(${COMPILE_CMD_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
    add_dependencies(generate_compile_cmd ${COMPILE_CMD_TARGET})
endfunction()
# Create a per-compute-unit target that merges the default/inner/excluded
# ops-info ini files into one aic-<unit>-ops-info.json, copies it into the
# custom opp tree, and installs it into the vendor package layout.
#
# NOTE(review): like add_compile_cmd_target, this reads
# `base_aclnn_binary_dir` from the caller's scope -- confirm.
function(add_ops_info_target)
    cmake_parse_arguments(OPINFO "" "COMPUTE_UNIT" "" ${ARGN})
    set(OPS_INFO_TARGET generate_ops_info_${OPINFO_COMPUTE_UNIT})
    set(OPS_INFO_JSON ${ASCEND_AUTOGEN_DIR}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.json)
    set(CUSTOM_OPS_INFO_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT})
    set(OPS_INFO_INI ${base_aclnn_binary_dir}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)
    set(OPS_INFO_INNER_INI ${base_aclnn_binary_dir}/inner/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)
    set(OPS_INFO_EXCLUDE_INI ${base_aclnn_binary_dir}/exc/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini)
    add_custom_command(OUTPUT ${OPS_INFO_JSON}
        COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/parse_ini_to_json.py
                ${OPS_INFO_INI}
                ${OPS_INFO_INNER_INI}
                ${OPS_INFO_EXCLUDE_INI}
                ${OPS_INFO_JSON}
        COMMAND mkdir -p ${CUSTOM_OPS_INFO_DIR}
        COMMAND cp -f ${OPS_INFO_JSON} ${CUSTOM_OPS_INFO_DIR}
    )
    add_custom_target(${OPS_INFO_TARGET} ALL
        DEPENDS ${OPS_INFO_JSON}
    )
    # Ini inputs come from the opbuild_gen_* generator targets.
    add_dependencies(${OPS_INFO_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc)
    add_dependencies(generate_ops_info ${OPS_INFO_TARGET})
    install(FILES ${OPS_INFO_JSON}
        DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT} OPTIONAL
    )
endfunction()
# Record per-op compile options. In the V2 flow, ascendc_gen_options.py
# updates the custom options ini at configure time; in the legacy flow the
# options are appended directly as "op,unit,options" CSV lines. No-op when
# OPTIONS is empty.
function(add_ops_compile_options)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;OPTIONS" ${ARGN})
    if(NOT OP_COMPILE_OPTIONS)
        return()
    endif()
    if(ADD_OPS_COMPILE_OPTION_V2)
        execute_process(COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py
                        ${ASCEND_CUSTOM_OPTIONS} ${OP_COMPILE_OP_NAME}
                        ${OP_COMPILE_COMPUTE_UNIT} ${OP_COMPILE_OPTIONS}
                        RESULT_VARIABLE EXEC_RESULT
                        OUTPUT_VARIABLE EXEC_INFO
                        ERROR_VARIABLE EXEC_ERROR)
        if (EXEC_RESULT)
            message("add ops compile options info: ${EXEC_INFO}")
            message("add ops compile options error: ${EXEC_ERROR}")
            message(FATAL_ERROR "Error: add ops compile options failed!")
        endif ()
    else()
        file(APPEND ${ASCEND_CUSTOM_OPTIONS}
            "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_OPTIONS}\n"
        )
    endif()
endfunction()
# Record the tiling keys to build for an op. In the V2 flow the keys are
# folded into a --tiling_key compile option via add_ops_compile_options;
# in the legacy flow they are appended as "op,unit,keys" CSV lines.
# No-op when TILING_KEYS is empty.
#
# NOTE(review): the V2 branch does not forward COMPUTE_UNIT to
# add_ops_compile_options (the legacy branch records it) -- confirm whether
# the V2 options file is intentionally unit-agnostic.
function(add_ops_tiling_keys)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;TILING_KEYS" ${ARGN})
    if(NOT OP_COMPILE_TILING_KEYS)
        return()
    endif()
    if(ADD_OPS_COMPILE_OPTION_V2)
        list(JOIN OP_COMPILE_TILING_KEYS "," STRING_TILING_KEYS)
        add_ops_compile_options(
            OP_NAME ${OP_COMPILE_OP_NAME}
            OPTIONS --tiling_key=${STRING_TILING_KEYS}
        )
    else()
        file(APPEND ${ASCEND_CUSTOM_TILING_KEYS}
            "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_TILING_KEYS}\n"
        )
    endif()
endfunction()
# Translate symbolic opc debug-config tokens (comma-separated in CONFIG) into
# the corresponding opc flags and record them via add_ops_compile_options.
# Only meaningful in the V2 option flow; otherwise a no-op.
#   ccec_g   -> -g
#   ccec_O0  -> -O0
#   oom      -> --oom
#   dump_cce -> --save-temp-files
function(add_opc_config)
    cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;CONFIG" ${ARGN})
    # Nothing to do outside the V2 flow or without a CONFIG value.
    if(NOT ADD_OPS_COMPILE_OPTION_V2 OR NOT OP_COMPILE_CONFIG)
        return()
    endif()
    string(REPLACE "," ";" config_tokens "${OP_COMPILE_CONFIG}")
    set(opc_flags)
    foreach(token ${config_tokens})
        if("${token}" STREQUAL "ccec_g")
            list(APPEND opc_flags "-g")
        elseif("${token}" STREQUAL "ccec_O0")
            list(APPEND opc_flags "-O0")
        elseif("${token}" STREQUAL "oom")
            list(APPEND opc_flags "--oom")
        elseif("${token}" STREQUAL "dump_cce")
            list(APPEND opc_flags "--save-temp-files")
        endif()
        # Unrecognized tokens are silently ignored, matching prior behavior.
    endforeach()
    if(opc_flags)
        add_ops_compile_options(
            OP_NAME ${OP_COMPILE_OP_NAME}
            OPTIONS ${opc_flags}
        )
    endif()
endfunction()
# Create a target (TARGET_NAME) that copies an op's kernel sources from
# SRC/op_kernel into DST, plus (once per compute unit) a target that copies
# the shared kernel headers into <dst_root>/ascendc/common. If BE_RELIED is
# given, that target is made to depend on the new copy target.
function(add_ops_src_copy)
    cmake_parse_arguments(SRC_COPY "" "TARGET_NAME;SRC;DST;BE_RELIED;COMPUTE_UNIT" "" ${ARGN})
    set(OPS_UTILS_INC_KERNEL_TARGET ops_utils_inc_kernel_${SRC_COPY_COMPUTE_UNIT})
    if (EXISTS ${OPS_ADV_UTILS_KERNEL_INC})
        # Shared kernel headers are copied once per compute unit.
        if (NOT TARGET ${OPS_UTILS_INC_KERNEL_TARGET})
            get_filename_component(_ROOT_OPS_SRC_DIR "${SRC_COPY_DST}" DIRECTORY)
            set(OPS_UTILS_INC_KERNEL_DIR ${_ROOT_OPS_SRC_DIR}/ascendc/common)
            add_custom_command(OUTPUT ${OPS_UTILS_INC_KERNEL_DIR}
                COMMAND mkdir -p ${OPS_UTILS_INC_KERNEL_DIR}
                COMMAND cp -rf ${OPS_ADV_UTILS_KERNEL_INC}/*.* ${OPS_UTILS_INC_KERNEL_DIR}
            )
            add_custom_target(${OPS_UTILS_INC_KERNEL_TARGET}
                DEPENDS ${OPS_UTILS_INC_KERNEL_DIR}
            )
        endif ()
    endif ()
    if (NOT TARGET ${SRC_COPY_TARGET_NAME})
        # A .done stamp file keeps the copy incremental.
        set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
        add_custom_command(OUTPUT ${_BUILD_FLAG}
            COMMAND mkdir -p ${SRC_COPY_DST}
            COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
            COMMAND touch ${_BUILD_FLAG}
        )
        add_custom_target(${SRC_COPY_TARGET_NAME}
            DEPENDS ${_BUILD_FLAG}
        )
    endif ()
    if (TARGET ${OPS_UTILS_INC_KERNEL_TARGET})
        add_dependencies(${SRC_COPY_TARGET_NAME} ${OPS_UTILS_INC_KERNEL_TARGET})
    endif ()
    if (DEFINED SRC_COPY_BE_RELIED)
        add_dependencies(${SRC_COPY_BE_RELIED} ${SRC_COPY_TARGET_NAME})
    endif ()
endfunction()
# Wire up the per-op, per-compute-unit kernel binary compilation targets.
#
# For each generated compile script <op_type>-<op_file>-<index>.sh found in
# the gen directory (produced earlier by add_compile_cmd_target), this:
#   1. creates aggregate targets <op_file> and <op_file>_<unit>;
#   2. sets up source/py copy and output-dir helper targets;
#   3. decides via an optional "step-start" group spec in ASCEND_OP_NAME
#      whether this particular script index should be compiled;
#   4. registers the actual compile command (a shell invocation of the
#      generated script) plus install rules and the binary_info_config.json
#      aggregation target.
#
# Keyword arguments:
#   COMPUTE_UNIT - SoC compute unit (e.g. ascend910b)
#   OP_INFO      - list of op source directories (basename = op name)
function(add_bin_compile_target)
    cmake_parse_arguments(BINARY "" "COMPUTE_UNIT" "OP_INFO" ${ARGN})
    set(_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/kernel)
    set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${BINARY_COMPUTE_UNIT})
    set(BIN_OUT_DIR ${_OUT_DIR}/bin)
    set(GEN_OUT_DIR ${_OUT_DIR}/gen)
    set(SRC_OUT_DIR ${_OUT_DIR}/src)
    file(MAKE_DIRECTORY ${BIN_OUT_DIR})
    # Map each op name to its source directory via a <name>_dir variable.
    foreach(_op_info ${BINARY_OP_INFO})
        get_filename_component(_op_name "${_op_info}" NAME)
        set(${_op_name}_dir ${_op_info})
    endforeach()
    set(_ops_target_list)
    set(compile_scripts)
    file(GLOB scripts_list ${GEN_OUT_DIR}/*.sh)
    list(APPEND compile_scripts ${scripts_list})
    foreach(bin_script ${compile_scripts})
        # Script names follow "<op_type>-<op_file>-<op_index>.sh".
        get_filename_component(bin_file ${bin_script} NAME_WE)
        string(REPLACE "-" ";" bin_sep ${bin_file})
        list(GET bin_sep 0 op_type)
        list(GET bin_sep 1 op_file)
        list(GET bin_sep 2 op_index)
        # Ignore scripts for ops not in OP_INFO.
        if (NOT DEFINED ${op_file}_dir)
            continue()
        endif ()
        if (NOT TARGET ${op_file})
            add_custom_target(${op_file})
            add_dependencies(ops_kernel ${op_file})
        endif ()
        set(OP_TARGET_NAME ${op_file}_${BINARY_COMPUTE_UNIT})
        if (NOT TARGET ${OP_TARGET_NAME})
            # First script seen for this op+unit: create the aggregate target
            # and its copy/mkdir helpers (shared by all script indices).
            add_custom_target(${OP_TARGET_NAME})
            add_dependencies(${op_file} ${OP_TARGET_NAME})
            list(APPEND _ops_target_list ${OP_TARGET_NAME})
            set(OP_SRC_OUT_DIR ${SRC_OUT_DIR}/${op_file})
            set(OP_BIN_OUT_DIR ${BIN_OUT_DIR}/${op_file})
            file(MAKE_DIRECTORY ${OP_SRC_OUT_DIR})
            add_ops_src_copy(
                TARGET_NAME
                ${OP_TARGET_NAME}_src_copy
                SRC
                ${${op_file}_dir}
                DST
                ${OP_SRC_OUT_DIR}
                COMPUTE_UNIT
                ${BINARY_COMPUTE_UNIT}
            )
            # Dependency ops also get their kernel sources copied, ordered
            # before this op's copy via BE_RELIED.
            if (DEFINED ${op_file}_depends)
                foreach(depend_info ${${op_file}_depends})
                    get_filename_component(_depend_op_name "${depend_info}" NAME)
                    set(_depend_op_target ${_depend_op_name}_${BINARY_COMPUTE_UNIT}_src_copy)
                    add_ops_src_copy(
                        TARGET_NAME
                        ${_depend_op_target}
                        SRC
                        ${CMAKE_SOURCE_DIR}/${depend_info}
                        DST
                        ${SRC_OUT_DIR}/${_depend_op_name}
                        COMPUTE_UNIT
                        ${BINARY_COMPUTE_UNIT}
                        BE_RELIED
                        ${OP_TARGET_NAME}_src_copy
                    )
                endforeach()
            endif ()
            # The generated dynamic impl (<op_file>.py) is staged next to the
            # kernel sources under the op_type name the compile script expects.
            set(DYNAMIC_PY_FILE ${OP_SRC_OUT_DIR}/${op_type}.py)
            add_custom_command(OUTPUT ${DYNAMIC_PY_FILE}
                COMMAND cp -rf ${ASCEND_IMPL_OUT_DIR}/dynamic/${op_file}.py ${DYNAMIC_PY_FILE}
            )
            add_custom_target(${OP_TARGET_NAME}_py_copy
                DEPENDS ${DYNAMIC_PY_FILE}
            )
            add_custom_command(OUTPUT ${OP_BIN_OUT_DIR}
                COMMAND mkdir -p ${OP_BIN_OUT_DIR}
            )
            add_custom_target(${OP_TARGET_NAME}_mkdir
                DEPENDS ${OP_BIN_OUT_DIR}
            )
            install(DIRECTORY ${OP_BIN_OUT_DIR}
                DESTINATION ${_INSTALL_DIR}/${BINARY_COMPUTE_UNIT} OPTIONAL
            )
            install(FILES ${BIN_OUT_DIR}/${op_file}.json
                DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL
            )
        endif ()
        # Optional sharding: a "step-start" token (e.g. "4-1") following the
        # op name in ASCEND_OP_NAME restricts which script indices this
        # invocation compiles. Default "1-0" compiles every index.
        set(_group "1-0")
        if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "")
            if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL")
                if (${op_file} IN_LIST ASCEND_OP_NAME)
                    list(LENGTH ASCEND_OP_NAME _len)
                    list(FIND ASCEND_OP_NAME ${op_file} _index)
                    math(EXPR _next_index "${_index} + 1")
                    if (${_next_index} LESS ${_len})
                        list(GET ASCEND_OP_NAME ${_next_index} _group_str)
                        set(_regex "^[0-9]+-[0-9]+$")
                        string(REGEX MATCH "${_regex}" match "${_group_str}")
                        if (match)
                            set(_group ${_group_str})
                        endif ()
                    endif ()
                endif ()
            endif ()
        endif ()
        # Compile index op_index iff it is reachable from start_index in
        # increments of step.
        string(REPLACE "-" ";" _group_sep ${_group})
        list(GET _group_sep 1 start_index)
        set(end_index ${op_index})
        list(GET _group_sep 0 step)
        set(_compile_flag false)
        if (${start_index} LESS ${end_index})
            foreach(i RANGE ${start_index} ${end_index} ${step})
                if (${i} EQUAL ${end_index})
                    set(_compile_flag true)
                    break()
                endif ()
            endforeach()
        elseif (${start_index} EQUAL ${end_index})
            set(_compile_flag true)
        else()
            set(_compile_flag false)
        endif ()
        if (_compile_flag)
            set(_BUILD_COMMAND)
            set(_BUILD_FLAG ${GEN_OUT_DIR}/${OP_TARGET_NAME}_${op_index}.done)
            # Environment for the generated compile script; the custom opp
            # path is only needed when host-side artifacts are built too.
            if (ENABLE_OPS_HOST)
                list(APPEND _BUILD_COMMAND export ASCEND_CUSTOM_OPP_PATH=${CUSTOM_DIR} &&)
            endif ()
            list(APPEND _BUILD_COMMAND export HI_PYTHON="python3" &&)
            list(APPEND _BUILD_COMMAND export TILINGKEY_PAR_COMPILE=1 &&)
            list(APPEND _BUILD_COMMAND export BIN_FILENAME_HASHED=1 &&)
            list(APPEND _BUILD_COMMAND bash ${bin_script} ${OP_SRC_OUT_DIR}/${op_type}.py ${OP_BIN_OUT_DIR})
            # presumably lets the script cooperate with the make jobserver
            # under the Makefiles generator -- TODO confirm.
            if(CMAKE_GENERATOR MATCHES "Unix Makefiles")
                list(APPEND _BUILD_COMMAND && echo $(MAKE))
            endif()
            add_custom_command(OUTPUT ${_BUILD_FLAG}
                COMMAND ${_BUILD_COMMAND}
                COMMAND touch ${_BUILD_FLAG}
                WORKING_DIRECTORY ${GEN_OUT_DIR}
            )
            add_custom_target(${OP_TARGET_NAME}_${op_index}
                DEPENDS ${_BUILD_FLAG}
            )
            if (ENABLE_OPS_HOST)
                add_dependencies(${OP_TARGET_NAME}_${op_index} optiling generate_ops_info)
            endif ()
            add_dependencies(${OP_TARGET_NAME}_${op_index} ${OP_TARGET_NAME}_src_copy ${OP_TARGET_NAME}_py_copy ${OP_TARGET_NAME}_mkdir)
            add_dependencies(${OP_TARGET_NAME} ${OP_TARGET_NAME}_${op_index})
        endif ()
    endforeach()
    # Aggregate all compiled op binaries into binary_info_config.json.
    if (_ops_target_list)
        set(OPS_CONFIG_TARGET ops_config_${BINARY_COMPUTE_UNIT})
        set(BINARY_INFO_CONFIG_FILE ${BIN_OUT_DIR}/binary_info_config.json)
        add_custom_command(OUTPUT ${BINARY_INFO_CONFIG_FILE}
            COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_ops_config.py -p ${BIN_OUT_DIR} -s ${BINARY_COMPUTE_UNIT}
        )
        add_custom_target(${OPS_CONFIG_TARGET}
            DEPENDS ${BINARY_INFO_CONFIG_FILE}
        )
        add_dependencies(ops_config ${OPS_CONFIG_TARGET})
        foreach(_op_target ${_ops_target_list})
            add_dependencies(${OPS_CONFIG_TARGET} ${_op_target})
        endforeach()
        install(FILES ${BINARY_INFO_CONFIG_FILE}
            DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL
        )
    endif ()
endfunction()
# For every source file of the given targets, redefine __FILE__ to the file's
# basename, keeping absolute build paths out of compiled binaries.
#
# Keyword arguments:
#   TARGET_NAME - one or more targets whose sources are rewritten
function(redefine_file_macro)
    cmake_parse_arguments(_FILE "" "" "TARGET_NAME" ${ARGN})
    foreach(tgt ${_FILE_TARGET_NAME})
        # Redefining a builtin macro provokes a compiler warning; mute it.
        target_compile_options(${tgt} PRIVATE
            -Wno-builtin-macro-redefined
        )
        get_target_property(tgt_sources ${tgt} SOURCES)
        foreach(src_file ${tgt_sources})
            get_filename_component(src_basename "${src_file}" NAME)
            set_source_files_properties(${src_file}
                PROPERTIES COMPILE_DEFINITIONS __FILE__="${src_basename}"
            )
        endforeach()
    endforeach()
endfunction()
# In-tree static build support: for each ops type in OPS_STATIC_TYPES, take
# the aclnn sources under SRC_DIR from aclnn_ops_<type>, mirror them into a
# temp dir (rewritten by OPS_STATIC_SCRIPT), attach the rewritten copies to
# aclnn_ops_<type>_static, and attach the matching generated aclnn/aclnnInner
# sources to acl_op_<type>_builtin.
#
# Keyword arguments:
#   SRC_DIR         - directory prefix selecting this module's aclnn sources
#   ACLNN_SRC       - candidate generated aclnn_<op> sources
#   ACLNN_INNER_SRC - candidate generated aclnnInner_<op> sources
function(add_static_ops)
    cmake_parse_arguments(STATIC "" "SRC_DIR" "ACLNN_SRC;ACLNN_INNER_SRC" ${ARGN})
    set(prepare_ops_adv_static_target prepare_ops_adv_static)
    set(static_src_temp_dir ${CMAKE_CURRENT_BINARY_DIR}/static_src_temp_dir)
    set(modified_files)
    foreach(ops_type ${OPS_STATIC_TYPES})
        get_target_property(all_srcs aclnn_ops_${ops_type} SOURCES)
        set(add_srcs)
        set(generate_aclnn_srcs)
        # Select this module's sources (those under SRC_DIR).
        foreach(_src ${all_srcs})
            string(REGEX MATCH "^${STATIC_SRC_DIR}" is_match "${_src}")
            if (is_match)
                list(APPEND add_srcs ${_src})
            endif ()
        endforeach()
        # For each aclnn_<op>.* source, find the matching generated
        # aclnn_<op> / aclnnInner_<op> files among the candidates.
        foreach(_src ${add_srcs})
            get_filename_component(name_without_ext ${_src} NAME_WE)
            string(REGEX REPLACE "^aclnn_" "" _op_name ${name_without_ext})
            foreach(_aclnn_src ${STATIC_ACLNN_SRC})
                get_filename_component(aclnn_name ${_aclnn_src} NAME_WE)
                if("aclnn_${_op_name}" STREQUAL "${aclnn_name}")
                    list(APPEND generate_aclnn_srcs ${_aclnn_src})
                    break()
                endif()
            endforeach()
            foreach(_aclnn_inner_src ${STATIC_ACLNN_INNER_SRC})
                get_filename_component(aclnn_inner_name ${_aclnn_inner_src} NAME_WE)
                if("aclnnInner_${_op_name}" STREQUAL "${aclnn_inner_name}")
                    list(APPEND generate_aclnn_srcs ${_aclnn_inner_src})
                    break()
                endif()
            endforeach()
        endforeach()
        if(add_srcs)
            # Re-point the selected sources at their temp-dir copies, which
            # are produced (GENERATED) by the prepare target below.
            list(TRANSFORM add_srcs REPLACE "${STATIC_SRC_DIR}" "${static_src_temp_dir}" OUTPUT_VARIABLE add_static_srcs)
            list(APPEND modified_files ${add_static_srcs})
            set(aclnn_ops_static_target aclnn_ops_${ops_type}_static)
            set_source_files_properties(${add_static_srcs}
                TARGET_DIRECTORY ${aclnn_ops_static_target}
                PROPERTIES GENERATED TRUE
            )
            target_sources(${aclnn_ops_static_target} PRIVATE
                ${add_static_srcs}
            )
            add_dependencies(${aclnn_ops_static_target} ${prepare_ops_adv_static_target})
        endif()
        if(generate_aclnn_srcs)
            list(REMOVE_DUPLICATES generate_aclnn_srcs)
            set(aclnn_op_target acl_op_${ops_type}_builtin)
            set_source_files_properties(${generate_aclnn_srcs}
                TARGET_DIRECTORY ${aclnn_op_target}
                PROPERTIES GENERATED TRUE
            )
            target_sources(${aclnn_op_target} PRIVATE
                ${generate_aclnn_srcs}
            )
        endif()
    endforeach()
    # One-time target that copies SRC_DIR/src into the temp dir and lets
    # OPS_STATIC_SCRIPT rewrite the listed files.
    if(NOT TARGET ${prepare_ops_adv_static_target})
        list(REMOVE_DUPLICATES modified_files)
        add_custom_command(OUTPUT ${static_src_temp_dir}
            COMMAND mkdir -p ${static_src_temp_dir}
            COMMAND cp -rf ${STATIC_SRC_DIR}/src ${static_src_temp_dir}
            COMMAND ${HI_PYTHON} -B ${OPS_STATIC_SCRIPT} InsertIni -p ${static_src_temp_dir} -f ${modified_files}
        )
        add_custom_target(${prepare_ops_adv_static_target}
            DEPENDS ${static_src_temp_dir}
        )
    endif()
endfunction()
# Standalone builds pull in additional helper modules for unit tests and
# example tests when the corresponding test modes are enabled.
if (BUILD_OPEN_PROJECT)
    if (TESTS_UT_OPS_TEST)
        include(${OPS_ADV_CMAKE_DIR}/func_utest.cmake)
    endif ()
    if (TESTS_EXAMPLE_OPS_TEST)
        include(${OPS_ADV_CMAKE_DIR}/func_examples.cmake)
    endif ()
endif ()

12
csrc/cmake/intf.cmake Normal file
View File

@@ -0,0 +1,12 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Standalone (custom-op package) builds define the public interface target
# via intf_pub.cmake; in-tree builds provide intf_pub elsewhere.
if (BUILD_OPEN_PROJECT)
    include(${OPS_ADV_CMAKE_DIR}/intf_pub.cmake)
endif ()

75
csrc/cmake/intf_pub.cmake Normal file
View File

@@ -0,0 +1,75 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Custom package scenario, public compilation configuration for Host side targets
# Note: To ensure compatibility with the built-in package compilation process, the intf_pub name cannot be changed
add_library(intf_pub INTERFACE)
# CANN toolkit headers: public, external and experiment-level interfaces.
target_include_directories(intf_pub
    INTERFACE
        ${ASCEND_CANN_PACKAGE_PATH}/include
        ${ASCEND_CANN_PACKAGE_PATH}/include/external
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
        ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)
target_link_directories(intf_pub
    INTERFACE
        ${ASCEND_CANN_PACKAGE_PATH}/lib64
)
# Warning configuration: a broad set of warnings is enabled first, then
# selected diagnostics are disabled (-Wno-*) or demoted from errors
# (-Wno-error=*) further down; with GCC the last matching flag wins.
# Fix: a redundant duplicate "-Wall" in the -Wno-* group was removed
# (it was already enabled at the top of the list).
target_compile_options(intf_pub
    INTERFACE
        -fPIC
        -O2
        -Wall -Wundef -Wcast-qual -Wpointer-arith -Wdate-time
        -Wfloat-equal -Wformat=2 -Wshadow
        -Wsign-compare -Wunused-macros -Wvla -Wdisabled-optimization -Wempty-body -Wignored-qualifiers
        -Wimplicit-fallthrough=3 -Wtype-limits -Wshift-negative-value -Wswitch-default
        -Wframe-larger-than=32768 -Woverloaded-virtual
        -Wnon-virtual-dtor -Wshift-overflow=2 -Wshift-count-overflow
        -Wwrite-strings -Wmissing-format-attribute -Wformat-nonliteral
        -Wdelete-non-virtual-dtor -Wduplicated-cond
        -Wtrampolines -Wsized-deallocation -Wlogical-op -Wsuggest-attribute=format
        -Wduplicated-branches
        -Wmissing-include-dirs -Wformat-signedness
        -Wreturn-local-addr -Wextra
        -Wredundant-decls -Wfloat-conversion
        -Wno-write-strings -Wno-dangling-else -Wno-comment -Wno-conversion-null -Wno-return-type
        -Wno-unknown-pragmas -Wno-sign-compare
        -Wno-error=undef
        -Wno-error=comment
        -Wno-error=conversion-null
        -Wno-error=dangling-else
        -Wno-error=return-type
        -Wno-error=shadow
        -Wno-error=sign-compare
        -Wno-error=unknown-pragmas
        -Wno-error=unused-parameter
        -Wno-error=cast-qual
        -Wno-error=format=
        -Wno-error=maybe-uninitialized
        -Wno-error=missing-field-initializers
        -Wno-error=redundant-decls
        -Wno-error=unused-variable
        $<$<COMPILE_LANGUAGE:C>:-Wnested-externs>
        $<$<CONFIG:Debug>:-g>
        $<IF:$<VERSION_GREATER:${CMAKE_C_COMPILER_VERSION},4.8.5>,-fstack-protector-strong,-fstack-protector-all>
)
# Keep the pre-C++11 std::string/std::list ABI (CANN binaries are built with
# it); enable fortified libc wrappers in release builds.
target_compile_definitions(intf_pub
    INTERFACE
        $<$<COMPILE_LANGUAGE:CXX>:_GLIBCXX_USE_CXX11_ABI=0>
        $<$<CONFIG:Release>:_FORTIFY_SOURCE=2>
)
# Hardening link flags: full RELRO (+now), non-executable stack, PIE for
# executables, and symbol stripping in release builds.
target_link_options(intf_pub
    INTERFACE
        $<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-pie>
        $<$<CONFIG:Release>:-s>
        -Wl,-z,relro
        -Wl,-z,now
        -Wl,-z,noexecstack
)

View File

@@ -0,0 +1,113 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Short-circuit: a previous find_package(alog) in this scope already
# populated the targets and variables; nothing left to do.
if (alog_FOUND)
    message(STATUS "Package alog has been found.")
    return()
endif()
# Classify the expected imported targets into already-defined / not-yet-defined
# (pattern borrowed from CMake-generated export files).
set(_cmake_targets_defined "")
set(_cmake_targets_not_defined "")
set(_cmake_expected_targets "")
foreach(_cmake_expected_target IN ITEMS slog alog alog_headers)
    list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
    if(TARGET "${_cmake_expected_target}")
        list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
    else()
        list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
    endif()
endforeach()
unset(_cmake_expected_target)
# All expected targets already exist: nothing to (re)define.
# Fix: the export-file template this was copied from also called
# cmake_policy(POP) and unset(CMAKE_IMPORT_FILE_VERSION) here. This module
# never calls cmake_policy(PUSH) nor sets CMAKE_IMPORT_FILE_VERSION, so the
# POP would abort configuration with a policy-stack error whenever this
# branch runs; both leftovers were removed.
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
    unset(_cmake_targets_defined)
    unset(_cmake_targets_not_defined)
    unset(_cmake_expected_targets)
    return()
endif()
# A partial overlap means some other script defined targets with these names;
# refuse to continue rather than silently mixing definitions.
if(NOT _cmake_targets_defined STREQUAL "")
    string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
    string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
    message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)
# Locate the alog public header and the shared library. The NO_CMAKE_SYSTEM_PATH /
# NO_CMAKE_FIND_ROOT_PATH options restrict the search to caller-supplied prefixes.
# Note: both slog_SHARED_LIBRARY and alog_SHARED_LIBRARY resolve to the same
# libascendalog.so; two cache entries are kept so both imported targets below
# have their own IMPORTED_LOCATION.
find_path(_INCLUDE_DIR
    NAMES base/alog_pub.h
    NO_CMAKE_SYSTEM_PATH
    NO_CMAKE_FIND_ROOT_PATH)
find_library(slog_SHARED_LIBRARY
    NAMES libascendalog.so
    PATH_SUFFIXES lib64
    NO_CMAKE_SYSTEM_PATH
    NO_CMAKE_FIND_ROOT_PATH)
find_library(alog_SHARED_LIBRARY
    NAMES libascendalog.so
    PATH_SUFFIXES lib64
    NO_CMAKE_SYSTEM_PATH
    NO_CMAKE_FIND_ROOT_PATH)
# Standard result handling: sets alog_FOUND when all REQUIRED_VARS resolved.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(alog
    FOUND_VAR
        alog_FOUND
    REQUIRED_VARS
        _INCLUDE_DIR
        slog_SHARED_LIBRARY
        alog_SHARED_LIBRARY
)
# Define the imported targets (slog, alog, alog_headers) once everything was
# found, and print the resolved locations for diagnostics.
if(alog_FOUND)
    set(alog_INCLUDE_DIR "${_INCLUDE_DIR}")
    include(CMakePrintHelpers)
    message(STATUS "Variables in alog module:")
    cmake_print_variables(alog_INCLUDE_DIR)
    cmake_print_variables(slog_SHARED_LIBRARY)
    cmake_print_variables(alog_SHARED_LIBRARY)
    # Both library targets share the same include-dir carrier (alog_headers)
    # and compile definitions.
    add_library(slog SHARED IMPORTED)
    set_target_properties(slog PROPERTIES
        INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG"
        INTERFACE_LINK_LIBRARIES "alog_headers"
        IMPORTED_LOCATION "${slog_SHARED_LIBRARY}"
    )
    add_library(alog SHARED IMPORTED)
    set_target_properties(alog PROPERTIES
        INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG"
        INTERFACE_LINK_LIBRARIES "alog_headers"
        IMPORTED_LOCATION "${alog_SHARED_LIBRARY}"
    )
    add_library(alog_headers INTERFACE IMPORTED)
    set_target_properties(alog_headers PROPERTIES
        INTERFACE_INCLUDE_DIRECTORIES "${alog_INCLUDE_DIR}"
    )
    # CMakePrintHelpers is already included above; the second, redundant
    # include() was removed.
    cmake_print_properties(TARGETS slog
        PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION
    )
    cmake_print_properties(TARGETS alog
        PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION
    )
    cmake_print_properties(TARGETS alog_headers
        PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
    )
endif()
# Cleanup temporary variables.
set(_INCLUDE_DIR)

View File

@@ -0,0 +1,130 @@
#!/bin/bash
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Parallel build jobs: twice the number of online processors.
# grep -c replaces the useless `cat | grep | wc -l` pipeline.
CPU_NUM=$(($(grep -c "^processor" /proc/cpuinfo)*2))
JOB_NUM="-j${CPU_NUM}"
# Parse command-line options. Every recognised option consumes one value
# argument; the first unrecognised argument stops parsing.
while [[ $# -gt 0 ]]; do
    case $1 in
    -s)
        # Source tree to configure.
        PATH_TO_SOURCE="$2"
        shift 2
        ;;
    -b)
        # Out-of-source build directory (recreated by clean()).
        PATH_TO_BUILD="$2"
        shift 2
        ;;
    -p)
        # Installed CANN toolkit root.
        ASCEND_CANN_PACKAGE_PATH="$2"
        shift 2
        ;;
    --autogen-dir)
        ASCEND_AUTOGEN_DIR="$2"
        shift 2
        ;;
    --build-open-project)
        BUILD_OPEN_PROJECT="$2"
        shift 2
        ;;
    --binary-out-dir)
        ASCEND_BINARY_OUT_DIR="$2"
        shift 2
        ;;
    --impl-out-dir)
        ASCEND_IMPL_OUT_DIR="$2"
        shift 2
        ;;
    --op-build-tool)
        OP_BUILD_TOOL="$2"
        shift 2
        ;;
    --ascend-cmake-dir)
        ASCEND_CMAKE_DIR="$2"
        shift 2
        ;;
    --tiling-key)
        # '::'-separated list; converted to a CMake ';' list in set_env().
        TILING_KEY="$2"
        shift 2
        ;;
    --ops-compile-options)
        # '::'-separated list; converted to a CMake ';' list in set_env().
        OPS_COMPILE_OPTIONS="$2"
        shift 2
        ;;
    --check-compatible)
        CHECK_COMPATIBLE="$2"
        shift 2
        ;;
    --ascend-compute_unit)
        # '::'-separated list; converted to a CMake ';' list in set_env().
        ASCEND_COMPUTE_UNIT="$2"
        shift 2
        ;;
    --ascend-op-name)
        ASCEND_OP_NAME="$2"
        shift 2
        ;;
    --op_debug_config)
        OP_DEBUG_CONFIG="$2"
        shift 2
        ;;
    *)
        break
        ;;
    esac
done
# Recreate the build directory from scratch (no-op when -b was not given).
# Fix: quote ${PATH_TO_BUILD} — an unquoted expansion in `rm -rf` is subject
# to word splitting and globbing, which is dangerous on a destructive command.
function clean() {
    if [ -n "${PATH_TO_BUILD}" ]; then
        rm -rf "${PATH_TO_BUILD}"
        mkdir -p "${PATH_TO_BUILD}"
    fi
}
# Convert a '::'-separated list into a CMake ';'-separated list.
# Fix: quote "${_input}" — the unquoted `echo $_input` word-split the value
# and glob-expanded any wildcard characters in it.
function convert_string() {
    local _input=$1
    local _output
    _output=$(echo "${_input}" | sed 's/::/;/g')
    echo "${_output}"
}
# Derive the CMake-list forms of the '::'-separated option values.
# Results are exported as globals consumed by build().
function set_env() {
    CONVERT_TILING_KEY="$(convert_string ${TILING_KEY})"
    CONVERT_OPS_COMPILE_OPTIONS="$(convert_string ${OPS_COMPILE_OPTIONS})"
    CONVERT_ASCEND_COMPUTE_UNIT="$(convert_string ${ASCEND_COMPUTE_UNIT})"
}
# Configure the source tree in the build directory and run the
# 'prepare_build' target. Fixes: all variable expansions are quoted
# (paths/options may contain spaces), and a failed `cd` now aborts instead
# of configuring in whatever directory the script happened to be in.
function build() {
    cd "${PATH_TO_BUILD}" || return 1
    cmake "${PATH_TO_SOURCE}" \
        -DBUILD_OPEN_PROJECT="${BUILD_OPEN_PROJECT}" \
        -DPREPARE_BUILD=ON \
        -DCUSTOM_ASCEND_CANN_PACKAGE_PATH="${ASCEND_CANN_PACKAGE_PATH}" \
        -DASCEND_AUTOGEN_DIR="${ASCEND_AUTOGEN_DIR}" \
        -DASCEND_BINARY_OUT_DIR="${ASCEND_BINARY_OUT_DIR}" \
        -DASCEND_IMPL_OUT_DIR="${ASCEND_IMPL_OUT_DIR}" \
        -DOP_BUILD_TOOL="${OP_BUILD_TOOL}" \
        -DASCEND_CMAKE_DIR="${ASCEND_CMAKE_DIR}" \
        -DCHECK_COMPATIBLE="${CHECK_COMPATIBLE}" \
        -DTILING_KEY="${CONVERT_TILING_KEY}" \
        -DOPS_COMPILE_OPTIONS="${CONVERT_OPS_COMPILE_OPTIONS}" \
        -DASCEND_COMPUTE_UNIT="${CONVERT_ASCEND_COMPUTE_UNIT}" \
        -DOP_DEBUG_CONFIG="${OP_DEBUG_CONFIG}" \
        -DASCEND_OP_NAME="${ASCEND_OP_NAME}"
    make ${JOB_NUM} prepare_build
}
# Entry point: wipe the build directory, derive the converted list
# variables, then configure and build the prepare_build target.
main() {
    clean
    set_env
    build
}
main "$@"

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Kernel-specific compile options for GroupedMatmulSwigluQuantWeightNzTensorList.
add_ops_compile_options(
    OP_NAME GroupedMatmulSwigluQuantWeightNzTensorList
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
)
# Operator definition registration for the aclnn executor library.
target_sources(op_host_aclnnExc PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp
)
# L0 kernel wrapper + L2 aclnn interface, linked into opapi.
target_sources(opapi PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
    aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
)
# The built-in (non-open) package additionally bundles these sources into the
# train/infer aclnn libraries.
if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
        aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
    )
    target_sources(aclnn_ops_infer PRIVATE
        grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
        aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
    )
endif ()
target_sources(optiling PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp
)
target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
    grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp
)
# Install the public aclnn header. The file name is fixed, so it is listed
# explicitly instead of via file(GLOB): globs silently skip missing files
# and hurt configure-time reproducibility. OPTIONAL keeps the previous
# "skip if absent" install behavior.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,329 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <dlfcn.h>
#include <new>
#include "aclnn_kernels/contiguous.h"
#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/make_op_executor.h"
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
#include "aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
static constexpr int64_t SPLIT = 2;
static constexpr int64_t K_LIMIT = 65536;
static constexpr int64_t N_LIMIT = 4096;
static constexpr int64_t NZ_DIM_3 = 32;
static constexpr int64_t NZ_DIM_2 = 16;
static constexpr int64_t OUTPUT_IDX_0 = 0;
static constexpr int64_t OUTPUT_IDX_1 = 1;
static constexpr size_t X_DIM_LIMIT = 2;
static constexpr size_t WEIGHT_ND_DIM_LIMIT = 2;
static constexpr size_t WEIGHT_NZ_DIM_LIMIT = 4;
static constexpr size_t WEIGHT_SCALE_DIM_LIMIT = 1;
static constexpr size_t TOKEN_SCALE_DIM_LIMIT = 1;
static constexpr size_t GROUP_LIST_DIM_LIMIT = 1;
static constexpr size_t QUANTOUT_DIM_LIMIT = 2;
static constexpr size_t QUANTSCALEOUT_DIM_LIMIT = 1;
static const std::initializer_list<DataType> X_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
static const std::initializer_list<DataType> WEIGHT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
static const std::initializer_list<DataType> WEIGHT_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16};
static const std::initializer_list<DataType> X_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16};
static const std::initializer_list<DataType> GROUP_LIST_DTYPE_SUPPORT_LIST = {DataType::DT_INT64};
static const std::initializer_list<DataType> QUANTOUT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8};
static const std::initializer_list<DataType> QUANTSCALEOUT_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT};
// Verifies that all mandatory pointers are non-null and warns about the
// optional inputs (bias / offset / outputOffset) that the current kernel
// version ignores. Returns false on any null mandatory argument.
static bool CheckNotNull(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset,
                         const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList,
                         const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset)
{
    OP_CHECK_NULL(x, return false);
    OP_CHECK_NULL(weight, return false);
    OP_CHECK_NULL(weightScale, return false);
    OP_CHECK_NULL(xScale, return false);
    OP_CHECK_NULL(groupList, return false);
    OP_CHECK_NULL(output, return false);
    OP_CHECK_NULL(outputScale, return false);
    if (bias != nullptr) {
        OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where bias is not 0. "
                "Features and accuracy are not guaranteed if inputting bias with values other than 0s.");
    }
    // Fix: the two warnings below previously said "inputting bias" (copy-paste);
    // they now name the parameter they actually refer to.
    if (offset != nullptr) {
        OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where offset is not 0. "
                "Features and accuracy are not guaranteed if inputting offset with values other than 0s.");
    }
    if (outputOffset != nullptr) {
        OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where outputOffset is not 0. "
                "Features and accuracy are not guaranteed if inputting outputOffset with values other than 0s.");
    }
    return true;
}
static bool CheckInputOutDims(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
const aclTensor* xScale, const aclTensor* groupList,
const aclTensor* output, const aclTensor* outputScale)
{
OP_CHECK_WRONG_DIMENSION(x, X_DIM_LIMIT, return false);
op::Format weightViewFormat = (*weight)[0]->GetViewFormat();
if (IsPrivateFormat(weightViewFormat)){
OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_NZ_DIM_LIMIT, return false);
} else {
OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_ND_DIM_LIMIT, return false);
}
OP_CHECK_WRONG_DIMENSION((*weightScale)[0], WEIGHT_SCALE_DIM_LIMIT, return false);
OP_CHECK_WRONG_DIMENSION(xScale, TOKEN_SCALE_DIM_LIMIT, return false);
OP_CHECK_WRONG_DIMENSION(groupList, GROUP_LIST_DIM_LIMIT, return false);
OP_CHECK_WRONG_DIMENSION(output, QUANTOUT_DIM_LIMIT, return false);
OP_CHECK_WRONG_DIMENSION(outputScale, QUANTSCALEOUT_DIM_LIMIT, return false);
return true;
}
static bool CheckInputOutShape(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
const aclTensor* xScale, const aclTensor* groupList,
const aclTensor* output, const aclTensor* outputScale)
{
int64_t m = x->GetViewShape().GetDim(0);
int64_t k = x->GetViewShape().GetDim(1);
int64_t n = (*weightScale)[0]->GetViewShape().GetDim(0);
int64_t e = weight->Size();
if (n % SPLIT != 0){
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, N is %ld , not an even number.", n);
return false;
}
int64_t nAfterHalve = static_cast<int64_t>(n / SPLIT);
// x shape is expected to be [M, K]
op::Shape xExpectShape = {m, k};
// The ND shape of each weight in TensorList is expected to be [K, N]
op::Shape weightNDExpectShape = {k, n};
// The NZ shape of each weight in TensorList is expected to be [N // 32, K // 16, 16, 32]
op::Shape weightNZExpectShape = {static_cast<int64_t>(n / NZ_DIM_3),
static_cast<int64_t>(k / NZ_DIM_2),
NZ_DIM_2, NZ_DIM_3};
// weightScale shape is expected to be [N]
op::Shape weightScaleExpectShape = {n};
// xScale shape is expected to be [E, N]
op::Shape xScaleExpectShape = {m};
// output shape is expected to be [M, N]
op::Shape outputExpectShape = {m, nAfterHalve};
// outputScale shape is expected to be [M]
op::Shape outputScaleExpectShape = {m};
for (size_t i = 0; i < weight->Size(); ++i) {
op::Format weightViewFormat = (*weight)[i]->GetViewFormat();
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(x, xExpectShape, return false);
if (IsPrivateFormat(weightViewFormat)){
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNZExpectShape, return false);
} else {
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNDExpectShape, return false);
}
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weightScale)[i], weightScaleExpectShape, return false);
}
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(xScale, xScaleExpectShape, return false);
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(output, outputExpectShape, return false);
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(outputScale, outputScaleExpectShape, return false);
// The length of groupList should be less than or equal to the number of experts in weight
int64_t groupListLen = groupList->GetViewShape().GetDim(0);
if(groupListLen > e) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, Length of 'groupList' out of range (expected to be in range of [1, %ld], but got %ld)",
e, groupListLen);
return false;
}
if(nAfterHalve > N_LIMIT) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
where N after halve is %ld greater than %ld.",
nAfterHalve, N_LIMIT);
return false;
}
if(k >= K_LIMIT) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
The tail axis dimension of input0(x) is %ld, which need lower than %ld.",
k, K_LIMIT);
return false;
}
return true;
}
// Validates the data types of all inputs/outputs against the supported lists
// declared at the top of this file (x/weight: INT8; weightScale/xScale:
// FLOAT32/FLOAT16/BF16; groupList: INT64; output: INT8; outputScale: FLOAT32).
// NOTE(review): the loop indexes weightScale with weight's index - assumes
// both lists have the same length; this function does not verify that itself.
static bool CheckDtypeValid(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale,
                            const aclTensor* xScale, const aclTensor* groupList,
                            const aclTensor* output, const aclTensor* outputScale)
{
    OP_CHECK_DTYPE_NOT_SUPPORT(x, X_DTYPE_SUPPORT_LIST, return false);
    for (size_t i = 0; i < weight->Size(); ++i) {
        OP_CHECK_DTYPE_NOT_SUPPORT((*weight)[i], WEIGHT_DTYPE_SUPPORT_LIST, return false);
        OP_CHECK_DTYPE_NOT_SUPPORT((*weightScale)[i], WEIGHT_SCALE_DTYPE_SUPPORT_LIST, return false);
    }
    OP_CHECK_DTYPE_NOT_SUPPORT(xScale, X_SCALE_DTYPE_SUPPORT_LIST, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(groupList, GROUP_LIST_DTYPE_SUPPORT_LIST, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(output, QUANTOUT_DTYPE_SUPPORT_LIST, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(outputScale, QUANTSCALEOUT_DTYPE_SUPPORT_LIST, return false);
    return true;
}
// Validates storage formats: weight must be FRACTAL_NZ; x and output must use
// a public (non-private) format.
// NOTE(review): only weight[0] is inspected - this assumes every weight in the
// list shares the same storage format; confirm with callers.
static bool CheckFormat(const aclTensor* x, const aclTensorList* weight, const aclTensor* output)
{
    bool isNZ = (*weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ;
    if (!isNZ) {
        // fp16 in fp32 out that is split k template, not precision-advanced now
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
 weight Format expect is FRACTAL_NZ, but got [%s].", op::ToString((*weight)[0]->GetStorageFormat()).GetString());
        return false;
    }
    if (IsPrivateFormat(x->GetStorageFormat())) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
 x Format Not support Private Format.");
        return false;
    }
    if (IsPrivateFormat(output->GetStorageFormat())) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\
 output Format Not support Private Format.");
        return false;
    }
    return true;
}
// Aggregated argument validation for the workspace-size phase.
// The order matters: null checks run first so the later checks can
// dereference the pointers freely.
static aclnnStatus CheckParams(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset,
                               const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList,
                               const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) {
    // 1. Check if parameters are null pointers
    CHECK_RET(CheckNotNull(x, weight, bias, offset, weightScale, xScale,
                           groupList, output, outputScale, outputOffset), ACLNN_ERR_PARAM_NULLPTR);
    // 2. Verify input and output parameter dimensions
    CHECK_RET(CheckInputOutDims(x, weight, weightScale, xScale,
                                groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID);
    // 3. Verify input and output shape parameters
    CHECK_RET(CheckInputOutShape(x, weight, weightScale, xScale,
                                 groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID);
    // 4. Check if the input data types are within the supported data type range
    CHECK_RET(CheckDtypeValid(x, weight, weightScale, xScale,
                              groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID);
    // 5. Check if data format is supported
    CHECK_RET(CheckFormat(x, weight, output), ACLNN_ERR_PARAM_INVALID);
    return ACLNN_SUCCESS;
}
// Shared implementation of the workspace-size phase: validates arguments,
// short-circuits on empty outputs, makes strided inputs contiguous, invokes
// the L0 kernel and records view-copies of its results into the
// caller-provided output tensors.
static aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(const aclTensor *x, const aclTensorList *weight,
                                                                                         const aclTensor *bias, const aclTensor *offset,
                                                                                         const aclTensorList *weightScale, const aclTensor *xScale,
                                                                                         const aclTensor *groupList,
                                                                                         aclTensor *output, aclTensor *outputScale,
                                                                                         aclTensor *outputOffset, uint64_t *workspaceSize,
                                                                                         aclOpExecutor **executor){
    // Fixed pattern, create OpExecutor
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    // Fixed pattern, parameter check
    auto ret = CheckParams(x, weight, bias, offset, weightScale, xScale,
                           groupList, output, outputScale, outputOffset);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    // Empty tensor scenario: nothing to compute, report a zero-byte workspace.
    if (output->IsEmpty() || groupList->IsEmpty() || outputScale->IsEmpty()) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }
    // Convert to contiguous
    x = l0op::Contiguous(x, uniqueExecutor.get());
    CHECK_RET(x != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // NOTE(review): this mutates the caller's weight tensors in place (their
    // original shape is overwritten with the view shape) - presumably needed
    // by the NZ kernel's shape handling; confirm.
    for (size_t i = 0; i < weight->Size(); ++i) {
        (*weight)[i]->SetOriginalShape((*weight)[i]->GetViewShape());
    }
    xScale = l0op::Contiguous(xScale, uniqueExecutor.get());
    CHECK_RET(xScale != nullptr, ACLNN_ERR_INNER_NULLPTR);
    groupList = l0op::Contiguous(groupList, uniqueExecutor.get());
    CHECK_RET(groupList != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Call L0 operator capability
    auto ret_0 = l0op::GroupedMatmulSwigluQuantWeightNzTensorList(x, weight, weightScale, xScale, groupList, uniqueExecutor.get());
    CHECK_RET(ret_0 != std::tuple(nullptr, nullptr), ACLNN_ERR_INNER_NULLPTR);
    // Copy the two L0 outputs (quantized result, output scale) into the
    // caller-provided tensors.
    auto out0 = std::get<OUTPUT_IDX_0>(ret_0);
    auto ret_1 = l0op::ViewCopy(out0, output, uniqueExecutor.get());
    CHECK_RET(ret_1 != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto out1 = std::get<OUTPUT_IDX_1>(ret_0);
    auto ret_2 = l0op::ViewCopy(out1, outputScale, uniqueExecutor.get());
    CHECK_RET(ret_2 != nullptr, ACLNN_ERR_INNER_NULLPTR);
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
// Public phase-1 entry: normalizes the weight tensors' formats, then defers
// to the common implementation above.
// NOTE(review): the const_cast below mutates the caller's (const) weight
// tensors - their storage/view formats are rewritten in place. This appears
// intentional (the op requires NZ storage) but is a visible side effect on
// the inputs; confirm that callers tolerate it.
aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(const aclTensor *x, const aclTensorList *weight,
                                                                            const aclTensor *bias, const aclTensor *offset,
                                                                            const aclTensorList *weightScale, const aclTensor *xScale,
                                                                            const aclTensor *groupList,
                                                                            aclTensor *output, aclTensor *outputScale,
                                                                            aclTensor *outputOffset, uint64_t *workspaceSize,
                                                                            aclOpExecutor **executor) {
    OP_CHECK_COMM_INPUT(workspaceSize, executor);
    L2_DFX_PHASE_1(aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
                   DFX_IN(x, weight, bias, offset, weightScale, xScale, groupList),
                   DFX_OUT(output, outputScale, outputOffset));
    // weight is forcibly bound to StorageFormat and ViewFormat as NZ in this scenario
    CHECK_RET(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR);
    for (size_t i = 0; i < weight->Size(); ++i) {
        auto storgeShape = (*weight)[i]->GetStorageShape();
        auto viewShape = (*weight)[i]->GetViewShape();
        aclTensor* weightNZ = const_cast<aclTensor*>((*weight)[i]);
        // NZ storage is always 4-D; reject anything else before rebinding formats.
        CHECK_COND((storgeShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT),
                   ACLNN_ERR_PARAM_INVALID,
                   "aclnnGroupedMatmulSwigluQuantWeightNZTensorList, The dimnum of storageShape for second input (weight) \
must be 4. \n But StorageShape got %s , and dimNum is %lu.",
                   op::ToString(storgeShape).GetString(), storgeShape.GetDimNum());
        // The StorageFormat of weight is unconditionally regarded as NZ
        weightNZ->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
        if (viewShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT){
            // If the viewShape of weight is 4-dimensional, it is regarded as NZ
            weightNZ->SetViewFormat(op::Format::FORMAT_FRACTAL_NZ);
        } else if (viewShape.GetDimNum() == WEIGHT_ND_DIM_LIMIT){
            // If the viewShape of weight is 2-dimensional, it is regarded as ND
            weightNZ->SetViewFormat(op::Format::FORMAT_ND);
        }
    }
    // Call the common interface
    return aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(x, weight, bias, offset, weightScale, xScale, groupList,
                                                                                 output, outputScale, outputOffset, workspaceSize, executor);
}
// Public phase-2 entry: launches the computation recorded in the executor
// (built by the GetWorkspaceSize phase) on the given stream.
aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void *workspace,
                                                            uint64_t workspaceSize,
                                                            aclOpExecutor *executor,
                                                            aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnGroupedMatmulSwigluQuantWeightNzTensorList);
    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "This is an error in GroupedMatmulSwigluQuantWeightNzTensorList launch aicore");
    return ACLNN_SUCCESS;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,56 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#define OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * @brief First phase of aclnnGroupedMatmulSwigluQuantWeightNzTensorList: validates the
 *        arguments and computes the device-side workspace size needed by the second phase.
 * @domain aclnn_ops_infer
 *
 * @param [in] x: input activations. Data type INT8, data format ND.
 * @param [in] weight: per-expert weight tensors. Data type INT8; NZ (FRACTAL_NZ) storage format expected.
 * @param [in] bias: optional; the current version only supports nullptr / all-zero bias.
 * @param [in] offset: optional; the current version only supports nullptr / all-zero offset.
 * @param [in] weightScale: per-channel quantization scales, one tensor per expert.
 *             Data type FLOAT32, FLOAT16 or BFLOAT16, data format ND.
 * @param [in] xScale: per-token quantization scale. Data type FLOAT32, FLOAT16 or BFLOAT16, data format ND.
 * @param [in] groupList: required; indices on the input/output grouping axis. Data type INT64.
 * @param [out] output: quantized result. Data type INT8, data format ND.
 * @param [out] outputScale: output quantization scale. Data type FLOAT32.
 * @param [out] outputOffset: optional; the current version only supports nullptr.
 * @param [out] workspaceSize: returns the workspace size to allocate on the npu device side.
 * @param [out] executor: returns the op executor, containing the operator calculation process.
 * @return aclnnStatus: Returns the status code.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(
    const aclTensor *x, const aclTensorList *weight, const aclTensor *bias, const aclTensor *offset,
    const aclTensorList *weightScale, const aclTensor *xScale, const aclTensor *groupList,
    aclTensor *output, aclTensor *outputScale, aclTensor *outputOffset, uint64_t *workspaceSize, aclOpExecutor **executor);
/**
 * @brief Second phase of aclnnGroupedMatmulSwigluQuantWeightNzTensorList: executes the computation.
 * @param [in] workspace: start address of the workspace memory allocated on the npu device side.
 * @param [in] workspaceSize: size obtained from aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize.
 * @param [in] executor: op executor, containing the operator calculation process.
 * @param [in] stream: acl stream used for the launch.
 * @return aclnnStatus: Returns the status code.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void* workspace,
    uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,56 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "opdev/op_log.h"
#include "opdev/op_dfx.h"
#include "opdev/make_op_executor.h"
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
using namespace op;
namespace l0op {
OP_TYPE_REGISTER(GroupedMatmulSwigluQuantWeightNzTensorList);
const std::tuple<aclTensor*, aclTensor*> GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x,
const aclTensorList *weight,
const aclTensorList *perChannelScale,
const aclTensor *perTokenScale,
const aclTensor *groupList,
aclOpExecutor *executor) {
L0_DFX(GroupedMatmulSwigluQuantWeightNzTensorList, x, weight, perChannelScale, perTokenScale, groupList);
if (x == nullptr) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "x is nullptr.");
return std::tuple(nullptr, nullptr);
}
int64_t m = perTokenScale->GetViewShape().GetDim(0);
int64_t n = (*perChannelScale)[0]->GetViewShape().GetDim(0);
int64_t nAfterHalve = static_cast<int64_t>(n / 2);
gert::Shape outShape({m, nAfterHalve});
gert::Shape scaleOutShape({m});
auto out = executor->AllocTensor(outShape, DataType::DT_INT8, ge::FORMAT_ND);
auto scaleOut = executor->AllocTensor(scaleOutShape, DataType::DT_FLOAT, ge::FORMAT_ND);
auto ret = INFER_SHAPE(GroupedMatmulSwigluQuantWeightNzTensorList,
OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList),
OP_OUTPUT(out, scaleOut));
if (ret != ACLNN_SUCCESS) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "InferShape failed.");
return std::tuple(nullptr, nullptr);
}
ret = ADD_TO_LAUNCHER_LIST_AICORE(GroupedMatmulSwigluQuantWeightNzTensorList,
OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList),
OP_OUTPUT(out, scaleOut));
if (ret != ACLNN_SUCCESS) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed.");
return std::tuple(nullptr, nullptr);
}
return std::tie(out, scaleOut);
}
} // namespace l0op

View File

@@ -0,0 +1,24 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H
#define OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H
#include "opdev/op_executor.h"
namespace l0op {
// Level-0 declaration: grouped matmul + SwiGLU + per-token quantization with
// NZ-format weights supplied as tensor lists (one tensor per expert group).
// Returns {int8 quantized output, float per-token scale}; on failure both
// tuple members are nullptr. Implemented in
// grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp.
const std::tuple<aclTensor*, aclTensor*> GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x,
                                                                                    const aclTensorList *weight,
                                                                                    const aclTensorList *perChannelScale,
                                                                                    const aclTensor *perTokenScale,
                                                                                    const aclTensor *groupList,
                                                                                    aclOpExecutor *executor);
}
#endif

View File

@@ -0,0 +1,65 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp
* \brief
*/
#include <cstdint>
#include "register/op_def_registry.h"
namespace ops {
// Prototype registration for GroupedMatmulSwigluQuantWeightNzTensorList.
// The declaration order below fixes the runtime operand indices:
//   inputs : 0 x, 1 weight (dynamic), 2 weight_scale (dynamic),
//            3 x_scale, 4 group_list
//   outputs: 0 y, 1 y_scale
// Each DataType list enumerates one dtype combination per column.
class GroupedMatmulSwigluQuantWeightNzTensorList : public OpDef {
public:
    explicit GroupedMatmulSwigluQuantWeightNzTensorList(const char* name) : OpDef(name)
    {
        // Quantized activations, row-major.
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-group weights in FRACTAL_NZ layout; DYNAMIC = tensor list.
        this->Input("weight")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // Per-channel dequant scales; fp32/bf16/fp16 variants supported.
        this->Input("weight_scale")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-token activation scales.
        this->Input("x_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // Cumulative row counts delimiting each expert group.
        this->Input("group_list")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT64,ge::DT_INT64,ge::DT_INT64})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // Re-quantized int8 result.
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // Per-token output dequant scale.
        this->Output("y_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT})
            .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND});
        // Supported only on Ascend 910B / 910_93 AI cores.
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true);
        this->AICore().AddConfig("ascend910b", aicore_config);
        this->AICore().AddConfig("ascend910_93", aicore_config);
    }
};
OP_ADD(GroupedMatmulSwigluQuantWeightNzTensorList);
}

View File

@@ -0,0 +1,49 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp
* \brief
*/
#include "register/op_impl_registry.h"
#include "log/ops_log.h"
#include "platform/platform_info.h"
using namespace ge;
namespace ops {
// Operand indices used by shape inference (see the OpDef registration):
const int64_t X_INDEX = 0;            // x: (m, k) int8 activations
const int64_t WEIGHTSCALE_INDEX = 2;  // weight_scale: dynamic input, first tensor read
const int64_t M_DIM_INDEX = 0;        // row dim of x
const int64_t N_DIM_INDEX = 0;        // channel dim of weight_scale[0]
// Infer output shapes:
//   y       : (m, n/2) with m = x dim0, n = weight_scale[0] dim0
//             (SwiGLU halves the gate+up channel dimension)
//   y_scale : (m,) per-token dequantization scale
static ge::graphStatus InferShape4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferShapeContext* context) {
    const gert::Shape* xShape = context->GetInputShape(X_INDEX);
    const gert::Shape* weightScaleShape = context->GetDynamicInputShape(WEIGHTSCALE_INDEX, 0);
    // Bugfix: both shape pointers were dereferenced without validation.
    if (xShape == nullptr || weightScaleShape == nullptr) {
        return GRAPH_FAILED;
    }
    int64_t m = xShape->GetDim(M_DIM_INDEX);
    int64_t n = static_cast<int64_t>(weightScaleShape->GetDim(N_DIM_INDEX) / 2);
    auto outShape = context->GetOutputShape(0);
    auto outScaleShape = context->GetOutputShape(1);
    // Bugfix: output shape pointers were likewise used unchecked.
    if (outShape == nullptr || outScaleShape == nullptr) {
        return GRAPH_FAILED;
    }
    outShape->SetDimNum(2);
    outShape->SetDim(0, m);
    outShape->SetDim(1, n);
    outScaleShape->SetDimNum(1);
    outScaleShape->SetDim(0, m);
    return GRAPH_SUCCESS;
}
// Output dtypes are fixed by the op contract: y is int8 (re-quantized
// activations) and y_scale is float32 (per-token dequant scale).
static graphStatus InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferDataTypeContext* context) {
    context->SetOutputDataType(0, DataType::DT_INT8);
    context->SetOutputDataType(1, DataType::DT_FLOAT);
    return GRAPH_SUCCESS;
}
// Register the shape / dtype inference callbacks for this op type.
IMPL_OP_INFERSHAPE(GroupedMatmulSwigluQuantWeightNzTensorList)
    .InferShape(InferShape4GroupedMatmulSwigluQuantWeightNzTensorList)
    .InferDataType(InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList);
} // namespace ops

View File

@@ -0,0 +1,188 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp
* \brief
*/
#include <climits>
#include <graph/utils/type_utils.h>
#include "register/op_impl_registry.h"
#include "log/ops_log.h"
#include "error/ops_error.h"
#include "tiling/tiling_base.h"
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h"
using namespace ge;
using namespace AscendC;
using namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling;
// Integer ceiling division. A zero divisor yields 0 instead of trapping,
// so callers can feed possibly-empty dimensions without pre-checking.
template <typename T1, typename T2>
static T1 CeilDiv(T1 num, T2 den)
{
    return den == 0 ? 0 : (num + den - 1) / den;
}
namespace optiling {
// Platform facts cached at compile-parse time and reused by every tiling call.
struct GMMSwigluCompileInfo {
    uint64_t ubSize_ = 0;   // unified-buffer capacity in bytes
    uint32_t aicNum_ = 0;   // number of AI cube cores
    uint32_t baseM_ = 128;  // matmul base tile rows
    uint32_t baseN_ = 256;  // matmul base tile cols
};
static uint64_t CalcMaxTmpSize(const uint32_t row, const uint64_t n) {
std::vector<int64_t> shape_vec = {static_cast<int64_t>(row * n)};
Shape shape(shape_vec);
uint32_t max;
uint32_t min;
GetSwiGLUMaxMinTmpSize(shape, 4, max, min, false);
uint32_t averageTmp = (max + min) >> 1;
GetAscendQuantMaxMinTmpSize(shape, 4, max, min);
uint32_t average = (max + min) >> 1;
average = average > averageTmp ? average : averageTmp;
GetAscendDequantMaxMinTmpSize(shape, 4, max, min);
averageTmp = (max + min) >> 1;
return average > averageTmp ? average : averageTmp;
}
// Computes how many rows of `n` elements fit into the UB per vector pass.
// Budget model: one token of output (n * 4 bytes) is reserved, then each row
// costs (8 + n*4) bytes plus a shared temp area from CalcMaxTmpSize; rows is
// shrunk until the total fits. NOTE(review): `rows` is unsigned and the loop
// assumes the estimate converges before underflow — confirm ubSize is always
// large enough for at least one row.
static uint64_t CalRows(const uint64_t ubSize, const uint64_t n) {
    uint64_t tokenSize = n << 2;  // n * sizeof(float)
    uint64_t expectSize = ubSize - tokenSize;
    uint64_t rows = expectSize / (8 + tokenSize);
    uint64_t realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n);
    while (expectSize < realSize) {
        // Shrink by a quarter of the overshoot's row equivalent and re-estimate.
        rows -= CeilDiv(realSize - expectSize, (8 + tokenSize) << 2);
        realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n);
    }
    return rows;
}
// Selects the kernel branch: tiling key 1 = split-workspace path, 0 = default
// path. Both branches use batch-mode scheduling.
static void SetTilingKey(gert::TilingContext* context, bool isSplitWorkSpace) {
    context->SetTilingKey(isSplitWorkSpace ? 1 : 0);
    context->SetScheduleMode(BATCH_MODE_SCHEDULE);
}
// Returns true when the problem matches the prefill fast path: exactly 128
// groups, at least PREFILL_M_MIN_SIZE rows, and a (k, n) pair on the
// whitelist of known-good prefill shapes.
static bool IsPreFill(GMMSwigluQuantTilingData &tilingData) {
    const int64_t groupNum = tilingData.gmmSwigluBaseParams.get_groupNum();
    const int64_t m = tilingData.gmmSwigluBaseParams.get_M();
    if (groupNum != 128 || m < PREFILL_M_MIN_SIZE) { // 128: prefiling groupNum
        return false;
    }
    const std::array<int64_t, 2> shapeKey = {tilingData.gmmSwigluBaseParams.get_K(),
                                             tilingData.gmmSwigluBaseParams.get_N()};
    return PREFILL_WHITE_LIST.count(shapeKey) != 0;
}
// Per-invocation tiling for GroupedMatmulSwigluQuantWeightNzTensorList.
// Derives (m, k) from x, n from the first NZ weight tensor and groupNum from
// group_list; fills GMMSwigluQuantTilingData (base params, vector params and
// the cube matmul tiling); sizes the int32 intermediate workspace; and picks
// the tiling key (0 = single pass, 1 = split workspace when m > 2 * mLimit).
ASCENDC_EXTERN_C graphStatus TilingGMMSwigluQuant(gert::TilingContext* context) {
    // set info
    OPS_LOG_I(context->GetNodeName(), "Begin Run GMM Swiglu Tiling .");
    auto compileInfoPtr = context->GetCompileInfo<GMMSwigluCompileInfo>();
    // Bugfix: compileInfoPtr was dereferenced below without a null check,
    // unlike every other pointer fetched in this function.
    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return GRAPH_FAILED);
    auto xTensor = context->GetInputTensor(X_INDEX);
    OPS_LOG_E_IF_NULL(context, xTensor, return GRAPH_FAILED);
    const int64_t m = xTensor->GetStorageShape().GetDim(0);
    const int64_t k = xTensor->GetStorageShape().GetDim(1);
    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
    OPS_LOG_E_IF_NULL(context, wTensor, return GRAPH_FAILED);
    // NZ storage: dims 0 and 3 multiply back to the logical n.
    const int64_t n = wTensor->GetStorageShape().GetDim(0) * wTensor->GetStorageShape().GetDim(3);
    // NOTE(review): group_list is declared REQUIRED in the op def but fetched
    // via GetDynamicInputTensor here — confirm the index mapping is intended.
    auto groupListTensor = context->GetDynamicInputTensor(GROUPLIST_INDEX, 0);
    OPS_LOG_E_IF_NULL(context, groupListTensor, return GRAPH_FAILED);
    const int64_t groupNum = groupListTensor->GetStorageShape().GetDim(0);
    GMMSwigluQuantTilingData tilingData;
    const int64_t row = CalRows(compileInfoPtr->ubSize_, n);  // rows per UB vector pass
    tilingData.gmmSwigluBaseParams.set_groupNum(groupNum);
    tilingData.gmmSwigluBaseParams.set_coreNum(compileInfoPtr->aicNum_);
    tilingData.gmmSwigluBaseParams.set_K(k);
    tilingData.gmmSwigluBaseParams.set_N(n);
    tilingData.gmmSwigluBaseParams.set_M(m);
    tilingData.gmmSwiglu.set_maxProcessRowNum(row);
    tilingData.gmmSwiglu.set_groupListLen(groupNum);
    tilingData.gmmSwiglu.set_tokenLen(n);
    OPS_LOG_D(context->GetNodeName(),"grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.");
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.groupNum: %ld", groupNum);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.coreNum: %u ", compileInfoPtr->aicNum_);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.M: %ld", m);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.K: %ld", k);
    OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.N: %ld", n);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.maxProcessRowNum: %ld", row);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.groupListLen: %ld", groupNum);
    OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.tokenLen: %ld", n);
    // Cube tiling: int8 ND activations x int8 NZ weights -> int32 ND output.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    using namespace matmul_tiling;
    MatmulApiTiling tiling(ascendcPlatform);
    tiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT8);
    tiling.SetBType(TPosition::GM, CubeFormat::NZ, matmul_tiling::DataType::DT_INT8);
    tiling.SetCType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT32);
    tiling.SetBias(false);
    tiling.SetShape(compileInfoPtr->baseM_, compileInfoPtr->baseN_, k);
    tiling.SetOrgShape(m, n, k);
    tiling.SetBufferSpace(-1, -1, -1);
    OPS_ERR_IF(tiling.GetTiling(tilingData.mmTilingData) == -1,
               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling, get tiling failed"),
               return GRAPH_FAILED);
    // Workspace: int32 matmul results for up to 2 * mLimit rows (double buffer),
    // capped at m, on top of the fixed system reservation.
    auto workspaceSizes = context->GetWorkspaceSizes(1);
    bool isPreFill = IsPreFill(tilingData);
    tilingData.gmmSwigluBaseParams.set_isPreFill(isPreFill);
    int64_t usrWorkspaceLimit = isPreFill ? PREFILL_USER_WORKSPACE_LIMIT : USER_WORKSPACE_LIMIT;
    int64_t mLimit = ((usrWorkspaceLimit / DOUBLE_WORKSPACE_SPLIT) / INT32_DTYPE_SIZE) / n;
    OPS_ERR_IF(mLimit <= 0,
               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "mLimit is %ld must over then 0.", mLimit),
               return GRAPH_FAILED);
    tilingData.gmmSwigluBaseParams.set_mLimit(mLimit);
    workspaceSizes[0] = SYS_WORKSPACE_SIZE + ((mLimit * DOUBLE_WORKSPACE_SPLIT > m
                                                   ? m
                                                   : mLimit * DOUBLE_WORKSPACE_SPLIT) * n * sizeof(int32_t));
    bool isSplitWorkSpace = m > mLimit * DOUBLE_WORKSPACE_SPLIT;
    OPS_LOG_D(context->GetNodeName(), "USER_WORKSPACE_LIMIT: %ld", usrWorkspaceLimit);
    OPS_LOG_D(context->GetNodeName(), "mLimit: %ld", mLimit);
    OPS_LOG_D(context->GetNodeName(), "workspaceSizes: %lu", workspaceSizes[0]);
    OPS_LOG_D(context->GetNodeName(), "isSplitWorkSpace: %s", isSplitWorkSpace ? "true" : "false");
    OPS_LOG_D(context->GetNodeName(), "isPreFill: %s", isPreFill ? "true" : "false");
    SetTilingKey(context, isSplitWorkSpace);
    tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->SetBlockDim(compileInfoPtr->aicNum_); // block dim is the number of aicube
    context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
    OPS_LOG_D(context->GetNodeName(), "End Run GMM Swiglu Tiling.");
    return GRAPH_SUCCESS;
}
// Compile-parse hook: cache platform facts (AI cube core count, UB capacity)
// into GMMSwigluCompileInfo once, so per-invocation tiling need not re-query
// the platform.
ASCENDC_EXTERN_C graphStatus TilingPrepareForGMMSwigluQuant(gert::TilingParseContext* context) {
    fe::PlatFormInfos* platformInfos = context->GetPlatformInfo();
    OPS_LOG_E_IF_NULL(context, platformInfos, return GRAPH_FAILED);
    auto compileInfo = context->GetCompiledInfo<GMMSwigluCompileInfo>();
    OPS_LOG_E_IF_NULL(context, compileInfo, return GRAPH_FAILED);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfos);
    compileInfo->aicNum_ = ascendcPlatform.GetCoreNumAic();
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfo->ubSize_);
    OPS_LOG_D(context->GetNodeName(), "ubSize is %lu, aicNum is %u.", compileInfo->ubSize_, compileInfo->aicNum_);
    return GRAPH_SUCCESS;
}
// Register the per-invocation tiling function and the compile-parse hook.
IMPL_OP_OPTILING(GroupedMatmulSwigluQuantWeightNzTensorList)
    .Tiling(TilingGMMSwigluQuant)
    .TilingParse<GMMSwigluCompileInfo>(TilingPrepareForGMMSwigluQuant);
} // namespace optiling

View File

@@ -0,0 +1,68 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h
* \brief
*/
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#define AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#include <set>
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
namespace optiling {
// Scalar problem description shared by the cube and vector phases.
BEGIN_TILING_DATA_DEF(GMMSwigluBaseParams)
    TILING_DATA_FIELD_DEF(uint32_t, groupNum);   // number of expert groups
    TILING_DATA_FIELD_DEF(uint32_t, coreNum);    // AI cube cores used
    TILING_DATA_FIELD_DEF(uint32_t, K);          // matmul reduce dim
    TILING_DATA_FIELD_DEF(uint32_t, N);          // weight output width (before SwiGLU halving)
    TILING_DATA_FIELD_DEF(uint32_t, M);          // total token rows
    TILING_DATA_FIELD_DEF(uint32_t, mLimit);     // max rows per workspace half
    TILING_DATA_FIELD_DEF(uint64_t, isPreFill);  // 1 when the prefill fast path applies
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(GMMSwigluBaseParamsOp, GMMSwigluBaseParams)
// Vector-phase (SwiGLU + quant) loop parameters.
BEGIN_TILING_DATA_DEF(GMMSwiglu)
    TILING_DATA_FIELD_DEF(uint32_t, maxProcessRowNum);  // rows per UB pass (CalRows)
    TILING_DATA_FIELD_DEF(uint32_t, groupListLen);      // == groupNum
    TILING_DATA_FIELD_DEF(uint32_t, tokenLen);          // == N
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(GMMSwigluOp, GMMSwiglu)
// Aggregate handed to the kernel: scalar params + vector params + cube tiling.
BEGIN_TILING_DATA_DEF(GMMSwigluQuantTilingData)
    TILING_DATA_FIELD_DEF_STRUCT(GMMSwigluBaseParams, gmmSwigluBaseParams);
    TILING_DATA_FIELD_DEF_STRUCT(GMMSwiglu, gmmSwiglu);
    TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(GroupedMatmulSwigluQuantWeightNzTensorList, GMMSwigluQuantTilingData)
}
// Constants shared between tiling and the host-side launcher.
namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling {
constexpr uint32_t X_INDEX = 0;          // operand index of x
constexpr uint32_t WEIGHT_INDEX = 1;     // operand index of weight (dynamic)
constexpr uint32_t GROUPLIST_INDEX = 4;  // operand index of group_list
constexpr uint32_t BATCH_MODE_SCHEDULE = 1;
constexpr uint32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024;        // fixed system reservation
constexpr int64_t USER_WORKSPACE_LIMIT = 256 * 1024 * 1024;      // decode-path workspace cap
constexpr int64_t PREFILL_USER_WORKSPACE_LIMIT = 64 * 1024 * 1024;  // prefill-path cap
constexpr int64_t DOUBLE_WORKSPACE_SPLIT = 2;  // workspace is double-buffered
constexpr int64_t INT32_DTYPE_SIZE = 4;        // matmul intermediate is int32
constexpr uint32_t PREFILL_M_MIN_SIZE = 16 * 1024;
const std::set<std::array<int64_t, 2>> PREFILL_WHITE_LIST = { // used for preFill case
    {{2048, 1536}},
    {{4096, 3072}}
};
} // namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H

View File

@@ -0,0 +1,80 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp
* \brief
*/
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"
#include <typeinfo>
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h"
using namespace AscendC;
using namespace matmul;
using namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST;
// The cube phase accumulates int8 x int8 into int32; the vector phase then
// dequantizes, applies SwiGLU and re-quantizes to int8.
using MM_DTYPE_Y = int32_t;
// Matmul operand descriptors: x is ND in GM, weight is FRACTAL_NZ in GM.
// NOTE(review): the `trans` template parameter is not forwarded by these
// aliases — transpose selection appears to happen elsewhere; confirm.
template <bool trans = false>
using xType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, DTYPE_X>;
template <bool trans = false>
using weightType = MatmulType<AscendC::TPosition::GM, CubeFormat::NZ, DTYPE_WEIGHT>;
using yType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, MM_DTYPE_Y>;
// Instantiates the matmul object, loads the tiling members from GM, then
// builds and runs the given compute class. Kept as a macro because
// GET_TILING_DATA_MEMBER introduces local variables that the compute object
// must reference. (No comments inside: every interior line ends in a '\'.)
#define GMM_CV_SPLIT_IMP(computeClass, dtypeC, transA, transB, sync, cfg, aType, bType, cType) \
do { \
using matmulType = MMImplType<aType<transA>, bType<transB>, cType, cType, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwigluBaseParams, gmmSwigluBaseParams_, tiling); \
GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwiglu, gmmSwiglu_, tiling); \
if ASCEND_IS_AIC { \
mm.SetSubBlockIdx(0); \
mm.Init(&mmTilingData_, &tPipe); \
} \
computeClass<matmulType, sync, dtypeC> computeOp(mm); \
computeOp.Init(x, weight, perChannelScale, perTokenScale, groupList, quantOutput, quantScaleOutput, \
user1, &gmmSwigluBaseParams_, &mmTilingData_, &gmmSwiglu_, &tPipe); \
computeOp.Process(); \
} while (0)
// Kernel entry point. Tiling key 0 dispatches the single-workspace compute
// class, key 1 the split-workspace variant (m exceeds the workspace capacity,
// see TilingGMMSwigluQuant).
extern "C" __global__ __aicore__ void grouped_matmul_swiglu_quant_weight_nz_tensor_list(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
GM_ADDR workspace, GM_ADDR tiling) {
    TPipe tPipe;
    AscendCUtils::SetOverflow(1);
    // 1 cube core paired with 2 vector cores.
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
    GM_ADDR user1 = GetUserWorkspace(workspace);
    if (TILING_KEY_IS(0)) { // antiquant msd
        KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_2);
        GMM_CV_SPLIT_IMP(
            GMMSwigluCompute, // computeClass
            DTYPE_WEIGHT_SCALE,
            false, // transA
            false, // transB
            false, // sync
            NZ_CFG_MDL, // cfg
            xType, // aType
            weightType, // bType
            yType); // cType
    } else if(TILING_KEY_IS(1)){
        KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_2);
        GMM_CV_SPLIT_IMP(
            GMMSwigluSplitWorkSpaceCompute, // computeClass
            DTYPE_WEIGHT_SCALE,
            false, // transA
            false, // transB
            false, // sync
            NZ_CFG_MDL, // cfg
            xType, // aType
            weightType, // bType
            yType); // cType
    }
}

View File

@@ -0,0 +1,498 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.h
* \brief
*/
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h"
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
/** @brief Internal computation class: grouped matmul (cube cores) followed by
 *  dequant + SwiGLU + re-quant (vector cores), communicating through a global
 *  int32 workspace.
 */
template <class mmType, bool sync = false, typename CHANNELDTYPE = float>
class GMMSwigluCompute{
public:
    using AT = typename mmType::AT::T;
    using BT = typename mmType::BT::T;
    using B = typename mmType::BT;
    using CT = typename mmType::CT::T;
    using BiasT = typename mmType::BiasT::T;
    using WT = int8_t;
    constexpr static bool transposeX = mmType::AT::isTrans;
    constexpr static bool transposeW = mmType::BT::isTrans;
    /** @brief constructor */
    __aicore__ inline GMMSwigluCompute(typename mmType::MT& mm_): mm(mm_) {}
    // Binds GM addresses and tiling data; see the definition below.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                GM_ADDR workspace,
                                const GMMSwigluBaseParams* __restrict gmmBaseParamsIN,
                                const TCubeTiling* __restrict mmTilingDataIN,
                                const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN);
    // Runs the cube phase (AIC) and the vector phase (AIV).
    __aicore__ inline void Process();
private:
    // Cube-phase helpers: per-tile matmul and group bookkeeping.
    __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx);
    __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig);
    __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
    __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
    __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k);
    __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
                                             const uint32_t count, const uint32_t thresholdM_dimN);
    // Vector-phase helpers: scale loading, data movement, dequant/SwiGLU/quant.
    template <typename DTYPE_CS>
    __aicore__ inline void UpdateChannelScale(uint32_t loopidx);
    __aicore__ inline void VectorCompute(uint32_t loopidx);
    template <typename DTYPE_CS>
    __aicore__ inline void PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal);
    __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig);
    __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx);
    __aicore__ inline void customDataCopyOut();
    __aicore__ inline void Dequant(uint32_t loopidx);
    __aicore__ inline void Quant(uint32_t loopidx);
    __aicore__ inline void Swiglu(uint32_t loopidx);
private:
    typename mmType::MT& mm;
    const GMMSwigluBaseParams* __restrict gmmBaseParams;
    const GMMSwiglu* __restrict gmmSwiglu;
    const TCubeTiling* __restrict mmTilingData;
    uint32_t blockIdx;
    VecConfig vecConfig;
    TPipe* pipe;
    // Global-memory views over the kernel operands and the int32 workspace.
    GlobalTensor<int8_t> xGM, weightGM;
    GlobalTensor<CHANNELDTYPE> perChannelScaleGM;
    GlobalTensor<float> perTokenScaleGM;
    GlobalTensor<int64_t> groupListGM;
    GlobalTensor<int8_t> quantOutputGM;
    GlobalTensor<float> quantScaleOutputGM;
    GlobalTensor<int32_t> mmOutGM;
    // define the que
    TQue<QuePosition::VECIN, 1> mmOutQueue;
    TQue<QuePosition::VECIN, 1> perChannelScaleInQueue;
    TQue<QuePosition::VECOUT, 1> quantOutQueue;
    TQue<QuePosition::VECOUT, 1> quantScaleOutQueue;
    TBuf<TPosition::VECCALC> reduceWorkspace;
    TBuf<TPosition::VECCALC> castWorkspace;
    bool sequentialWrite = true;
    uint32_t cubeNum; // Matmul completions on the kernel
    uint32_t groupNum; // number of expert groups (from tiling groupListLen)
    int32_t preOffset;
    int64_t aicCoreNum;
    int64_t aivCoreNum;
    // Raw tensor-list base pointers; per-group tensors resolved via GetTensorAddr.
    GM_ADDR xTensorPtr;
    GM_ADDR weightTensorPtr;
    GM_ADDR perChannelScalePtr;
};
// Caches tiling/config pointers and wires GlobalTensor views onto the kernel
// operands. AIC cores only need group_list and the workspace; AIV cores bind
// all vector-phase inputs/outputs.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                                                          GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                                                          GM_ADDR workspace,
                                                                          const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn,
                                                                          const TCubeTiling* __restrict mmTilingDataIN,
                                                                          const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN)
{
    aicCoreNum = GetBlockNum();
    aivCoreNum = aicCoreNum * 2;  // 2 vector cores per cube core (KERNEL_TYPE_MIX_AIC_1_2)
    blockIdx = GetBlockIdx();
    mmTilingData = mmTilingDataIN;
    gmmBaseParams = gmmSwigluBaseParamsIn;
    gmmSwiglu = gmmSwigluIN;
    pipe = tPipeIN;
    // Tensor-list base pointers; per-group tensors are resolved lazily.
    xTensorPtr = x;
    weightTensorPtr = weight;
    perChannelScalePtr = perChannelScale;
    groupNum = gmmSwiglu->groupListLen;
    if ASCEND_IS_AIC {
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        // Workspace receives the raw int32 matmul results.
        mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen);
    }
    if ASCEND_IS_AIV {
        mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen);
        perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, gmmSwiglu->groupListLen * gmmSwiglu->tokenLen);
        perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmSwiglu->maxProcessRowNum);
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        // Output width is tokenLen / 2: SwiGLU halves the channel dimension.
        quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2);
        quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmSwiglu->maxProcessRowNum);
    }
}
// Top-level pipeline.
// AIC: walk group_list (cumulative row counts, so per-group m is the diff of
// consecutive entries), split each group's output into singleM x singleN
// tiles, and round-robin tiles across cube cores; then sync with the vector
// cores. AIV: preload per-channel scales, then stream workspace rows in
// chunks of maxProcessRowNum, running dequant/SwiGLU/quant per row.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Process() {
    MNConfig mnConfig;
    if ASCEND_IS_AIC {
        preOffset = 0;
        int32_t prevSplitValue = 0;
        for (uint32_t groupIdx = 0, count = 0; groupIdx < gmmSwiglu->groupListLen; ++groupIdx) {
            UpdateMnConfig(mnConfig);
            // group_list holds cumulative counts; the group's rows are the delta.
            int32_t currSplitValue = static_cast<int32_t>(groupListGM.GetValue(groupIdx));
            int32_t splitValue = currSplitValue - prevSplitValue;
            prevSplitValue = currSplitValue;
            SetMNConfig(splitValue, groupIdx, mnConfig);
            if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) {
                continue;  // empty group: no tiles to schedule
            }
            mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM);
            mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN);
            uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN;
            // Keep the round-robin continuous across groups: cores whose id is
            // below `count` wrap to the next pass.
            uint32_t curBlock = blockIdx >= count ? blockIdx : blockIdx + gmmBaseParams->coreNum;
            uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN;
            while (curBlock < curCount) {
                MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN);
                MMCompute(groupIdx, mnConfig, blockIdx);
                curBlock += aicCoreNum;
            }
            count = curCount % gmmBaseParams->coreNum;
        }
        SyncAll<false>();  // let AIV cores see the finished workspace
    }
    if ASCEND_IS_AIV {
        UpdateVecConfig(blockIdx, vecConfig);
        if (blockIdx < vecConfig.usedCoreNum) {
            // Allocate UB buffers up front; they are reused for every chunk.
            LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.AllocTensor<float>();
            LocalTensor<int32_t> mmLocal = mmOutQueue.AllocTensor<int32_t>();
            LocalTensor<int8_t> quantLocal = quantOutQueue.AllocTensor<int8_t>();
            LocalTensor<float> quantScaleLocal = quantScaleOutQueue.AllocTensor<float>();
            mmOutQueue.EnQue(mmLocal);
            quantScaleOutQueue.EnQue(quantScaleLocal);
            quantOutQueue.EnQue(quantLocal);
            PreLoadTokenAndChannel<CHANNELDTYPE>(channelScaleLocal);
        }
        SyncAll<false>();  // wait for the cube phase to fill the workspace
        if (blockIdx < vecConfig.usedCoreNum) {
            for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) {
                // Last chunk may be short.
                vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1)
                                             ? vecConfig.tailLoopNum
                                             : gmmSwiglu->maxProcessRowNum;
                customDataCopyIn(outLoopIdx);
                for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) {
                    UpdateChannelScale<CHANNELDTYPE>(innerLoopIdx);
                    VectorCompute(innerLoopIdx);
                }
                customDataCopyOut();
            }
            // Release all UB buffers.
            LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.DeQue<float>();
            LocalTensor<int32_t> mmLocal = mmOutQueue.DeQue<int32_t>();
            LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
            LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
            perChannelScaleInQueue.FreeTensor(channelScaleLocal);
            mmOutQueue.FreeTensor(mmLocal);
            quantScaleOutQueue.FreeTensor(quantScaleLocal);
            quantOutQueue.FreeTensor(quantLocal);
        } else {
            return;  // idle vector core
        }
    }
}
// Copies the current group's per-channel scale row from GM into UB. When the
// scale dtype is not float (bf16/fp16), it is staged into the upper half of
// the buffer and cast to float in place.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal)
{
    GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
    // Resolve the tensor-list entry for the group this core starts on.
    perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
    DataCopyExtParams copyChannelParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
    DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
    if constexpr(!IsSameType<DTYPE_CS, float>::value) {
        // Stage the narrow dtype at offset tokenLen, then widen into the front.
        LocalTensor<DTYPE_CS> dstLocalT = channelScaleLocal.template ReinterpretCast<DTYPE_CS>();
        DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams);
        PipeBarrier<PIPE_ALL>();  // copy must land before the cast reads it
        Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
    } else {
        DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams);
    }
    perChannelScaleInQueue.EnQue(channelScaleLocal);
}
// Runs one matmul tile for group `groupIdx`: computes the tile's actual
// (singleM, singleN) extents (edge tiles are clipped), resolves the x and
// weight GM offsets, and iterates the matmul into the int32 workspace.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx)
{
    uint32_t tailN = mnConfig.nIdx * mnConfig.singleN;
    // Clip the trailing tiles in each dimension.
    uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN;
    uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM
                                                                 : mnConfig.m - mnConfig.mIdx * mnConfig.singleM;
    uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k;
    if constexpr (transposeX) {
        xOffset = mnConfig.mIdx * mnConfig.singleM;
    }
    uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN;
    xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset);
    // weight is a tensor list: fetch this group's tensor, then step to column tailN.
    weightGM.SetGlobalBuffer(GetTensorAddr<int8_t>(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k));
    if (mnConfig.blockDimM == 1){
        // Single row-tile: weights are read once, so skip L2 caching.
        weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
    }
    mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset;
    mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k);
    mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k);
    mm.SetTensorA(xGM[xOffset], transposeX);
    mm.SetTensorB(weightGM, transposeW);
    mm.template IterateAll<sync>(mmOutGM[mnConfig.workSpaceOffset], 0);
}
// Advances the per-group base offsets after (or before) processing a group.
// NZ-format weights are stored with k padded to 16 and n padded to 32, so the
// weight stride uses the aligned extents.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateMnConfig(MNConfig &mnConfig) {
    if constexpr (B::format == CubeFormat::NZ) {
        mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size
    } else {
        mnConfig.wBaseOffset += mnConfig.k * mnConfig.n;
    }
    mnConfig.nAxisBaseOffset += mnConfig.n;
    mnConfig.mAxisBaseOffset += mnConfig.m;
    mnConfig.xBaseOffset += mnConfig.m * mnConfig.k;
    mnConfig.yBaseOffset += mnConfig.m * mnConfig.n;
}
// Fills mnConfig for one group: the problem shape (m, k, n) plus the fixed
// tiling constants that are identical for every group.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) {
    // Per-group problem shape: m from the group-list split, k/n from tiling data.
    SetMKN(splitValue, groupIdx, mnConfig);
    // Per-core tile sizes (compile-time constants).
    mnConfig.singleM = SINGLE_CORE_M;
    mnConfig.singleN = SINGLE_CORE_N;
    // Base block sizes of the matmul.
    mnConfig.baseM = BASIC_M;
    mnConfig.baseN = BASIC_N;
}
// Sets the matmul problem shape for one group. k and n are constant across
// groups (taken from tiling data); m is the number of rows routed to this group.
// `groupIdx` is not used in this variant.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig)
{
    mnConfig.k = gmmBaseParams->K; // tilingData
    mnConfig.n = gmmBaseParams->N; // tilingData
    mnConfig.m = static_cast<uint32_t>(splitValue);
}
// Returns the weight-matrix element offset of the column tile starting at
// `tailN`. NZ layout stores k rounded up to the 16-element fractal edge, so a
// column step advances by tailN * AlignUp<16>(k); plain layout advances by tailN.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline uint64_t GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::GetWOffset(uint32_t tailN, uint32_t k) {
    if constexpr (mmType::BT::format == CubeFormat::NZ) {
        return tailN * AlignUp<16>(k); // 16: nz format last two dim size
    }
    return tailN;
}
// Decodes the linear block id into (mIdx, nIdx) tile coordinates, row-major
// over the n dimension. `thresholdM_dimN` is not used here.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
                                                                                      const uint32_t count, const uint32_t thresholdM_dimN) {
    const uint32_t linearBlock = curBlock - count;
    mnConfig.mIdx = linearBlock / mnConfig.blockDimN;
    mnConfig.nIdx = linearBlock % mnConfig.blockDimN;
}
// Builds this AIV core's work partition and (re)allocates UB buffers.
// The group list holds cumulative row counts, so consecutive deltas are
// per-group sizes. Rows are split evenly across AIV cores; cores below
// `tailCoreIdx` get one extra row.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig)
{
    // Step 1: Read grouplist reduceSum to calculate total data count
    int64_t prevM = 0;
    for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){
        int64_t currM = groupListGM.GetValue(groupIdx);
        int64_t tempM = currM - prevM;
        prevM = currM;
        vecConfig.M += tempM;
    }
    // Step 2: Calculate core allocation
    uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum;
    vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M;
    uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum;
    vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1;
    vecConfig.startIdx = blockIdx < tailCoreIdx
                             ? eachCoreTaskNum * blockIdx
                             :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx);
    vecConfig.curIdx = vecConfig.startIdx;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    vecConfig.curOffset = vecConfig.startOffset;
    // Locate the group containing this core's first row; nextUpadteInterVal is
    // the number of rows left in that group from startIdx onward.
    int64_t curStartIdx = vecConfig.startIdx;
    prevM = 0;
    for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){
        int64_t currM = groupListGM.GetValue(groupIdx);
        int64_t tempM = currM - prevM;
        prevM = currM;
        if (curStartIdx >= 0 && curStartIdx - tempM < 0) {
            vecConfig.curGroupIdx = groupIdx;
            vecConfig.nextUpadteInterVal = tempM - curStartIdx;
        }
        curStartIdx -= tempM;
    }
    // Step 3: Calculate total data volume
    vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum;
    vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                : gmmSwiglu->maxProcessRowNum;
    pipe->Reset();
    // Step 4: Allocate space
    pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t));
    pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float));
    pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t));
    pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp<int32_t>(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float));
    pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float));
    pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t));
}
// Copies innerLoopNum rows of int32 matmul output from workspace GM into UB,
// casts them to float in place, then multiplies each row by its per-token
// scale. Scalar reads of the scale (PIPE_S) are interleaved with vector Muls
// (PIPE_V) using explicit event flags. Advances curIdx/curOffset past the rows
// consumed.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::customDataCopyIn(uint32_t outLoopIdx)
{
    LocalTensor<int32_t> _inMMLocal_0 = mmOutQueue.DeQue<int32_t>();
    DataCopyExtParams copyParams_0{1, static_cast<uint32_t>(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0};
    DataCopyPadExtParams<int32_t> padParams_0{false, 0 ,0, 0};
    DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0);
    mmOutQueue.EnQue(_inMMLocal_0);
    // int32 -> float conversion reuses the same UB buffer via reinterpret.
    LocalTensor<int32_t> _inMMLocal_1 = mmOutQueue.DeQue<int32_t>();
    Cast(_inMMLocal_1.ReinterpretCast<float>(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen);
    mmOutQueue.EnQue(_inMMLocal_1);
    LocalTensor<float> _inMMLocal_2 = mmOutQueue.DeQue<float>();
    set_flag(PIPE_S, PIPE_V, EVENT_ID0);
    for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        // Scalar read of this row's per-token dequant scale.
        float scale = perTokenScaleGM.GetValue(vecConfig.curIdx);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        vecConfig.curIdx++;
    }
    wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
    vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen;
    mmOutQueue.EnQue(_inMMLocal_2);
}
// Reloads the per-channel scale row when the current group is exhausted
// (nextUpadteInterVal reaches 0). Skips over empty groups (equal cumulative
// entries) to find the next non-empty one, then DMA-copies its scale row into
// UB, casting to float when CHANNELDTYPE is not float.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::UpdateChannelScale(uint32_t loopIdx){
    // Update perChannel
    if (unlikely(vecConfig.nextUpadteInterVal == 0)) {
        int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx;
        while (loop--) {
            int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            vecConfig.curGroupIdx++;
            // NOTE(review): on the final iteration curGroupIdx can equal
            // groupListLen, so this GetValue may read one past the group list —
            // confirm the buffer extends or the loop bound guarantees an early break.
            int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            if(nextTemp != curTemp){
                vecConfig.nextUpadteInterVal = nextTemp - curTemp;
                break;
            }
        }
        LocalTensor<float> _inChannel = perChannelScaleInQueue.DeQue<float>();
        DataCopyExtParams copyParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
        DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
        GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
        perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
        if constexpr(!IsSameType<DTYPE_CS, float>::value) {
            // Narrow dtype is staged in the second half of the buffer, then cast
            // into the float front half.
            LocalTensor<DTYPE_CS> dstLocalT = _inChannel.template ReinterpretCast<DTYPE_CS>();
            DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams);
            PipeBarrier<PIPE_ALL>();
            Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
        } else {
            DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams);
        }
        PipeBarrier<PIPE_ALL>();
        perChannelScaleInQueue.EnQue(_inChannel);
    }
}
// Runs the per-row vector stages in fixed order: Dequant applies the
// per-channel scale (the per-token scale was already applied in
// customDataCopyIn), Swiglu combines the two row halves, and Quant converts
// the half-length result to int8 with a per-row scale. Each stage reads what
// the previous one wrote, so the order must not change.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::VectorCompute(uint32_t loopIdx) {
    Dequant(loopIdx);
    Swiglu(loopIdx);
    Quant(loopIdx);
}
// Multiplies row `loopIdx` element-wise by the cached per-channel scale and
// counts down the rows remaining in the current group.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Dequant(uint32_t loopIdx) {
    // perChanelScale * perTokenScale
    LocalTensor<float> mmLocal = mmOutQueue.DeQue<float>();
    LocalTensor<float> perChannelLocal = perChannelScaleInQueue.DeQue<float>();
    Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen);
    // One more row of the current group consumed; triggers UpdateChannelScale at 0.
    vecConfig.nextUpadteInterVal--;
    mmOutQueue.EnQue(mmLocal);
    perChannelScaleInQueue.EnQue(perChannelLocal);
}
// Applies SwiGLU to row `loopIdx`: the second half of the row is src0 and the
// first half is src1; the half-length result is computed into the reduce
// workspace and copied back over the start of the row.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Swiglu(uint32_t loopIdx) {
    // High-level API swiglu
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    float beta = 1.0f;
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    LocalTensor<float> src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2];
    LocalTensor<float> src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen];
    SwiGLU<float, false>(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2);
    PipeBarrier<PIPE_ALL>();
    // blockLen unit is 8 floats (32B); assumes tokenLen/2 is a multiple of 8 — TODO confirm.
    DataCopyParams repeatParams{1, static_cast<uint16_t>((gmmSwiglu->tokenLen / 2) / 8), 0, 0};
    DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams);
    mmOutQueue.EnQue(_inMMLocal);
}
// Dynamic per-row int8 quantization of row `loopIdx`'s SwiGLU result:
// |row| -> row max -> scale = max / QUANT_SCALE_INT8, stores the scale at
// index loopIdx, multiplies the row by 1/scale, and casts to int8.
// Scalar GetValue/SetValue calls are fenced with full pipe barriers.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::Quant(uint32_t loopIdx) {
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    // |row| is written into the (now free) second half of the row.
    Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT],
        _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
        gmmSwiglu->tokenLen / BISECT);
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    PipeBarrier<PIPE_V>();
    ReduceMaxTemplate(workspaceLocal,
                      _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_ALL>();
    float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8;
    PipeBarrier<PIPE_ALL>();
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    PipeBarrier<PIPE_ALL>();
    quantScaleLocal.SetValue(loopIdx, quantScale);
    PipeBarrier<PIPE_ALL>();
    // Multiply by the reciprocal instead of dividing element-wise.
    quantScale = 1 / quantScale;
    PipeBarrier<PIPE_ALL>();
    Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
         quantScale, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_V>();
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    int32_t dstTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen / BISECT);
    int32_t srcTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen);
    int32_t tempCount = static_cast<int32_t>(gmmSwiglu->tokenLen / BISECT);
    LocalTensor<int8_t> castSpace = castWorkspace.Get<int8_t>();
    CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount);
    mmOutQueue.EnQue(_inMMLocal);
    quantOutQueue.EnQue(quantLocal);
}
// Writes this batch's results back to GM: innerLoopNum per-row quant scales
// and innerLoopNum half-length int8 rows, then advances the start cursor.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluCompute<mmType, sync, CHANNELDTYPE>::customDataCopyOut() {
    // Copy per-row quant scales, then the quantized activations.
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantScaleOutputGM[vecConfig.startIdx], quantScaleLocal, copyParams_0);
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantOutputGM[vecConfig.startIdx * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1);
    PipeBarrier<PIPE_ALL>();
    // Advance to the first row of the next outer-loop batch.
    vecConfig.startIdx += vecConfig.innerLoopNum;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    quantOutQueue.EnQue(quantLocal);
    quantScaleOutQueue.EnQue(quantScaleLocal);
}
} // namespace GROUPED_MATMUL
#endif // ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H

View File

@@ -0,0 +1,588 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h
* \brief
*/
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H
#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h"
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
/** @brief internal computation class
 *
 * Grouped-matmul + SwiGLU + dynamic int8 quantization kernel that splits the
 * matmul output across a double-buffered GM workspace (mmOutGM1/mmOutGM2):
 * AIC cores produce int32 matmul tiles into one buffer while AIV cores
 * dequantize / activate / quantize rows from the other.
 */
template <class mmType, bool sync = false, typename CHANNELDTYPE = float>
class GMMSwigluSplitWorkSpaceCompute{
public:
    using AT = typename mmType::AT::T;
    using BT = typename mmType::BT::T;
    using B = typename mmType::BT;
    using CT = typename mmType::CT::T;
    using BiasT = typename mmType::BiasT::T;
    using WT = int8_t;
    constexpr static bool transposeX = mmType::AT::isTrans;
    constexpr static bool transposeW = mmType::BT::isTrans;
    /** @brief constructor */
    __aicore__ inline GMMSwigluSplitWorkSpaceCompute(typename mmType::MT& mm_): mm(mm_) {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput,
                                GM_ADDR workspace,
                                const GMMSwigluBaseParams* __restrict gmmBaseParamsIN,
                                const TCubeTiling* __restrict mmTilingDataIN,
                                const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN);
    __aicore__ inline void Process();
private:
    __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor<int32_t> &mmOutGM);
    __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig);
    __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
    __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig);
    __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k);
    __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
                                             const uint32_t count, const uint32_t thresholdM_dimN);
    template <typename DTYPE_CS>
    __aicore__ inline void UpdateChannelScale(uint32_t loopidx, VecConfig& vecConfig);
    __aicore__ inline void VectorCompute(uint32_t loopidx, VecConfig& vecConfig);
    template <typename DTYPE_CS>
    __aicore__ inline void PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal, VecConfig& vecConfig);
    __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig);
    __aicore__ inline void UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx);
    __aicore__ inline void InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig);
    __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx, GlobalTensor<int32_t> &mmOutGM, VecConfig& vecConfig);
    __aicore__ inline void customDataCopyOut(VecConfig& vecConfig);
    __aicore__ inline void Dequant(uint32_t loopidx, VecConfig& vecConfig);
    __aicore__ inline void Quant(uint32_t loopidx, VecConfig& vecConfig);
    __aicore__ inline void Swiglu(uint32_t loopidx, VecConfig& vecConfig);
private:
    typename mmType::MT& mm;                            // matmul object supplied by the caller
    const GMMSwigluBaseParams* __restrict gmmBaseParams; // tiling: M/K/N, coreNum, mLimit, isPreFill
    const GMMSwiglu* __restrict gmmSwiglu;              // tiling: tokenLen, groupListLen, maxProcessRowNum
    const TCubeTiling* __restrict mmTilingData;
    uint32_t blockIdx;                                  // this core's index
    WorkSpaceSplitConfig workspaceSplitConfig;          // workspace chunking plan (see InitWorkSpaceSplitConfig)
    TPipe* pipe;
    GlobalTensor<int8_t> xGM;
    GlobalTensor<int8_t> weightGM;
    GlobalTensor<CHANNELDTYPE> perChannelScaleGM;
    GlobalTensor<float> perTokenScaleGM;
    GlobalTensor<int64_t> groupListGM;                  // cumulative per-group row counts
    GlobalTensor<int8_t> quantOutputGM;
    GlobalTensor<float> quantScaleOutputGM;
    GlobalTensor<int32_t> mmOutGM1;                     // ping workspace buffer (int32 matmul output)
    GlobalTensor<int32_t> mmOutGM2;                     // pong workspace buffer
    // define the que
    TQue<QuePosition::VECIN, 1> mmOutQueue;
    TQue<QuePosition::VECIN, 1> perChannelScaleInQueue;
    TQue<QuePosition::VECOUT, 1> quantOutQueue;
    TQue<QuePosition::VECOUT, 1> quantScaleOutQueue;
    TBuf<TPosition::VECCALC> reduceWorkspace;           // scratch for SwiGLU result / reduce-max
    TBuf<TPosition::VECCALC> castWorkspace;             // scratch for fp32 -> int8 cast
    bool sequentialWrite = true;                        // NOTE(review): not referenced in the code visible here — confirm
    uint32_t cubeNum; // Matmul completions on the kernel (not referenced in visible code — TODO confirm)
    uint32_t groupNum; // number of groups; set to gmmSwiglu->groupListLen in Init
    int64_t aicCoreNum;                                 // cube-core count (GetBlockNum)
    int64_t aivCoreNum;                                 // vector-core count (2 per cube core)
    GM_ADDR xTensorPtr;
    GM_ADDR weightTensorPtr;                            // tensor-list base: per-group weight tensors
    GM_ADDR perChannelScalePtr;                         // tensor-list base: per-group channel scales
};
// Binds all GM buffers and tiling pointers. AIC cores only need the group list
// and the two workspace buffers; AIV cores additionally bind scales and the
// int8/scale outputs. The two workspace buffers are adjacent slices of
// `workspace`, each mLimit * tokenLen int32 elements.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Init(GM_ADDR x, GM_ADDR weight,
                                                                                        GM_ADDR perChannelScale, GM_ADDR perTokenScale,
                                                                                        GM_ADDR groupList, GM_ADDR quantOutput,
                                                                                        GM_ADDR quantScaleOutput, GM_ADDR workspace,
                                                                                        const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn,
                                                                                        const TCubeTiling* __restrict mmTilingDataIN,
                                                                                        const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN)
{
    aicCoreNum = GetBlockNum();
    aivCoreNum = aicCoreNum * 2; // 2 vector cores per cube core
    blockIdx = GetBlockIdx();
    pipe = tPipeIN;
    xTensorPtr = x;
    weightTensorPtr = weight;
    perChannelScalePtr = perChannelScale;
    mmTilingData = mmTilingDataIN;
    gmmBaseParams = gmmSwigluBaseParamsIn;
    gmmSwiglu = gmmSwigluIN;
    groupNum = gmmSwiglu->groupListLen;
    if ASCEND_IS_AIC {
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
        mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen,
                                 gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
    }
    if ASCEND_IS_AIV {
        mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
        mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen,
                                 gmmBaseParams->mLimit * gmmSwiglu->tokenLen);
        perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale,
                                          gmmSwiglu->groupListLen * gmmSwiglu->tokenLen);
        perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmBaseParams->M);
        groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen);
        // Output row width is tokenLen/2: SwiGLU halves the hidden dimension.
        quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2);
        quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmBaseParams->M);
    }
}
// Initializes the workspace chunking plan. The group list holds cumulative row
// counts, so its last entry is the total row count M; rows are processed in
// loopCount chunks of at most gmmBaseParams->mLimit rows each, with the last
// chunk taking the remainder.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig)
{
    // Total rows = last cumulative group-list entry.
    workspaceSplitConfig.M = groupListGM.GetValue(gmmSwiglu->groupListLen - 1);
    workspaceSplitConfig.loopCount = Ceil(workspaceSplitConfig.M, gmmBaseParams->mLimit);
    // Every chunk except the last is a full mLimit; the last takes what remains.
    workspaceSplitConfig.notLastTaskSize = gmmBaseParams->mLimit;
    workspaceSplitConfig.lastLoopTaskSize = workspaceSplitConfig.M - (workspaceSplitConfig.loopCount - 1) * gmmBaseParams->mLimit;
    // Reset the per-chunk expert-window cursors.
    workspaceSplitConfig.isLastLoop = false;
    workspaceSplitConfig.leftMatrixStartIndex = 0;
    workspaceSplitConfig.rightMatrixExpertStartIndex = 0;
    workspaceSplitConfig.rightMatrixExpertNextStartIndex = 0;
}
// Advances the chunk window for loop `workspaceSplitLoopIdx`: sets the chunk's
// starting row (leftMatrixStartIndex) and scans the cumulative group list to
// find which expert (group) range [rightMatrixExpertStartIndex,
// rightMatrixExpertEndIndex] falls inside this mLimit-row chunk, plus where the
// next chunk's scan should resume (rightMatrixExpertNextStartIndex). An expert
// that straddles the chunk boundary is revisited by the next chunk.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx)
{
    workspaceSplitConfig.leftMatrixStartIndex = workspaceSplitLoopIdx * gmmBaseParams->mLimit;
    workspaceSplitConfig.rightMatrixExpertStartIndex = workspaceSplitConfig.rightMatrixExpertNextStartIndex;
    workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertStartIndex;
    // Calculate the right expert matrix end index (rightMatrixExpertEndIndex) and the next start index (rightMatrixExpertNextStartIndex)
    int32_t curTaskNum = 0;
    int32_t nextTaskNum = 0;
    while(workspaceSplitConfig.rightMatrixExpertEndIndex < gmmSwiglu->groupListLen)
    {
        // Rows from the chunk start through the end of the current / next expert.
        curTaskNum = groupListGM.GetValue(workspaceSplitConfig.rightMatrixExpertEndIndex) - workspaceSplitConfig.leftMatrixStartIndex;
        int32_t nextTaskIdx = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen - 1 \
                                  ? gmmSwiglu->groupListLen - 1 \
                                  : workspaceSplitConfig.rightMatrixExpertEndIndex + 1;
        nextTaskNum = groupListGM.GetValue(nextTaskIdx) - workspaceSplitConfig.leftMatrixStartIndex;
        if (curTaskNum > gmmBaseParams->mLimit){
            // Current expert itself overflows the chunk: it continues next chunk.
            workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex;
            break;
        } else if (curTaskNum == gmmBaseParams->mLimit && nextTaskNum > gmmBaseParams->mLimit){
            // Current expert ends exactly at the chunk boundary.
            workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex + 1;
            break;
        } else if (nextTaskNum > gmmBaseParams->mLimit){
            // Next expert would overflow: include the current one and stop.
            workspaceSplitConfig.rightMatrixExpertEndIndex++;
            workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex;
            break;
        }
        workspaceSplitConfig.rightMatrixExpertEndIndex++;
    }
    workspaceSplitConfig.isLastLoop = workspaceSplitLoopIdx == workspaceSplitConfig.loopCount - 1 ? true : false;
    if (workspaceSplitConfig.isLastLoop) {
        // Clamp so the last chunk never indexes past the group list.
        workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen \
                                                            ? gmmSwiglu->groupListLen - 1 \
                                                            : workspaceSplitConfig.rightMatrixExpertEndIndex;
    }
}
// Main loop: processes the rows in mLimit-sized chunks, ping-ponging matmul
// output between mmOutGM1 and mmOutGM2. AIC cores fill the chunk's workspace
// buffer with int32 matmul tiles; AIV cores then dequantize / SwiGLU /
// quantize the same chunk. SyncAll handshakes keep producer and consumer one
// (prefill: two) chunks apart — the flag placement must not change.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Process() {
    InitWorkSpaceSplitConfig(workspaceSplitConfig);
    int32_t parallelNum = gmmBaseParams->isPreFill ? 2 : 1; // 2: double workspace buffer
    for (int32_t workspaceSplitLoopIdx = 0; workspaceSplitLoopIdx < workspaceSplitConfig.loopCount; workspaceSplitLoopIdx++) {
        UpdateWorkSpaceSplitConfig(workspaceSplitConfig, workspaceSplitLoopIdx);
        // Ping-pong buffer selection: even chunks use buffer 1, odd chunks buffer 2.
        GlobalTensor<int32_t> mmOutGM = (workspaceSplitLoopIdx % 2 == 0 ) ? mmOutGM1 : mmOutGM2;
        if ASCEND_IS_AIC {
            if (workspaceSplitLoopIdx >= parallelNum){ // first parallelNum core no need to wait
                SyncAll<false>();
            }
            MNConfig mnConfig;
            int32_t prevSplitValue = workspaceSplitConfig.leftMatrixStartIndex;
            for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex, count = 0; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; ++groupIdx) {
                UpdateMnConfig(mnConfig);
                // Clamp this group's cumulative end to the chunk boundary; the
                // delta is the number of this group's rows inside the chunk.
                int32_t currSplitValue = static_cast<int32_t>(groupListGM.GetValue(groupIdx));
                currSplitValue = currSplitValue > (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \
                                     ? (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \
                                     : currSplitValue;
                int32_t splitValue = currSplitValue - prevSplitValue;
                prevSplitValue = currSplitValue;
                SetMNConfig(splitValue, groupIdx, mnConfig);
                if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) {
                    continue; // empty group inside this chunk
                }
                mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM);
                mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN);
                uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN;
                // NOTE(review): curBlock wraps with gmmBaseParams->coreNum but
                // strides by aicCoreNum — assumes coreNum == aicCoreNum; confirm.
                uint32_t curBlock = blockIdx >= count ? blockIdx : blockIdx + gmmBaseParams->coreNum;
                uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN;
                while (curBlock < curCount) {
                    MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN);
                    MMCompute(groupIdx, mnConfig, blockIdx, mmOutGM);
                    curBlock += aicCoreNum;
                }
                count = curCount % gmmBaseParams->coreNum;
            }
            // Signal AIV cores that this chunk's matmul output is complete.
            SyncAll<false>();
        }
        if ASCEND_IS_AIV {
            VecConfig vecConfig;
            UpdateVecConfig(blockIdx, vecConfig);
            if (blockIdx < vecConfig.usedCoreNum) {
                // Prime the queues and prefetch the first per-channel scale row
                // while the matmul is still running.
                LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.AllocTensor<float>();
                LocalTensor<int32_t> mmLocal = mmOutQueue.AllocTensor<int32_t>();
                LocalTensor<int8_t> quantLocal = quantOutQueue.AllocTensor<int8_t>();
                LocalTensor<float> quantScaleLocal = quantScaleOutQueue.AllocTensor<float>();
                mmOutQueue.EnQue(mmLocal);
                quantScaleOutQueue.EnQue(quantScaleLocal);
                quantOutQueue.EnQue(quantLocal);
                PreLoadTokenAndChannel<CHANNELDTYPE>(channelScaleLocal, vecConfig);
            }
            // Wait for AIC to finish producing this chunk.
            SyncAll<false>();
            if (blockIdx < vecConfig.usedCoreNum) {
                for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) {
                    vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1)
                                                 ? vecConfig.tailLoopNum
                                                 : gmmSwiglu->maxProcessRowNum;
                    PipeBarrier<PIPE_ALL>();
                    customDataCopyIn(outLoopIdx, mmOutGM, vecConfig);
                    PipeBarrier<PIPE_ALL>();
                    for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) {
                        UpdateChannelScale<CHANNELDTYPE>(innerLoopIdx, vecConfig);
                        VectorCompute(innerLoopIdx, vecConfig);
                    }
                    PipeBarrier<PIPE_ALL>();
                    customDataCopyOut(vecConfig);
                    PipeBarrier<PIPE_ALL>();
                }
                // Drain and release all queue tensors before the pipe is reset
                // for the next chunk.
                LocalTensor<float> channelScaleLocal = perChannelScaleInQueue.DeQue<float>();
                LocalTensor<int32_t> mmLocal = mmOutQueue.DeQue<int32_t>();
                LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
                LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
                perChannelScaleInQueue.FreeTensor(channelScaleLocal);
                mmOutQueue.FreeTensor(mmLocal);
                quantScaleOutQueue.FreeTensor(quantScaleLocal);
                quantOutQueue.FreeTensor(quantLocal);
            }
            if (workspaceSplitLoopIdx < workspaceSplitConfig.loopCount - parallelNum){
                // Release the buffer back to AIC for a later chunk.
                SyncAll<false>();
            }
        }
    }
}
// Prefetches the per-channel scale row of the starting group
// (vecConfig.curGroupIdx) into UB before the main loop. For non-float
// CHANNELDTYPE the raw row is staged in the second half of the buffer and cast
// into the float front half.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::PreLoadTokenAndChannel(LocalTensor<float>& channelScaleLocal, VecConfig& vecConfig)
{
    GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
    perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
    DataCopyExtParams copyChannelParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
    DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
    if constexpr(!IsSameType<DTYPE_CS, float>::value) {
        LocalTensor<DTYPE_CS> dstLocalT = channelScaleLocal.template ReinterpretCast<DTYPE_CS>();
        DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams);
        PipeBarrier<PIPE_ALL>();
        Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
    } else {
        DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams);
    }
    perChannelScaleInQueue.EnQue(channelScaleLocal);
}
// Launches one matmul tile for group `groupIdx` into the chunk's workspace
// buffer. Unlike the non-split variant, the activation base is offset by the
// chunk start (leftMatrixStartIndex * k), and the weight L2 cache hint is
// explicitly restored to NORMAL when multiple m-blocks reuse the tile.
// `coreIdx` is not used in this variant.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor<int32_t> &mmOutGM)
{
    // Edge tiles take the remainder; interior tiles use the full single-core size.
    uint32_t tailN = mnConfig.nIdx * mnConfig.singleN;
    uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN;
    uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM
                                                                : mnConfig.m - mnConfig.mIdx * mnConfig.singleM;
    uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k;
    if constexpr (transposeX) {
        // With transposed A, the m-tile offset indexes rows directly (no *k stride).
        xOffset = mnConfig.mIdx * mnConfig.singleM;
    }
    uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN;
    xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset + workspaceSplitConfig.leftMatrixStartIndex * mnConfig.k);
    weightGM.SetGlobalBuffer(GetTensorAddr<int8_t>(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k));
    if (mnConfig.blockDimM == 1){
        // Single m-block: bypass L2 caching for the weight (presumably no reuse — TODO confirm).
        weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
    } else {
        weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL);
    }
    mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset;
    mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k);
    mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k);
    mm.SetTensorA(xGM[xOffset], transposeX);
    mm.SetTensorB(weightGM, transposeW);
    mm.template IterateAll<sync>(mmOutGM[mnConfig.workSpaceOffset], 0);
}
// Advances the per-group base offsets in mnConfig after one group finishes.
// The weight stride depends on storage format: NZ pads k to 16 and n to 32
// (fractal edges); ND uses the plain k*n stride.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateMnConfig(MNConfig &mnConfig) {
    if constexpr (B::format == CubeFormat::NZ) {
        mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size
    } else {
        mnConfig.wBaseOffset += mnConfig.k * mnConfig.n;
    }
    mnConfig.nAxisBaseOffset += mnConfig.n;
    mnConfig.mAxisBaseOffset += mnConfig.m;
    // Rows of x consumed and rows of y produced by this group.
    mnConfig.xBaseOffset += mnConfig.m * mnConfig.k;
    mnConfig.yBaseOffset += mnConfig.m * mnConfig.n;
}
// Fills mnConfig for one group: the problem shape (m, k, n) plus the fixed
// tiling constants that are identical for every group.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) {
    // Per-group problem shape: m from the (chunk-clamped) group split, k/n from tiling data.
    SetMKN(splitValue, groupIdx, mnConfig);
    // Per-core tile sizes (compile-time constants).
    mnConfig.singleM = SINGLE_CORE_M;
    mnConfig.singleN = SINGLE_CORE_N;
    // Base block sizes of the matmul.
    mnConfig.baseM = BASIC_M;
    mnConfig.baseN = BASIC_N;
}
// Sets the matmul problem shape for one group. k and n are constant across
// groups (taken from tiling data); m is the number of rows of this group that
// fall inside the current chunk. `groupIdx` is not used in this variant.
// NOTE(review): the non-split variant casts splitValue to uint32_t here —
// confirm MNConfig::m's declared type.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig)
{
    mnConfig.k = gmmBaseParams->K; // tilingData
    mnConfig.n = gmmBaseParams->N; // tilingData
    mnConfig.m = static_cast<int64_t>(splitValue);
}
// Returns the weight-matrix element offset of the column tile starting at
// `tailN`. NZ layout stores k rounded up to the 16-element fractal edge, so a
// column step advances by tailN * AlignUp<16>(k); plain layout advances by tailN.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline uint64_t GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::GetWOffset(uint32_t tailN, uint32_t k) {
    if constexpr (mmType::BT::format == CubeFormat::NZ) {
        return tailN * AlignUp<16>(k); // 16: nz format last two dim size
    }
    return tailN;
}
// Decodes the linear block id into (mIdx, nIdx) tile coordinates, row-major
// over the n dimension. `thresholdM_dimN` is not used here.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock,
                                                                                                    const uint32_t count, const uint32_t thresholdM_dimN) {
    const uint32_t linearBlock = curBlock - count;
    mnConfig.mIdx = linearBlock / mnConfig.blockDimN;
    mnConfig.nIdx = linearBlock % mnConfig.blockDimN;
}
// Builds this AIV core's work partition for the current chunk and
// (re)allocates UB buffers. Unlike the non-split variant, M is the chunk size
// taken from the split config, and the group scan is limited to the chunk's
// expert window.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig)
{
    // Step 1: Read grouplist reduceSum to calculate total data count
    vecConfig.M = workspaceSplitConfig.isLastLoop \
                      ? workspaceSplitConfig.lastLoopTaskSize\
                      : workspaceSplitConfig.notLastTaskSize;
    // Step 2: Calculate core allocation
    uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum;
    vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M;
    uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum;
    vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1;
    vecConfig.startIdx = blockIdx < tailCoreIdx
                             ? eachCoreTaskNum * blockIdx
                             :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx);
    vecConfig.curIdx = vecConfig.startIdx;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    vecConfig.curOffset = vecConfig.startOffset;
    // Locate the group containing this core's first row of the chunk;
    // nextUpadteInterVal is the number of rows left in that group.
    int64_t curStartIdx = vecConfig.startIdx;
    int64_t prevM = workspaceSplitConfig.leftMatrixStartIndex;
    for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; groupIdx++){
        int64_t currM = groupListGM.GetValue(groupIdx);
        int64_t tempM = currM - prevM;
        prevM = currM;
        if (curStartIdx >= 0 && curStartIdx - tempM < 0) {
            vecConfig.curGroupIdx = groupIdx;
            vecConfig.nextUpadteInterVal = tempM - curStartIdx;
        }
        curStartIdx -= tempM;
    }
    // Step 3: Calculate total data volume
    vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum;
    vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum
                                : gmmSwiglu->maxProcessRowNum;
    pipe->Reset();
    // Step 4: Allocate space
    pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t));
    pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float));
    pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t));
    pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp<int32_t>(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float));
    pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float));
    pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t));
}
// Copies innerLoopNum rows of int32 matmul output from the chunk's workspace
// buffer into UB, casts them to float in place, then multiplies each row by
// its per-token scale (indexed with the chunk offset leftMatrixStartIndex).
// Scalar reads (PIPE_S) are interleaved with vector Muls (PIPE_V) using
// explicit event flags. Advances curIdx/curOffset past the rows consumed.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::customDataCopyIn(uint32_t outLoopIdx, GlobalTensor<int32_t> &mmOutGM, VecConfig& vecConfig)
{
    LocalTensor<int32_t> _inMMLocal_0 = mmOutQueue.DeQue<int32_t>();
    DataCopyExtParams copyParams_0{1, static_cast<uint32_t>(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0};
    DataCopyPadExtParams<int32_t> padParams_0{false, 0 ,0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0);
    mmOutQueue.EnQue(_inMMLocal_0);
    // int32 -> float conversion reuses the same UB buffer via reinterpret.
    LocalTensor<int32_t> _inMMLocal_1 = mmOutQueue.DeQue<int32_t>();
    Cast(_inMMLocal_1.ReinterpretCast<float>(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen);
    mmOutQueue.EnQue(_inMMLocal_1);
    LocalTensor<float> _inMMLocal_2 = mmOutQueue.DeQue<float>();
    set_flag(PIPE_S, PIPE_V, EVENT_ID0);
    for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        // Scalar read of this row's per-token scale; curIdx is chunk-relative.
        float scale = perTokenScaleGM.GetValue(vecConfig.curIdx + workspaceSplitConfig.leftMatrixStartIndex);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
        Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen);
        set_flag(PIPE_S, PIPE_V, EVENT_ID0);
        vecConfig.curIdx++;
    }
    wait_flag(PIPE_S, PIPE_V, EVENT_ID0);
    vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen;
    mmOutQueue.EnQue(_inMMLocal_2);
}
// Refreshes the per-channel (weight) scale buffer in UB when the current
// group's rows are exhausted (nextUpadteInterVal reached 0): advances
// curGroupIdx past groups whose groupList entry does not grow, then loads the
// new group's tokenLen scales, casting to float when DTYPE_CS is not float.
template <typename mmType, bool sync, typename CHANNELDTYPE>
template <typename DTYPE_CS>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::UpdateChannelScale(uint32_t loopIdx, VecConfig& vecConfig){
    // Update perChannel
    if (unlikely(vecConfig.nextUpadteInterVal == 0)) {
        // Walk groupList to the next group with a larger entry; the delta
        // becomes the number of rows before the next refresh.
        int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx;
        while (loop--) {
            int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            vecConfig.curGroupIdx++;
            int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx);
            if(nextTemp != curTemp){
                vecConfig.nextUpadteInterVal = nextTemp - curTemp;
                break;
            }
        }
        LocalTensor<float> _inChannel = perChannelScaleInQueue.DeQue<float>();
        DataCopyExtParams copyParams{1, static_cast<uint32_t>(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0};
        DataCopyPadExtParams<DTYPE_CS> padParams{false, 0 ,0, 0};
        GlobalTensor<CHANNELDTYPE> perChannelScaleTensor;
        // perChannelScalePtr is a packed tensor list; select the entry for the
        // current group.
        perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr<CHANNELDTYPE>(vecConfig.curGroupIdx, perChannelScalePtr));
        if constexpr(!IsSameType<DTYPE_CS, float>::value) {
            // Stage the raw non-float scales in the upper half of the buffer,
            // then cast them into the float front half.
            LocalTensor<DTYPE_CS> dstLocalT = _inChannel.template ReinterpretCast<DTYPE_CS>();
            DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams);
            PipeBarrier<PIPE_ALL>();
            Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen);
        } else {
            DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams);
        }
        PipeBarrier<PIPE_ALL>();
        perChannelScaleInQueue.EnQue(_inChannel);
    }
}
// Runs the per-row vector pipeline on row `loopIdx` of the staged tile:
// dequantize, apply SwiGLU, then quantize back to int8. Order is fixed —
// each stage consumes the in-place result of the previous one.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::VectorCompute(uint32_t loopIdx, VecConfig& vecConfig) {
    Dequant(loopIdx, vecConfig);
    Swiglu(loopIdx, vecConfig);
    Quant(loopIdx, vecConfig);
}
// Multiplies one row (already scaled per token in customDataCopyIn) by the
// current group's per-channel scales, and decrements the countdown of rows
// remaining before the next per-channel scale refresh.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Dequant(uint32_t loopIdx, VecConfig& vecConfig) {
    // perChanelScale * perTokenScale
    LocalTensor<float> mmLocal = mmOutQueue.DeQue<float>();
    LocalTensor<float> perChannelLocal = perChannelScaleInQueue.DeQue<float>();
    Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen);
    vecConfig.nextUpadteInterVal--;
    mmOutQueue.EnQue(mmLocal);
    perChannelScaleInQueue.EnQue(perChannelLocal);
}
// Applies SwiGLU to row `loopIdx`: the row's two tokenLen/2 halves are fed to
// the high-level SwiGLU API (result in the reduce workspace), then the
// tokenLen/2 activation is copied back over the first half of the row.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Swiglu(uint32_t loopIdx, VecConfig& vecConfig) {
    // High-level API swiglu
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    float beta = 1.0f;
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    // src0: second half of the row; src1: first half of the row.
    LocalTensor<float> src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2];
    LocalTensor<float> src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen];
    SwiGLU<float, false>(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2);
    PipeBarrier<PIPE_ALL>();
    // blockLen is in 32-byte blocks: (tokenLen/2) floats / 8 floats per block.
    DataCopyParams repeatParams{1, static_cast<uint16_t>((gmmSwiglu->tokenLen / 2) / 8), 0, 0};
    DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams);
    mmOutQueue.EnQue(_inMMLocal);
}
// Symmetric int8 quantization of row `loopIdx`'s SwiGLU output (first
// tokenLen/2 floats of the row): scale = max|x| / 127, stored per row in
// quantScaleOutQueue; values are multiplied by 1/scale and cast to int8.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::Quant(uint32_t loopIdx, VecConfig& vecConfig) {
    LocalTensor<float> _inMMLocal = mmOutQueue.DeQue<float>();
    // |x| written into the (now unused) upper half of the row as scratch.
    Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT],
        _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
        gmmSwiglu->tokenLen / BISECT);
    LocalTensor<float> workspaceLocal= reduceWorkspace.Get<float>();
    PipeBarrier<PIPE_ALL>();
    // Row max of the absolute values -> workspaceLocal[0].
    ReduceMaxTemplate(workspaceLocal,
        _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_ALL>();
    float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8;
    PipeBarrier<PIPE_ALL>();
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    PipeBarrier<PIPE_ALL>();
    quantScaleLocal.SetValue(loopIdx, quantScale);
    PipeBarrier<PIPE_ALL>();
    // Multiply by the reciprocal instead of dividing element-wise.
    // NOTE(review): an all-zero row gives quantScale == 0 and an inf
    // reciprocal — confirm upstream guarantees a nonzero row max.
    quantScale = 1 / quantScale;
    PipeBarrier<PIPE_ALL>();
    Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen],
        quantScale, gmmSwiglu->tokenLen / BISECT);
    PipeBarrier<PIPE_V>();
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    int32_t dstTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen / BISECT);
    int32_t srcTempOffset = static_cast<int32_t>(loopIdx * gmmSwiglu->tokenLen);
    int32_t tempCount = static_cast<int32_t>(gmmSwiglu->tokenLen / BISECT);
    LocalTensor<int8_t> castSpace = castWorkspace.Get<int8_t>();
    // fp32 -> fp16 -> int8 cast into the packed int8 output row.
    CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount);
    mmOutQueue.EnQue(_inMMLocal);
    quantOutQueue.EnQue(quantLocal);
}
// Writes one processed tile back to GM: the per-row quant scales and the
// packed int8 rows (tokenLen/2 bytes each), both at this split's row offset,
// then advances the tile start bookkeeping.
template <typename mmType, bool sync, typename CHANNELDTYPE>
__aicore__ inline void GMMSwigluSplitWorkSpaceCompute<mmType, sync, CHANNELDTYPE>::customDataCopyOut(VecConfig& vecConfig) {
    LocalTensor<float> quantScaleLocal = quantScaleOutQueue.DeQue<float>();
    // blockLen for DataCopyPad is in bytes.
    DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantScaleOutputGM[workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx], quantScaleLocal, copyParams_0);
    LocalTensor<int8_t> quantLocal = quantOutQueue.DeQue<int8_t>();
    DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0};
    PipeBarrier<PIPE_ALL>();
    DataCopyPad(quantOutputGM[(workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx) * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1);
    PipeBarrier<PIPE_ALL>();
    // Advance to the next tile of rows.
    vecConfig.startIdx += vecConfig.innerLoopNum;
    vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen;
    quantOutQueue.EnQue(quantLocal);
    quantScaleOutQueue.EnQue(quantScaleLocal);
}
} // namespace GROUPED_MATMUL
#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H

View File

@@ -0,0 +1,240 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h
* \brief
*/
#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H
#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H
#include "kernel_tiling/kernel_tiling.h"
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST {
using namespace AscendC;
constexpr uint32_t INT8_BITS = 8; // a int8 number has 8 bits
constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32; // 32: a block has 32 bytes data
constexpr uint32_t THRESHOLD_BLOCK_NUM = 8;
constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64; // 64: a block has 64 bytes data
constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2; // 2: a float16 data has two bytes
constexpr uint32_t SINGLE_CORE_M = 128;
constexpr uint32_t SINGLE_CORE_N = 512;
constexpr uint32_t SINGLE_CORE_K = 7168;
constexpr uint32_t BASIC_M = 128;
constexpr uint32_t BASIC_N = 256;
constexpr uint32_t BASIC_K = 128;
constexpr uint32_t STEP_M = 1;
constexpr uint32_t STEP_N = 1;
constexpr uint32_t STEP_Ka = 4;
constexpr uint32_t STEP_Kb = 4;
constexpr uint32_t DEPTH_A1 = 8;
constexpr uint32_t DEPTH_B1 = 8;
constexpr uint32_t VEC_LEN_ONCE_REPEAT_ELE = 64;
constexpr uint32_t VEC_LEN_ONCE_REPEAT_BLOCK = 8;
constexpr uint32_t BISECT = 2;
constexpr uint32_t MOD_32_MASK = 0x1F;
constexpr uint32_t MOD_16_MASK = 0x0F;
constexpr uint32_t ALIGN_8_ELE = 8;
constexpr uint32_t ALIGN_16_ELE = 16;
constexpr float QUANT_SCALE_INT8 = 127.0f;
constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, true);
// Builds the static matmul config for this kernel: starts from the NZ MDL
// template config and fills in the single-core and basic tile sizes.
constexpr MatmulConfig GetMMCFG() {
    MatmulConfig cfg = NZ_CFG_MDL;
    cfg.singleCoreM = SINGLE_CORE_M;
    cfg.singleCoreN = SINGLE_CORE_N;
    cfg.singleCoreK = SINGLE_CORE_K;
    cfg.basicM = BASIC_M;
    cfg.basicN = BASIC_N;
    cfg.basicK = BASIC_K;
    return cfg;
}
// Returns a copy of `mmTiling` with the step and L1-depth parameters
// overridden by this kernel's compile-time constants.
constexpr static MatmulApiStaticTiling GetMMTiling(const MatmulApiStaticTiling& mmTiling)
{
    MatmulApiStaticTiling result = mmTiling;
    result.stepM = STEP_M;
    result.stepN = STEP_N;
    result.stepKa = STEP_Ka;
    result.stepKb = STEP_Kb;
    result.depthA1 = DEPTH_A1;
    result.depthB1 = DEPTH_B1;
    return result;
}
// Bundles the four matmul operand types with the statically-derived config
// and tiling, and exposes the resulting MatmulImpl instantiation as MT.
// NOTE(review): the MM_CFG template parameter is never used — `cfg` always
// comes from GetMMCFG(); confirm this is intended.
template<class AT_, class BT_, class CT_, class BiasT_, const MatmulConfig& MM_CFG>
struct MMImplType {
    using AT = AT_;
    using BT = BT_;
    using CT = CT_;
    using BiasT = BiasT_;
    static constexpr MatmulConfig cfg = GetMMCFG();
    static constexpr MatmulApiStaticTiling mdl = GetMMTiling(GetMatmulApiTiling<AT, BT, CT, BiasT>(cfg));
    using MT = matmul::MatmulImpl<AT, BT, CT, BiasT, mdl>;
};
// Per-group matmul split state: problem sizes, block/tile decomposition and
// running base offsets for the current group. Maintained by the cube-side
// scheduling code (not visible here) — field meanings follow their names.
struct MNConfig {
    int64_t m = 0;                 // presumably matmul dims for the current group — confirm at fill site
    int64_t k = 0;
    int64_t n = 0;
    int64_t baseM = 0;
    int64_t baseN = 0;
    int64_t mIdx = 0;
    int64_t nIdx = 0;
    int64_t blockDimM = 0;
    int64_t blockDimN = 0;
    int64_t singleM = 0;
    int64_t singleN = 0;
    int64_t wBaseOffset = 0;
    int64_t nAxisBaseOffset = 0;
    int64_t mAxisBaseOffset = 0;
    int64_t xBaseOffset = 0;
    int64_t yBaseOffset = 0;
    int64_t wOutOffset = 0;
    int64_t workSpaceOffset = 0;
};
// Running state of the vector post-processing pipeline (dequant / swiglu /
// quant) over the rows assigned to one core.
struct VecConfig {
    int64_t M = 0;
    int64_t usedCoreNum = 0;
    int64_t startOffset = 0;        // element offset of the current tile start (startIdx * tokenLen)
    int64_t curOffset = 0;          // element offset of the next row to load (curIdx * tokenLen)
    int64_t startIdx = 0;           // first row of the tile currently being written out
    int64_t curIdx = 0;             // next row to be loaded/scaled
    int64_t taskNum = 0;            // total rows this core must process
    int64_t curGroupIdx = 0;        // current index into groupList
    int64_t outLoopNum = 0;         // ceil(taskNum / maxProcessRowNum) tiles
    int64_t innerLoopNum = 0;       // rows in the current tile
    int64_t tailLoopNum = 0;        // rows in the last tile
    int64_t nextUpadteInterVal = 0; // rows left before the per-channel scale refresh ("Upadte" [sic])
};
// Describes how the workspace (matmul output) is split into sequential
// processing chunks; leftMatrixStartIndex is the global row offset of the
// current split, used when indexing per-token scales and outputs.
struct WorkSpaceSplitConfig {
    int64_t M = 0;
    int64_t loopCount = 0;
    int64_t leftMatrixStartIndex = 0;
    int64_t rightMatrixExpertStartIndex = 0;
    int64_t rightMatrixExpertNextStartIndex = 0;
    int64_t rightMatrixExpertEndIndex = 0;
    int64_t notLastTaskSize = 0;
    int64_t lastLoopTaskSize = 0;
    bool isLastLoop = false;        // true while handling the final split
};
// Rounds `a` up to the nearest multiple of the compile-time constant `base`.
template <uint32_t base, typename T = uint32_t>
__aicore__ inline T AlignUp(T a) {
    T quotient = (a + base - 1) / base;
    return quotient * base;
}
// Rounds `a` up to the nearest multiple of the runtime value `base`.
template <typename T>
__aicore__ inline T AlignUp(T a, T base) {
    T quotient = (a + base - 1) / base;
    return quotient * base;
}
// Rounds `a` down to the nearest multiple of `base`; a zero base would divide
// by zero, so the input is returned unchanged in that case.
template <typename T>
__aicore__ inline T AlignDown(T a, T base) {
    if (unlikely(base == 0)) {
        return a;
    }
    T quotient = a / base;
    return quotient * base;
}
template <>
__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) {
    // Round up to a multiple of 4 without a divide: add 3, then clear the two
    // low bits. For any a, (a + 3) & ~3 is the least multiple of 4 >= a.
    return (a + 3) & ~3; // & ~3: set last two bits of (a+3) to be zero
}
template <>
__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) {
    // Least multiple of b (b a power of two) >= a: (a + (b - 1)) & ~(b - 1).
    return (a + 7) & ~7; // & ~7: set last three bits of (a+7) to be zero
}
template <>
__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) {
    // Least multiple of b (b a power of two) >= a: (a + (b - 1)) & ~(b - 1).
    return (a + 15) & ~15; // & ~15: set last four bits of (a+15) to be zero
}
template <>
__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) {
    // Least multiple of 32 >= a; see AlignUp<4> for the derivation.
    return (a + 31) & ~31; // & ~31: set last five bits of (a+31) to be zero
}
// Computes max(srcLocal[srcOffset .. srcOffset + count)) into dstLocal[0],
// choosing a reduction shape by `count`:
//  - count > 64 and a multiple of 64: two WholeReduceMax stages (per-repeat
//    maxima, then the max of those partials);
//  - count <= 64: one WholeReduceMax repeat;
//  - otherwise: generic ReduceMax fallback.
__aicore__ inline void ReduceMaxTemplate(LocalTensor<float>& dstLocal, LocalTensor<float>& srcLocal,
    uint32_t srcOffset, uint32_t count)
{
    if (likely(count > VEC_LEN_ONCE_REPEAT_ELE && count % VEC_LEN_ONCE_REPEAT_ELE == 0)){
        WholeReduceMax(dstLocal,
            srcLocal[srcOffset], VEC_LEN_ONCE_REPEAT_ELE,
            count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1,
            VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
        PipeBarrier<PIPE_V>();
        // Reduce the count/64 partial maxima down to a single value.
        WholeReduceMax(dstLocal, dstLocal,
            count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1, 1,
            VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
    } else if (count <= VEC_LEN_ONCE_REPEAT_ELE) {
        WholeReduceMax(dstLocal,
            srcLocal[srcOffset],
            count, 1, 1, 1, VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
    } else {
        ReduceMax(dstLocal, srcLocal[srcOffset], dstLocal, count, false);
    }
}
// Casts `count` floats to int8 via an intermediate in-place fp16 round
// (CAST_RINT), writing the bytes to dstLocal[dstOffset]. Two destination
// alignments are handled:
//  - dstOffset a multiple of 32 elements: direct cast;
//  - dstOffset only a multiple of 16: cast everything past the first 16
//    elements directly, and route the first 16 through a one-block workspace
//    with scalar copies (an unaligned int8 vector store is not possible).
// NOTE(review): a dstOffset aligned to neither 32 nor 16 writes nothing —
// confirm callers never produce such offsets.
__aicore__ inline void CastFp32ToInt8Template(LocalTensor<int8_t>& dstLocal, LocalTensor<float>& srcLocal,
    LocalTensor<int8_t>& oneBlockWorkspace,
    int32_t dstOffset, int32_t srcOffset, int32_t count)
{
    Cast(srcLocal[srcOffset].ReinterpretCast<half>(), srcLocal[srcOffset], RoundMode::CAST_RINT, count);
    PipeBarrier<PIPE_V>();
    if ((dstOffset & MOD_32_MASK) == 0) {
        Cast(dstLocal[dstOffset],
            srcLocal[srcOffset].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, count);
    } else if ((dstOffset & MOD_16_MASK) == 0) {
        Cast(dstLocal[dstOffset + ALIGN_16_ELE],
            srcLocal[srcOffset + ALIGN_8_ELE].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, count - ALIGN_16_ELE);
        PipeBarrier<PIPE_V>();
        Cast(oneBlockWorkspace, srcLocal[srcOffset].ReinterpretCast<half>(),
            RoundMode::CAST_RINT, ALIGN_16_ELE);
        PipeBarrier<PIPE_ALL>();
        // Scalar copy of the 16 leading bytes into the unaligned destination.
        for (int32_t i = 0; i < ALIGN_16_ELE; i++) {
            int8_t temp = oneBlockWorkspace.GetValue(i);
            dstLocal.SetValue(dstOffset + i, temp);
        }
        PipeBarrier<PIPE_ALL>();
    }
}
// Resolves the GM address of the `index`-th tensor inside a packed
// tensor-list argument: the first uint64 at tensorPtr is the byte offset of
// the per-tensor address table, whose entries are the tensors' GM addresses.
template <typename T>
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
    uint64_t tensorPtrOffset = *dataAddr;  // byte offset of the address table from the list head
    // Shifting right by 3 divides by sizeof(uint64_t) (bytes -> elements).
    __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
    return reinterpret_cast<__gm__ T*>(*(retPtr + index));
}
} // namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST
#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H

View File

@@ -552,6 +552,41 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
output_offset);
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
// Launches aclnnGroupedMatmulSwigluQuantWeightNzTensorList: grouped matmul
// (weights passed as a tensor list in NZ format) fused with SwiGLU and
// per-token int8 quantization.
//   x:            [m, k] activations.
//   weight:       weight tensor list; n is read from weight[0].sizes()[1].
//   weight_scale: per-channel dequant scales, one tensor per group.
//   x_scale:      per-token scales for x.
//   group_list:   group boundaries for the grouped matmul.
//   bias/offset:  optional, forwarded to the aclnn op.
// Returns (output [m, n/2] int8, output_scale [m] float, output_offset [m] float).
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
    const at::Tensor & x,
    const at::TensorList & weight,
    const at::TensorList & weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & bias,
    const c10::optional<at::Tensor> & offset)
{
    TORCH_CHECK(!weight.empty(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list: weight list must not be empty");
    auto x_size = x.sizes();
    // Use int64_t to avoid narrowing large dims (tensor sizes are int64).
    int64_t n = weight[0].sizes()[1];
    int64_t m = x_size[0];
    // SwiGLU halves the feature dimension, hence n / 2 output columns.
    at::Tensor output = at::zeros({m, n / 2}, x.options().dtype(at::kChar));
    at::Tensor output_scale = at::zeros({m}, x.options().dtype(at::kFloat));
    at::Tensor output_offset = at::zeros({m}, x.options().dtype(at::kFloat));
    EXEC_NPU_CMD(
        aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
        x,
        weight,
        bias,
        offset,
        weight_scale,
        x_scale,
        group_list,
        output,
        output_scale,
        output_offset);
    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -614,4 +649,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor group_list, *, Tensor? bias=None,"
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
ops.def(
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
" Tensor group_list, *,"
" Tensor? bias=None, Tensor? offset=None) ->"
" (Tensor output, Tensor output_scale, Tensor output_offset)"
);
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list);
}

View File

@@ -130,14 +130,34 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
return {output, output_scale, output_offset};
}
// Meta (shape-inference) implementation of
// grouped_matmul_swiglu_quant_weight_nz_tensor_list: produces output tensors
// with the correct shapes/dtypes so the op can be traced and captured into
// aclgraph. Outputs inherit x's options, so under the Meta dispatch key they
// stay on the meta device; at::empty avoids materializing zeroed CPU data.
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta(
    const at::Tensor & x,
    const at::TensorList & weight,
    const at::TensorList & weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & bias,
    const c10::optional<at::Tensor> & offset)
{
    TORCH_CHECK(!weight.empty(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list: weight list must not be empty");
    auto x_size = x.sizes();
    // SwiGLU halves the feature dimension, hence n / 2 output columns.
    int64_t n = weight[0].sizes()[1];
    int64_t m = x_size[0];
    at::Tensor output = at::empty({m, n / 2}, x.options().dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
} // namespace meta
} // namespace vllm_ascend
namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
@@ -150,5 +170,7 @@ namespace {
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
// grouped_matmul_swiglu_quant meta implementation
ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
// Grouped matmul swiglu quant weight nz tensor list
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
}
}

48
csrc/utils/CMakeLists.txt Normal file
View File

@@ -0,0 +1,48 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Header-only INTERFACE target exposing the op tiling utility headers.
# Consumers inherit the include paths and the log-module compile definitions;
# the CANN experiment/* include dirs are added only for open-project builds.
add_library(ops_utils_tiling_headers INTERFACE)
target_include_directories(ops_utils_tiling_headers INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/inc>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/slog>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof>>
    $<INSTALL_INTERFACE:include/ops_adv/utils>
)
# Tag log output with the OP_TILING sub-module; "[Custom]" marks open-project builds.
target_compile_definitions(ops_utils_tiling_headers INTERFACE
    OPS_UTILS_LOG_SUB_MOD_NAME="OP_TILING"
    OPS_UTILS_LOG_PACKAGE_TYPE=$<IF:$<BOOL:${BUILD_OPEN_PROJECT}>,"[Custom]","">
)
# Same header set for op-proto consumers, with the OP_PROTO log sub-module
# and the aclnn/opdev include dir instead of the runtime/msprof ones.
add_library(ops_utils_proto_headers INTERFACE)
target_include_directories(ops_utils_proto_headers INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/inc>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/slog>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef>>
    $<$<BOOL:${BUILD_OPEN_PROJECT}>:$<BUILD_INTERFACE:${ASCEND_CANN_PACKAGE_PATH}/include/aclnn/opdev>>
    $<INSTALL_INTERFACE:include/ops_adv/utils>
)
target_compile_definitions(ops_utils_proto_headers INTERFACE
    OPS_UTILS_LOG_SUB_MOD_NAME="OP_PROTO"
    OPS_UTILS_LOG_PACKAGE_TYPE=$<IF:$<BOOL:${BUILD_OPEN_PROJECT}>,"[Custom]","">
)
# In the closed (non-open) build, install the headers and targets via the
# install_package helper — presumably defined by the surrounding build system;
# TODO confirm it is in scope here.
if(NOT BUILD_OPEN_PROJECT)
    install_package(
        PACKAGE ops_adv
        TARGETS ops_utils_tiling_headers ops_utils_proto_headers
        DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/inc/
        DESTINATION include/ops_adv/utils
    )
endif()

View File

@@ -0,0 +1,14 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_ACLNN_UTIL_H
#define OP_API_INC_ACLNN_UTIL_H
/* Marks aclnn entry points as exported: visibility("default") keeps the
 * symbol visible even when the library is built with hidden visibility. */
#define ACLNN_API __attribute__((visibility("default")))
#endif // OP_API_INC_ACLNN_UTIL_H

View File

@@ -0,0 +1,25 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file ops_error.h
* \brief
*/
#pragma once
#include "log/ops_log.h"
/* Basic error reporting: report an inner error with a fixed error code
 * (E89999 for vector ops, E69999 for cube ops). */
#define OPS_REPORT_VECTOR_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E89999", OPS_DESC, __VA_ARGS__)
#define OPS_REPORT_CUBE_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E69999", OPS_DESC, __VA_ARGS__)
/* Conditional error reporting: delegates to OPS_LOG_STUB_IF, which handles
 * COND / LOG_FUNC / EXPR. */
#define OPS_ERR_IF(COND, LOG_FUNC, EXPR) OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR)

497
csrc/utils/inc/fallback.h Normal file
View File

@@ -0,0 +1,497 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file fallback.h
* \brief
*/
#ifndef ACLNNFALLBACK_OPAPI_H_
#define ACLNNFALLBACK_OPAPI_H_
#include <dlfcn.h>
#include <functional>
#include <tuple>
#include <type_traits>
#include <vector>
#include "aclnn/aclnn_base.h"
#include "fallback_comm.h"
#include "error/ops_error.h"
#include "runtime/base.h"
namespace fallback {
using namespace std;
using namespace gert;
using namespace ge;
using namespace std;
// C++11-compatible backport of std::index_sequence / std::make_index_sequence
// (the std:: versions require C++14).
namespace std_utils {
template <std::size_t... Is>
struct index_sequence {};
// Recursive helper: peels N down to 0, prepending each index to the pack.
template <std::size_t N, std::size_t... Is>
struct make_index_sequence_helper : make_index_sequence_helper<N - 1, N - 1, Is...> {};
// Base case: all indices 0..N-1 generated.
template <std::size_t... Is>
struct make_index_sequence_helper<0, Is...> {
    using type = index_sequence<Is...>;
};
template <std::size_t N>
using make_index_sequence = typename make_index_sequence_helper<N>::type;
}
using aclOpExecutor = struct aclOpExecutor;
using aclTensor = struct aclTensor;
using aclScalar = struct aclScalar;
using aclIntArray = struct aclIntArray;
using aclFloatArray = struct aclFloatArray;
using aclBoolArray = struct aclBoolArray;
using aclTensorList = struct aclTensorList;
using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, aclDataType data_type,
const int64_t* stride, int64_t offset, aclFormat format,
const int64_t* storage_dims, uint64_t storage_dims_num, void* tensor_data);
using _aclCreateScalar = aclScalar* (*)(void* value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray* (*)(const int64_t* value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray* (*)(const float* value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray* (*)(const bool* value, uint64_t size);
using _aclCreateTensorList = aclTensorList* (*)(const aclTensor* const* value, uint64_t size);
using _aclDestroyTensor = int (*)(const aclTensor* tensor);
using _aclDestroyScalar = int (*)(const aclScalar* scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray* array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray* array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray* array);
using _aclDestroyTensorList = int (*)(const aclTensorList* array);
#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))
// Returns the file name of the stock CANN op-api shared library.
inline const char* GetOpApiLibName(void) {
    static const char* const kOpApiLib = "libopapi.so";
    return kOpApiLib;
}
// Returns the file name of the custom-op aclnn shared library.
inline const char* GetCustOpApiLibName(void) {
    static const char* const kCustOpApiLib = "libcust_opapi.so";
    return kCustOpApiLib;
}
// Looks up `apiName` in an already-opened library handle. Returns the symbol
// address, or nullptr (after logging a warning) when it is absent.
inline void* GetOpApiFuncAddrInLib(void* handler, const char* libName, const char* apiName) {
    void* funcAddr = dlsym(handler, apiName);
    if (funcAddr != nullptr) {
        return funcAddr;
    }
    OPS_LOG_W("aclnnfallback", "dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
    return funcAddr;
}
// Opens a shared library with lazy binding. Returns the handle, or nullptr
// (after logging a warning) when dlopen fails.
inline void* GetOpApiLibHandler(const char* libName) {
    void* handler = dlopen(libName, RTLD_LAZY);
    if (handler == nullptr) {
        OPS_LOG_W("aclnnfallback", "dlopen %s failed, error:%s.", libName, dlerror());
    }
    return handler;
}
inline void* GetAclnnArrdByApiName(const char *apiName) {
vector<std:: string> libs = {"libaclnn_ops_infer.so", "libaclnn_ops_train.so", "libaclnn_math.so",
"libaclnn_rand.so", "libaclnn_sparse.so", "libaclnn_fft.so"};
for (const auto &libName : libs) {
static auto libHandler = GetOpApiLibHandler(libName.c_str());
if (libHandler != nullptr) {
auto funcAddr = GetOpApiFuncAddrInLib(libHandler, libName.c_str(), apiName);
if (funcAddr != nullptr) {
return funcAddr;
}
}
}
OPS_LOG_E("aclnnfallback", "api %s can't find in any aclnn lib.", apiName);
return nullptr;
}
// Resolves an aclnn API symbol by name. Lookup order:
//   1. the custom-op library (libcust_opapi.so),
//   2. the stock op-api library (libopapi.so),
//   3. the individual aclnn libraries (GetAclnnArrdByApiName).
inline void* GetOpApiFuncAddr(const char* apiName) {
    static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
    if (custOpApiHandler != nullptr) {
        void* addr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
        if (addr != nullptr) {
            return addr;
        }
    }
    static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
    if (opApiHandler != nullptr) {
        void* addr = GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
        if (addr != nullptr) {
            return addr;
        }
    }
    OPS_LOG_D("aclnnfallback", "opapi lib is not exist,will use aclnn lib.");
    return GetAclnnArrdByApiName(apiName);
}
// Identity overload: an aclTensor needs no conversion for EXEC dispatch.
inline aclTensor* ConvertType(aclTensor* ge_tensor) {
    return ge_tensor;
}
// Wraps a non-empty int64 vector in an aclIntArray; empty input -> nullptr.
inline aclIntArray* ConvertType(const std::vector<int64_t> &arr) {
    if (arr.empty()) {
        return nullptr;
    }
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    return aclCreateIntArray(arr.data(), arr.size());
}
// Maps a gert tensor's ge::DataType to the matching aclDataType.
// Any type not listed below falls back to ACL_FLOAT16.
inline aclDataType GetConvertType(const gert::Tensor* ge_tensor) {
    switch (ge_tensor->GetDataType()) {
        case DT_FLOAT:  return aclDataType::ACL_FLOAT;
        case DT_BF16:   return aclDataType::ACL_BF16;
        case DT_BOOL:   return aclDataType::ACL_BOOL;
        case DT_INT64:  return aclDataType::ACL_INT64;
        case DT_INT32:  return aclDataType::ACL_INT32;
        case DT_UINT64: return aclDataType::ACL_UINT64;
        case DT_UINT32: return aclDataType::ACL_UINT32;
        case DT_INT8:   return aclDataType::ACL_INT8;
        case DT_UINT8:  return aclDataType::ACL_UINT8;
        case DT_INT4:   return aclDataType::ACL_INT4;
        default:        return aclDataType::ACL_FLOAT16;
    }
}
// Wraps a gert::Tensor as a contiguous ND aclTensor aliasing the same device
// memory (no copy). Returns nullptr on null input or creation failure.
// Fix: removed the unused local `tensor_place` (GetPlacement() result was
// never read).
inline aclTensor* ConvertType(const gert::Tensor* ge_tensor) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }
    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);
    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());
    auto dataType = GetConvertType(ge_tensor);
    OPS_LOG_D("aclnnfallback", "aclCreateTensor: tensor type is %d", dataType);
    // Convert the storage shape.
    auto gert_shape = ge_tensor->GetStorageShape();
    std::vector<int64_t> shape;
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }
    // Compute row-major strides for a contiguous tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    aclTensor* out = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(),
                                     0, aclFormat::ACL_FORMAT_ND,
                                     shape.data(), shape.size(), device_addr);
    OPS_ERR_IF(out == nullptr,
               OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);
    return out;
}
// Converts a list of gert tensors into an aclTensorList; each element is
// wrapped via ConvertType (aliasing, no copy). Returns nullptr on an empty
// list or when the aclCreateTensorList symbol cannot be resolved.
// Fix: the second OPS_ERR_IF logged "ge_tenserList size 0" for a failed
// symbol lookup — the message now names the actual failure.
inline aclTensorList* ConvertType(std::vector<const gert::Tensor*>& ge_tenserList) {
    OPS_ERR_IF(ge_tenserList.size() == 0,
               OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr);
    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
    OPS_ERR_IF(aclCreateTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclCreateTensorList nullptr"), return nullptr);
    std::vector<aclTensor*> tmp;
    tmp.reserve(ge_tenserList.size());
    for (size_t i = 0; i < ge_tenserList.size(); i++) {
        tmp.push_back(ConvertType(ge_tenserList[i]));
    }
    return aclCreateTensorList(tmp.data(), tmp.size());
}
// Creates an aclScalar from a plain value. Only float is supported; any other
// type yields nullptr.
// Fix: replaced the runtime `typeid(value) == typeid(float)` comparison
// (needs RTTI) with a compile-time std::is_same check — same behavior, no
// RTTI dependency (<type_traits> is already included by this header).
template <typename T>
inline aclScalar* ConvertScalarType(T value) {
    static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
    OPS_ERR_IF(aclCreateScalar == nullptr,
               OPS_LOG_E("aclnnfallback", "aclCreateScalar nullptr"), return nullptr);
    if (std::is_same<T, float>::value) {
        return aclCreateScalar(&value, aclDataType::ACL_FLOAT);
    }
    return nullptr;
}
// Fallback overload: pass any other argument type through unchanged.
template <typename T>
T ConvertType(T value) {
    return value;
}
// Wraps a gert::Tensor as an aclTensor for matmul use (aliasing, no copy).
// For rank >= 2 tensors it can express a transposed view (swapping the last
// two dims' strides and view shape) and, when enable_NZ is set and the
// storage format is FRACTAL_NZ, tags the tensor with the NZ acl format.
// Rank <= 1 tensors fall back to the plain ConvertType path.
inline aclTensor* ConvertMmType(const gert::Tensor* ge_tensor, bool transpose, bool enable_NZ=false) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }
    auto gert_shape = ge_tensor->GetStorageShape();
    if (gert_shape.GetDimNum() <= 1) {
        return ConvertType(ge_tensor);
    }
    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);
    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());
    // convert data type
    auto dataType_ge = ge_tensor->GetDataType();
    auto dataType = ToAclDataType(dataType_ge);
    // convert shape
    std::vector<int64_t> shape;
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }
    // Compute row-major strides for a contiguous tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    auto viewShape = shape;
    // For a transposed tensor, swap strides and viewShape of the last two dims.
    if (transpose) {
        // dimM is the second-to-last dim, dimN the last.
        auto dimM = shape.size() - 2;
        auto dimN = shape.size() - 1;
        auto swap = strides[dimN];
        strides[dimN] = strides[dimM];
        strides[dimM] = swap;
        // Swap the view shape as well.
        viewShape[dimN] = shape[dimM];
        viewShape[dimM] = shape[dimN];
    }
    auto acl_format = aclFormat::ACL_FORMAT_ND;
    if (enable_NZ && GetPrimaryFormat(ge_tensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) {
        acl_format = aclFormat::ACL_FORMAT_FRACTAL_NZ;
    }
    aclTensor* out = aclCreateTensor(viewShape.data(), shape.size(), dataType, strides.data(),
        0, acl_format, shape.data(), shape.size(), device_addr);
    OPS_ERR_IF(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);
    return out;
}
// Destroys an aclTensor via the dynamically resolved destroyer symbol.
inline void Release(aclTensor* p) {
    static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
    OPS_ERR_IF(aclDestroyTensor == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensor is null"), return);
    aclDestroyTensor(p);
}
// Destroys an aclScalar via the dynamically resolved destroyer symbol.
inline void Release(aclScalar* p) {
    static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
    OPS_ERR_IF(aclDestroyScalar == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyScalar is null"), return);
    aclDestroyScalar(p);
}
// Destroys an aclIntArray via the dynamically resolved destroyer symbol.
inline void Release(aclIntArray* p) {
    static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
    OPS_ERR_IF(aclDestroyIntArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyIntArray is null"), return);
    aclDestroyIntArray(p);
}
// Destroys an aclBoolArray via the dynamically resolved destroyer symbol.
inline void Release(aclBoolArray* p) {
    static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
    OPS_ERR_IF(aclDestroyBoolArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyBoolArray is null"), return);
    aclDestroyBoolArray(p);
}
// Destroys an aclTensorList (and, per acl convention, presumably its member
// tensors — TODO confirm against aclDestroyTensorList docs).
inline void Release(aclTensorList* p) {
    static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
    OPS_ERR_IF(aclDestroyTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensorList is null"), return);
    aclDestroyTensorList(p);
}
// Catch-all overload: non-acl parameter types need no cleanup.
template <typename T>
void Release(T value) {
    (void)value;  // intentionally unused
}
// Invokes Release on every element of the tuple; the initializer_list trick
// forces left-to-right expansion of the parameter pack (pre-C++17 fold).
template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std_utils::index_sequence<I...>) {
    (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}
// Releases every converted opapi argument held in the tuple.
template <typename Tuple>
void ReleaseConvertTypes(Tuple& t) {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    CallRelease(t, std_utils::make_index_sequence<size>{});
}
// Converts each argument with the matching ConvertType overload and bundles
// the results into a tuple (owned by the caller; release with
// ReleaseConvertTypes).
template <typename... Ts>
auto ConvertTypes(Ts&... args) -> decltype(std::make_tuple(ConvertType(args)...)) {
    return std::make_tuple(ConvertType(args)...);
}
// Applies f to the unpacked tuple elements (pre-C++17 std::apply equivalent).
template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple t, std_utils::index_sequence<I...>) -> int {
    return f(std::get<I>(t)...);
}
// Convenience wrapper: builds the index sequence for the tuple and dispatches
// to the unpacking overload above.
template <typename Function, typename Tuple>
auto call(Function f, Tuple t) -> int {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return call(f, t, std_utils::make_index_sequence<size>{});
}
// Reinterprets a raw symbol address as a function pointer whose parameter
// types mirror the (decayed) element types of the converted-argument tuple.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr, std_utils::index_sequence<I...>)
    -> int (*)(typename std::decay<decltype(std::get<I>(params))>::type...) {
    using OpApiFunc = int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
    auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
    return func;
}
// Entry point: SFINAE-restricted to non-empty tuples, forwards with the
// matching index sequence.
template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr)
    -> typename std::enable_if<std::tuple_size<Tuple>::value != 0,
        decltype(ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<std::tuple_size<Tuple>::value>{}))>::type {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<size>{});
}
// RAII owner of a tuple of converted opapi arguments (aclTensor*, aclScalar*,
// ...). The destructor releases them via ReleaseConvertTypes; moved-from
// instances are marked invalid so resources are released exactly once.
template <typename Tuple>
class ConvertedParams {
public:
    ConvertedParams(Tuple&& convertedParams) : convertedParams_(std::move(convertedParams)){};
    ConvertedParams(ConvertedParams&& other) : convertedParams_(std::move(other.convertedParams_)) {
        other.validParams_ = false;
    };
    ConvertedParams& operator=(ConvertedParams&& other) {
        if (this == &other) {
            return *this;
        }
        // Bugfix: release the currently-held params before overwriting them,
        // otherwise previously converted acl objects leak.
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
        convertedParams_ = std::move(other.convertedParams_);
        // Bugfix: adopt the source's validity instead of unconditionally
        // setting true — assigning from a moved-from object must not make
        // this instance claim ownership of already-released resources.
        validParams_ = other.validParams_;
        other.validParams_ = false;
        return *this;
    }
    ConvertedParams() = delete;
    ConvertedParams(const ConvertedParams& other) = delete;
    ConvertedParams& operator=(const ConvertedParams& other) = delete;
    ~ConvertedParams() {
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
    }
    // Read-only access to the converted argument tuple.
    const Tuple& GetConvertedParams() const {
        return convertedParams_;
    }
private:
    Tuple convertedParams_;   // converted acl arguments, owned while valid
    bool validParams_{true};  // false once moved-from / released
};
// Function-pointer aliases for symbols resolved at runtime from the opapi library.
using InitHugeMemThreadLocal = int (*)(void*, bool);
using UnInitHugeMemThreadLocal = void (*)(void*, bool);
using ReleaseHugeMem = void (*)(void*, bool);
using PTAGetExecCache = aclOpExecutor* (*)(uint64_t, uint64_t*);
using InitPTACacheThreadLocal = void (*)();
using SetPTAHashKey = void (*)(uint64_t);
using CanUsePTACache = bool (*)(const char*);
using ResetCacheThreadLocal = void (*)();
// Runs a two-phase aclnn opapi call from a fallback kernel:
//   1) resolve <api>GetWorkspaceSize / <api> / ResetCacheThreadLocal symbols,
//   2) convert the variadic GE arguments to acl types and query workspace size,
//   3) allocate the workspace from `host_api_ctx` and launch on its stream,
//   4) release converted arguments and free the workspace after the call.
// Requires a `host_api_ctx` variable in the expanding scope; evaluates (as a
// GNU statement expression) to GRAPH_SUCCESS or GRAPH_FAILED.
// NOTE(review): `ret` and `getWorkspaceSizeFunc` are per-call-site statics —
// presumably deliberate one-time init; confirm thread-safety expectations.
// NOTE(review): the log on the api_ret != 0 path says "allocate workspace
// failed" but actually reports op-execution failure — looks copy-pasted.
#define EXEC_OPAPI_CMD(aclnn_api, ...) \
({ \
static auto ret = GRAPH_SUCCESS; \
do { \
static const auto ResetCacheThreadLocalAddr = GetOpApiFuncAddr("ResetCacheThreadLocal"); \
static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr || ResetCacheThreadLocalAddr == nullptr) { \
OPS_LOG_E("aclnnfallback", "%s or %s not in %s or %s or ResetCacheThreadLocal not found.", \
#aclnn_api "GetWorkspaceSize", #aclnn_api, GetOpApiLibName(), GetOpApiLibName()); \
ret = GRAPH_FAILED; \
break; \
} \
auto ResetCacheThreadLocalFunc = reinterpret_cast<ResetCacheThreadLocal>(ResetCacheThreadLocalAddr); \
ResetCacheThreadLocalFunc(); \
uint64_t workspace_size = 0; \
uint64_t* workspace_size_addr = &workspace_size; \
aclOpExecutor* executor = nullptr; \
aclOpExecutor** executor_addr = &executor; \
auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \
static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
if (workspace_status != 0) { \
OPS_LOG_E("aclnnfallback", "call %s failed:", #aclnn_api); \
ret = GRAPH_FAILED; \
break; \
} \
void* workspace_addr = nullptr; \
if (workspace_size > 0) { \
workspace_addr = host_api_ctx->MallocWorkspace(workspace_size); \
if (workspace_addr == nullptr) { \
OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed", #aclnn_api); \
ret = GRAPH_FAILED; \
break; \
} \
} \
auto acl_stream = host_api_ctx->GetStream(); \
auto acl_call = [converted_params, workspace_addr, workspace_size, host_api_ctx, acl_stream, \
executor]() -> int { \
using OpApiFunc = int (*)(void*, uint64_t, aclOpExecutor*, const aclrtStream); \
OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
ReleaseConvertTypes(converted_params); \
host_api_ctx->FreeWorkspace(); \
if (api_ret != 0) { \
OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed api_ret: %d", #aclnn_api, api_ret); \
return GRAPH_FAILED; \
} \
return api_ret; \
}; \
 \
ret = acl_call(); \
} while (false); \
(ret); \
})
} // namespace fallback
#endif // ACLNNFALLBACK_OPAPI_H_

View File

@@ -0,0 +1,38 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file fallback_comm.h
* \brief
*/
#ifndef INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
#define INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
#include "aclnn/aclnn_base.h"
#include "exe_graph/runtime/op_execute_context.h"
#include "exe_graph/runtime/tensor.h"
#include "register/op_impl_registry.h"
#include "runtime/base.h"
#ifdef __cplusplus
extern "C" {
#endif
namespace fallback {
// Maps a GE data type enum value to the corresponding ACL data type.
aclDataType ToAclDataType(ge::DataType dtype);
} // namespace fallback
#ifdef __cplusplus
}
#endif
#endif // INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_

View File

@@ -0,0 +1,121 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dropmask.h
* \brief
*/
#ifndef DROPMASK_H
#define DROPMASK_H
#include "util.h"
using AscendC::DROPOUT_MODE_BIT_MISALIGN;
using AscendC::DropOutShapeInfo;
using AscendC::DropOut;
// Per-tile bookkeeping for reading and applying the dropout mask.
struct DropMaskInfo {
    // for compute dropout mask offset
    // Offsets are computed as if B/N/G/S1/S2 are all tiled; for an axis that
    // is not tiled, set its field to 0 or the full size as appropriate.
    int64_t n2G;              // n2 * g
    int64_t gSize;            // g
    int64_t s1Size;           // s1
    int64_t s2Size;           // s2
    int64_t gOutIdx;          // g out index
    int64_t bSSOffset;        // boidx * s1 * s2 ===bSSOffset
    int64_t n2OutIdx;         // n out index
    int64_t s1OutIdx;         // s1 out index ===s1oIdx
    int64_t s1InnerIdx;       // s1 inner index (within the n-ratio loop) ===loopIdx
    int64_t s1BaseSize;       // S1 base block size
    int64_t splitS1BaseSize;  // s1 split size ===vec1S1BaseSize
    int64_t s2StartIdx;       // s2 start index
    int64_t s2Idx;            // s2 index =====s2LoopCount
    int64_t s2BaseNratioSize; // s2 ratio length: s2BaseSize (S2 base block) * nRatio
    // for copy in dropout mask
    uint32_t s1CopySize;
    uint32_t s2CopySize;
    int64_t s2TotalSize;
    // for compute dropout mask
    uint32_t firstAxis;
    uint32_t lstAxis;
    uint32_t maskLstAxis;
    int64_t vecCoreOffset = 0;
    float keepProb;   // dropout keep probability
    bool boolMode;    // true: mask stored as bool bytes; false: bit-packed
};
// Computes the flat element offset into the dropout mask for the current
// (b, n2, g, s1, s2) tile. Compiled out (returns 0) when dropout is disabled.
template <bool hasDrop>
__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        // boidx * n2 * g * s1 * s2
        int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G;
        // n2oIdx * g * s1 * s2
        int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // goIdx * s1 * s2
        int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // (s1oIdx * s1BaseSize + vecCoreOffset + s1InnerIdx * vec1S1BaseSize) * s2Size
        int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset +
                            dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size;
        // s2StartIdx + s2Idx * s2BaseNratioSize
        int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize;
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
// Copies the dropout mask tile from global to local memory. Chooses the
// byte-per-element path (boolMode) or the bit-packed path, both with
// alignment handled by the respective helper. No-op when dropout is off.
template <bool hasDrop>
__aicore__ inline void CopyInDropMask(LocalTensor<uint8_t>&dstTensor, GlobalTensor<uint8_t>& srcBoolTensor,
    GlobalTensor<uint8_t>& srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes)
{
    if constexpr (hasDrop == true) {
        int64_t dropMaskOffset = ComputeDropOffset<hasDrop>(dropMaskInfo);
        if (unlikely(dropMaskInfo.boolMode)) {
            BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset,
                       dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        } else {
            Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1,
                           dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        }
        return;
    }
}
// Applies the dropout mask to srcTensor, writing into dstTensor via the
// AscendC DropOut primitive. Selects the mask layout: byte mask (boolMode),
// bit mask with a 32B-aligned last axis, or the misaligned-bit variant.
template <typename T, bool hasDrop>
__aicore__ inline void ComputeDropMask(LocalTensor<T>& dstTensor, LocalTensor<T>& srcTensor,
    LocalTensor<uint8_t>& dropoutBuffer, LocalTensor<uint8_t>& tmpDropBuffer, DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis;
        dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis;
        if (unlikely(dropMaskInfo.boolMode)) {
            // byte mask: pad last axis up to a 32B boundary
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes;
            DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
        } else {
            // bit mask: 8 elements per byte, then pad to a 32B boundary
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes;
            if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) {
                DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
            } else {
                DropOut<T, false, DROPOUT_MODE_BIT_MISALIGN>(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer,
                                                             dropMaskInfo.keepProb, dropOutShapeInfo);
            }
        }
        return;
    }
}
#endif // DROPMASK_H

483
csrc/utils/inc/kernel/pse.h Normal file
View File

@@ -0,0 +1,483 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file pse.h
* \brief
*/
#ifndef FLASH_ATTENTION_SCORE_PSE_H
#define FLASH_ATTENTION_SCORE_PSE_H
#include "kernel_operator.h"
#include "util.h"
// pse tensor shape/layout type codes (see PseInfo::pseShapeType).
constexpr static int64_t pseS1S2 = 0;      // full (s1, s2) bias per head
constexpr static int64_t pse1S2 = 1;       // broadcast (1, s2) bias
constexpr static int64_t pseSlopeBn = 2;   // alibi slopes, per (b, n)
constexpr static int64_t pseSlopeN = 3;    // alibi slopes, per n
constexpr static uint8_t pseEncodeALibiS2Full = 0x11;  // compressed alibi encoding marker
// Where/how the pse bias is combined with the attention score.
enum class PseTypeEnum {
    PSE_OUTER_MUL_ADD_TYPE = 0, // default
    PSE_OUTER_ADD_MUL_TYPE,
    PSE_INNER_MUL_ADD_TYPE,
    PSE_INNER_MUL_ADD_SQRT_TYPE,
    PSE_INVALID_TYPE
};
// Per-tile bookkeeping for reading and applying the pse (positional bias).
struct PseInfo {
    int64_t blockCount;
    int64_t bSSOffset;        // boidx * s1 * s2
    int64_t boIdx;
    int64_t gSize;
    int64_t goIdx;
    int64_t loopIdx;
    int64_t n2G;
    int64_t n2oIdx;
    int64_t pseBSize;
    int64_t pseS1Size;        // for alibi
    int64_t pseS2ComputeSize; // for alibi, do not need assignment
    int64_t pseS2Size;        // for alibi
    uint32_t pseShapeType;
    int64_t readS2Size;       // for alibi, do not need assignment
    int64_t s1BaseSize;
    int64_t s1Size;
    int64_t s1oIdx;
    int64_t s2AlignedSize;
    int64_t s2BaseNratioSize;
    int64_t s2LoopCount;
    int64_t s2RealSize;
    int64_t s2Size;
    int64_t s2SizeAcc;        // accumulated sum of s2 size
    int64_t s2StartIdx;
    int64_t vec1S1BaseSize;
    int64_t vec1S1RealSize;
    uint32_t pseEncodeType;   // for distinguish alibi
    uint32_t pseType;         // 0: outer, mul-add 1:outer, add-mul 2:inner, mul-add 3:inner, mul-add-sqrt
    int64_t pseAlibiBaseS1;
    int64_t pseAlibiBaseS2;
    int64_t qStartIdx;
    int64_t kvStartIdx;
    int64_t vecCoreOffset = 0;
    bool needCast;
    bool align8 = false;
    bool pseEndogenous = false;
};
// Copies an (s1Size, s2Size) tile from global into local memory, padding the
// inner axis to alignedS2Size. Uses DataCopy when rows are 32B-aligned in
// global memory, otherwise falls back to DataCopyPad.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInCommon(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
    int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize,
    int32_t alignedS2Size)
{
    if constexpr (hasPse == true) {
        uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignedS2Size)};
        dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        dstTensor.SetSize(s1Size * alignedS2Size);
        DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = s1Size;
        dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // unit: 32B blocks
        dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap
        if (actualS2Len * dtypeSize % blockBytes == 0) {
            // aligned rows: stride counted in 32B blocks
            dataCopyParams.srcStride =
                (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap
            DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
        } else {
            // unaligned rows: DataCopyPad takes lengths/strides in bytes
            dataCopyParams.blockLen = s2Size * dtypeSize; // unit: bytes
            dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen);
            dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes;
            DataCopyPadParams dataCopyPadParams;
            dataCopyPadParams.isPad = false;
            DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
        }
    }
}
// Copy-in with the inner axis padded to a multiple of `alignedSize` elements.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
    int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}
// Copy-in with the inner axis padded to a 32-byte boundary (i.e. 32/dtypeSize
// elements — 8 elements for 4-byte types, hence the name).
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInAlign8(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
    int64_t s1Size, int64_t s2Size, int64_t actualS2Len)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        if (dtypeSize == 0){  // defensive guard; sizeof is never 0 in practice
            return;
        }
        int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize);
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}
/*
dst = BroadcastAdd(src0, src1)
src0 shape: (s1, s2)
src1 shape: (1, s2)
dst shape: (s1, s2)
*/
// In-place row-broadcast add: src0[r, :] += src1[0, :] for repeatTimes rows,
// chunked to the 256B-per-repeat vector-engine limit.
template <typename T, bool hasPse>
__aicore__ inline void BroadcastAdd(const LocalTensor<T> &src0Tensor, const LocalTensor<T> &src1Tensor,
    int64_t src0Offset, int32_t src1Size, int32_t repeatTimes)
{
    if constexpr (hasPse == true) {
        /* Total data number of single step should be smaller than 256bytes.
         * If larger, we need to do add multiple times. */
        int32_t innerLoop = src1Size / repeatMaxSize;   // number of full chunks along s2
        int32_t innerRemain = src1Size % repeatMaxSize; // tail size along s2
        BinaryRepeatParams binaryRepeatParams;
        binaryRepeatParams.src0BlkStride = 1;
        binaryRepeatParams.src0RepStride = src1Size / blockSize;
        binaryRepeatParams.src1BlkStride = 1;
        binaryRepeatParams.src1RepStride = 0;  // broadcast: reread the same row each repeat
        binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride;
        binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride;
        for (int32_t j = 0; j < innerLoop; j++) {
            auto innerOffset = j * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes,
                binaryRepeatParams);
        }
        if (innerRemain > 0) {
            auto innerOffset = innerLoop * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes,
                binaryRepeatParams);
        }
    }
}
// Adds the pse bias onto dstTensor. Full-shape pse types use a flat Add;
// the broadcast (1, s2) type uses BroadcastAdd in chunks of at most
// repeatMaxTimes rows (vector-engine repeat limit).
template <typename T, bool hasPse>
__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor<T> &pseUb,
    const LocalTensor<T> &dstTensor, uint32_t pseShapeType)
{
    if constexpr (hasPse == true) {
        if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) {
            Add(dstTensor, dstTensor, pseUb, computeSize);
        } else {
            /* Total repeated times should be <= repeatMaxTimes. If larger,
             * we need to do multiple inner loops. */
            int32_t s1OuterLoop = s1Size / repeatMaxTimes;
            int32_t s1OuterRemain = s1Size % repeatMaxTimes;
            for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) {
                int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes);
            }
            if (s1OuterRemain > 0) {
                int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain);
            }
        }
    }
}
// Computes the flat read offset into the pse tensor for the current tile,
// honoring the (b, n2, g, s1, s2) vs (b, n2, g, 1, s2) layouts and a
// batch-broadcast pse (pseBSize == 1).
template <bool hasPse> __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t s1Offset = 0;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t gOffset = 0;
        if (pseInfo.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = pseInfo.bSSOffset * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size;
            s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                        pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size;
        } else if (pseInfo.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = pseInfo.s2SizeAcc * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s2Size;
        }
        if (pseInfo.pseBSize == 1) {
            bOffset = 0;  // pse is broadcast over batch
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
// Computes the read offset into the compressed alibi pse block for the
// current tile, mapping the tile's (row, column) position in the full score
// matrix to coordinates (m, k) inside the (pseS1Size, pseS2Size) alibi
// window. Also updates pseInfo.readS2Size/pseS2ComputeSize as side effects.
// NOTE(review): the diagonal-remapping arithmetic assumes the alibi window
// covers the matrix's lower-right band — verify against the compressed-alibi
// layout spec if modifying.
template <LayOutTypeEnum layOutType, bool hasPse> __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                      pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t m = 0;
        int64_t k = 0;
        if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
            int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size;
            if (row >= threshold) {
                // tile lies inside the stored bottom rows of the alibi window
                m = row - threshold;
                k = column;
            } else {
                // tile is above the stored rows: wrap onto the diagonal band
                m = row % pseInfo.pseS1Size;
                k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m);
            }
        } else {
            int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size;
            int64_t posVal = row - column - threshold;
            if (threshold >= 0) {
                if (posVal >= 0) {
                    m = posVal;
                    k = 0;
                } else {
                    m = 0;
                    k = -posVal;
                }
            } else {
                m = posVal;
                k = 0;
            }
        }
        int64_t s1Offset = m * pseInfo.pseS2Size;
        int64_t s2Offset = k;
        // clamp the readable width to what remains of the alibi row
        pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k);
        pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size);
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
// Returns whether the current tile intersects the lower triangle.
// Alibi encoding only computes the lower-triangular part; tiles entirely
// above the diagonal are skipped.
template <bool hasPse> __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        // tile's last row index <= tile's first column index => strictly upper
        if (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
            (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize <=
            pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize) {
            return false;
        }
        return true;
    } else {
        return false;
    }
}
// Copies the alibi pse tile from global memory into the local buffer.
// Same-dtype inputs land directly in dstTensor (8-aligned or alignedSize
// path); mixed dtypes stage in tmpTensor and Cast into dstTensor after an
// MTE2->V sync.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseAlibiCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
    GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;  // tile is entirely above the diagonal
        }
        int64_t offset = PseAlibiComputeOffset<layOutType, hasPse>(pseInfo);
        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                            pseInfo.pseS2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize,
                                                  pseInfo.readS2Size, pseInfo.pseS2Size);
            }
            return;
        }
        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                    pseInfo.pseS2Size, alignedSize);
        if (pseInfo.needCast) {
            // wait for the copy-in before casting on the vector unit
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        }
        return;
    }
}
// Builds the slope-based (inner) alibi bias for the current tile: copies the
// precomputed position ramp from alibiGm, casts it, shifts by the tile's
// diagonal offset, takes |.| (and optionally sqrt), then scales by the
// negated per-head slope read from pseSlope.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCopyIn(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
    __gm__ uint8_t *pseSlope, GlobalTensor<half> &alibiGm, PseInfo &pseInfo,
    int64_t alignedSize = 16) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;
        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;  // slopes indexed per (b, n)
        }
        int64_t offset = bOffset + n2Offset + gOffset;
        DataCopyIn<half, hasPse>(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize,
                                 pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize);
        // wait for the copy-in before vector compute
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if (pseInfo.needCast) {
            int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
            Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
            pipe_barrier(PIPE_V);
            // distance of this tile's top-left corner from the diagonal
            int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                               pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
            int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
            float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);
            Adds(dstTensor, dstTensor, posShift, computeSize);
            pipe_barrier(PIPE_V);
            Abs(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
            // bias = -slope * |distance| (optionally sqrt'ed)
            float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
            if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
                Sqrt(dstTensor, dstTensor, computeSize);
                pipe_barrier(PIPE_V);
            }
            Muls(dstTensor, dstTensor, slopes, computeSize);
            pipe_barrier(PIPE_V);
        }
    }
}
// Same slope-bias computation as PseSlopeCopyIn but starting from an already
// staged helpTensor: cast, shift by diagonal distance, |.| (optional sqrt),
// then scale by the negated per-head slope.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCast(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
    __gm__ uint8_t *pseSlope, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;
        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;  // slopes indexed per (b, n)
        }
        int64_t offset = bOffset + n2Offset + gOffset;
        int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
        Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
        pipe_barrier(PIPE_V);
        // distance of this tile's top-left corner from the diagonal
        int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                           pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);
        Adds(dstTensor, dstTensor, posShift, computeSize);
        pipe_barrier(PIPE_V);
        Abs(dstTensor, dstTensor, computeSize);
        pipe_barrier(PIPE_V);
        // bias = -slope * |distance| (optionally sqrt'ed)
        float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
        if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            Sqrt(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
        }
        Muls(dstTensor, dstTensor, slopes, computeSize);
        pipe_barrier(PIPE_V);
    }
}
// Copies the pse tile from global memory into the local buffer, dispatching
// to the alibi path for compressed-alibi encodings. Mixed dtypes stage in
// tmpTensor and Cast into dstTensor after an MTE2->V sync.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
    GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCopyIn<INPUT_T, T, layOutType, hasPse>(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize);
        }
        int64_t offset = PseComputeOffset<hasPse>(pseInfo);
        // (1, s2) pse reads blockCount rows (min 1); full pse reads the tile rows
        int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) :
                         pseInfo.vec1S1RealSize;
        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize,
                                            pseInfo.s2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size);
            }
            return;
        }
        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size,
                                    alignedSize);
        if (pseInfo.needCast) {
            // wait for the copy-in before casting on the vector unit
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize);
        }
        return;
    }
}
// Adds the alibi pse tile onto dstTensor; skips tiles above the diagonal.
template <typename T, bool hasPse>
__aicore__ inline void PseAlibiCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        return;
    }
}
// Applies the pse bias to dstTensor, dispatching to the alibi path for
// compressed-alibi encodings and to the broadcast add otherwise.
template <typename T, bool hasPse>
__aicore__ inline void PseCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCompute<T, hasPse>(dstTensor, pseTensor, pseInfo);
        }
        // full-shape pse covers the whole tile; (1, s2) pse covers one row
        int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn ||
                               pseInfo.pseShapeType == pseSlopeN)
                                  ? pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize
                                  : pseInfo.s2AlignedSize;
        PseBroadcastAdd<T, hasPse>(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor,
                                   dstTensor, pseInfo.pseShapeType);
        return;
    }
}
// Pre-generates the inner-alibi position ramp in global memory: row i holds
// the sequence (-i, -i+1, ...) of length pseAlibiBaseS2, built row-by-row in
// helpTensor and copied out with explicit V/MTE3 event synchronization.
// Only runs for the inner pse types.
template <bool hasPse>
__aicore__ inline void PseInnerAlibiCreate(GlobalTensor<half> &dstTensor, LocalTensor<half> &helpTensor, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE && pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            return;
        }
        event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        event_t eventIdMte3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        float tmpValue = -1.0;
        for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) {
            // row i starts at -i and increments by 1
            CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            SetFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
            WaitFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
        }
    }
}
#endif

View File

@@ -0,0 +1,144 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file util.h
* \brief
*/
#ifndef FLASH_ATTENTION_UTIL_H
#define FLASH_ATTENTION_UTIL_H
// Hardware/layout constants shared by the attention kernels.
constexpr int32_t blockBytes = 32;                    // UB block granularity in bytes
constexpr int32_t byteBitRatio = 8;                   // bits per byte (bit-packed masks)
constexpr int64_t prefixAttenMaskDownHeight = 1024;
constexpr static int32_t blockSize = blockBytes / 4;  // 4 means sizeof(T)
constexpr static int32_t repeatMaxBytes = 256;        // vector-engine bytes per repeat
constexpr static int32_t repeatMaxTimes = 255;        // vector-engine max repeat count
constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T)
using AscendC::LocalTensor;
using AscendC::GlobalTensor;
using AscendC::DataFormat;
using AscendC::ShapeInfo;
using AscendC::DataCopyParams;
using AscendC::DataCopyPadParams;
using AscendC::BinaryRepeatParams;
using AscendC::IsSameType;
using AscendC::HardEvent;
using AscendC::SetFlag;
using AscendC::WaitFlag;
// Memory layout of the attention inputs (B=batch, S=seq, H=hidden, N=head, D=dim, T=token-packed).
enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5};
namespace math {
// Integer ceiling division: smallest q such that q * b >= a; returns 0 when b == 0.
template <typename T> __aicore__ inline T Ceil(T a, T b)
{
    return (b == 0) ? static_cast<T>(0) : ((a + b - 1) / b);
}
// Round a up to the nearest multiple of b; returns 0 when b == 0.
template <typename T> __aicore__ inline T Align(T a, T b)
{
    return (b == 0) ? static_cast<T>(0) : ((a + b - 1) / b * b);
}
}
// Mixed-type integer ceiling division, result in the dividend's type T1;
// returns 0 when the divisor is 0.
template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 a, T2 b)
{
    return (b == 0) ? static_cast<T1>(0) : static_cast<T1>((a + b - 1) / b);
}
// Larger of a and b, converted to T1; returns b when the operands are equal.
template <typename T1, typename T2>
__aicore__ inline T1 Max(T1 a, T2 b)
{
    if (a > b) {
        return a;
    }
    return b;
}
// Smaller of a and b, converted to T1; returns a when the operands are equal.
template <typename T1, typename T2>
__aicore__ inline T1 Min(T1 a, T2 b)
{
    if (a > b) {
        return b;
    }
    return a;
}
// Copy an [s1Size, s2Size] boolean (uint8) mask tile from global memory into
// a UB tensor whose rows are padded up to `alignedSize` bytes.
//   srcOffset   - element offset of the tile inside srcTensor
//   totalS2Size - row stride (in elements) of the full source matrix in GM
// NOTE(review): when rows are not block-aligned, the tail of each row is
// padded with 1s (paddingValue = 1), i.e. padded positions read as "masked".
__aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes)
{
    // Destination row length after aligning s2Size up to alignedSize.
    uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {s1Size, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size; // one burst per row
    dataCopyParams.dstStride = 0;
    if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0
        // Source rows are exactly one block wide but the destination was
        // aligned to 64B: copy block-wide rows and leave a one-block gap
        // between destination rows instead.
        dataCopyParams.dstStride = 1;
        alignedSize = blockBytes;
        alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes;
    }
    if (totalS2Size % alignedSize == 0) {
        // Block-aligned source rows: plain DataCopy; blockLen/srcStride are in 32B blocks.
        dataCopyParams.blockLen = alignedS2Size / blockBytes;
        dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams);
    } else {
        // Unaligned source rows: DataCopyPad with byte-granular lengths/strides.
        dataCopyParams.blockLen = s2Size;
        dataCopyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        // Pad at most one block at the right edge of each row.
        dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes);
        dataCopyPadParams.paddingValue = 1;
        DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams);
    }
}
// Copy a bit-packed (1 bit per element) mask tile from global memory into UB.
// batchSize * s1BaseSize rows of s2BaseSize bits each are copied; offsets and
// row strides given in bits are converted to bytes via byteBitRatio (8).
// NOTE(review): assumes srcOffset and s2BaseSize are multiples of 8 so the
// bit->byte division is exact -- confirm with callers.
__aicore__ inline void Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize,
    int64_t alignedSize = blockBytes)
{
    // Destination row length in bytes, aligned up to alignedSize.
    uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = batchSize * s1BaseSize; // one burst per row
    dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes);
    dataCopyParams.dstStride = 0;
    if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) {
        // Both the full row and the tile are byte-block aligned: plain DataCopy
        // with strides expressed in 32B blocks.
        dataCopyParams.srcStride =
            (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams);
    } else {
        // Unaligned case: DataCopyPad with byte-granular length/stride; no
        // explicit padding bytes are written (rightPadding = 0).
        dataCopyParams.blockLen = CeilDiv(s2BaseSize, byteBitRatio);
        dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = 0;
        dataCopyPadParams.paddingValue = 0;
        DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams);
    }
}
// Round shape up to a multiple of 16 (cube fractal alignment).
__aicore__ inline int32_t Align(int32_t shape)
{
    constexpr int32_t alignFactor = 16;
    return CeilDiv<int32_t, int32_t>(shape, alignFactor) * alignFactor;
}
#endif // FLASH_ATTENTION_UTIL_H

View File

@@ -0,0 +1,190 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dfx_base.h
* \brief 外部模块不应直接引用本头文件
*/
#pragma once
#include <string>
#include <cstdint>
#include <sstream>
#include <unistd.h>
#include <sys/syscall.h>
#include <securec.h>
#include <base/alog_pub.h>
#include <base/err_msg.h>
#include <exe_graph/runtime/tiling_context.h>
#include <exe_graph/runtime/tiling_parse_context.h>
#include <exe_graph/runtime/infer_shape_context.h>
#include <exe_graph/runtime/infer_datatype_context.h>
namespace ops {
namespace utils {
// Static helpers shared by the logging macros below: thread-id lookup plus a
// uniform way to render strings or gert contexts as an "OpType:OpName" tag.
class LogBase {
public:
    static constexpr const int MAX_LOG_LEN = 16000; // max bytes of one formatted message
    static constexpr const int MSG_HDR_LEN = 200;   // bytes reserved for the log header
    // Current OS thread id via the raw gettid syscall (Linux-specific).
    static inline uint64_t GetTid()
    {
        return static_cast<uint64_t>(syscall(__NR_gettid));
    }
    // Normalize a std::string to a printable C string.
    static inline const char *GetStr(const std::string &str)
    {
        return str.c_str();
    }
    // Pass-through overload for C strings.
    static inline const char *GetStr(const char *str)
    {
        return str;
    }
    // GetOpInfo overloads: plain strings pass through unchanged...
    static inline const std::string &GetOpInfo(const std::string &str)
    {
        return str;
    }
    static inline const char *GetOpInfo(const char *str)
    {
        return str;
    }
    // ...while gert contexts are rendered as "NodeType:NodeName".
    static inline std::string GetOpInfo(const gert::TilingContext *context)
    {
        return GetOpInfoFromContext(context);
    }
    static inline std::string GetOpInfo(const gert::TilingParseContext *context)
    {
        return GetOpInfoFromContext(context);
    }
    static inline std::string GetOpInfo(const gert::InferShapeContext *context)
    {
        return GetOpInfoFromContext(context);
    }
    static inline std::string GetOpInfo(const gert::InferDataTypeContext *context)
    {
        return GetOpInfoFromContext(context);
    }
private:
    // Build "NodeType:NodeName" from any gert context type; "nil" stands in
    // for null pointers so logging never dereferences an invalid context.
    template <class T> static inline std::string GetOpInfoFromContext(T context)
    {
        if (context == nullptr) {
            return "nil:nil";
        }
        std::string opInfo = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil";
        opInfo += ":";
        opInfo += context->GetNodeName() != nullptr ? context->GetNodeName() : "nil";
        return opInfo;
    }
};
} // namespace utils
// Render a shape-like object (anything exposing GetDimNum()/GetDim(i)) as
// "[d0, d1, ...]"; an empty shape renders as "[]".
template <typename T>
std::string Shape2String(const T& shape) {
    std::ostringstream out;
    out << "[";
    const size_t dimCount = shape.GetDimNum();
    for (size_t idx = 0; idx < dimCount; ++idx) {
        if (idx > 0) {
            out << ", ";
        }
        out << shape.GetDim(idx);
    }
    out << "]";
    return out.str();
}
} // namespace ops
// Before using these macros, the submodule-name macro OPS_UTILS_LOG_SUB_MOD_NAME
// must be predefined, e.g. #define OPS_UTILS_LOG_SUB_MOD_NAME "OP_TILING",
// or passed as a predefined macro via CMake.
// Core log emitter: checks the level first, then records one line tagged with
// file/line, submodule, package type, function, thread id and "OpType:OpName".
#define OPS_LOG_STUB(MOD_ID, LOG_LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (AlogCheckDebugLevel(static_cast<int>(MOD_ID), (LOG_LEVEL)) == 1) { \
            AlogRecord(static_cast<int>(MOD_ID), DLOG_TYPE_DEBUG, (LOG_LEVEL), \
                "[%s:%d][%s]%s[%s][%lu] OpName:[%s] " #FMT, \
                __FILE__, __LINE__, (OPS_UTILS_LOG_SUB_MOD_NAME), \
                (OPS_UTILS_LOG_PACKAGE_TYPE), __FUNCTION__, ops::utils::LogBase::GetTid(), \
                ops::utils::LogBase::GetStr(ops::utils::LogBase::GetOpInfo(OPS_DESC)), ##__VA_ARGS__); \
        } \
    }while (0)
// Conditional log-then-act: when COND is true, run LOG_FUNC then EXPR
// (typically a return). COND must be exactly bool (enforced by static_assert).
// NOTE(review): the static_assert sits outside the do/while, so this macro
// expands to two statements -- avoid using it as the body of an unbraced if.
#define OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) \
    static_assert(std::is_same<bool, std::decay<decltype(COND)>::type>::value, "condition should be bool"); \
    do { \
        if (__builtin_expect((COND), 0)) { \
            LOG_FUNC; \
            EXPR; \
        } \
    } while (0)
// Error log + inner-error report (both emit the same formatted message).
#define OPS_INNER_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)
// NOTE(review): currently identical to OPS_INNER_ERR_STUB; kept separate so
// "call" errors can diverge from "inner" errors later.
#define OPS_CALL_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)
// Per-level shorthands over OPS_LOG_STUB.
#define OPS_LOG_STUB_D(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_DEBUG, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_I(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_INFO, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_W(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_WARN, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_E(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_EVENT(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_EVENT, OPS_DESC, FMT, ##__VA_ARGS__)
// Oversized-message logger: formats into a local buffer and, when the result
// exceeds the per-line limit, splits it on newlines and chunk boundaries so
// nothing is truncated.
// NOTE(review): relies on MSG_LENGTH being defined by the alog headers --
// confirm it is in scope wherever this macro is expanded.
#define OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (0 == AlogCheckDebugLevel(OP, (LEVEL))) { \
            break; \
        } \
        char msgbufxyz[ops::utils::LogBase::MAX_LOG_LEN]; \
        size_t msgmaxlen = (MSG_LENGTH - ops::utils::LogBase::MSG_HDR_LEN); \
        int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, FMT, ##__VA_ARGS__); \
        if (rettmp == -1) { \
            msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \
        } \
        size_t msglength = std::strlen(msgbufxyz); \
        if (msglength < msgmaxlen) { \
            OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgbufxyz); \
            break; \
        } \
        char *msgchunkbegin = msgbufxyz; \
        char *msgchunkend = nullptr; \
        while (msgchunkbegin < msgbufxyz + msglength) { \
            if (msgchunkbegin[0] == '\n') { \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), ""); \
                msgchunkbegin += 1; \
                continue; \
            } \
            msgchunkend = std::strchr(msgchunkbegin, '\n'); \
            if (msgchunkend == nullptr) { \
                msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \
            } \
            while (msgchunkend > msgchunkbegin) { \
                std::string msgchunk(msgchunkbegin, \
                    std::min(msgmaxlen, static_cast<size_t>(msgchunkend - msgchunkbegin))); \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgchunk.c_str()); \
                msgchunkbegin += msgchunk.size(); \
            } \
            msgchunkbegin += 1; \
        } \
    } while (0)

View File

@@ -0,0 +1,59 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file ops_log.h
* \brief
*/
#pragma once
#include "log/inner/dfx_base.h"
/* Basic logging */
#define OPS_LOG_D(OPS_DESC, ...) OPS_LOG_STUB_D(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I(OPS_DESC, ...) OPS_LOG_STUB_I(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W(OPS_DESC, ...) OPS_LOG_STUB_W(OPS_DESC, __VA_ARGS__)
/* Error log that additionally reports inner error code EZ9999. */
#define OPS_LOG_E(OPS_DESC, ...) OPS_INNER_ERR_STUB("EZ9999", OPS_DESC, __VA_ARGS__)
#define OPS_LOG_E_WITHOUT_REPORT(OPS_DESC, ...) OPS_LOG_STUB_E(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_EVENT(OPS_DESC, ...) OPS_LOG_STUB_EVENT(OPS_DESC, __VA_ARGS__)
/* Full-length logging:
 * emits oversized messages; a message that exceeds the limit is split across multiple lines */
#define OPS_LOG_FULL(LEVEL, OPS_DESC, ...) OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_D_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_DEBUG, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_INFO, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_WARN, OPS_DESC, __VA_ARGS__)
/* Conditional logging: when COND holds, log then execute EXPR (often a return). */
#define OPS_LOG_D_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_D(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_I_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_I(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_W_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_W(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_E_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_E(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_EVENT_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_EVENT(OP_DESC, __VA_ARGS__), EXPR)
/* Null-pointer guard: logs the error, reports it, then executes EXPR.
 * NOTE(review): expands to a bare if (not do/while(0)) -- dangling-else hazard
 * at call sites; the message is also emitted twice (STUB_E + CALL_ERR_STUB). */
#define OPS_LOG_E_IF_NULL(OPS_DESC, PTR, EXPR) \
    if (__builtin_expect((PTR) == nullptr, 0)) { \
        OPS_LOG_STUB_E(OPS_DESC, "%s is nullptr!", #PTR); \
        OPS_CALL_ERR_STUB("EZ9999", OPS_DESC, "%s is nullptr!", #PTR); \
        EXPR; \
    }
/* Generic check: when COND holds, run LOG_FUNC then EXPR.
 * NOTE(review): bare if statement -- beware dangling else in callers. */
#define OPS_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    }
#define OP_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    }

View File

@@ -0,0 +1,47 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling.h
* \brief
*/
#pragma once
#include <vector>
#include <graph/tensor.h>
#include "data_copy_transpose_tiling_def.h"
namespace optiling {
// Populate a CopyTransposeTiling struct from destination and source shapes.
// typeSize is the element size in bytes (feeds originalShapeNLen).
// Assumes both shapes are 4-D: dst = [B, N, S, H], src = [B, N, S, H/N]
// -- TODO confirm the layout convention with callers.
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
    optiling::CopyTransposeTiling &tiling)
{
    std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
    std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
    // NOTE(review): no rank validation; indices 0-3 require GetDimNum() >= 4.
    tiling.set_dstShapeB(dstShapeInfo[0]);
    tiling.set_dstShapeN(dstShapeInfo[1]);
    tiling.set_dstShapeS(dstShapeInfo[2]);
    tiling.set_dstShapeH(dstShapeInfo[3]);
    // Per-head hidden size (H / N); integer division.
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
    tiling.set_srcShapeB(srcShapeInfo[0]);
    tiling.set_srcShapeN(srcShapeInfo[1]);
    tiling.set_srcShapeS(srcShapeInfo[2]);
    tiling.set_srcShapeHN(srcShapeInfo[3]);
    // Byte length of one source head row.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    // Pre-computed products used by the kernel's address arithmetic.
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH())
;
}
} // namespace optiling

View File

@@ -0,0 +1,43 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling_def.h
* \brief
*/
#pragma once
#include <cstdint>
#include <register/tilingdata_base.h>
namespace optiling {
// Tiling data consumed by the DataCopyTranspose kernel; fields are filled by
// GetDataCopyTransposeTiling (dst/src [B, N, S, H(/N)] shapes plus cached products).
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);  // dst H / N (per-head hidden size)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);  // src head row length in bytes
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);       // dst S * H
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);       // dst N * S
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);      // dst N * src S * src N
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);       // dst B * H
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
} // namespace optiling

View File

@@ -0,0 +1,225 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_base.h
* \brief
*/
#pragma once
#include <sstream>
#include <exe_graph/runtime/tiling_context.h>
#include <graph/utils/type_utils.h>
#include <tiling/platform/platform_ascendc.h>
#include "log/ops_log.h"
#ifdef ASCENDC_OP_TEST
#define ASCENDC_EXTERN_C extern "C"
#else
#define ASCENDC_EXTERN_C
#endif
namespace optiling {
// Per-platform hardware resource limits queried at tiling time.
struct AiCoreParams {
    uint64_t ubSize;    // unified buffer size in bytes
    uint64_t blockDim;  // number of usable cores (block dimension)
    uint64_t aicNum;    // number of AI cube cores
    uint64_t l1Size;    // L1 buffer size in bytes
    uint64_t l0aSize;   // L0A buffer size in bytes
    uint64_t l0bSize;   // L0B buffer size in bytes
    uint64_t l0cSize;   // L0C buffer size in bytes
};
// Platform info captured at compile(parse) time for FlashAttentionScoreGrad tiling.
struct FlashAttentionScoreGradCompileInfo {
    uint32_t aivNum;       // number of AI vector cores
    uint32_t aicNum;       // number of AI cube cores
    uint64_t ubSize;
    uint64_t l1Size;
    uint64_t l0aSize;
    uint64_t l0bSize;
    uint64_t l0cSize;
    uint64_t l2CacheSize;
    int64_t coreNum;
};
// Template-method base class for op tiling: DoTiling() drives the fixed
// sequence of steps 1-8 below, each step implemented by a concrete subclass.
class TilingBaseClass {
public:
    TilingBaseClass() = default;
    explicit TilingBaseClass(gert::TilingContext *context) : context_(context)
    {
    }
    virtual ~TilingBaseClass() = default;
    // Tiling execution framework. Return values:
    // 1. GRAPH_SUCCESS: succeeded; no further tiling class needs to run
    // 2. GRAPH_FAILED: failed; abort the whole tiling flow
    // 3. GRAPH_PARAM_INVALID: this class does not support the case; the caller
    //    should try the next registered tiling class
    ge::graphStatus DoTiling()
    {
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }
    // Update the cached context (e.g. when reusing this object for a new call).
    virtual void Reset(gert::TilingContext *context)
    {
        context_ = context;
    }
protected:
    // Whether this tiling class can handle the current shapes/attributes.
    virtual bool IsCapable() = 0;
    // 1. Query platform info such as core count and UB/L1/L0C sizes
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Fetch INPUT/OUTPUT/ATTR information
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Compute the data-split TilingData
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Compute TilingData for the high-level API libraries
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Compute the TilingKey
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Compute the workspace size
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Persist the tiling data
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump the tiling data (only when debug logging is enabled)
    virtual void DumpTilingInfo()
    {
        int32_t enable = AlogCheckDebugLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        // Dump raw tiling words as comma-separated uint32 values.
        auto buf = (uint32_t *)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
                OPS_LOG_D(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OPS_LOG_D(context_, "%s", oss.str().c_str());
    }
    // Convert a vector-core slice count to a TSCH block dim, accounting for the
    // vector-to-cube core ratio.
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration;
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration;
    }
    // Render a shape as "[d0, d1, ...]" for debug output.
    template <typename T> [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }
    // Render one tensor's dtype/shape/format for debug output; "nil " when absent.
    [[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape,
        const gert::CompileTimeTensorDesc *tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }
    // Render every input and output of the current node for debug output.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }
    // Render the raw tiling buffer as comma-separated int32 values.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t *>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }
protected:
    gert::TilingContext *context_ = nullptr;
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
    uint32_t blockDim_{0};
    uint64_t workspaceSize_{0};
    uint64_t tilingKey_{0};
    AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};
};
} // namespace optiling

View File

@@ -0,0 +1,162 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_templates_registry.h
* \brief
*/
#pragma once
#include <map>
#include <string>
#include <memory>
#include <exe_graph/runtime/tiling_context.h>
#include "tiling/tiling_base.h"
#include "log/ops_log.h"
#include "error/ops_error.h"
namespace optiling {
// Factory helper: heap-allocates tiling class T (nothrow) for the given
// context and hands ownership back through the TilingBaseClass interface;
// yields nullptr on allocation failure.
template <typename T> std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext *context)
{
    T *instance = new (std::nothrow) T(context);
    return std::unique_ptr<TilingBaseClass>(instance);
}
// Signature shared by all registered tiling factory functions.
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext *);
// Holds the tiling-class factories registered for one op type, keyed by
// priority (smaller value = tried first).
class TilingCases {
public:
    explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
    {
    }
    // Register tiling class T under `priority`. A duplicate priority is
    // reported as an inner error and the registration is dropped.
    template <typename T> void AddTiling(int32_t priority)
    {
        OPS_ERR_IF(cases_.find(priority) != cases_.end(),
                   OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return);
        cases_[priority] = TILING_CLASS<T>;
        OPS_ERR_IF(
            cases_[priority] == nullptr,
            OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."),
            return);
    }
    // Priority-ordered map of the registered factories.
    const std::map<int32_t, TilingClassCase> &GetTilingCases()
    {
        return cases_;
    }
private:
    std::map<int32_t, TilingClassCase> cases_;  // priority -> factory
    const std::string op_type_;                 // op name, used in error reports
};
// Process-wide registry mapping op type -> its TilingCases. DoTilingImpl
// walks the registered tiling classes in priority order and runs the first
// one that accepts the current context.
class TilingRegistry {
public:
    TilingRegistry() = default;
#ifdef ASCENDC_OP_TEST
    static TilingRegistry &GetInstance();
#else
    // Meyers singleton; the UT build provides its own definition.
    static TilingRegistry &GetInstance()
    {
        static TilingRegistry registry_impl_;
        return registry_impl_;
    }
#endif
    // Get (or lazily create) the TilingCases container for `op_type`.
    // Returns nullptr only when allocation fails.
    std::shared_ptr<TilingCases> RegisterOp(const std::string &op_type)
    {
        if (registry_map_.find(op_type) == registry_map_.end()) {
            registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
        }
        OPS_ERR_IF(registry_map_[op_type] == nullptr,
                   OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."),
                   return nullptr);
        return registry_map_[op_type];
    }
    // Try every registered tiling class in priority order; a class returning
    // GRAPH_PARAM_INVALID is skipped, any other status is final.
    ge::graphStatus DoTilingImpl(gert::TilingContext *context)
    {
        const char *op_type = context->GetNodeType();
        // const ref instead of a copy of the whole case map
        const auto &tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }
    // Same as above but restricted to the given priorities, tried in the
    // given order; only GRAPH_SUCCESS stops the search.
    ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector<int32_t> &priorities)
    {
        const char *op_type = context->GetNodeType();
        const auto &tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto priorityId : priorities) {
            // Bug fix: the previous map::operator[] lookup default-inserted a
            // null factory for an unregistered priority and then called
            // through it; skip unknown priorities instead.
            auto it = tilingTemplateRegistryMap.find(priorityId);
            if (it == tilingTemplateRegistryMap.end() || it->second == nullptr) {
                OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId);
                continue;
            }
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status == ge::GRAPH_SUCCESS) {
                    OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId);
                    return status;
                }
                OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId);
            }
        }
        return ge::GRAPH_FAILED;
    }
    // Registered cases for `op_type`; returns an empty map (and reports an
    // inner error) when the op was never registered.
    const std::map<int32_t, TilingClassCase> &GetTilingTemplates(const std::string &op_type)
    {
        OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(),
                   OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."),
                   return empty_tiling_case_);
        return registry_map_[op_type]->GetTilingCases();
    }
private:
    std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;  // op type -> cases
    const std::map<int32_t, TilingClassCase> empty_tiling_case_ {};     // returned for unknown ops
};
// Fluent registration helper used by REGISTER_TILING_TEMPLATE:
//   Register("OpType").tiling<MyTilingClass>(priority);
class Register {
public:
    explicit Register(std::string op_type) : op_type_(std::move(op_type))
    {
    }
    // Register tiling class T for this op under `priority`; returns *this so
    // several registrations can be chained.
    template <typename T> Register &tiling(int32_t priority)
    {
        auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
        // Bug fix: error message previously read "please the op name."
        OPS_ERR_IF(tilingCases == nullptr,
                   OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please check the op name."),
                   return *this);
        tilingCases->AddTiling<T>(priority);
        return *this;
    }
private:
    const std::string op_type_;  // op name used for all registrations
};
// op_type: operator name; class_name: the tiling class being registered;
// priority: priority of the tiling class -- smaller means higher priority,
// i.e. a greater chance of being selected.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
    static Register VAR_UNUSED##op_type_##class_name##priority_register = Register(op_type).tiling<class_name>(priority)
} // namespace optiling

View File

@@ -0,0 +1,136 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_type.h
* \brief
*/
#pragma once
#include <cstdint>
namespace optiling {
// Axes of the attention problem used for UB/core splits in the tiling key.
enum class AxisEnum {
    B = 0,
    N2 = 1,
    G = 2,
    S1 = 3,
    S2 = 4,
    D = 5,
    NONE = 9,  // no axis split
};
// Input/output data type encoded into the tiling key.
enum class DtypeEnum {
    FLOAT16 = 0,
    FLOAT32 = 1,
    BFLOAT16 = 2,
    FLOAT16_PRECISION = 3,
};
enum class PerformanceOrientedEnum {
    BIG_BUFFER = 1,
    BIG_DOUBLE_BUFFER = 2,
};
// Matmul high-level API configuration variant.
enum class MatmulConfig {
    NULL_CONFIG = 0,
    NORMAL_CONFIG = 1,
    MDL_CONFIG = 2
};
// Whether a PSE (positional bias) input is present.
enum class PseConfig {
    NO_PSE = 0,
    EXIST_PSE = 1
};
// Whether an attention mask input is present.
enum class AttenMaskConfig {
    NO_ATTEN_MASK = 0,
    EXIST_ATTEN_MASK = 1
};
// Whether dropout is applied.
enum class DropOutConfig {
    NO_DROP_OUT = 0,
    EXIST_DROP_OUT = 1
};
// Storage format fed to the cube unit.
enum class CubeFormatEnum {
    ND = 0,
    NZ = 1
};
// Tensor layout encoded into the tiling key.
enum class LayoutEnum {
    BSND = 0,
    SBND = 1,
    BNSD = 2,
    TND = 3
};
// Where the cube unit reads its input from.
enum class CubeInputSourceEnum {
    GM = 0,
    L1 = 1
};
// Generic on/off switch used by several tiling-key fields.
enum class OptionEnum {
    DISABLE = 0,
    ENABLE = 1
};
// Sparse attention-mask pattern supported by a tiling key.
enum class SparseEnum {
    ALL = 0,
    NONE = 1,
    ANY = 2,
    CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    BAND_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL = 7,
    RIGHT_DOWN_CAUSAL_BAND = 8,
    BAND_LEFT_UP_CAUSAL = 9
};
// Base case: an empty parameter pack contributes 0.
constexpr uint64_t RecursiveSum()
{
    return 0;
}
// Pack enum/integer ids into a decimal number: the first argument occupies
// the lowest decimal digit, each subsequent argument the next digit up.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T head, Args... tail)
{
    const uint64_t digit = static_cast<uint64_t>(head);
    return digit + 10 * RecursiveSum(tail...);
}
// TilingKey generation rule:
// FlashAttentionScore/FlashAttentionScoreGrad assemble the tiling key from
// decimal digits; from the lowest digit to the highest the fields are
// Ub0, Ub1, Block, DataType, Format, Sparse, specialized template.
// Ub0 / Ub1: the axes split inside UB (AxisEnum). At most two axes may be
//   split, hence UB0 and UB1; if there is no intra-core UB split, use
//   AXIS_NONE. UB0 and UB1 each occupy one decimal digit.
// Block: the axis used to split work across cores (AxisEnum), one decimal digit.
// DataType: the input/output data type supported by this tiling key
//   (enum SupportedDtype), one decimal digit.
// Format: the layout supported by this tiling key (enum InputLayout), one decimal digit.
// Sparse: whether this tiling key supports sparse (enum SparseCapability), one decimal digit.
// Other specialized scenarios define their own digit fields and values.
// usage: get tilingKey from input types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
// Base offset keeping every tiling key 20 digits wide.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Build a tiling key: offset plus the digits packed by RecursiveSum (the
// first argument lands in the lowest decimal digit).
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from input types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
    SparseEnum::sparse))
} // namespace optiling

View File

@@ -0,0 +1,53 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file fallback_comm.cpp
* \brief
*/
#include "fallback_comm.h"
#include <iostream>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "aclnn/aclnn_base.h"
#include "runtime/base.h"
#ifdef __cplusplus
extern "C" {
#endif
namespace fallback {
using namespace std;
using namespace gert;
using namespace ge;
// Map a GE data type to its ACL counterpart; returns ACL_DT_UNDEFINED for
// GE types with no ACL equivalent.
// NOTE(review): assumes each convertible GE enum value is numerically equal
// to its aclDataType counterpart -- confirm against the acl/ge headers.
aclDataType ToAclDataType(ge::DataType dtype) {
  // GE data types that have a same-valued aclDataType counterpart.
  static const std::vector<DataType> CANN_CONVERT_TO_ACL_DataType_LIST = {
      ge::DataType::DT_FLOAT, ge::DataType::DT_FLOAT16, ge::DataType::DT_INT8, ge::DataType::DT_INT32,
      ge::DataType::DT_UINT8, ge::DataType::DT_INT16, ge::DataType::DT_UINT16, ge::DataType::DT_UINT32,
      ge::DataType::DT_INT64, ge::DataType::DT_DOUBLE, ge::DataType::DT_BOOL, ge::DataType::DT_STRING,
      ge::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX128, ge::DataType::DT_BF16, ge::DataType::DT_UINT64,
      ge::DataType::DT_INT4};
  for (const DataType supported : CANN_CONVERT_TO_ACL_DataType_LIST) {
    if (supported == dtype) {
      return static_cast<aclDataType>(dtype);
    }
  }
  return aclDataType::ACL_DT_UNDEFINED;
}
} // namespace fallback
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,25 @@
# Adding a custom aclnn operation
This document describes how to add a custom aclnn operation to vllm-ascend.
## How do custom aclnn operations work in vllm-ascend?
Custom aclnn operations are built and installed into the `vllm_ascend/cann_ops_custom` directory during the build process of vllm-ascend. The aclnn operators are then bound to the `torch.ops._C_ascend` module, enabling users to invoke them in vllm-ascend Python code.
To enable custom operations, use the following code:
```python
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
```
## How to add a custom aclnn operation?
- Create a new operation folder under `csrc` directory
- Create `op_host` and `op_kernel` directories for host and kernel source code
- Add build options in `csrc/build_aclnn.sh` for supported SOC. Note that multiple ops should be separated with `;`, i.e. `CUSTOM_OPS=op1;op2;op3`
- Bind aclnn operators to torch.ops._C_ascend module in `csrc/torch_binding.cpp`
- Write a meta implementation in `csrc/torch_binding_meta.cpp` for op being captured into aclgraph
After a successful build of vllm-ascend, the custom aclnn operation can be invoked in python code.

View File

@@ -12,4 +12,5 @@ eplb_swift_balancer.md
Multi_Token_Prediction
ACL_Graph
KV_Cache_Pool_Guide
add_custom_aclnn_op
:::

View File

@@ -1,9 +1,11 @@
[build-system]
# Should be mirrored in requirements.txt
requires = [
"attrs",
"cmake>=3.26",
"decorator",
"einops",
"googleapis-common-protos",
"numpy<2.0.0",
"packaging",
"pip",
@@ -12,6 +14,7 @@ requires = [
"scipy",
"pandas",
"pandas-stubs",
"psutil",
"setuptools>=64",
"setuptools-scm>=8",
"transformers<=4.57.1",

View File

@@ -25,7 +25,7 @@ import sys
from sysconfig import get_paths
from typing import Dict, List
from setuptools import Extension, find_packages, setup
from setuptools import Command, Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
@@ -199,6 +199,27 @@ class custom_build_info(build_py):
super().run()
class build_and_install_aclnn(Command):
    """Setuptools command that builds and installs the ACLNN custom ops.

    Runs ``csrc/build_aclnn.sh`` with the repository root and the target
    SOC version (``envs.SOC_VERSION``). A non-zero exit status from the
    script aborts the whole build via ``SystemExit``.
    """

    description = "Build and install AclNN by running build_aclnn.sh"
    user_options = []

    def initialize_options(self):
        # No command-line options to initialize.
        pass

    def finalize_options(self):
        # No command-line options to validate.
        pass

    def run(self):
        try:
            print("Running bash build_aclnn.sh ...")
            subprocess.check_call(
                ["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION])
            # Fixed typo: message previously read "buid_aclnn.sh".
            print("build_aclnn.sh executed successfully!")
        except subprocess.CalledProcessError as e:
            print(f"Error running build_aclnn.sh: {e}")
            # Propagate the script's exit code so the build fails loudly.
            raise SystemExit(e.returncode)
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config: Dict[str, bool] = {}
@@ -385,8 +406,22 @@ class cmake_build_ext(build_ext):
shutil.copy(src_path, dst_path)
print(f"Copy: {src_path} -> {dst_path}")
# copy back _cann_ops_custom directory
src_cann_ops_custom = os.path.join(ROOT_DIR, "vllm_ascend",
"_cann_ops_custom")
dst_cann_ops_custom = os.path.join(self.build_lib, "vllm_ascend",
"_cann_ops_custom")
if os.path.exists(src_cann_ops_custom):
import shutil
if os.path.exists(dst_cann_ops_custom):
shutil.rmtree(dst_cann_ops_custom)
shutil.copytree(src_cann_ops_custom, dst_cann_ops_custom)
print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}")
def run(self):
# First, run the standard build_ext command to compile the extensions
# First, ensure ACLNN custom-ops is built and installed.
self.run_command("build_aclnn")
# Then, run the standard build_ext command to compile the extensions
super().run()
@@ -450,6 +485,7 @@ def get_requirements() -> List[str]:
cmdclass = {
"develop": custom_develop,
"build_py": custom_build_info,
"build_aclnn": build_and_install_aclnn,
"build_ext": cmake_build_ext,
"install": custom_install
}

View File

@@ -0,0 +1,148 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
# enable internal format
torch_npu.npu.config.allow_internal_format = True
# enable vllm-ascend custom ops
enable_custom_op()
def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor,
                     perChannelScale: torch.Tensor,
                     perTokenScale: torch.Tensor, m: int):
    """Golden reference for one group: int8 matmul, dequant, SwiGLU, then
    per-row symmetric int8 re-quantization.

    Parameters:
        x (torch.Tensor): Input tensor with shape (m, k), int8.
        weight (torch.Tensor): Weight tensor with shape (k, n), int8.
        perChannelScale (torch.Tensor): Per-channel scale with shape (n,).
        perTokenScale (torch.Tensor): Per-token scale with shape (m,).
        m (int): Number of tokens (rows of x).

    Returns:
        quantOutput (torch.Tensor): int8 output with shape (m, n // 2)
            (SwiGLU halves the n output columns, not k).
        quantScaleOutput (torch.Tensor): Dequantization scale, shape (m,).
    """
    # Accumulate the matmul in int32 (int8 * int8 would overflow), then
    # move to float32 for the scaling steps.
    acc = torch.matmul(x.to(torch.int32), weight.to(torch.int32))
    acc = acc.to(torch.float32)
    # Dequantize: per output channel first, then per token/row.
    scaled = torch.mul(acc, perChannelScale)
    scaled = torch.mul(scaled, perTokenScale.reshape(m, 1))
    # SwiGLU: first half of the columns is the activation input, second
    # half is the gate.
    act_in, gate = scaled.chunk(2, dim=-1)
    activated = act_in * torch.sigmoid(act_in)
    gated = activated * gate
    # Per-row symmetric int8 quantization. `row_max` renamed from `max`,
    # which shadowed the Python builtin. NOTE: an all-zero row yields
    # inf/nan scales here, matching the original behavior.
    row_max = torch.max(torch.abs(gated), -1).values
    quant_scale = 127 / row_max
    quantOutput = torch.round(gated * quant_scale.reshape(m, 1)).to(
        torch.int8)
    # Invert so callers can dequantize by multiplication.
    quantScaleOutput = 1 / quant_scale
    return quantOutput, quantScaleOutput
def process_groups(x: torch.Tensor, weight: torch.Tensor,
                   perChannelScale: torch.Tensor, perTokenScale: torch.Tensor,
                   groupList: torch.Tensor):
    """Run the quantized GMM + SwiGLU golden reference group by group.

    Parameters:
        x (torch.Tensor): Input tensor with shape (M, K).
        weight (torch.Tensor): Weight tensor with shape (E, K, N).
        perChannelScale (torch.Tensor): Per-channel scales, shape (E, N).
        perTokenScale (torch.Tensor): Per-token scale with shape (M,).
        groupList (torch.Tensor): Cumulative token counts per group.

    Returns:
        quantOutput (torch.Tensor): int8 output with shape (M, N // 2).
        quantScaleOutput (torch.Tensor): Quantization scales, shape (M,).
    """
    M, N = x.shape[0], weight.shape[2]
    # Pre-allocate full-size outputs; rows past the last group stay zero.
    quantOutput = torch.zeros(M, N // 2).to(torch.int8)
    quantScaleOutput = torch.zeros(M).to(torch.float32)
    boundaries = groupList.tolist()
    row = 0       # first row of the current group
    prev_end = 0  # cumulative count consumed so far
    for expert_idx, cum_end in enumerate(boundaries):
        # groupList is cumulative, so the group size is the difference.
        size = cum_end - prev_end
        prev_end = cum_end
        if size <= 0:
            continue
        sl = slice(row, row + size)
        quantOutput[sl], quantScaleOutput[sl] = gmm_swiglu_quant(
            x[sl], weight[expert_idx], perChannelScale[expert_idx],
            perTokenScale[sl], size)
        row += size
    return quantOutput, quantScaleOutput
@torch.inference_mode()
def test_gmm_swiglu_quant_weight_nz_tensor_list():
    """Compare the custom NPU op against the pure-CPU golden reference."""
    M, K, E, N = 8192, 7168, 4, 4096
    # Random int8 activations (M, K) and per-expert int8 weights (E, K, N).
    x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
    weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)
    # Per-channel weight scales in (0.1, 1.0), float32.
    weight_scale = (torch.rand(E, N) * 0.9 + 0.1).to(torch.float32)
    # Device-side copies: weights cast to internal format 29 (NZ layout),
    # scales moved as-is; both passed as tensor lists.
    weight_nz_npu = [
        torch_npu.npu_format_cast(weight[i].npu(), 29) for i in range(E)
    ]
    weight_scale_npu = [weight_scale[i].npu() for i in range(E)]
    # Per-token scales in (0.1, 1.0), float32.
    x_scale = (torch.rand(M) * 0.9 + 0.1).to(torch.float32)
    # Cumulative token counts: 4 equally sized groups covering all M rows.
    group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64)

    ref_out, ref_scale = process_groups(x, weight, weight_scale, x_scale,
                                        group_list)

    npu_out, npu_scale, _ = \
        torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
            x.npu(), weight_nz_npu, weight_scale_npu, x_scale.npu(),
            group_list.npu())

    # Only rows up to the last cumulative boundary are valid output.
    valid = group_list[-1]
    torch.testing.assert_close(npu_out[:valid, :].cpu(),
                               ref_out,
                               atol=1,
                               rtol=2**-13)
    torch.testing.assert_close(npu_scale[:valid].cpu(),
                               ref_scale,
                               atol=1e-9,
                               rtol=1e-6)

    # Release device memory so subsequent tests start from a clean slate.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,3 @@
# This folder is reserved for the installation of custom aclnn operators tailored for vLLM-Ascend.
# Source code of the operators can be found in the `csrc` folder of the vllm-ascend repository.
# The operators are compiled into a custom CANN software package and installed to this folder automatically.

View File

@@ -38,6 +38,27 @@ from vllm_ascend.utils import (
prefill_context_parallel_enable, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
# set custom ops path
# Register the bundled ACLNN custom ops with CANN before any NPU work.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
# NOTE(review): this joins "vllm_ascend" under this file's directory; if
# this module already lives inside the vllm_ascend package the result
# would be .../vllm_ascend/vllm_ascend/_cann_ops_custom — confirm the
# intended on-disk layout against setup.py's copy destination.
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "vllm_ascend", "_cann_ops_custom",
                               "vendors", "customize")
CUSTOM_LIB_PATH = os.path.join(CUSTOM_OPP_PATH, "op_api", "lib")
if os.path.exists(CUSTOM_OPP_PATH):
    # Prepend so the vllm-ascend build takes precedence over any
    # previously configured custom-op package directories.
    current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
    if current_cust_opp_path:
        os.environ[
            "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
    else:
        os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
if os.path.exists(CUSTOM_LIB_PATH):
    # NOTE(review): mutating LD_LIBRARY_PATH after process start does not
    # change the already-initialized dynamic loader search path of this
    # process; it only affects subprocesses and lookups that re-read the
    # environment — verify this is sufficient for CANN's dlopen of the
    # op_api library.
    current_lib_path = os.environ.get("LD_LIBRARY_PATH", "")
    if current_lib_path:
        os.environ["LD_LIBRARY_PATH"] = f"{CUSTOM_LIB_PATH}:{current_lib_path}"
    else:
        os.environ["LD_LIBRARY_PATH"] = CUSTOM_LIB_PATH
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser