This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,103 @@
cmake_minimum_required(VERSION 3.8)
project(tmo_kernels)
message(STATUS "project name: ${PROJECT_NAME}")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
################################################################################
# Build Evironment
################################################################################
set(BANG_TARGET_CPU_ARCH ${TARGET_CPU_ARCH})
message("-- TARGET_CPU_ARCH=${TARGET_CPU_ARCH}")
set(TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
message("-- TARGET_MLU_ARCH=${TARGET_MLU_ARCH}")
set(NEUWARE_HOME ${NEUWARE_HOME})
message("-- NEUWARE_HOME=${NEUWARE_HOME}")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
"${CMAKE_SOURCE_DIR}/cmake"
"${NEUWARE_HOME}/cmake"
"${NEUWARE_HOME}/cmake/modules"
)
find_package(BANG)
if(NOT BANG_FOUND)
message(FATAL_ERROR "BANG cannot be found.")
else ()
if (NOT BANG_CNCC_EXECUTABLE)
message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake")
endif()
endif()
set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/test")
set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -pthread -pipe")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${CMAKE_C_FLAGS} -g3 -O0")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${CMAKE_C_FLAGS} -O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fPIC -std=c++17 -pthread -pipe")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CMAKE_CXX_FLAGS} -g3 -O0")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} -Wl,--gc-sections -fPIC")
set(BANG_CNCC_FLAGS "-Wall -Werror -Wdeprecated-declarations -fPIC -std=c++17 -pthread --target=${TARGET_CPU_ARCH}")
if ( "${_cncc_version}" VERSION_LESS "5.0.0") # [CNNLCORE-19128]
message(STATUS "Default rounding mode will be rn when computing float numbers, otherwise will be tz when computing int numbers")
# This compile option was enabled by JIRA: CNNLCORE-12027
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Xbang-cnas --deprecated-cvt-default-round-mode-rn")
endif()
if(${TARGET_CPU_ARCH} MATCHES ".*x86_64.*")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -mcmodel=large")
endif()
string(TOLOWER ${CMAKE_BUILD_TYPE} _CMAKE_BUILD_TYPE_LOWER)
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "debug")
message(STATUS "Build debug mode")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g3 -O0")
endif()
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "release")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3 -DNDEBUG")
endif()
if(${TARGET_MLU_ARCH} MATCHES "CNFATBIN")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=mtp_592 --bang-mlu-arch=mtp_613 --no-neuware-version-check")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
else()
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=${TARGET_MLU_ARCH}")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
endif()
# setup predefined macro for host sources, only for single mlu arch, useful for edge
if (${TARGET_MLU_ARCH} MATCHES "^(m?tp_)?([0-9]+)$")
# convert mtp_xxx or tp_xxx to xxx
string(REGEX REPLACE "^(m?tp_)?([0-9]+)$" "\\2" _TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
add_definitions(-DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH})
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH}")
endif()
################################################################################
# Neuware Evironment
################################################################################
if(EXISTS ${NEUWARE_HOME})
include_directories("${NEUWARE_HOME}/include")
link_directories("${NEUWARE_HOME}/lib64")
link_directories("${NEUWARE_HOME}/lib")
else()
message(FATAL_ERROR "NEUWARE cannot be found, refer README.md to prepare NEUWARE_HOME environment.")
endif()
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
################################################################################
# Build TMO kernels
################################################################################
# aux_source_directory(src DIR_SRCS)
file(GLOB_RECURSE bang_src_files FOLLOW_SYMLINKS "${CMAKE_CURRENT_SOURCE_DIR}/*.mlu")
bang_add_library(tmo_kernels STATIC "${bang_src_files}")
target_link_libraries(tmo_kernels cnnl cnrt cndrv dl)

View File

@@ -0,0 +1,28 @@
#include "add_scalar.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define ONCHIP_DATA_NUM ((int)(__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)))
__nram__ int nram_buffer[ONCHIP_DATA_NUM];
__mlu_global__ void MLUBlockAddScalar(int *dst, int *src, int count, int scalar) {
int offset = ONCHIP_DATA_NUM * taskId;
int deal_num = std::min(ONCHIP_DATA_NUM, count - offset);
if (deal_num <= 0) return;
__memcpy(nram_buffer, src + offset, deal_num * sizeof(int), GDRAM2NRAM);
__bang_add_scalar(nram_buffer, nram_buffer, scalar, deal_num);
__memcpy(dst + offset, nram_buffer, deal_num * sizeof(int), NRAM2GDRAM);
}
} // namespace kernels
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar) {
uint32_t task_dim = (count + ONCHIP_DATA_NUM - 1) / ONCHIP_DATA_NUM;
cnrtDim3_t dim{task_dim, 1, 1};
kernels::MLUBlockAddScalar<<<dim, cnrtFuncTypeBlock, queue>>>(dst, src, count, scalar);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_ADD_SCALAR_MLUH_
#define CSRC_KERNELS_ADD_SCALAR_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Add src with a scalar and save the result to dst.
* @param queue: The queue for mlu.
* @param dst: Pointer to the MLU memory of dst.
* @param src: Pointer to the MLU memory of src.
* @param count: The elements number in src.
* @param scalar: The scalar to add.
* @note: only support int. dst can overlap with src.
*/
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar);
} // namespace tmo
#endif // CSRC_KERNELS_ADD_SCALAR_MLUH_

View File

@@ -0,0 +1,205 @@
#!/bin/bash
set -e
TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd ${TOP_DIR}
################################################################################
# Evironment Variables
# BUILD_MODE: release/debug
# BUILD_DIR: build(default)
# TARGET_MLU_ARCH: CNFATBIN/MLU590
# TARGET_CPU_ARCH: x86_64-linux-gnu
# TARGET_C_COMPILER: C comppiler full-path
# TARGET_CXX_COMPILER: CXX comppiler full-path
# STRIP strip tool path
################################################################################
BUILD_MODE=${BUILD_MODE:-release}
BUILD_DIR="${BUILD_DIR:-build}"
BUILD_JOBS=${BUILD_JOBS:-32}
TARGET_MLU_ARCH=${TARGET_MLU_ARCH:-CNFATBIN}
TARGET_CPU_ARCH=${TARGET_CPU_ARCH:-$(uname -m)-linux-gnu}
TARGET_C_COMPILER=${TARGET_C_COMPILER:-gcc}
TARGET_CXX_COMPILER=${TARGET_CXX_COMPILER:-g++}
STRIP="${STRIP}" # empty by default, check later
# to forward variable to other scripts
export BUILD_DIR
################################################################################
# Shell Common Functions
################################################################################
check_deb_package() {
if [ -z "$(dpkg -l | grep ${1})" ]; then
echo "-- Please sudo apt install ${1}"
exit -1
fi
}
check_rpm_package() {
if [ -z "$(rpm -qa | grep ${1})" ]; then
echo "-- Please sudo yum install ${1}"
exit -1
fi
}
usage () {
echo "USAGE: build.sh <options>"
echo
echo " If need specify neuware path, please:"
echo " export NEUWARE_HOME=/path/of/your/neuware"
echo
echo "OPTIONS:"
echo " -h, --help Print usage"
echo " <null> If no --mluxxx specified, default arch is cnfatbin which contain all mlu arch"
echo " --mlu590 Build for target product MLU590: __BANG_ARCH__ = 592"
echo " cncc --bang-mlu-arch=mtp_592, cnas --mlu-arch mtp_592"
echo " -d, --debug Build test case with debug mode"
echo " -v, --verbose Build with verbose output"
echo " -j, --jobs=* Build parallel jobs"
echo " --cache Build without deleting BUILD_DIR contents first"
}
################################################################################
# Build Main Entry
################################################################################
# 1. Check cmake tool for build, cmake-3.23.1 is recommended
if [ -f "/etc/os-release" ]; then
source /etc/os-release
if [[ "${NAME}" == Ubuntu* ]] || [[ "${NAME}" == Debian* ]]; then
check_deb_package cmake
CMAKE=cmake
elif [[ "${NAME}" == CentOS* ]] || [[ "${NAME}" == Kylin* ]]; then
if [[ "${VERSION_ID}" == 7 ]]; then
check_rpm_package cmake3
CMAKE=cmake3
else
check_rpm_package cmake
CMAKE=cmake
fi
elif [[ "${NAME}" == Anolis* ]];then
check_rpm_package cmake
CMAKE=cmake
else
echo "-- Not support build on this os!"
exit -1
fi
else
echo "-- Not support build on this os!"
exit -1
fi
# 2. Create build dir
if [ ! -d "$BUILD_DIR" ]; then
mkdir "$BUILD_DIR"
fi
# 3. Handle build options
cmdline_args=$(getopt -o h,d,v,j: --long help,debug,verbose,jobs:,mlu590,cache -n 'build.sh' -- "$@")
eval set -- "$cmdline_args"
if [ $? != 0 ]; then echo "Unknown options, use -h or --help" >&2 ; exit -1; fi
if [ $# != 0 ]; then
while true; do
case "$1" in
--mlu590)
TARGET_MLU_ARCH="mtp_592"
shift
;;
-h | --help)
usage
exit 0
;;
-d | --debug)
BUILD_MODE="debug"
echo "-- Using debug mode."
shift
;;
-v | --verbose)
BUILD_VERBOSE="VERBOSE=1"
shift
;;
-j | --jobs)
shift
BUILD_JOBS=$1
shift
;;
--cache)
FLAG_KEEP_CACHE=1
shift
;;
--)
shift
break
;;
*)
echo "-- Unknown options ${1}, use -h or --help"
usage
exit -1
;;
esac
done
fi
# 5. Check NEUWARE_HOME and cncc
if [ ! -z "${NEUWARE_HOME}" ]; then
echo "-- using NEUWARE_HOME = ${NEUWARE_HOME}"
else
echo "-- NEUWARE_HOME is null, refer README.md to prepare NEUWARE_HOME environment."
exit -1
fi
# 6. Check device compiler
export PATH="${NEUWARE_HOME}/bin":$PATH
export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64":$LD_LIBRARY_PATH
if [ -z $(which cncc) ]; then
echo "-- ERROR: cannot find cncc"
exit -1
fi
cncc --version || ( echo "-- ERROR: cncc is not for current CPU target" && exit -1 )
echo "-- cncc: $(which cncc)"
# Check host compiler
## check compiler version and consider activate devtoolset for CentOS 7
if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then
if [ ! -f "/opt/rh/devtoolset-7/enable" ]; then
echo "You are using CentOS 7 but without 'devtoolset-7' installed."
echo "Please install devtoolset-7 or gnu-g++ that verion >= 5."
sleep 2
else
source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \
|| echo "devtoolset-7 has installed on your server, but source failed."
fi
fi
if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then
echo "we do not support g++<5, try to use higher version"
exit 1
fi
TARGET_C_COMPILER=$(which gcc)
TARGET_CXX_COMPILER=$(which g++)
echo "-- TARGET_C_COMPILER: " ${TARGET_C_COMPILER}
echo "-- TARGET_CXX_COMPILER: " ${TARGET_CXX_COMPILER}
export CC=$(basename ${TARGET_C_COMPILER})
export CXX=$(basename ${TARGET_CXX_COMPILER})
################################################################################
# Project Build
################################################################################
CMAKE_EXTRA_OPTIONS=()
SOURCE_DIR=${TOP_DIR}
pushd ${BUILD_DIR}
if [[ -z "${FLAG_KEEP_CACHE}" ]]; then
echo "Remove cmake cache ${PWD}"
rm -rf ./*
fi
${CMAKE} -DCMAKE_BUILD_TYPE="${BUILD_MODE}" \
-DNEUWARE_HOME="${NEUWARE_HOME}" \
-DTARGET_MLU_ARCH="${TARGET_MLU_ARCH}" \
-DTARGET_CPU_ARCH="${TARGET_CPU_ARCH}" \
-DCMAKE_C_COMPILER="$(basename ${TARGET_C_COMPILER})" \
-DCMAKE_CXX_COMPILER="$(basename ${TARGET_CXX_COMPILER})" \
-DCMAKE_STRIP="${STRIP}" \
${CMAKE_EXTRA_OPTIONS[@]} ${SOURCE_DIR}
popd
${CMAKE} --build ${BUILD_DIR} -- ${BUILD_VERBOSE} -j${BUILD_JOBS}

View File

@@ -0,0 +1,192 @@
#include <stdint.h>
#include <cmath>
#include <iostream>
#include <vector>
#include "cnnl.h"
#include "cnrt.h"
#include "copy_blocks.mluh"
#include "kernel_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753
#define LAYER_SIZE 128
#define BLOCK_PAIR_SIZE 512
#define ALIGN_BYTES 64
struct CopyBlocksInfo {
void *key_addrs[LAYER_SIZE];
void *value_addrs[LAYER_SIZE];
unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2];
bool has_value_cache = true;
};
namespace kernels {
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info,
uint32_t num_per_core,
uint32_t block_mapping_offset,
int32_t num_layers,
uint32_t block_size_in_bytes) {
for (uint32_t i = 0; i < num_per_core; i++) {
uint32_t map_offset = block_mapping_offset + i * 2;
uint32_t src_idx = info.mapping_addrs[map_offset];
uint32_t dst_idx = info.mapping_addrs[map_offset + 1];
int64_t src_offset = block_size_in_bytes * src_idx;
int64_t dst_offset = block_size_in_bytes * dst_idx;
for (uint32_t j = 0; j < num_layers; j++) {
__memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset,
block_size_in_bytes, GDRAM2GDRAM);
if (info.has_value_cache) {
__memcpy((int8_t *)info.value_addrs[j] + dst_offset,
(int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM);
}
}
}
}
__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info,
int32_t num_pairs,
int32_t num_layers,
uint32_t block_size_in_bytes) {
uint32_t num_per_core = num_pairs / taskDim;
uint32_t remain_for_core = num_pairs % taskDim;
num_per_core += ((taskId < remain_for_core) ? 1 : 0);
uint32_t block_mapping_offset =
num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core);
block_mapping_offset *= 2;
#if (__BANG_ARCH__ >= 592)
if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) {
auto num_pair_data_width = sizeof(int32_t);
uint32_t align_num = ALIGN_BYTES / num_pair_data_width;
unsigned int num_per_core_2 = num_per_core * 2;
unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num;
unsigned int *gather_src_offset = (unsigned int *)nram_buffer;
unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align;
int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align);
uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2;
unsigned int *scatter_dst_offset = gather_src_offset + num_per_core;
uint32_t num_per_loop = nram_remain / block_size_in_bytes;
uint32_t repeat = num_per_core / num_per_loop;
uint32_t remain = num_per_core % num_per_loop;
for (int i = 0; i < num_per_core; i++) {
unsigned int mapping_addrs_idx = block_mapping_offset + i * 2;
block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx];
block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1];
}
__bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes,
num_per_core_2);
__sync();
for (uint32_t k = 0; k < num_layers; k++) {
for (uint32_t i = 0; i < repeat; i++) {
__gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
__scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
if (info.has_value_cache) {
__gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
__scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
}
}
if (remain != 0) {
uint32_t repeat_nums = repeat * num_per_loop;
__gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, remain);
__scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, remain);
if (info.has_value_cache) {
__gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, remain);
__scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, remain);
}
}
}
} else {
copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
}
#else
copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
#endif
}
} // namespace kernels
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
const std::vector<void *> &key_caches,
const std::vector<void *> &value_caches,
const std::vector<int32_t> &block_mapping_vec,
const size_t block_size_in_bytes) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
if (key_caches.empty()) {
std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (!value_caches.empty() && key_caches.size() != value_caches.size()) {
std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches "
<< "size if value_caches is not empty." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t mapping_size = block_mapping_vec.size();
int32_t num_pairs = mapping_size / 2;
uint32_t task_dim = std::min(num_pairs, cluster_num * core_num);
cnrtDim3_t k_dim{task_dim, 1, 1};
int32_t num_layers = key_caches.size();
int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE);
int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num);
int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE);
int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num);
CopyBlocksInfo info;
if (value_caches.empty()) {
info.has_value_cache = false;
}
for (int32_t i = 0; i < layer_loop_num; i++) {
int32_t sub_num_layers =
std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop);
for (int32_t l = 0; l < sub_num_layers; l++) {
info.key_addrs[l] = key_caches[l + i * layer_num_per_loop];
if (info.has_value_cache) {
info.value_addrs[l] = value_caches[l + i * layer_num_per_loop];
}
}
for (int32_t j = 0; j < pair_loop_num; j++) {
int32_t sub_num_pairs =
std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop);
int32_t lens_block_mapping = sub_num_pairs * 2;
int32_t block_vec_offset = j * pair_num_per_loop * 2;
for (int32_t m = 0; m < lens_block_mapping; m++) {
info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset];
}
kernels::launchCopyBlocksKernel<<<k_dim, k_type, queue>>>(info, sub_num_pairs, sub_num_layers,
block_size_in_bytes);
}
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,37 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_COPY_BLOCKS_MLUH_
#define CSRC_KERNELS_COPY_BLOCKS_MLUH_
#include <vector>
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform copy_blocks operation.
* @param queue: The queue for mlu.
* @param key_caches: Output/Input. Pointer to the MLU memory that stores the key_caches
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
* @param value_caches: Output/Input. Pointer to the MLU memory that stores the value_caches
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
* @param block_mapping_vec: block_mapping vector.
* @param block_size_in_bytes: one block data size.
*/
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
const std::vector<void *> &key_caches,
const std::vector<void *> &value_caches,
const std::vector<int32_t> &block_mapping_vec,
const size_t block_size_in_bytes);
} // namespace tmo
#endif // CSRC_KERNELS_COPY_BLOCKS_MLUH_

View File

@@ -0,0 +1,271 @@
#include <cmath>
#include <cstddef>
#include "cnnl.h"
#include "cnrt.h"
#include "create_cos_sin_table.mluh"
namespace {
// constexpr int LINEAR_SCALING = 0;
// constexpr int FIX_NTK_SCALING = 1;
constexpr int DYNAMIC_NTK_SCALING = 2;
} // namespace
namespace tmo {
namespace kernels {
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - 32 * 1024];
__nram__ const float range[64] = {
0.0F, 2.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 20.0F,
22.0F, 24.0F, 26.0F, 28.0F, 30.0F, 32.0F, 34.0F, 36.0F, 38.0F, 40.0F, 42.0F,
44.0F, 46.0F, 48.0F, 50.0F, 52.0F, 54.0F, 56.0F, 58.0F, 60.0F, 62.0F, 64.0F,
66.0F, 68.0F, 70.0F, 72.0F, 74.0F, 76.0F, 78.0F, 80.0F, 82.0F, 84.0F, 86.0F,
88.0F, 90.0F, 92.0F, 94.0F, 96.0F, 98.0F, 100.0F, 102.0F, 104.0F, 106.0F, 108.0F,
110.0F, 112.0F, 114.0F, 116.0F, 118.0F, 120.0F, 122.0F, 124.0F, 126.0F};
__mlu_func__ void genRangeDims(float *range_nram, int elem_count) {
int count = 64;
__bang_move(range_nram, range, std::min(count, elem_count) * sizeof(float));
while (count < elem_count) {
__bang_add_scalar(range_nram + count, range_nram, (float)count * 2.0F,
std::min(count, elem_count - count));
count *= 2;
}
}
__mlu_func__ int getBatchMaxSeqLen(int *seq_lens_nram, int *seq_lens, int batch) {
__memcpy(seq_lens_nram, seq_lens, batch * sizeof(int), GDRAM2NRAM);
__bang_argmax((float *)seq_lens_nram, (float *)seq_lens_nram, batch);
return __load_nram(seq_lens_nram);
}
__mlu_func__ float getNTKAlpha(int curr_seq_len, int max_position_embeddings, int kv_seq_len) {
int seq_len = kv_seq_len > max_position_embeddings ? curr_seq_len : kv_seq_len;
float context_value = std::log2((float)seq_len / (float)max_position_embeddings) + 1.0F;
float ntk_alpha = std::pow(2.0F, std::ceil(context_value)) - 1.0F;
return std::max(ntk_alpha, 1.0F);
}
__mlu_func__ void getRotaryInvFreq(float *inv_freq_nram,
float *base_nram,
float *range_nram,
float base,
int rotary_dim,
int elem_count) {
// inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
__bang_write_value(base_nram, elem_count, base);
__bang_mul_scalar(inv_freq_nram, range_nram, 1.0F / (float)rotary_dim, elem_count);
__bang_log(base_nram, base_nram, elem_count);
__bang_mul(inv_freq_nram, inv_freq_nram, base_nram, elem_count);
__bang_pow2(inv_freq_nram, inv_freq_nram, elem_count);
__bang_recip(inv_freq_nram, inv_freq_nram, elem_count);
}
template <typename T>
__mlu_func__ void convertCosSinTable(float *cos_table, float *sin_table, int elem_count) {}
template <>
__mlu_func__ void convertCosSinTable<half>(float *cos_table, float *sin_table, int elem_count) {
__bang_float2half((half *)cos_table, cos_table, elem_count);
__bang_float2half((half *)sin_table, sin_table, elem_count);
}
template <>
__mlu_func__ void convertCosSinTable<bfloat16_t>(float *cos_table,
float *sin_table,
int elem_count) {
__bang_float2bfloat16((bfloat16_t *)cos_table, cos_table, elem_count);
__bang_float2bfloat16((bfloat16_t *)sin_table, sin_table, elem_count);
}
__mlu_global__ void MLUUpdateCachedAlpha(float *rotary_emb_alpha_cached,
int *seq_lens,
int max_position_embeddings,
int batch) {
int *seq_lens_nram = (int *)nram_buffer; // [batch]
int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
rotary_emb_alpha_cached[taskIdY] =
getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
}
template <typename T>
__mlu_global__ void MLUCreateCosSinTableKernel(void *cos_sin_table,
float *rotary_emb_alpha_cached,
int *seq_lens,
int max_position_embeddings,
int batch,
int batch_stride,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
float rotary_base,
float rotary_scaling,
int rotary_scaling_type,
int seq_seg,
bool interleaved,
cnnlDataType_t dtype) {
int half_rotary_dim = rotary_dim / 2;
float *base_nram = (float *)nram_buffer; // [rotary_dim / 2]
float *range_nram = base_nram + half_rotary_dim; // [rotary_dim / 2]
float *inv_freq_nram = range_nram + half_rotary_dim; // [rotary_dim / 2]
float *freqs_nram = inv_freq_nram + half_rotary_dim; // [rotary_dim / 2]
float *cos_nram = freqs_nram + half_rotary_dim; // [rotary_dim]
float *sin_nram = cos_nram + rotary_dim; // [rotary_dim]
float *swap_nram = sin_nram + rotary_dim; // [rotary_dim]
int *seq_lens_nram = (int *)(swap_nram + rotary_dim); // [batch]
genRangeDims(range_nram, half_rotary_dim);
float adjust_base = rotary_base;
if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
float ntk_alpha = getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
if (rotary_emb_alpha_cached[taskIdY] == ntk_alpha) {
return;
}
adjust_base = rotary_base * std::pow(ntk_alpha, (float)rotary_dim / (float)(rotary_dim - 2));
}
getRotaryInvFreq(inv_freq_nram, base_nram, range_nram, adjust_base, rotary_dim, half_rotary_dim);
int seq_start = taskIdX * seq_seg;
int seq_end = (taskIdX + 1) * seq_seg > rotary_seq_len ? rotary_seq_len : (taskIdX + 1) * seq_seg;
T *cos_table = (T *)cos_sin_table + (size_t)taskIdY * batch_stride;
T *sin_table = cos_table + rotary_dim;
for (int idx = seq_start; idx < seq_end; ++idx) {
__bang_mul_scalar(freqs_nram, inv_freq_nram, idx, half_rotary_dim);
__bang_cos(cos_nram, freqs_nram, half_rotary_dim);
__bang_sin(sin_nram, freqs_nram, half_rotary_dim);
convertCosSinTable<T>(cos_nram, sin_nram, half_rotary_dim);
if (!interleaved) {
__memcpy(cos_table + idx * rotary_stride, cos_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
half_rotary_dim * sizeof(T), 0, 1);
__memcpy(sin_table + idx * rotary_stride, sin_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
half_rotary_dim * sizeof(T), 0, 1);
} else {
__bang_move((T *)cos_nram + half_rotary_dim, (T *)cos_nram, half_rotary_dim * sizeof(T));
__bang_transpose((T *)swap_nram, (T *)cos_nram, 2, half_rotary_dim);
__memcpy(cos_table + idx * rotary_stride, (T *)swap_nram, half_rotary_dim * 2 * sizeof(T),
NRAM2GDRAM);
__bang_move((T *)sin_nram + half_rotary_dim, (T *)sin_nram, half_rotary_dim * sizeof(T));
__bang_transpose((T *)cos_nram, (T *)sin_nram, 2, half_rotary_dim);
__memcpy((T *)sin_table + idx * rotary_stride, (T *)cos_nram, half_rotary_dim * 2 * sizeof(T),
NRAM2GDRAM);
}
}
}
#if __BANG_ARCH__ < 592
template <>
__mlu_global__ void MLUCreateCosSinTableKernel<bfloat16_t>(void *cos_sin_table,
float *rotary_emb_alpha_cached,
int *seq_lens,
int max_position_embeddings,
int batch,
int batch_stride,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
float rotary_base,
float rotary_scaling,
int rotary_scaling_type,
int seq_seg,
bool interleaved,
cnnlDataType_t dtype) {}
#endif
} // namespace kernels
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
void *cos_sin_table,
float *rotary_emb_alpha_cached,
int *seq_lens,
int max_position_embeddings,
int batch,
int batch_stride,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
float rotary_base,
float rotary_scaling,
int rotary_scaling_type,
bool interleaved,
cnnlDataType_t data_type) {
bool is_supported_dtype = data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_FLOAT ||
data_type == CNNL_DTYPE_BFLOAT16;
if (!is_supported_dtype) {
std::cerr << "[invokeCreateCosSinTable]: unsupport data type for create cos sin table kernel."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
// clang-format off
void (*create_sin_cos_kernels[])(
void*, /* cos_sin_table */
float*, /* rotary_emb_alpha_cached */
int*, /* seq_lens */
int, /* max_position_embeddings */
int, /* batch */
int, /* batch_stride */
int, /* rotary_seq_len */
int, /* rotary_dim */
int, /* rotary_stride */
float, /* rotary_base */
float, /* rotary_scaling */
int, /* rotary_scaling_type */
int, /* seq_seg */
bool, /* interleaved */
cnnlDataType_t /* data_type */
) = {
kernels::MLUCreateCosSinTableKernel<half>,
kernels::MLUCreateCosSinTableKernel<float>,
kernels::MLUCreateCosSinTableKernel<bfloat16_t>
};
// clang-format on
int kernel_index = 0;
if (data_type == CNNL_DTYPE_HALF) {
kernel_index = 0;
} else if (data_type == CNNL_DTYPE_FLOAT) {
kernel_index = 1;
} else {
kernel_index = 2;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num = 1;
int core_num = 1;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
int used_core_num = std::min(rotary_seq_len, cluster_num * core_num);
int seq_seg = (rotary_seq_len + used_core_num - 1) / used_core_num;
cnrtDim3_t dim1;
dim1.x = used_core_num;
dim1.y = rotary_scaling_type == DYNAMIC_NTK_SCALING ? batch : 1;
dim1.z = 1;
if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
std::cerr << "[invokeCreateCosSinTable]: MLU300 devices do not support bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
create_sin_cos_kernels[kernel_index]<<<dim1, cnrtFuncTypeBlock, queue>>>(
cos_sin_table, rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch,
batch_stride, rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling,
rotary_scaling_type, seq_seg, interleaved, data_type);
if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
cnrtDim3_t dim2;
dim2.x = 1;
dim2.y = batch;
dim2.z = 1;
kernels::MLUUpdateCachedAlpha<<<dim2, cnrtFuncTypeBlock, queue>>>(
rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,62 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
#define CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Create cos and sin table for rotary embedding.
* @param queue: The queue for mlu.
* @param cos_sin_table: Output. Pointer to the MLU memory that stores the cos and sin table.
* If rotary_scaling_type is linear, the shape is [rotary_seq_len, rotary_stride].
* If rotary_scaling_type is dynamic ntk, the shape is [batch, rotary_seq_len,
* rotary_stride].
* @param rotary_emb_alpha_cached: Output/Input. Pointer to the MLU memory that
* stores the ntk alpha cache. Only used in dynamic ntk, the shape is [batch].
* @param seq_lens: Input. Pointer to the MLU memory that stores the true sequence len.
* The shape is [batch].
* @param max_position_embeddings: The maximum rotary embedding positions.
* @param batch: Batch size.
* @param batch_stride: The stride for batch dim of cos_sin_table.
* Only used in dynamic ntk, the value is rotary_seq_len * rotary_stride.
* @param rotary_seq_len: The rotary sequence length of cos and sin table.
* @param rotary_dim: The rotary dim value of cos and sin table.
* @param rotary_stride: The stride of rotary_seq_len dim for cos and sin table.
* @param rotary_base: The rotary base, value is usually 10000.
* @param rotary_scaling: The rotary scaling, value is usually 1.
* @param rotary_scaling_type: The rotary scaling type, value is linear or dynamic ntk.
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
* @param dtype: Data type of cos and sin table generated.
*/
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
void *cos_sin_table,
float *rotary_emb_alpha_cached,
int *seq_lens,
int max_position_embeddings,
int batch,
int batch_stride,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
float rotary_base,
float rotary_scaling,
int rotary_scaling_type,
bool interleaved,
cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_

View File

@@ -0,0 +1,812 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include "dequant_from_linear_cache.mluh"
#include "quant_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#pragma bang walign(16)
#define REM_FOR_STACK (32 * 1024)
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
#define DEQUANT_LINEAR_PERHEAD kernels::MLUDequantFromLinearCacheKernelPerHead
#define DEQUANT_LINEAR_PERCHANNEL kernels::MLUDequantFromLinearCacheKernelPerChannel
#define DEQUANT_FUNC_LEN (24)
#define DEQUANT_BATCH_NUM (1024)
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
__nram__ uint8_t pre_table_nram[TRANS_TABLE_SIZE];
// Uses 8K = 1K * (4 + 4) to process offsets
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
__mlu_func__ void calcu_offsets_per_channel(int32_t &cache_id,
size_t &context_offset,
size_t &cache_offset,
size_t &scale_offset,
const int32_t *cache_bs_id,
const int32_t *cache_seq_offsets,
const int32_t cache_mem_len,
const int32_t seq_len,
const int32_t seq_begin,
const int32_t seq_offset,
const int32_t batch_idx,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_seq_stride,
const size_t scale_bs_stride) {
cache_id = cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
int32_t cache_seq_offset =
cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) {
context_offset = context_seq_stride * (seq_offset + seq_begin);
cache_offset = cache_bs_stride * cache_id + cache_seq_stride * (cache_seq_offset + seq_begin);
scale_offset = scale_bs_stride * cache_id;
} else {
cache_id = -1;
}
}
__mlu_func__ void calcu_offsets_per_head(int32_t &cache_id,
size_t &context_offset,
size_t &key_cache_offset,
size_t &value_cache_offset,
size_t &scale_offset,
const int32_t *cache_bs_id,
const int32_t *cache_seq_offsets,
const int32_t cache_mem_len,
const int32_t seq_len,
const int32_t seq_begin,
const int32_t seq_offset,
const int32_t batch_idx,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride) {
cache_id = cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
int32_t cache_seq_offset =
cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) {
context_offset = context_seq_stride * (seq_offset + seq_begin);
key_cache_offset =
cache_bs_stride * cache_id + key_cache_seq_stride * (cache_seq_offset + seq_begin);
value_cache_offset =
cache_bs_stride * cache_id + value_cache_seq_stride * (cache_seq_offset + seq_begin);
scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id;
} else {
cache_id = -1;
}
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
T *data,
Tc *cache,
Ts *scale,
const int32_t scale_num,
const int32_t seq_num,
const int32_t head_num,
const int32_t head_size,
const size_t context_offset,
const size_t cache_offset,
const size_t scale_offset,
const size_t context_seq_stride,
const size_t context_head_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
if (scale_bs_stride != 0) {
__memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
}
if (std::is_same<Tc, int4x2_t>::value) {
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
head_size >> 1, head_num - 1, scale_num >> 1, seq_num - 1, cache_head_stride,
head_num - 1, cache_seq_stride, seq_num - 1);
} else {
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
head_size * sizeof_(Tc), head_num - 1, scale_num * sizeof_(Tc), seq_num - 1,
cache_head_stride * sizeof_(Tc), head_num - 1, cache_seq_stride * sizeof_(Tc),
seq_num - 1);
}
dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
seq_num * scale_num, scale_num);
__memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_value_per_channel(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
int8_t *temp_nram,
T *data,
Tc *cache,
Ts *scale,
const int32_t scale_num,
const int32_t seq_num,
const int32_t head_num,
const int32_t head_size,
const size_t context_offset,
const size_t cache_offset,
const size_t scale_offset,
const size_t context_seq_stride,
const size_t context_head_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride,
const bool pad_front) {
/* Step 1. load scale [head_num, head_size]*/
if (scale_bs_stride != 0) {
__memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
}
/* Step 2. load cache [load_seq_num, head_num, head_size] */
int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
int32_t deal_seq_num = load_seq_num << 1;
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
head_num - 1, scale_num, load_seq_num - 1, cache_head_stride, head_num - 1,
cache_seq_stride, load_seq_num - 1);
/* Step 3. convert into int8 [load_seq_num, head_num, head_size, 2] */
convert((int8_t *)output_nram, (int4x2_t *)input_nram, deal_seq_num * scale_num);
/* Step 4. transpose to [deal_seq_num (load_seq_num, 2), head_num, head_size] */
trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
load_seq_num, head_num, head_size, 2);
/* Step 5. dequantize [save_seq_num, head_num, head_size] */
int save_seq_num = pad_front ? seq_num - 1 : seq_num;
dequantize<T, int8_t, Ts>((T *)output_nram, (int8_t *)temp_nram + (pad_front ? scale_num : 0),
(Ts *)scale_nram, (Ts *)nbuf, save_seq_num * scale_num, scale_num);
/* Step 6. store [save_seq_num, head_num, head_size]*/
__memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
save_seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T),
save_seq_num - 1);
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_head(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
T *temp_nram,
T *data,
Tc *cache,
Ts *scale,
const int32_t scale_num,
const int32_t seq_num,
const int32_t head_num,
const int32_t head_size,
const size_t context_offset,
const size_t cache_offset,
const size_t scale_offset,
const size_t context_seq_stride,
const size_t context_head_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t scale_head_stride) {
__memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, seq_num * sizeof_(Ts), GDRAM2NRAM,
seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
if (std::is_same<Tc, int4x2_t>::value) {
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
head_size >> 1, seq_num - 1, seq_num * (head_size >> 1), head_num - 1,
cache_seq_stride, seq_num - 1, cache_head_stride, head_num - 1);
} else {
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
head_size * sizeof_(Tc), seq_num - 1, seq_num * head_size * sizeof_(Tc), head_num - 1,
cache_seq_stride * sizeof_(Tc), seq_num - 1, cache_head_stride * sizeof_(Tc),
head_num - 1);
}
convert((float *)output_nram, (Tc *)input_nram, head_num * seq_num * head_size);
if (std::is_same<T, float>::value) {
conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
head_num * seq_num, head_size, 1);
} else {
conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
head_num * seq_num, head_size, 1);
output_nram = (T *)temp_nram;
}
__memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
context_seq_stride * sizeof_(T), seq_num - 1, context_head_stride * sizeof_(T),
head_num - 1, head_size * sizeof_(T), seq_num - 1, head_size * seq_num * sizeof_(T),
head_num - 1);
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_value_per_head(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
T *temp_nram,
T *data,
Tc *cache,
Ts *scale,
const int32_t scale_num,
const int32_t seq_num,
const int32_t head_num,
const int32_t head_size,
const size_t context_offset,
const size_t cache_offset,
const size_t scale_offset,
const size_t context_seq_stride,
const size_t context_head_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride,
const bool pad_front) {
int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
int32_t deal_seq_num = load_seq_num << 1;
/* Step1. load scale first, [head_num, deal_seq_num] */
__memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, deal_seq_num * sizeof_(Ts), GDRAM2NRAM,
deal_seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
/* Step2. load cache input, [head_num, load_seq_num, head_size, 2] for int4 */
__memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
load_seq_num - 1, load_seq_num * head_size, head_num - 1, cache_seq_stride,
load_seq_num - 1, cache_head_stride, head_num - 1);
convert((int8_t *)output_nram, (Tc *)input_nram, head_num * head_size * deal_seq_num);
/* Step3. trans to [head_num, load_seq_num, 2, head_size]*/
trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
head_num * load_seq_num, head_size, 1, 2);
/* Step4. dequant to T [head_num, deal_seq_num, head_size] */
convert((float *)output_nram, (int8_t *)temp_nram, head_num * deal_seq_num * head_size);
if (std::is_same<T, float>::value) {
conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
head_num * deal_seq_num, head_size, 1);
} else {
conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
head_num * deal_seq_num, head_size, 1);
output_nram = (T *)temp_nram;
}
/* Step5. save [head_num, save_seq_num, head_size]*/
int32_t save_seq_num = pad_front ? seq_num - 1 : seq_num;
__memcpy((T *)data + context_offset, (T *)output_nram + (pad_front ? head_size : 0),
head_size * sizeof_(T), NRAM2GDRAM, context_seq_stride * sizeof_(T), save_seq_num - 1,
context_head_stride * sizeof_(T), head_num - 1, head_size * sizeof_(T), save_seq_num - 1,
head_size * deal_seq_num * sizeof_(T), head_num - 1);
}
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromLinearCacheKernelPerChannel(void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t *cache_bs_id,
const int32_t *cache_seq_offsets,
const int32_t max_context_len,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t cache_mem_len,
const int32_t head_size,
const int32_t seq_block,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
bool has_key = (key && key_cache && key_scale);
bool has_value = (value && value_cache && value_scale);
if (!(has_key || has_value)) {
return;
}
/* *********************************nram space **************************************
* NRAM |scale[head_num, head_size]|output/input[seq_block, head_num, head_size]|
*/
int32_t scale_num = head_num * head_size;
Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
float *output_nram = (float *)nbuf + scale_num;
Tc *input_nram = (Tc *)output_nram +
(std::is_same<Tc, int4x2_t>::value
? (7 * seq_block * (scale_num >> 1))
: (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
// temp_nram for nram to store int8 input
int8_t *temp_nram = (int8_t *)output_nram + seq_block * scale_num * 3;
int32_t seq_offset;
int32_t seq_len;
int32_t seq_begin;
int32_t deal_seq_num;
int32_t cache_id;
int32_t cache_seq_offset;
size_t context_offset;
size_t cache_offset;
size_t scale_offset;
process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
(int32_t *)context_seq_offsets, batch_size);
if (has_key) {
load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
scale_head_stride);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
calcu_offsets_per_channel(cache_id, context_offset, cache_offset, scale_offset,
(int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets, cache_mem_len,
seq_len, seq_begin, seq_offset, batch_idx, context_seq_stride,
cache_bs_stride, key_cache_seq_stride, scale_bs_stride);
if (cache_id < 0) continue;
dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
(Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num, head_num,
head_size, context_offset, cache_offset, scale_offset,
context_seq_stride, context_head_stride, cache_head_stride,
key_cache_seq_stride, scale_bs_stride, scale_head_stride);
}
}
if (has_value) {
if (std::is_same<Tc, int4x2_t>::value) {
__reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
}
load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
scale_head_stride);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
if (std::is_same<Tc, int4x2_t>::value) {
seq_begin = taskIdZ * seq_block;
cache_id =
cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
cache_seq_offset = cache_seq_offsets == nullptr
? 0
: __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
// move seq_begin left by 1 when cache_seq_offset is odd
seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
if (cache_id >= 0 && cache_seq_offset >= 0 &&
(cache_seq_offset + seq_len) <= cache_mem_len) {
context_offset =
context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
// value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
cache_offset = cache_bs_stride * cache_id +
value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
scale_offset = scale_bs_stride * cache_id;
} else {
cache_id = -1;
}
if (cache_id < 0) continue;
dequantize_value_per_channel(
(T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (int8_t *)temp_nram, (T *)value,
(Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, head_num, head_size,
context_offset, cache_offset, scale_offset, context_seq_stride, context_head_stride,
cache_head_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride,
seq_begin == -1);
} else {
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
calcu_offsets_per_channel(
cache_id, context_offset, cache_offset, scale_offset, (int32_t *)cache_bs_id,
(int32_t *)cache_seq_offsets, cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
context_seq_stride, cache_bs_stride, value_cache_seq_stride, scale_bs_stride);
if (cache_id < 0) continue;
dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
(Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
head_num, head_size, context_offset, cache_offset, scale_offset,
context_seq_stride, context_head_stride, cache_head_stride,
value_cache_seq_stride, scale_bs_stride, scale_head_stride);
}
}
}
}
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromLinearCacheKernelPerHead(void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t *cache_bs_id,
const int32_t *cache_seq_offsets,
const int32_t max_context_len,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t cache_mem_len,
const int32_t head_size,
const int32_t seq_block,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
bool has_key = (key && key_cache && key_scale);
bool has_value = (value && value_cache && value_scale);
if (!(has_key || has_value)) {
return;
}
/* *********************************nram space **************************************
* NRAM |scale[seq_block, head_num]|output/input[head_size, seq_block, head_num]|temp|
*/
int32_t scale_num = seq_block * head_num;
Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
float *output_nram = (float *)nbuf + scale_num;
Tc *input_nram = (Tc *)output_nram +
(std::is_same<Tc, int4x2_t>::value
? (7 * head_size * (scale_num >> 1))
: head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
// temp_nram for nram to store converted output
float *temp_nram = (float *)output_nram + head_size * scale_num;
int32_t seq_offset;
int32_t seq_len;
int32_t seq_begin;
int32_t deal_seq_num;
int32_t cache_id;
size_t context_offset;
size_t key_cache_offset;
size_t value_cache_offset;
size_t scale_offset;
__bang_write_value((float *)nbuf, head_size * 16, 1.0f);
mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
(int32_t *)context_seq_offsets, batch_size);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
calcu_offsets_per_head(cache_id, context_offset, key_cache_offset, value_cache_offset,
scale_offset, (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets,
cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
context_seq_stride, cache_bs_stride, key_cache_seq_stride,
value_cache_seq_stride, scale_bs_stride);
if (cache_id < 0) continue;
if (has_key) {
dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
(T *)key, (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num,
head_num, head_size, context_offset, key_cache_offset, scale_offset,
context_seq_stride, context_head_stride, cache_head_stride,
key_cache_seq_stride, scale_head_stride);
}
if (has_value && std::is_same<Tc, int8_t>::value) {
dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
(T *)value, (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
head_num, head_size, context_offset, value_cache_offset, scale_offset,
context_seq_stride, context_head_stride, cache_head_stride,
value_cache_seq_stride, scale_head_stride);
}
}
// process value int4 differently
if (has_value && std::is_same<Tc, int4x2_t>::value) {
int32_t cache_seq_offset;
__reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
seq_begin = taskIdZ * seq_block;
cache_id =
cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
cache_seq_offset =
cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
// move seq_begin left by 1 when cache_seq_offset is odd
seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) {
context_offset =
context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
// value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
value_cache_offset = cache_bs_stride * cache_id +
value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id;
} else {
cache_id = -1;
}
if (cache_id < 0) continue;
dequantize_value_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram,
(T *)temp_nram, (T *)value, (Tc *)value_cache, (Ts *)value_scale,
scale_num, deal_seq_num, head_num, head_size, context_offset,
value_cache_offset, scale_offset, context_seq_stride,
context_head_stride, cache_head_stride, value_cache_seq_stride,
scale_bs_stride, scale_head_stride, seq_begin == -1);
}
}
}
} // namespace kernels
#define DEQUANT_LINEAR_INIT(T, Tc, Ts, C, Name) \
template __mlu_global__ void kernels::MLUDequantFromLinearCacheKernel##Name<T, Tc, Ts, C>( \
void *key, void *value, const void *key_cache, const void *value_cache, \
const void *key_scale, const void *value_scale, const int32_t *context_lens, \
const int32_t *context_seq_offsets, const int32_t *cache_bs_id, \
const int32_t *cache_seq_offsets, const int32_t max_context_len, const int32_t batch_size, \
const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \
const int32_t cache_mem_len, const int32_t head_size, const int32_t seq_block, \
const size_t context_head_stride, const size_t context_seq_stride, \
const size_t cache_bs_stride, const size_t cache_head_stride, \
const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \
const size_t scale_bs_stride, const size_t scale_head_stride);
DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerHead)
typedef void (*DequantFromLinearCachePointer)(void *, // key
void *, // value
const void *, // key_cache
const void *, // value_cache
const void *, // key_scale
const void *, // value_scale
const int32_t *, // context_lens
const int32_t *, // context_seq_offsets
const int32_t *, // cache_bs_id
const int32_t *, // cache_seq_offsets
const int32_t, // max_context_len
const int32_t, // batch_size
const int32_t, // head_num
const int32_t, // key_group_num
const int32_t, // value_group_num
const int32_t, // cache_mem_len
const int32_t, // head_size
const int32_t, // seq_block
const size_t, // context_head_stride
const size_t, // context_seq_stride
const size_t, // cache_bs_stride
const size_t, // cache_head_stride
const size_t, // key_cache_seq_stride
const size_t, // value_cache_seq_stride
const size_t, // scale_bs_stride
const size_t); // scale_head_stride
static DequantFromLinearCachePointer DequantFromLinearCacheFuncArr[DEQUANT_FUNC_LEN] = {
DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<half, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<float, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<half, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<float, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, true>};
uint32_t getDequantLinearIdx(cnnlDataType_t dtype,
int32_t quant_mode,
int32_t quant_bit,
const void *context_seq_offset) {
uint32_t idx = 0;
idx += (quant_mode != 0) ? 6 : 0;
idx += (quant_bit != 8) ? 3 : 0;
idx += (dtype == CNNL_DTYPE_BFLOAT16) ? 1 : 0;
idx += (dtype == CNNL_DTYPE_FLOAT) ? 2 : 0;
idx += (context_seq_offset == nullptr) ? 12 : 0;
return idx;
}
void getBlockAndDimForLinear(int32_t &seq_block,
cnrtDim3_t &task_dim,
cnrtFunctionType_t &task_type,
const int32_t max_context_len,
const int32_t head_num,
const int32_t batch_size,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const cnnlDataType_t dtype) {
int32_t core_dim;
int32_t cluster_dim;
int32_t nram_size = 480 * 1024;
int32_t wram_size = 512 * 1024;
int32_t sram_size = 2016 * 1024;
getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
if (quant_mode == 0) {
seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1);
if (quant_bit == 4) {
if (seq_block <= 1) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * head_size * sizeof_(float) should be less than "
<< (nram_size >> 1) << " when quant_mode is 0." << std::endl;
}
} else {
if (seq_block <= 0) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * head_size * sizeof_(float) should be less than " << nram_size
<< " when quant_mode is 0." << std::endl;
}
}
} else {
int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half);
seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) +
(int64_t)head_num * head_size * dtype_size));
if (quant_bit == 4) {
if (seq_block <= 1) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
"context_dtype_size) "
<< "should be less than " << (nram_size >> 1) << " when quant_mode is 1."
<< std::endl;
}
} else {
if (seq_block <= 0) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
"context_dtype_size) "
<< "should be less than " << nram_size << " when quant_mode is 1." << std::endl;
}
}
/* head_size * 64B put in the wram. */
if (head_size * ONE_LINE >= wram_size) {
std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
<< wram_size << " when quant_mode is 1." << std::endl;
}
}
seq_block = std::min(seq_block, max_context_len);
if (seq_block > 16 && seq_block < max_context_len) {
seq_block = PAD_DOWN(seq_block, 16);
}
if (quant_bit == 4) {
seq_block = PAD_DOWN(seq_block, 2);
}
int seq_seg = DIV_UP(max_context_len, seq_block);
// need an extra seg block to dealwith int4 value_cache [...,seq_len/2, head_size]
if (quant_bit == 4) {
seq_seg += 1;
}
uint32_t core_num = cluster_dim * core_dim;
if (batch_size * seq_seg <= (core_num / 2)) {
int times = core_num / batch_size / seq_seg;
seq_block = std::max(seq_block / times, 2);
if (quant_bit == 4) {
seq_block = PAD_DOWN(seq_block, 2);
}
seq_seg = DIV_UP(max_context_len, seq_block);
// same as above to dealwise int4 value_cache with an extra seg block
if (quant_bit == 4) {
seq_seg += 1;
}
}
task_dim.x = 1;
task_dim.y = uint32_t(std::min(batch_size, cluster_dim * core_dim));
task_dim.z = uint32_t(seq_seg);
task_type = cnrtFuncTypeBlock;
}
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const void *context_lens,
const void *context_seq_offsets,
const void *cache_bs_id,
const void *cache_seq_offsets,
const int32_t max_context_len,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t cache_mem_len,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride,
const cnnlDataType_t dtype) {
if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
std::cerr << "[invokeDequantFromPagedCache]: "
"MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t index;
int32_t seq_block;
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
getBlockAndDimForLinear(seq_block, k_dim, k_type, max_context_len, head_num, batch_size,
head_size, quant_mode, quant_bit, dtype);
index = getDequantLinearIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
auto dequant_linear_func = DequantFromLinearCacheFuncArr[index];
dequant_linear_func<<<k_dim, k_type, queue>>>(
(void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
(const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
(const int32_t *)context_seq_offsets, (const int32_t *)cache_bs_id,
(const int32_t *)cache_seq_offsets, max_context_len, batch_size, head_num, key_group_num,
value_group_num, cache_mem_len, head_size, seq_block, context_head_stride, context_seq_stride,
cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride,
scale_bs_stride, scale_head_stride);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,108 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief De-quantizes the key and value tensors from the provided linear cache and scale.
* @param queue: The queue for mlu.
* @param key: Pointer to the MLU memory that stores the key tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be float32, half,
* or bfloat16. This parameter can be nullptr.
* @param value: Pointer to the MLU memory that stores the value tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be float32,
* half, or bfloat16.This parameter can be nullptr.
* @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
* with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
* or [max_bs, head_num, cache_mem_len, head_size//2] for 4-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
* with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
* or [max_bs, head_num, cache_mem_len//2, head_size] for 4-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
* Shape depends on quantization mode:
* - For per-channel quantization (quant_mode = 0): [head_num, head_size].
* - For per-token quantization (quant_mode = 1): [max_batch_size, head_num, cache_mem_len].
* Data type must be float32. This parameter can be nullptr.
* @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
* with the same shape as key_scale. Data type must be float32. This parameter can be
nullptr.
* @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
* The shape must be [batch].
* @param context_seq_offset: Pointer to the MLU memory that stores the sequence offset in the
context.
* The shape must be [batch]. If nullptr, the default value is the cumulative sum of
context_lengths.
* @param cache_bs_id: Pointer to the MLU memory that stores the batch index in the cache.
* The shape must be [batch]. If nullptr, the default value is {0, 1, 2, ..., batch - 1}.
* @param cache_seq_offset: Pointer to the MLU memory that stores the sequence offset in the cache.
* The shape must be [batch]. If nullptr, the default value is 0 for every batch.
* @param max_contxt_len: The maximum sequence length of context.
* @param batch: Batch size.
* @param head_num: Head number.
* @param key_group_num: group number of key group-wise quantization.
* @param value_group_num: group number of value group-wise quantization.
* @param cache_mem_len: The maximum sequence length of cache.
* @param head_size: Head size.
* @param quant_mode: An integer value indicating the quantization mode:
* 0 for per-channel quantization and 1 for per-token quantization.
* @param quant_bit: An integer value indicating the quantization bit width:
* 8 for 8-bit quantization and 4 for 4-bit quantization.
* @param contxt_head_stride: The stride of head_num in context.
* @param contxt_seq_stride: The stride of max_contxt_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
* @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
* @param cache_scale_bs_stride: The stride of batch in cache scale, only valid if quant_per_quant.
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param dtype: The data type of the key and value tensors.
* @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
* If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
*/
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const void *context_lens,
const void *context_seq_offsets,
const void *cache_bs_ids,
const void *cache_seq_offsets,
const int32_t max_context_len,
const int32_t batch,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t cache_mem_len,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t cache_scale_bs_stride,
const size_t cache_scale_head_stride,
const cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,616 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include "dequant_from_paged_cache.mluh"
#include "quant_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#pragma bang walign(16)
#define REM_FOR_STACK (32 * 1024)
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
#define DEQUANT_PAGED_PERHEAD kernels::MLUDequantFromPagedCacheKernelPerHead
#define DEQUANT_PAGED_PERCHANNEL kernels::MLUDequantFromPagedCacheKernelPerChannel
#define DEQUANT_FUNC_LEN (24)
#define DEQUANT_BATCH_NUM (1024)
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
// Uses 8K = 1K * (4 + 4) to process offsets
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
T *data,
const int32_t scale_num,
const int32_t seq_num,
const int32_t head_num,
const int32_t head_size,
const size_t context_offset,
const size_t context_seq_stride,
const size_t context_head_stride) {
dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
seq_num * scale_num, scale_num);
__memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_head(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
T *temp_nram,
T *data,
const int32_t seq_num,
const int32_t head_num,
const int32_t block_size,
const int32_t head_size,
const size_t context_offset,
const size_t context_seq_stride,
const size_t context_head_stride) {
int block_count = DIV_UP(seq_num, block_size);
int rem_token = seq_num % block_size;
convert((float *)output_nram, (Tc *)input_nram, block_count * head_num * block_size * head_size);
T *res_nram = std::is_same<T, float>::value ? (T *)output_nram : (T *)temp_nram;
// dequantize [block_count, head_num, block_size, head_size]
conv_fuse_mul_cvt((T *)res_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
block_count * head_num * block_size, head_size, 1);
// copy to [seq_num, head_num, head_size]
int whole_block_count = block_count - int(rem_token > 0);
if (whole_block_count) {
for (int i = 0; i < head_num; ++i) {
// copy from [whole_block_count, i, block_size, head_size]
// to [whole_block_count * block_size, 1, head_size]
__memcpy((T *)data + context_offset + i * context_head_stride,
(T *)res_nram + i * block_size * head_size, head_size * sizeof_(T), NRAM2GDRAM,
context_seq_stride * sizeof_(T), block_size - 1,
block_size * context_seq_stride * sizeof_(T), whole_block_count - 1,
head_size * sizeof_(T), block_size - 1,
head_num * block_size * head_size * sizeof_(T), whole_block_count - 1);
}
}
if (rem_token) {
// copy from [last, head_num, block_size(rem_token), head_size]
// to [rem_token, head_num, head_size]
__memcpy((T *)data + context_offset + whole_block_count * block_size * context_seq_stride,
(T *)res_nram + whole_block_count * head_num * block_size * head_size,
head_size * sizeof_(T), NRAM2GDRAM, context_head_stride * sizeof_(T), head_num - 1,
context_seq_stride * sizeof_(T), rem_token - 1, block_size * head_size * sizeof_(T),
head_num - 1, head_size * sizeof_(T), rem_token - 1);
}
}
template <typename Tc>
__mlu_func__ void load_input_per_channel(Tc *input_nram,
Tc *cache,
Tc *temp_nram,
int32_t *block_offsets,
uint32_t *cache_offsets,
const int32_t *block_tables,
const int32_t batch_idx,
const int32_t scale_num,
const int32_t max_block_num,
const int32_t head_num,
const int32_t block_size,
const int32_t head_size,
const int32_t seq_begin,
const int32_t deal_seq_num,
const size_t cache_bs_stride,
const size_t cache_head_stride) {
int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
int32_t block_count = block_end - block_start + 1;
// make sure elements in block_tables >= 0
__memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
block_count * sizeof_(int32_t), GDRAM2NRAM);
__bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
(uint32_t)cache_bs_stride * sizeof(Tc), block_count);
#if __BANG_ARCH__ >= 500
// gather [block_count, head_num, block_size, head_size]
__gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
(uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
(uint32_t)cache_bs_stride * sizeof(Tc), block_count);
if (head_num != 1 && block_size != 1) {
// mv to [head_num, whole_block_count, block_size, head_size]
__memcpy((Tc *)temp_nram, (Tc *)input_nram, block_size * head_size * sizeof(Tc), NRAM2NRAM,
block_size * head_size * sizeof(Tc), block_count - 1,
block_count * block_size * head_size * sizeof(Tc), head_num - 1,
head_num * block_size * head_size * sizeof(Tc), block_count - 1,
block_size * head_size * sizeof(Tc), head_num - 1);
// mv to [whole_block_count, block_size, head_num, head_size]
__memcpy((Tc *)input_nram, (Tc *)temp_nram, head_size * sizeof(Tc), NRAM2NRAM,
head_size * sizeof(Tc), head_num - 1, head_num * head_size * sizeof(Tc),
block_count * block_size - 1, block_count * block_size * head_size * sizeof(Tc),
head_num - 1, head_size * sizeof(Tc), block_count * block_size - 1);
}
#endif
}
template <typename Tc, typename Ts>
__mlu_func__ void load_input_per_head(Tc *input_nram,
Ts *scale_nram,
Tc *cache,
Ts *scale,
Tc *temp_nram,
int32_t *block_offsets,
uint32_t *cache_offsets,
uint32_t *scale_offsets,
const int32_t *block_tables,
const int32_t batch_idx,
const int32_t max_block_num,
const int32_t head_num,
const int32_t block_size,
const int32_t head_size,
const int32_t seq_begin,
const int32_t deal_seq_num,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
int32_t block_count = block_end - block_start + 1;
// make sure elements in block_tables >= 0
__memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
block_count * sizeof_(int32_t), GDRAM2NRAM);
__bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
(uint32_t)cache_bs_stride * sizeof(Tc), block_count);
__bang_mul_scalar((uint32_t *)scale_offsets, (uint32_t *)block_offsets,
(uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#if __BANG_ARCH__ >= 500
// gather [block_count, head_num, block_size, head_size]
__gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
(uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
(uint32_t)cache_bs_stride * sizeof(Tc), block_count);
// gather [block_count, head_num, block_size]
__gather((Ts *)scale_nram, (Ts *)scale, (uint32_t *)scale_offsets,
(uint32_t)scale_bs_stride * sizeof(Ts), GDRAM2NRAM,
(uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#endif
}
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerChannel(void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t *block_tables,
const int32_t max_context_len,
const int32_t max_block_num,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t block_size,
const int32_t head_size,
const int32_t seq_block,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
bool has_key = (key && key_cache && key_scale);
bool has_value = (value && value_cache && value_scale);
if (!(has_key || has_value)) {
return;
}
/* *********************************nram space **************************************
* NRAM |scale[head_num, head_size] fp32|output/input[seq_block, head_num, head_size] fp32|
*/
int32_t scale_num = head_num * head_size;
Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
float *output_nram = (float *)nbuf + scale_num;
Tc *input_nram = (Tc *)output_nram +
(std::is_same<Tc, int4x2_t>::value
? (7 * seq_block * (scale_num >> 1))
: (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + seq_block * scale_num);
uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
int32_t seq_offset;
int32_t seq_len;
int32_t seq_begin;
int32_t deal_seq_num;
size_t context_offset;
process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
(int32_t *)context_seq_offsets, batch_size);
if (has_key) {
load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
scale_head_stride);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
// seq_begin % block_size != 0 only when seq_block < block_size
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
context_offset = context_seq_stride * (seq_offset + seq_begin);
load_input_per_channel((Tc *)input_nram, (Tc *)key_cache, (Tc *)output_nram,
(int32_t *)block_offsets, (uint32_t *)cache_offsets,
(int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
cache_head_stride);
dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
scale_num, deal_seq_num, head_num, head_size, context_offset,
context_seq_stride, context_head_stride);
}
}
if (has_value) {
load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
scale_head_stride);
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
// seq_begin % block_size != 0 only when seq_block < block_size
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
context_offset = context_seq_stride * (seq_offset + seq_begin);
load_input_per_channel((Tc *)input_nram, (Tc *)value_cache, (Tc *)output_nram,
(int32_t *)block_offsets, (uint32_t *)cache_offsets,
(int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
cache_head_stride);
dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
scale_num, deal_seq_num, head_num, head_size, context_offset,
context_seq_stride, context_head_stride);
}
}
}
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerHead(void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t *block_tables,
const int32_t max_context_len,
const int32_t max_block_num,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t block_size,
const int32_t head_size,
const int32_t seq_block,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
bool has_key = (key && key_cache && key_scale);
bool has_value = (value && value_cache && value_scale);
if (!(has_key || has_value)) {
return;
}
/* *********************************nram space **************************************
* NRAM |scale[seq_block, head_num] * fp32|output/input[seq_block, head_num, head_size] * fp32|
* |temp[seq_block, head_num, head_size] * output_dtype|
* WRAM |head_size * 64B|
*/
int32_t scale_num = seq_block * head_num;
Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
float *output_nram = (float *)nbuf + scale_num;
Tc *input_nram =
(Tc *)output_nram + (head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
// temp_nram for nram to store temp output
float *temp_nram = (float *)output_nram + head_size * scale_num;
int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + head_size * scale_num);
uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
uint32_t *scale_offsets = (uint32_t *)cache_offsets + DIV_UP(seq_block, block_size);
int32_t seq_len;
int32_t seq_begin;
int32_t seq_offset;
int32_t deal_seq_num;
size_t context_offset;
__bang_write_value((float *)nbuf, head_size * 16, 1.0f);
mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
(int32_t *)context_seq_offsets, batch_size);
if (has_key) {
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
context_offset = context_seq_stride * (seq_offset + seq_begin);
load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)key_cache, (Ts *)key_scale,
(Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
(uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
(T *)key, deal_seq_num, head_num, block_size, head_size, context_offset,
context_seq_stride, context_head_stride);
}
}
if (has_value) {
for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
(int32_t *)context_lens, (int32_t *)context_seq_offsets,
batch_idx);
seq_begin = taskIdZ * seq_block;
deal_seq_num = std::min(seq_len - seq_begin, seq_block);
if (deal_seq_num <= 0 || seq_offset < 0) continue;
context_offset = context_seq_stride * (seq_offset + seq_begin);
load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)value_cache, (Ts *)value_scale,
(Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
(uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
(T *)value, deal_seq_num, head_num, block_size, head_size, context_offset,
context_seq_stride, context_head_stride);
}
}
}
} // namespace kernels
#define DEQUANT_PAGED_INIT(T, Tc, Ts, C, Name) \
template __mlu_global__ void kernels::MLUDequantFromPagedCacheKernel##Name<T, Tc, Ts, C>( \
void *key, void *value, const void *key_cache, const void *value_cache, \
const void *key_scale, const void *value_scale, const int32_t *context_lens, \
const int32_t *context_seq_offsets, const int32_t *block_tables, \
const int32_t max_context_len, const int32_t max_block_num, const int32_t batch_size, \
const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \
const int32_t block_size, const int32_t head_size, const int32_t seq_block, \
const size_t context_head_stride, const size_t context_seq_stride, \
const size_t cache_bs_stride, const size_t cache_head_stride, \
const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \
const size_t scale_bs_stride, const size_t scale_head_stride);
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerHead)
typedef void (*DequantFromPagedCachePointer)(void *, // key
void *, // value
const void *, // key_cache
const void *, // value_cache
const void *, // key_scale
const void *, // value_scale
const int32_t *, // context_lens
const int32_t *, // context_seq_offsets
const int32_t *, // block_tables
const int32_t, // max_context_len
const int32_t, // max_block_num
const int32_t, // batch_size
const int32_t, // head_num
const int32_t, // key_group_num
const int32_t, // value_group_num
const int32_t, // block_size
const int32_t, // head_size
const int32_t, // seq_block
const size_t, // context_head_stride
const size_t, // context_seq_stride
const size_t, // cache_bs_stride
const size_t, // cache_head_stride
const size_t, // key_cache_seq_stride
const size_t, // value_cache_seq_stride
const size_t, // scale_bs_stride
const size_t); // scale_head_stride
static DequantFromPagedCachePointer DequantFromPagedCacheFuncArr[DEQUANT_FUNC_LEN] = {
DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, false>,
DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, false>,
DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, false>, nullptr, nullptr, nullptr,
// DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, false>,
// DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
// DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, false>,
DEQUANT_PAGED_PERHEAD<half, int8_t, float, false>,
DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, false>,
DEQUANT_PAGED_PERHEAD<float, int8_t, float, false>, nullptr, nullptr, nullptr,
// DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, false>,
// DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, false>,
// DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, false>,
DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, true>,
DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, true>,
DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, true>, nullptr, nullptr, nullptr,
// DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, true>,
// DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
// DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, true>,
DEQUANT_PAGED_PERHEAD<half, int8_t, float, true>,
DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, true>,
DEQUANT_PAGED_PERHEAD<float, int8_t, float, true>, nullptr, nullptr, nullptr};
// DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, true>};
uint32_t getDequantPagedIdx(cnnlDataType_t dtype,
int32_t quant_mode,
int32_t quant_bit,
const void *context_seq_offset) {
uint32_t idx = 0;
idx += (quant_mode != 0) ? 6 : 0;
idx += (quant_bit != 8) ? 3 : 0;
idx += (dtype == CNNL_DTYPE_BFLOAT16) ? 1 : 0;
idx += (dtype == CNNL_DTYPE_FLOAT) ? 2 : 0;
idx += (context_seq_offset == nullptr) ? 12 : 0;
return idx;
}
void getBlockAndDimForPaged(int32_t &seq_block,
cnrtDim3_t &task_dim,
cnrtFunctionType_t &task_type,
const int32_t max_context_len,
const int32_t head_num,
const int32_t batch_size,
const int32_t head_size,
const int32_t block_size,
const int32_t quant_mode,
const int32_t quant_bit,
const cnnlDataType_t dtype) {
int32_t core_dim;
int32_t cluster_dim;
int32_t nram_size = 480 * 1024;
int32_t wram_size = 512 * 1024;
int32_t sram_size = 2016 * 1024;
getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
if (quant_mode == 0) {
seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1);
if (seq_block < block_size) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * head_size * sizeof_(float) should be less than "
<< nram_size / block_size << " when quant_mode is 0." << std::endl;
}
} else {
int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half);
seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) +
(int64_t)head_num * head_size * dtype_size));
if (seq_block < block_size) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
"context_dtype_size) "
<< "should be less than " << nram_size / block_size << " when quant_mode is 1."
<< std::endl;
}
/* head_size * 64B put in the wram. */
if (head_size * ONE_LINE >= wram_size) {
std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
<< wram_size << " when quant_mode is 1." << std::endl;
}
}
// seq_block should be a multiply of block_size
seq_block = PAD_DOWN(seq_block, block_size);
int seq_seg = DIV_UP(max_context_len, seq_block);
int32_t core_num = cluster_dim * core_dim;
if (batch_size * seq_seg <= (core_num / 2)) {
int times = core_num / batch_size / seq_seg;
seq_block = std::max(seq_block / times, 2);
if (seq_block > block_size) {
seq_block = PAD_DOWN(seq_block, block_size);
} else {
seq_block = block_size;
}
seq_seg = DIV_UP(max_context_len, seq_block);
}
task_dim.x = 1;
task_dim.y = uint32_t(std::min(batch_size, core_num));
task_dim.z = uint32_t(seq_seg);
task_type = cnrtFuncTypeBlock;
}
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const void *context_lens,
const void *context_seq_offsets,
const void *block_tables,
const int32_t max_context_len,
const int32_t max_block_num,
const int32_t batch_size,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t block_size,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t scale_bs_stride,
const size_t scale_head_stride,
const cnnlDataType_t dtype) {
if (is_arch300()) {
std::cerr << "[invokeDequantFromPagedCache]: kernel does not support MLU300 devices."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t index;
int32_t seq_block;
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
getBlockAndDimForPaged(seq_block, k_dim, k_type, max_context_len, head_num, batch_size, head_size,
block_size, quant_mode, quant_bit, dtype);
index = getDequantPagedIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
auto dequant_paged_func = DequantFromPagedCacheFuncArr[index];
dequant_paged_func<<<k_dim, k_type, queue>>>(
(void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
(const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
(const int32_t *)context_seq_offsets, (const int32_t *)block_tables, max_context_len,
max_block_num, batch_size, head_num, key_group_num, value_group_num, block_size, head_size,
seq_block, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
key_cache_seq_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,106 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief De-quantizes the key and value tensors from the provided paged cache and scale.
* @param queue: The queue for mlu.
* @param key: Pointer to the MLU memory that stores the key tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be half
* or bfloat16. This parameter can be nullptr.
* @param value: Pointer to the MLU memory that stores the value tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be
* half or bfloat16.This parameter can be nullptr.
* @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
* Shape depends on quantization mode:
* - For per-channel quantization (quant_mode = 0): [head_num, head_size].
* - For per-token quantization (quant_mode = 1): [total_blocks, head_num, block_size].
* Data type must be float32. This parameter can be nullptr.
* @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
* with the same shape as key_scale. Data type must be float32. This parameter can be
nullptr.
* @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
* The shape must be [batch].
* @param context_seq_offset: Pointer to the MLU memory that stores the sequence offset in the
context.
* The shape must be [batch]. If nullptr, the default value is the cumulative sum of
context_lengths.
* @param block_tables: Pointer to the MLU memory that stores the block tables for indexing.
* The shape must be [batch, max_block_num].
* @param max_contxt_len: The maximum sequence length of context.
* @param max_block_num: The maximum block number of each batch.
* @param batch: Batch size.
* @param head_num: Head number.
* @param key_group_num: group number of key group-wise quantization.
* @param value_group_num: group number of value group-wise quantization.
* @param block_size: The block size of the cache.
* @param head_size: Head size.
* @param quant_mode: An integer value indicating the quantization mode:
* 0 for per-channel quantization and 1 for per-token quantization.
* @param quant_bit: An integer value indicating the quantization bit width:
* 8 for 8-bit quantization.
* @param contxt_head_stride: The stride of head_num in context.
* @param contxt_seq_stride: The stride of max_contxt_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
* @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
* @param cache_scale_bs_stride: The stride of batch in cache scale, only valid if quant_per_quant.
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param dtype: The data type of the key and value tensors.
* @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
* If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
*/
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const void *context_lens,
const void *context_seq_offsets,
const void *block_tables,
const int32_t max_context_len,
const int32_t max_block_num,
const int32_t batch,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t block_size,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t cache_scale_bs_stride,
const size_t cache_scale_head_stride,
const cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,254 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "dequantify.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
template <typename T>
struct PackValueNum {
const static int value = 1;
};
template <>
struct PackValueNum<int4x2_t> {
const static int value = 2;
};
__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)];
__nram__ int8_t nram_buf_scale[8192];
__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) {
__bang_int82float(dst, src, count, 0);
__bang_mul_scalar(dst, dst, scale, count);
}
__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) {
__bang_int42float_rn(dst, src, count, 0);
__bang_mul_scalar(dst, dst, scale, count);
}
__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) {
__bang_int82half(dst, src, count, 0);
__bang_mul_scalar(dst, dst, (half)scale, count);
}
__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) {
__bang_int42half_rn(dst, src, count, 0);
__bang_mul_scalar(dst, dst, (half)scale, count);
}
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
T *tmp = ping;
ping = pong;
pong = tmp;
}
template <typename TDst, typename TSrc>
__mlu_global__ void dequantifyPerTensor(void *all_dst,
const void *all_src,
size_t all_src_count,
float scale) {
scale = 1.0f / scale;
size_t src_per_core = all_src_count / taskDim;
size_t src_remain = all_src_count % taskDim;
size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain);
const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0);
TDst *dst = reinterpret_cast<TDst *>(all_dst) + start * PackValueNum<TSrc>::value;
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + start;
constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
128; // align to 128
constexpr int src_num_unit = size_unit / sizeof(TSrc);
constexpr int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
int8_t *nram_buf_ping = nram_buf;
int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2;
TSrc *nram_src_ping = reinterpret_cast<TSrc *>(nram_buf_ping);
TDst *nram_dst_ping =
reinterpret_cast<TDst *>(nram_buf_ping + static_cast<int>(sizeof(TSrc)) * size_unit);
TSrc *nram_src_pong = reinterpret_cast<TSrc *>(nram_buf_pong);
TDst *nram_dst_pong =
reinterpret_cast<TDst *>(nram_buf_pong + static_cast<int>(sizeof(TSrc)) * size_unit);
int loop_count = src_count / src_num_unit;
int remain_count = src_count % src_num_unit;
// L
__memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
swap(nram_src_ping, nram_src_pong);
swap(nram_dst_ping, nram_dst_pong);
__sync_io_move_compute();
// L C
__memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
swap(nram_src_ping, nram_src_pong);
swap(nram_dst_ping, nram_dst_pong);
__sync_io_move_compute();
// L C S
for (int i = 0; i < loop_count - 2; ++i) {
__memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit,
GDRAM2NRAM);
__memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
swap(nram_src_ping, nram_src_pong);
swap(nram_dst_ping, nram_dst_pong);
__sync_io_move_compute();
}
// C S
__memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
NRAM2GDRAM);
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
swap(nram_src_ping, nram_src_pong);
swap(nram_dst_ping, nram_dst_pong);
__sync_io_move_compute();
// S
__memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
NRAM2GDRAM);
__sync_io_move_compute();
if (remain_count > 0) {
__memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count,
GDRAM2NRAM);
convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum<TSrc>::value, scale);
__memcpy(dst + loop_count * dst_num_unit, nram_dst_ping,
sizeof(TDst) * remain_count * PackValueNum<TSrc>::value, NRAM2GDRAM);
}
}
// does not use a pipeline because per channel is more complicated but it's a one-time operation, so
// performance doesn't matter.
template <typename TDst, typename TSrc>
__mlu_global__ void dequantifyPerChannel(void *all_dst,
const void *all_src,
int src_ci,
int all_co,
const void *scale) {
const int co_per_core = all_co / taskDim;
const int co_remain = all_co % taskDim;
const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain);
const int co_count = co_per_core + (taskId < co_remain ? 1 : 0);
assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst));
constexpr int size_unit = sizeof(nram_buf) /
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
128; // align to 128
// yes, we only deal with 1 channel at a time
// no, there's no need to optimize a one-time operation
const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci);
const int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
TSrc *const nram_src = reinterpret_cast<TSrc *>(nram_buf);
TDst *const nram_dst =
reinterpret_cast<TDst *>(nram_buf + static_cast<int>(sizeof(TSrc)) * size_unit);
const TDst *nram_scale = reinterpret_cast<const TDst *>(nram_buf_scale);
const int loop_one_channel = src_ci / src_num_unit;
const int remain_one_channel = src_ci % src_num_unit;
for (int o = start_co; o < start_co + co_count; ++o) {
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + o * src_ci;
TDst *dst = reinterpret_cast<TDst *>(all_dst) + o * src_ci;
const TDst scale_value = 1. / nram_scale[o];
for (int i = 0; i < loop_one_channel; ++i) {
__memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
convert(nram_dst, nram_src, dst_num_unit, scale_value);
__memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
}
if (remain_one_channel > 0) {
__memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel,
GDRAM2NRAM);
convert(nram_dst, nram_src, remain_one_channel * PackValueNum<TSrc>::value, scale_value);
__memcpy(dst + loop_one_channel * dst_num_unit, nram_dst,
sizeof(TDst) * remain_one_channel * PackValueNum<TSrc>::value, NRAM2GDRAM);
}
}
}
} // namespace kernels
static const std::map<std::pair<int, cnnlDataType_t>,
decltype(&kernels::dequantifyPerTensor<half, int4x2_t>)>
per_tensor_func_map = {
{{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int4x2_t>},
{{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int4x2_t>},
{{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int8_t>},
{{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int8_t>},
};
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
size_t src_count,
float scale) {
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
CNdev dev;
cnnlGetDevice(handle, &dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};
auto iter = per_tensor_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
if (iter == per_tensor_func_map.end()) {
std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth
<< " dst_dtype: " << dst_dtype;
return KernelStatus::KERNEL_STATUS_FAILED;
}
iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_count, scale);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
static const std::map<std::pair<int, cnnlDataType_t>,
decltype(&kernels::dequantifyPerChannel<half, int4x2_t>)>
per_channel_func_map = {
{{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int4x2_t>},
{{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int4x2_t>},
{{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int8_t>},
{{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int8_t>},
};
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
int src_ci,
int co,
const void *scale) {
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
CNdev dev;
cnnlGetDevice(handle, &dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};
auto iter = per_channel_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
if (iter == per_channel_func_map.end()) {
std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth
<< " dst_dtype: " << dst_dtype;
return KernelStatus::KERNEL_STATUS_FAILED;
}
iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_ci, co, scale);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,57 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANTIFY_MLUH_
#define CSRC_KERNELS_DEQUANTIFY_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Dequantify per tensor.
* @param handle: The handle of cnnl.
* @param src: Input. Pointer to the MLU memory that stores the input.
* @param src_bitwidth: The bitwidth of input quantized data.
* @param dst: Output. Pointer to the MLU memory that stores the output.
* @param dst_dtype: The data type of output.
* @param src_count: The number of elements in input.
* @param scale: The scale for dequantify.
*/
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
size_t src_count,
float scale);
/**
* @brief Dequantify per channel.
* @param handle: The handle of cnnl.
* @param src: Input. Pointer to the MLU memory that stores the input.
* @param src_bitwidth: The bitwidth of input quantized data.
* @param dst: Output. Pointer to the MLU memory that stores the output.
* @param dst_dtype: The data type of output.
* @param src_ci: The ci of input.
* @param co: The co of input.
* @param scale: Pointer to the MLU memory that stores the scale for dequantify.
*/
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
int src_ci,
int co,
const void *scale);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANTIFY_MLUH_

View File

@@ -0,0 +1,310 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
int base = total / num;
int tail = total - base * num;
every = base + (id < tail ? 1 : 0);
offset = base * id + (id < tail ? id : tail);
}
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
int *input_ids,
T *output,
int vocab_offset,
int vocab_size,
int input_size,
int total_seq) {
if (__is_mpu()) {
return;
};
int bs_core = 0;
int bs_offset = 0;
split(total_seq, taskDim, taskId, bs_core, bs_offset);
// 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
// by 8
int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
(input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));
int vocab_start = vocab_offset;
int vocab_end = vocab_offset + vocab_size - 1;
T *zeros_nram = (T *)nram_buffer; // input_size * sizeof(T)
T *emb_nram = zeros_nram + input_size; // limit * input_size * sizeof(T)
int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
int *idxs_nram = ones_nram + limit; // limit * sizeof(int)
int *mask_nram = idxs_nram + limit; // limit_pad * sizeof(int)
int *temp_nram = mask_nram + PAD_UP(limit, 8); // limit * sizeof(int)
uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit); // limit * sizeof(int8_t)
__bang_write_zero(zeros_nram, input_size);
__bang_write_zero(zeros_offset_nram, limit);
__bang_write_value(ones_nram, limit, 1);
int repeat = bs_core / limit;
int remain = bs_core % limit;
for (int i = 0; i < repeat + 1; i++) {
if ((i == repeat) && (remain == 0)) {
return;
}
int num = (i == repeat) ? remain : limit;
int num_pad = PAD_UP(num, 8); // for __bang_eq_bitindex
__memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
__sync();
__bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
__bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
__bang_mul(mask_nram, mask_nram, temp_nram, num);
__bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
num_pad); // gather valid mask
__bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num); // gather invalid mask
__bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num); // true index
__bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
(unsigned int)input_size * sizeof(T), num); // gather offset
__sync();
__gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
GDRAM2NRAM, input_size * sizeof(T), num);
__gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
NRAM2NRAM, input_size * sizeof(T), num);
__sync();
__memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
num * input_size * sizeof(T), NRAM2GDRAM);
__sync();
}
}
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
__bang_write_zero(dst, elem_count);
}
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
__bang_write_zero(dst, elem_count);
#endif
}
template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
int *input_ids,
T *output,
int vocab_offset,
int vocab_size,
int input_size,
int total_seq) {
if (__is_mpu()) {
return;
};
int bs_core = 0;
int bs_offset = 0;
split(total_seq, taskDim, taskId, bs_core, bs_offset);
int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
limit = PAD_DOWN(limit, 2);
int repeat = bs_core / limit;
int remain = bs_core % limit;
int vocab_start = vocab_offset;
int vocab_end = vocab_offset + vocab_size - 1;
T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
for (int i = 0; i < repeat + 1; i++) {
if ((i == repeat) && (remain == 0)) {
return;
}
int num = (i == repeat) ? remain : limit;
__memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
__sync();
int idx1 = idxs_nram[0];
int idx2 = idxs_nram[1];
bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
for (int n = 0; n < num / 2 * 2; n += 2) {
if (first && second) {
__memcpy_async(emb_nram + n * input_size,
filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
1);
} else if (!first && !second) {
write_zero(emb_nram + n * input_size, 2 * input_size);
} else if (first && !second) {
write_zero(emb_nram + (n + 1) * input_size, input_size);
__memcpy_async(emb_nram + n * input_size,
filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
GDRAM2NRAM);
} else {
write_zero(emb_nram + n * input_size, input_size);
__memcpy_async(emb_nram + (n + 1) * input_size,
filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
GDRAM2NRAM);
}
idx1 = idxs_nram[n + 2];
idx2 = idxs_nram[n + 3];
first = (idx1 >= vocab_start && idx1 <= vocab_end);
second = (idx2 >= vocab_start && idx2 <= vocab_end);
} // copy loop
// last idx copy
if (num % 2 == 1) {
if (first) {
__memcpy_async(emb_nram + (num - 1) * input_size,
filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
GDRAM2NRAM);
} else {
write_zero(emb_nram + (num - 1) * input_size, input_size);
}
}
__sync();
__memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
num * input_size * sizeof(T), NRAM2GDRAM);
__sync();
}
}
template <typename T>
__mlu_func__ void embeddingImpl_generic(T *filter,
int *input_ids,
T *output,
int vocab_offset,
int vocab_size,
int input_size,
int total_seq) {
if (__is_mpu()) {
return;
};
int bs_core = 0;
int bs_offset = 0;
split(total_seq, taskDim, taskId, bs_core, bs_offset);
int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
limit = PAD_DOWN(limit, 2);
int repeat = bs_core / limit;
int remain = bs_core % limit;
int vocab_start = vocab_offset;
int vocab_end = vocab_offset + vocab_size - 1;
T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
for (int i = 0; i < repeat + 1; i++) {
if ((i == repeat) && (remain == 0)) {
return;
}
int num = (i == repeat) ? remain : limit;
__memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
__sync();
int idx = idxs_nram[0];
bool hit = (idx >= vocab_start && idx <= vocab_end);
for (int n = 0; n < num; n++) {
if (hit) {
__memcpy_async(emb_nram + n * input_size,
filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
GDRAM2NRAM);
} else {
write_zero(emb_nram + n * input_size, input_size);
}
idx = idxs_nram[n + 1];
hit = (idx >= vocab_start && idx <= vocab_end);
}
__sync();
__memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
num * input_size * sizeof(T), NRAM2GDRAM);
__sync();
}
}
template <typename T>
__mlu_global__ void MLUEmbeddingKernel(T *filter,
int *input_ids,
T *output,
int vocab_offset,
int vocab_size,
int total_vocab_size,
int input_size,
int total_seq) {
#if __BANG_ARCH__ > 372
// __gather index maximum dtype is unsigned int
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
} else {
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
total_seq);
}
#else
// __memcpy 2D src_stride dtype is int
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
} else {
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
total_seq);
}
#endif
}
} // namespace kernels
KernelStatus invokeEmbedding(cnrtQueue_t queue,
void *filter,
void *input_ids,
void *output,
const cnnlDataType_t dtype,
int vocab_offset,
int vocab_size,
int total_vocab_size,
int input_size,
int total_seq) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
vocab_size, total_vocab_size, input_size, total_seq);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
vocab_size, total_vocab_size, input_size, total_seq);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,63 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
#define CSRC_KERNELS_EMBEDDING_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Look up table for ids which greater than vocab_offset and less than
* vocab_offset + vocab_size, and write the results back to the position
* corresponding to the ids. For ids that are not in the range, write 0
* to the corresponding position.
* @example
* filter:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [4, 3, 2, 1]]
* input_ids:
* [[1, 5, 6, 7, 8, 9]]
* vocab_offset = 5
* vocab_size = 3
* input_size = 4
* total_seq = 6
* output:
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
* @param queue: The queue for mlu.
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
* the shape must be [vocab_size, input_size].
* @param input_ids: Input. Pointer to the MLU memory that stores the token id,
* the shape must be [batch, seq].
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [batch, seq, input_size].
* @param dtype: Data type.
* @param vocab_offset: embedding table offset.
* @param vocab_size: embedding table size.
* @param total_vocab_size: total embedding table size.
* @param input_size: embedding dim.
* @param total_seq: Total sequence length.
*/
KernelStatus invokeEmbedding(cnrtQueue_t queue,
void *filter,
void *input_ids,
void *output,
const cnnlDataType_t dtype,
int vocab_offset,
int vocab_size,
int total_vocab_size,
int input_size,
int total_seq);
} // namespace tmo
#endif // CSRC_KERNELS_EMBEDDING_MLUH_

View File

@@ -0,0 +1,658 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "cnnl.h"
#include "cnrt.h"
#include "fused_rope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
using bf16 = bfloat16_t;
namespace tmo {
namespace kernels {
#ifndef PAD_UP
#define PAD_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0) * (y))
#endif
#if __BANG_ARCH__ > 500
#include <bang_fusor.h>
template <typename T>
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
#endif
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ float nram_mask[256];
__nram__ int nram_rope_offsets[256];
__nram__ float nram_zeros[1024] = {0.f};
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, num);
} else if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
__bang_bfloat162float(dst, (bf16 *)src, num);
#endif
}
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, num);
} else if (std::is_same<T, bf16>::value) {
__bang_float2bfloat16_rn((bf16 *)dst, src, num);
}
}
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
int *cache_seq_offsets_begin,
int *slot_mapping_begin,
int *nram_k_cache_offsets,
int *nram_v_cache_offsets,
int *nram_v_onchip_offsets,
int *nram_kv_scale_offsets,
float *nram_cache_mask,
float *nram_zeros,
float *nram_temp,
int task_deal_batch,
int task_begin_batch,
int head_num_k,
int head_size,
int max_decode_len,
int block_size,
int kv_out_size,
int group_num,
bool discrete_batch,
bool paged_cache,
bool mixed_cache) {
// 目前先用标量化计算offset便于理解(性能无影响)
int bh = task_deal_batch * head_num_k;
if (paged_cache) {
int cache_seq_stride = head_size;
int cache_head_stride = block_size * head_size;
int cache_scale_head_stride = block_size * group_num;
int cache_block_stride = head_num_k * cache_head_stride;
int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
int *nram_slot_mapping = (int *)nram_mask;
__memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
for (int i = 0; i < task_deal_batch; i++) {
int mapping_idx = __load_nram(nram_slot_mapping + i);
if (mapping_idx < 0) {
__bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
continue;
}
int block_idx = mapping_idx / block_size;
int seq_idx = mapping_idx % block_size;
int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;
for (int j = 0; j < head_num_k; j++) {
__store_nram(
nram_k_cache_offsets + i * head_num_k + j,
(int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
if (mixed_cache) {
__store_nram(nram_v_cache_offsets + i * head_num_k + j,
(int)(v_seq_offset + j * cache_head_stride / 2));
__store_nram(nram_kv_scale_offsets + i * head_num_k + j,
(int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
__store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
}
}
}
} else {
int *nram_seq_offsets = (int *)nram_mask;
int *nram_bs_id = nram_seq_offsets + 32;
int cache_seq_stride = head_size;
int cache_head_stride = max_decode_len * head_size;
int cache_scale_head_stride = max_decode_len * group_num;
int cache_bs_stride = head_num_k * cache_head_stride;
int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
__memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
if (discrete_batch) {
__memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
}
for (int i = 0; i < task_deal_batch; i++) {
int bs_idx = __load_nram(nram_bs_id + i);
int seq_idx = __load_nram(nram_seq_offsets + i);
int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
int temp_seq_idx = seq_idx;
bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
if (masked) {
__bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
continue;
}
int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;
for (int j = 0; j < head_num_k; j++) {
__store_nram(
nram_k_cache_offsets + i * head_num_k + j,
(int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
if (mixed_cache) {
__store_nram(nram_v_cache_offsets + i * head_num_k + j,
(int)(v_seq_offset + j * cache_head_stride / 2));
__store_nram(nram_kv_scale_offsets + i * head_num_k + j,
(int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
__store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
}
}
}
}
// 此处是为了做上scatter指令的 mask如果bs offset或seq offset小于0则需要mask掉
__bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
__bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}
__mlu_func__ void layernormImpl(float *nram_k,
float *norm_params,
int task_deal_batch,
int k_hidden,
float eps) {
#if __BANG_ARCH__ > 500
float *buffer = nram_k + task_deal_batch * k_hidden;
for (int i = 0; i < task_deal_batch; i++) {
float *k_ = nram_k + i * k_hidden;
__bang_mul(buffer, k_, k_, k_hidden);
float mean = __bang_sum(k_, k_hidden);
mean = mean / k_hidden;
float rstd = __bang_sum(buffer, k_hidden);
rstd = rstd / k_hidden - mean * mean;
rstd = rstd < 0 ? eps : rstd + eps;
rstd = 1.f / std::sqrt(rstd);
__bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
}
__bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
task_deal_batch * k_hidden, k_hidden);
#endif
}
__mlu_func__ void foldRotaryImpl(float *nram_qk,
float *nram_qk_rot,
float *nram_table,
int task_deal_batch,
int head_num_qk,
int head_size) {
int rotary_low_dim = task_deal_batch * head_size;
__bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
__bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
__bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
head_num_qk * rotary_low_dim, rotary_low_dim);
__bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}
template <typename T>
__mlu_func__ void quantify(T *input,
float *float_input,
void *output_hp,
void *output_lp,
float *nram_trans,
float *scale_hp,
float *scale_lp,
float *scale_lp_temp,
int batch,
int head_num,
int head_size,
int group_num,
int group_size,
bool quant_kv_hp,
bool mixed_cache) {
if (quant_kv_hp) {
int hidden = head_num * head_size;
int bh = batch * head_num;
toFloat<T>(float_input, input, batch * hidden);
__bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"fuse.nram.crn.s8.f32 "
"[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
";\n\t" ::[dst] "r"(output_hp),
[num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
[src1] "r"(scale_hp), [pos] "i"(0));
#endif
if (mixed_cache) {
__bang_transpose(nram_trans, float_input, bh * group_num, group_size);
__bang_abs(float_input, nram_trans, bh * head_size);
__bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
1);
__bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
__bang_recip(scale_lp_temp, scale_lp, bh * group_num);
__bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
bh * group_num);
__bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
__bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
}
}
}
template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
T *sin_table,
T *cos_table,
int *rope_offsets,
T *gamma,
T *beta,
float *key_scale_hp,
float *value_scale_hp,
float *key_scale_lp,
float *value_scale_lp,
int *cache_bs_id_hp,
int *cache_seq_offsets_hp,
int *cache_bs_id_lp,
int *cache_seq_offsets_lp,
int *slot_mapping_hp,
int *slot_mapping_lp,
int rotary_stride,
int task_deal_batch,
int task_begin_batch,
int head_num_q,
int head_num_k,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
int batch_cap,
float eps) {
#if __BANG_ARCH__ > 500
/*
由于需要支持mixed cachekernel支持cache的功能组合比较多现规定只存在以下几种
1.只存在hp_cache的情况(通过lp tensor不为0判断)cache支持bf16fp16量化下支持离线perchannel int8
支持linear和pagedkey和value cache形状一致key/value_scale_hp形状为[head_num, head_size]
2.mixed cache的情况hp支持离线perchannel int8量化支持linear和pagedkey和value cache形状一致
key/value_scale_hp 形状为[head_num, head_size]. lp支持int4在线pertoken group量化
key_cache形状为 [batch, head_num_k, max_decode_len_lp, head_size / 2]
paged情况也是head_size / 2, value_cache的形状为[batch, head_num_l, max_decode_len_lp / 2,
head_size]paged cache形状为 [num_blocks, head_num_k, block_size / 2,
head_size]key/value_scale_lp形状为 [batch, head_num_k, max_decode_len_lp,
group_num]paged_cache 为 [num_blocks, head_num_k, block_size, group_num]
*/
bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
bool discrete_batch_hp = cache_bs_id_hp != nullptr;
bool discrete_batch_lp = cache_bs_id_lp != nullptr;
bool paged_cache_hp = slot_mapping_hp != nullptr;
bool paged_cache_lp = slot_mapping_lp != nullptr;
int head_num_qk = head_num_q + head_num_k;
int head_num_qkv = head_num_q + head_num_k * 2;
int qkv_hidden = head_num_qkv * head_size;
int qk_hidden = head_num_qk * head_size;
int q_hidden = head_num_q * head_size;
int k_hidden = head_num_k * head_size;
int float_size = sizeof(float);
int dtype_size = sizeof(T);
int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
int group_num = mixed_cache ? head_size / group_size : 1;
// task ddr offset
T *input_begin = input + task_begin_batch * qkv_hidden;
int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;
// nram_buffer
float *nram_qk = (float *)nram_buffer;
float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
float *nram_kv_trans = nram_v + batch_cap * k_hidden;
float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
float *norm_params = nram_table + 2 * batch_cap * head_size;
float *nram_k_scale_hp = norm_params + 2 * head_size;
float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
int8_t *nram_kv_hp =
(int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
int *nram_kv_cache_offsets_hp =
(int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
int *nram_v_cache_offsets_lp =
nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
float *cache_mask_hp =
(float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);
// 这里将qk和qk_rot放在一起是为升位宽可以一起做减少指令同样还有sincostable和norm的gamma和beta
T *qk_in = (T *)nram_qk_rot;
T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
T *v_in =
(T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);
// 生成 kv cache的offset和mask供scatter kv到kvcache使用
genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
discrete_batch_hp, paged_cache_hp, false);
if (mixed_cache) {
int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
}
/*
-----------------------
load v |
-----------------------
load qk | quant v
-----------------------
store v | rope qk
-----------------------
store_q | layernorm k
| quant k
-----------------------
store k |
*/
// prepare v v_scale cache_v rope_offset
__memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
if (quant_kv_hp) {
__memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
__memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
}
__memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
GDRAM2NRAM);
__sync_io();
if (mixed_cache) {
__gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
task_deal_batch * head_num_k);
}
__bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
task_deal_batch);
__sync_compute();
/*==============================================================================================*/
// load_qk,rope_table | quant v
__memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
head_num_qk - 1);
__gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
__gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
__memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
__memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);
int8_t *nram_temp = (int8_t *)nram_qk;
if (mixed_cache) {
__bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
0);
__bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
}
quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
group_size, quant_kv_hp, mixed_cache);
if (mixed_cache) {
__scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
task_deal_batch * head_num_k);
__bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
__bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
0);
}
__sync_io_move_compute();
/*==============================================================================================*/
// rope | store v
// 将qk的左右部分交换用于生成qk_rot
__memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
__memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
toFloat<T>(norm_params, norm_params_in, 2 * head_size);
__bang_write_value(nram_mask, head_size / 2, (float)-1);
__bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);
int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
__scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
head_num_k * task_deal_batch);
if (mixed_cache) {
__scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
head_size * sizeof(int8_t), head_num_k * task_deal_batch);
__scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
head_num_k * task_deal_batch);
}
__sync_io_move_compute();
/*==============================================================================================*/
// layernrom k quant k | store q
// 从qk的nram buffer中提取出k做layernorm和量化
float *nram_k = nram_qk_rot;
__memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
task_deal_batch - 1);
layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
head_size, group_num, group_size, quant_kv_hp, mixed_cache);
if (mixed_cache) {
__bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
}
if (!quant_kv_hp) {
floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
}
// store q
__memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
head_size * dtype_size, task_deal_batch - 1,
task_deal_batch * head_size * dtype_size, head_num_q - 1);
// ===============================================================================================
int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
__scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
head_num_k * task_deal_batch);
if (mixed_cache) {
__scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
head_num_k * task_deal_batch);
__scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
head_num_k * task_deal_batch);
}
__sync_io_move_compute();
#endif
}
template <typename T>
__mlu_global__ void MLUFuseRope(T *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
T *sin_table,
T *cos_table,
int *rope_offsets,
T *gamma,
T *beta,
float *key_scale_hp,
float *value_scale_hp,
float *key_scale_lp,
float *value_scale_lp,
int *cache_bs_id_hp,
int *cache_seq_offsets_hp,
int *cache_bs_id_lp,
int *cache_seq_offsets_lp,
int *slot_mapping_hp,
int *slot_mapping_lp,
int rotary_stride,
int batch,
int head_num_q,
int head_num_k,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
int batch_cap,
int task_avg_batch,
float eps) {
int task_begin_batch = taskId * task_avg_batch;
int task_deal_batch = std::min(batch - task_begin_batch, task_avg_batch);
if (task_deal_batch <= 0 || __is_mpu()) {
return;
}
int task_loop = (task_deal_batch + batch_cap - 1) / batch_cap;
int once_batch = (task_deal_batch + task_loop - 1) / task_loop;
for (int i = 0; i < task_loop; i++) {
int cur_batch = std::min(task_deal_batch - i * once_batch, once_batch);
int batch_offset = task_begin_batch + once_batch * i;
fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
rotary_stride, cur_batch, batch_offset, head_num_q, head_num_k, head_size,
max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
once_batch, eps);
}
}
} // namespace kernels
KernelStatus invokeFusedRope(cnrtQueue_t queue,
void *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
const void *sin_table,
const void *cos_table,
const void *rope_offsets,
const void *gamma,
const void *beta,
const void *key_scale_hp,
const void *value_scale_hp,
void *key_scale_lp,
void *value_scale_lp,
const void *cache_bs_id_hp,
const void *cache_seq_offsets_hp,
const void *cache_bs_id_lp,
const void *cache_seq_offsets_lp,
const void *slot_mapping_hp,
const void *slot_mapping_lp,
int rotary_stride,
int batch_size,
int head_num_q,
int head_num_kv,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
cnnlDataType_t dtype,
float eps) {
if (is_arch300()) {
std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
uint32_t taskdimx = cluster_num * core_num;
int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
int float_size = sizeof(float);
int group_num = head_size / group_size;
bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
int nram_avalible_bytes = 480 * 1024;
int task_max_batch = 32;
int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
int nram_params_bytes = 2 * head_size * float_size;
int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
int nram_remain_bytes =
nram_avalible_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
int nram_table_bytes = 2 * head_size * float_size;
int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
int nram_cache_offsets_hp = head_num_kv * sizeof(int);
int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
int batch_cap =
nram_remain_bytes /
(nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;
cnrtDim3_t dim{taskdimx, 1, 1};
if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
(half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
(float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
(float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
(int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
(int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
task_avg_batch, eps);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
(bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
(bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
(float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
(float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
(int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
(int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
task_avg_batch, eps);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,119 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Apply query and kery rotary embedding, key layernorm and
* quantize key and value to kv cache.
* @param queue: The queue for mlu.
* @param input: Input/Output. Pointer to the MLU memory that stores the input,
* the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
* @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
* continous. The shape must be [rotary_seq_len, rotary_dim].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
* continous. The shape must be [rotary_seq_len, rotary_dim].
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each
* batch. The shape must be [batch].
* @param norm_gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm.
* @param norm_beta: Input. Pointer to the MLU memory that stores the beta param of layernorm.
* @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* key. The shape must be [head_num_kv, head_size]. If key_scale is nullptr,
* that means key do not need to be quantized.
* @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* value. The shape must be [head_num_kv, head_size]. If value_scale is nullptr,
* that means value do not need to be quantized.
* @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
* precision key. The shape must be [batch_size, head_num_kv, max_deocde_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
* precision value. The shape must be [batch_size, head_num_kv, max_deocde_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
* offset of high precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
* offset of high precision cache, the shape must be [batch].
* @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
* offset of low precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
* offset of low precision cache, the shape must be [batch].
* @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
* @param batch_size: Batch size.
* @param head_num_q: Head number of query.
* @param head_num_kv: Head number of key and value.
* @param head_size: Head size. For simplify, the rotary dim must be the same as head_size.
* @param max_decode_len_hp: The maximum sequence length of high precision cache.
* @param max_decode_len_lp: The maximum sequence length of low precision cache.
* @param block_size_hp: Number of tokens per block of high precision cache.
* @param block_size_lp: Number of tokens per block of low precision cache.
* @param data_type: Data type of all inputs and outputs.
* @param eps: float number use for layernorm.
* @note: Head_num_q and head_num_kv must be in range [1, 32].
* Head_size must be in range [1, 128], and must be divided by 2.
*/
KernelStatus invokeFusedRope(cnrtQueue_t queue,
void *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
const void *sin_table,
const void *cos_table,
const void *rope_offsets,
const void *gamma,
const void *beta,
const void *key_scale_hp,
const void *value_scale_hp,
void *key_scale_lp,
void *value_scale_lp,
const void *cache_bs_id_hp,
const void *cache_seq_offsets_hp,
const void *cache_bs_id_lp,
const void *cache_seq_offsets_lp,
const void *slot_mapping_hp,
const void *slot_mapping_lp,
int rotary_stride,
int batch_size,
int head_num_q,
int head_num_kv,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
cnnlDataType_t dtype,
float eps);
} // namespace tmo
#endif // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_

View File

@@ -0,0 +1,130 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "generate_alibi_slope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
__nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64};
__nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25,
27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51,
53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77,
79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103,
105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127};
__mlu_func__ void genRange(float *range_nram,
float *range_base_nram,
int fill_num,
int base_num,
int offset = 0) {
int loop = (fill_num + base_num - 1) / base_num;
for (int i = 0; i < loop; i++) {
int num = std::min((fill_num - i * base_num), base_num);
float *fill_nram = range_nram + i * base_num;
__bang_move(fill_nram, range_base_nram, num * sizeof(float));
__bang_add_scalar(fill_nram, fill_nram, i * base_num + offset, num);
}
}
__mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes,
int *true_seq_lens,
int batch_num,
int head_start,
int head_num,
int head_num_total,
int max_sequence_length,
bool use_dynamic,
int closest_power_of_2,
int farthest_power_of_2,
float base,
float extra_base) {
float *range_nram = (float *)nram_buffer;
float *base_nram = range_nram + head_num;
float *slope_nram = base_nram + head_num;
float scale = 1.0;
float dynamic_base = base;
if (use_dynamic) {
float a0 = 1.0;
float a = a0 * true_seq_lens[taskIdX] / max_sequence_length;
a = std::max(a, 1.0f);
scale = powf(a, (1.0 / (head_num_total - 1)));
dynamic_base = base / scale;
}
int close_head_num = 0;
if (head_start >= closest_power_of_2) {
close_head_num = 0;
} else if (head_start + head_num <= closest_power_of_2) {
close_head_num = head_num;
} else {
close_head_num = closest_power_of_2 - head_start;
}
int far_head_num = head_num - close_head_num;
// fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1
if (close_head_num) {
genRange(range_nram, range_1, close_head_num, 64, head_start);
__bang_write_value(base_nram, close_head_num, dynamic_base);
}
if (far_head_num) {
genRange(range_nram + close_head_num, range_2, far_head_num, 64,
(head_start + close_head_num - closest_power_of_2) * 2);
__bang_write_value(base_nram + close_head_num, far_head_num, extra_base);
}
// base_nram ** range_nram
__bang_log(base_nram, base_nram, head_num);
__bang_mul(slope_nram, base_nram, range_nram, head_num);
__bang_pow2(slope_nram, slope_nram, head_num);
if (use_dynamic) {
__bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num);
}
__memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM);
}
} // namespace kernels
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
void *alibi_slopes,
void *true_seq_lens,
int batch_num,
int head_start,
int head_num,
int head_num_total,
int max_sequence_length,
bool use_dynamic) {
cnrtDim3_t dim{.x = (uint32_t)batch_num, .y = 1, .z = 1};
int closest_power_of_2 = pow(2, floor(log2(head_num_total)));
int farthest_power_of_2 = closest_power_of_2 * 2;
float base = pow(2, (-pow(2, -(log2(closest_power_of_2) - 3))));
float extra_base = pow(2, (-pow(2, -(log2(2 * closest_power_of_2) - 3))));
kernels::MLUAlibiSlopeKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start, head_num, head_num_total,
max_sequence_length, use_dynamic, closest_power_of_2, farthest_power_of_2, base, extra_base);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,43 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
#define CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Generate causal mask for context satge.
* @param queue: The queue for mlu.
* @param alibi_slopes: Output. Pointer to the MLU memory that stores the output, the shape must be
* [batch_num, head_num].
* @param true_seq_lens: Input. Pointer to the MLU memory that stores the actual sequence length of
* each batch, the shape must be [batch_num].
* @param batch_num: Batch number.
* @param head_start: The index of first head.
* @param head_num: Head number in this card.
* @param head_num_total: Total head number in all cards.
* @param max_sequence_length: The maximum sequence length used during training.
* @param use_dynamic: A boolean value indicates whether to use dynamic NTK.
*/
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
void *alibi_slopes,
void *true_seq_lens,
int batch_num,
int head_start,
int head_num,
int head_num_total,
int max_sequence_length,
bool use_dynamic);
} // namespace tmo
#endif // CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_

View File

@@ -0,0 +1,214 @@
#include <cstddef>
#include <iostream>
#include "cn_api.h"
#include "cnnl.h"
#include "cnrt.h"
#include "generate_mask.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
template <typename T>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) {
__bang_write_value(dst, elem_count, value);
}
template <>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) {
#if __BANG_ARCH__ >= 500
__bang_write_value(dst, elem_count, value);
#endif
}
// [once_len, once_len]
__nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)];
// [1 + once_len, 2 * once_len]
__nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)];
// [once_len * 2 + 1]
__nram__ int8_t nram_tiny[2048];
template <typename T>
class GenerateMask {
constexpr static int once_len = sizeof(T) == 4 ? 160 : 256;
// [once_len, once_len]
T *nram_upper = (T *)(nram_small);
// [1 + once_len, 2 * once_len]
T *nram_buf = (T *)(nram_large);
// [once_len, once_len], reuse upper part of nram_buf
T *nram_filled = nram_buf;
// [once_len, once_len], reuse lower part of nram_buf
T *nram_zeros = nram_buf + once_len * once_len;
// [once_len]
T *nram_ones_zeros = (T *)nram_tiny;
__mlu_func__ void initBuffers(T fill_value = -10000) {
/* nram_buf:
|---once_len---||---once_len---|
0, 1, 1, 1, ..., 1, 0, 0, 0, ...
0, 0, 1, 1, ..., 1, 1, 0, 0, ...
0, 0, 0, 1, ..., 1, 1, 1, 0, ...
... */
nram_buf[0] = 0;
constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T);
__memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1);
__memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T),
once_len * 2 * sizeof(T), once_len - 1);
// nram_buf is nolonger needed
write_value(nram_filled, once_len * once_len, (T)fill_value);
write_value(nram_zeros, once_len * once_len, (T)0);
}
__mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len]
int max_seq_len,
int seq_len) {
/*
| once_len |
+----------+-----------------------------------+
| | | |
| upper | fill_value | |
| | | |
+----------+----------+ | |
| | | | |
| | upper | | fill |
| | | | value |
| +----------+----------+ | |
| | | | |
| | upper | | |
| 0 | | | |
| +----------+---+ |
| | u | |
|--------------------------------+---+ |
| |
| fill_value |
| |
+----------------------------------------------+
*/
int tile_count = seq_len / once_len;
int tile_remain = seq_len % once_len;
int boarder_len = max_seq_len - seq_len;
int row = 0;
for (; row < tile_count * once_len; row += once_len) {
// fill left with zeros
// assume that max_seq_len <= once_len^2
if (row > 0) {
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
max_seq_len * sizeof(T), 0, once_len - 1);
}
// fill middle with upper
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T),
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1);
// fill right with fill_value
if (row + once_len < max_seq_len) {
__memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled,
(max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM,
max_seq_len * sizeof(T), 0, once_len - 1);
}
}
if (tile_remain) {
// fill left with zeros
if (row > 0) {
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
max_seq_len * sizeof(T), 0, tile_remain - 1);
}
// fill middle with upper
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T),
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1);
// fill right with fill_value
if (row + tile_remain < max_seq_len) {
__memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled,
(max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM,
max_seq_len * sizeof(T), 0, tile_remain - 1);
}
}
if (boarder_len) {
// fill right boarder with fill_value
__memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM,
max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1);
// fill bottom boarder with fill_value
__memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T),
NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1);
}
__sync_io();
}
public:
__mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
int *batch_seq_len,
int total_batch,
int max_seq_len,
T fill_value = -10000) {
int batch_each = total_batch / taskDimY;
int batch_remain = total_batch % taskDimY;
int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain);
int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0);
write_value(nram_ones_zeros, once_len, (T)fill_value);
write_value(nram_ones_zeros + once_len, once_len + 1, (T)0);
initBuffers();
for (int n = batch_start; n < batch_start + batch_count; n++) {
T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len;
int seq_len = batch_seq_len[n];
dealOneBatch(output, max_seq_len, seq_len);
}
}
};
template <typename T>
__mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
int *batch_seq_len,
int total_batch,
int max_seq_len,
T fill_value = -10000) {
if (coreId != 0) {
return; // we only use 1 core in a cluster
}
GenerateMask<T>().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value);
}
} // namespace kernels
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
void *output_ddr,
int *batch_seq_len,
int total_batch,
int max_seq_len,
cnnlDataType_t data_type,
float fill_value) {
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
CNdev dev;
cnnlGetDevice(handle, &dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
cnrtDim3_t dim;
dim.x = 4;
dim.y = cluster_num;
dim.z = 1;
if (data_type == CNNL_DTYPE_FLOAT) {
kernels::MLUUnion1GenerateMask<float><<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<float *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
static_cast<float>(fill_value));
} else if (data_type == CNNL_DTYPE_HALF) {
kernels::MLUUnion1GenerateMask<half><<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<half *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
static_cast<half>(fill_value));
} else if (data_type == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUUnion1GenerateMask<bfloat16_t><<<dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<bfloat16_t *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
static_cast<bfloat16_t>(fill_value));
} else {
std::cerr << "[invokeGenerateMask]: invokeGenerateMask: data_type is not supported"
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,37 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_
#define CSRC_KERNELS_GENERATE_MASK_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Generate causal mask for context stage.
* @param handle: The handle of cnnl.
* @param output_ddr: Output. Pointer to the MLU memory that stores the output.
* @param batch_seq_len: Input. Pointer to the MLU memory that stores the sequence length.
* @param total_batch: Batch size.
* @param max_seq_len: The maximum sequence length of context.
* @param data_type: Data type.
* @param fill_value: The fill value of the pad part.
*/
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
void *output_ddr,
int *batch_seq_len,
int total_batch,
int max_seq_len,
cnnlDataType_t data_type,
float fill_value);
} // namespace tmo
#endif // CSRC_KERNELS_GENERATE_MASK_MLUH_

View File

@@ -0,0 +1,60 @@
#include <stdexcept>
#include "cnrt.h"
#include "get_glm_position_id.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
__nram__ int nram_buffer[__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)];
__mlu_global__ void MLUBlockSliceIndividualPosId(int *context_pos_id,
int *generate_pos_id,
int batch,
int context_seq_len,
int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
if (taskId != 0) return;
__memcpy(nram_buffer, context_pos_id + context_seq_len - 1, sizeof(int), GDRAM2NRAM, sizeof(int),
context_seq_len * sizeof(int), pos_id_dimension * batch - 1);
if (pos_id_dimension == 2) {
for (int i = 1; i < 2 * batch; i += 2) {
nram_buffer[i] += 1;
}
}
__memcpy(generate_pos_id, nram_buffer, pos_id_dimension * batch * sizeof(int), NRAM2GDRAM);
}
__mlu_global__ void MLUBlockIncrement2DPosId(int *generate_pos_id, int batch) {
if (taskId != 0) return;
__memcpy(nram_buffer, generate_pos_id, 2 * batch * sizeof(int), GDRAM2NRAM);
for (int i = 1; i < 2 * batch; i += 2) {
nram_buffer[i] += 1;
}
__memcpy(generate_pos_id, nram_buffer, 2 * batch * sizeof(int), NRAM2GDRAM);
}
} // namespace kernels
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
int *context_pos_id,
int *generate_pos_id,
int batch,
int context_seq_len,
int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
if (pos_id_dimension != 1 && pos_id_dimension != 2) {
std::cerr << "[invokeSliceIndividualPosId]: pos_id_dimension must be 1 or 2, but got "
<< pos_id_dimension << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
cnrtDim3_t dim{4, 1, 1};
kernels::MLUBlockSliceIndividualPosId<<<dim, cnrtFuncTypeUnion1, queue>>>(
context_pos_id, generate_pos_id, batch, context_seq_len, pos_id_dimension);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch) {
cnrtDim3_t dim{4, 1, 1};
kernels::MLUBlockIncrement2DPosId<<<dim, cnrtFuncTypeUnion1, queue>>>(generate_pos_id, batch);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,59 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
#define CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Get generate position id from context position id, when position id is 2D,
* increase block position id by one.
* @example
* in GLM network, context_pos_id shape is [batch, 2, context_seq_len], data is
* [[[0, 1, 2, 2, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1]],
* [[0, 1, 2, 3, 4, 5, 5], [0, 0, 0, 0, 0, 0, 1]]]
* after invoke this kernel, the data is
* [[[2], [2]],
* [[5], [2]]]
* @param queue: The queue for mlu.
* @param context_pos_id: Input. Pointer to the MLU memory that stores the position id of
* context.
* @param generate_pos_id: Output. Pointer to the MLU memory that stores the position id of
* generate.
* @param batch: Batch size.
* @param context_seq_len: The sequence length of context.
* @param pos_id_dimension: The dimension of position id, 1 for 1D, 2 for 2D.
*/
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
int *context_pos_id,
int *generate_pos_id,
int batch,
int context_seq_len,
int pos_id_dimension);
/**
* @brief Increase block position id by one in generate stage.
* @example
* in GLM network, generate_pos_id shape is [batch, 2, 1], data is
* [[[2], [1]], [[5], [1]]]
* after invoke this kernel, the data is
* [[[2], [2]], [[5], [2]]]
* @param queue: The queue for mlu.
* @param generate_pos_id: Output/Input. Pointer to the MLU memory that stores the position id of
* generate.
* @param batch: Batch size.
*/
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch);
} // namespace tmo
#endif // CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_KERNEL_UTILS_H_
#define CSRC_KERNELS_KERNEL_UTILS_H_
#include <cassert>
#include <iostream>
#include <string>
#include "cnnl.h"
#include "cnrt.h"
namespace tmo {
const std::string arch_370 = "MLU370";
enum class KernelStatus { KERNEL_STATUS_SUCCESS = 0, KERNEL_STATUS_FAILED };
#ifndef PAD_DOWN
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#endif
#ifndef PAD_UP
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
inline bool isMlu300(const std::string &dev_name) {
if (dev_name.find("MLU3") != std::string::npos) {
return true;
} else {
return false;
}
}
inline bool is_arch300() {
int card_id = -1;
cnrtDeviceProp_t dev_info;
CNRT_CHECK(cnrtGetDevice(&card_id));
CNRT_CHECK(cnrtGetDeviceProperties(&dev_info, card_id));
std::string dev_name = dev_info.name;
return isMlu300(dev_name);
}
inline bool isBf16Supported() {
return !is_arch300();
}
} // namespace tmo
#endif // CSRC_KERNELS_KERNEL_UTILS_H_

View File

@@ -0,0 +1,521 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "add_bias_activation.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024)
#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024)
__nram__ int8_t nram_buffer[USE_NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE];
__mlu_func__ void get_expert_info(int *nram_count,
int *count_sram,
uint32_t *gather_offset,
int real_inner,
int tokens_start,
int tokens_end,
int tokens_load,
int expert_deal_start,
int expert_deal_end,
uint32_t &expert_start,
uint32_t &expert_end,
uint32_t &tokens_deal_first,
int dtype_size) {
bool record_start = false;
// loop expert to find first and last deal expert in current core
for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) {
if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) {
expert_start = expert_id;
tokens_deal_first =
std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1;
record_start = true;
}
if (__load_nram(nram_count + expert_id + 1) > tokens_end) {
expert_end = expert_id;
break;
}
}
// record expert offset to gather bias
__bang_write_zero(gather_offset, tokens_load);
int tokens_load_total = 0;
for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) {
int tokens_expand = __load_sram((int *)count_sram + expert_id);
if (expert_id == expert_start) {
tokens_expand = tokens_deal_first;
} else if (expert_id == expert_end) {
tokens_expand = tokens_load - tokens_load_total;
}
if (tokens_expand == 0) {
continue;
}
__bang_write_value(gather_offset + tokens_load_total, tokens_expand,
(int)((expert_id - expert_deal_start) * real_inner * dtype_size));
tokens_load_total += tokens_expand;
}
}
/*************** functions for compute basic operation ***************/
template <typename T>
__mlu_func__ void add_bias(T *dst_src, T *bias, int number) {
// cycle add bias
__bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number);
}
template <>
__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) {
#if __BANG_ARCH__ > 500
__bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number);
#endif
}
template <typename T>
__mlu_func__ void mul_left_right(T *left, T *right, int number) {
__bang_mul((T *)left, (T *)left, (T *)right, number);
}
template <>
__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) {
#if __BANG_ARCH__ > 500
__bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number);
#endif
}
__mlu_func__ void do_activation(float *input_left,
float *act_space,
int number,
float active_coef,
cnnlActivationMode_t act_type) {
if (act_type == CNNL_ACTIVATION_GELU) {
__bang_active_gelu((float *)input_left, (float *)input_left, number);
} else if (act_type == CNNL_ACTIVATION_SWISH) {
float *tmp = input_left;
if (active_coef != 1.0f) {
__bang_mul_scalar(act_space, input_left, active_coef, number);
tmp = act_space;
}
__bang_active_sigmoid((float *)act_space, (float *)tmp, number);
__bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number);
}
}
/*************** functions for steps of each loop ***************/
template <typename T>
__mlu_func__ void gather_bias(T *bias_sram,
T *bias_nram,
uint32_t *gather_offset,
int expert_start,
int expert_end,
int expert_deal_start,
int tokens_deal_first,
int tokens_deal,
int inner_size,
bool is_gated) {
#if __BANG_ARCH__ > 500
if (is_gated) {
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
inner_size * sizeof(T), tokens_deal);
__gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size,
gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T),
tokens_deal);
} else {
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
inner_size * sizeof(T), tokens_deal);
}
#else
for (int i = 0; i < tokens_deal; i++) {
if (is_gated) {
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
__memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size),
(int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
} else {
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
}
}
#endif
}
template <typename T>
__mlu_func__ void loadBiasInput(T *input,
T *left,
T *right,
T *bias_nram,
T *bias_sram,
uint32_t *gather_offset,
size_t input_offset,
int tokens_deal,
int inner_size,
uint32_t expert_start,
uint32_t expert_end,
int expert_deal_start,
uint32_t tokens_deal_first,
bool is_gated,
bool has_bias) {
if (is_gated) {
// if gated, stride io load input, left/right inner to input_left/right
__memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM,
inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
__memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T),
GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
} else {
// if not gated, load input to input_left total
__memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T),
GDRAM2NRAM);
}
if (has_bias) {
__sync_compute();
gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end,
expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated);
}
}
template <typename T>
__mlu_func__ void computeAddActivation(T *bias_nram,
T *left_dst,
T *input_right,
float *input_left,
float *act_space,
int tokens_deal,
int inner_size,
bool is_gated,
bool has_bias,
float active_coef,
cnnlActivationMode_t act_type) {
int number = tokens_deal * inner_size;
if (has_bias) {
add_bias((T *)left_dst, (T *)bias_nram, number);
if (is_gated) {
add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number);
}
}
// cast half/bfloat16 to float to acvication, if float, left_dst is same as input_left
if (std::is_same<T, half>::value) {
__bang_half2float((float *)input_left, (half *)left_dst, number);
}
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number);
}
#endif
// activation
do_activation(input_left, act_space, number, active_coef, act_type);
// if half/bfloat16, cast float to T to mul
if (std::is_same<T, half>::value) {
__bang_float2half((half *)input_left, (float *)input_left, number);
}
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number);
}
#endif
if (is_gated) {
mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size);
}
}
template <typename T>
__mlu_func__ void storeOutput(T *output,
T *output_nram,
size_t output_offset,
int output_stride,
int tokens_deal,
int inner_size) {
__memcpy_async((T *)output + output_offset, (T *)output_nram, inner_size * sizeof(T), NRAM2GDRAM,
output_stride * sizeof(T), inner_size * sizeof(T), tokens_deal - 1);
}
template <typename T>
__mlu_global__ void MLUAddBiasActivationKernel(T *output,
const T *input,
const T *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef) {
// if bias and token_count is nullptr, not add bias, only activation and gated mul.
bool has_bias = (bias != nullptr);
// 1. distrubute nram space
/* if not gated
---------------------------- ----------------------------------
| nram_token_count | bias_ping/pong |
| num_expert * sizeof(int) | 2 * x * inner_size * sizeof(T) |
---------------------------- ----------------------------------
--------------------------------------
| input_ping/pong |
| 2 * x * inner_size * sizeof(float) |
--------------------------------------
------------------------------------------------- -------------------
| act_space | gather_offset |
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
------------------------------------------------- -------------------
*/
/* if gated
---------------------------- ----------------------------------
| nram_token_count | bias_ping/pong |
| num_expert * sizeof(int) | 2 * x * real_inner * sizeof(T) |
---------------------------- ----------------------------------
-------------------------------------- ----------------------------------
| input_left_ping/pong | input_right_ping/pong |
| 2 * x * inner_size * sizeof(float) | 2 * x * inner_size * sizeof(T) |
-------------------------------------- ----------------------------------
------------------------------------------------- -------------------
| act_space | gather_offset |
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
------------------------------------------------- -------------------
*/
// distribute sram
int8_t *count_sram = (int8_t *)sram_buffer;
int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int);
// distrubute nram
int real_inner = (is_gated) ? inner_size * 2 : inner_size;
int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0;
int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float);
int gated_ext_size = is_gated ? sizeof(T) : 0;
int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) /
(2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size +
2 * bias_nram_size + sizeof(int));
int8_t *nram_count = (int8_t *)nram_buffer;
int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int);
int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size;
int8_t *input_right =
(int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0);
int8_t *act_space = (int8_t *)input_right +
2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float));
int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size;
// 2. cusum_token_count load to nram, because need to reuse in load bias.
if (has_bias) {
__memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int),
GDRAM2NRAM);
if (taskIdX == 0) {
__bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert);
__sync();
__memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM);
}
__sync_cluster();
}
// 3. sram loop to compute
// compute once load bias to sram due to sram_limit
int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T));
int real_expert = cusum_token_count == nullptr ? num_expert : expert_size;
int sram_loop_rem = real_expert % max_expert_deal;
int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0);
if (!has_bias) {
max_expert_deal = real_expert;
sram_loop = 1;
sram_loop_rem = 0;
}
for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) {
// load current bias, compute each core deal number
int expert_deal =
(deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal;
int expert_deal_start = deal_loop * max_expert_deal + start_expert_id;
int expert_deal_end = expert_deal_start + expert_deal - 1;
__sync_all();
if (has_bias && __is_mpu()) {
__memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner,
expert_deal * real_inner * sizeof(T), GDRAM2SRAM);
}
__sync_all();
// get tokens info of each core
int tokens_total_cur = total_tokens;
if (has_bias) {
tokens_total_cur =
__load_nram((int *)nram_count + expert_deal_end + 1) -
(start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start));
} else if (cusum_token_count != nullptr) {
tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) -
__load_gdram(cusum_token_count + expert_deal_start);
}
if (sram_loop != 1) {
tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] -
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
if (deal_loop == 0 && start_expert_id != 0) {
tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start);
}
}
int tokens_core_rem = tokens_total_cur % taskDim;
int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem);
if (tokens_cur_core == 0) {
continue;
}
// if ep, input start in current token, have a real start in total network
int real_start =
cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0;
int tokens_core_start = tokens_cur_core * taskId +
(taskId < tokens_core_rem ? 0 : tokens_core_rem) +
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
if (deal_loop != 0) {
tokens_core_start -= real_start;
}
uint32_t expert_start = 0;
uint32_t expert_end = 0;
uint32_t tokens_deal_first = 0;
// 4. nram loop compute
int nram_loop_rem = tokens_cur_core % max_token_deal;
int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0);
int tokens_load = max_token_deal;
int tokens_compute = max_token_deal;
int tokens_store = max_token_deal;
for (int loop = 0; loop < nram_loop + 2; loop++) {
int inner_io_offset = (loop % 2) * max_token_deal * inner_size;
int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size;
int real_io_offset = (loop % 2) * max_token_deal * real_inner;
int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner;
if (nram_loop_rem != 0) {
if (loop > 1 && (loop - 2) == (nram_loop - 1)) {
tokens_store = nram_loop_rem;
}
if (loop > 0 && (loop - 1) == (nram_loop - 1)) {
tokens_compute = nram_loop_rem;
}
if (loop == (nram_loop - 1)) {
tokens_load = nram_loop_rem;
}
}
int tokens_cur_start = tokens_core_start + loop * max_token_deal;
int tokens_cur_end = tokens_cur_start + tokens_load - 1;
// get current load info
if (loop < nram_loop && has_bias) {
get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner,
tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load,
expert_deal_start, expert_deal_end, expert_start, expert_end,
tokens_deal_first, sizeof(T));
}
// store
if (loop > 1) {
size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride;
storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset,
output_stride, tokens_store, inner_size);
}
// compute
if (loop > 0 && loop <= nram_loop) {
T *left_dst = (T *)((float *)input_left + inner_com_offset) +
((std::is_same<T, float>::value) ? 0 : tokens_compute * inner_size);
computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst,
(T *)input_right + inner_com_offset,
(float *)input_left + inner_com_offset, (float *)act_space,
tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type);
}
// load
if (loop < nram_loop) {
T *left_dst = (T *)((float *)input_left + inner_io_offset) +
((std::is_same<T, float>::value) ? 0 : tokens_load * inner_size);
size_t input_offset = tokens_cur_start * real_inner;
loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset,
(T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset,
input_offset, tokens_load, inner_size, expert_start, expert_end,
expert_deal_start, tokens_deal_first, is_gated, has_bias);
}
__sync();
}
}
}
} // namespace kernels
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
cnnlDataType_t dtype,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef) {
if (bias != NULL && cusum_token_count == NULL) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
<< "when have bias, cusum_token_count can not be nullptr.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
<< "activation mode only supports gelu and swish now.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert,
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
active_coef);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert,
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
active_coef);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias,
cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type,
start_expert_id, expert_size, active_coef);
} else {
std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, "
<< "only support float/half/bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,84 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Add bias and activate to all tokens. Different expert with different bias.
* If in gated mode, add bias to all input. But only activate in the
* input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:].
* Else, add bias and activate to all input, has no multiply process.
* @example
* is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2
* input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3],
* [3, 5, 7, 8], [6, 8, 5, 3],
* [2, 3, 4 ,5], [2, 9, 2, 3]]
* bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]]
* token_count = [2, 2, 1, 1]
* first step: add bias
* [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0],
* [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2],
* [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]]
* second step: act and mul
* output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3],
* [act(3)*9, act(6)*10], [act(6)*7, act(9)*5],
* [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* When is_gated is true, The shape is [total_tokens, input_size / 2].
* In this case, the input_size must be even. Otherwise the shape is [total_tokens,
* input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride.
* @param input: Input. Pointer to the MLU memory that stores the input tokens.
* The shape is [total_tokens, input_size].
* When is_gated is true, the shape is [total_tokens, 2 * inner_size].
* Otherwise the shape is [total_tokens, inner_size].
* @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be
* continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape
* is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process.
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token
* counts. The shape is [num_expert + 1]. If cusum_token_count
* is not nullptr, cusum_token_count, start_expert_id and
* expert_size together determine which tokens to process.
* If cusum_token_count is nullptr, process all tokens,
* the number of which is total_tokens. When bias is not nullptr,
* cusum_token_count must also not be nullptr.
* @param num_expert: The number of expert.
* @param total_tokens: The total number of tokens.
* @param inner_size: The inner size of output.
* @param output_stride: The stride of output, must be greater than or equal to inner_size.
* @param dtype: Data type.
* @param is_gated: Gated or not.
* @param act_type: The type of activation. Support gelu and swish.
* @param start_expert_id: The index of the start expert.
* @param expert_size: The number of experts to process.
* @param active_coef: The coefficient used in the swish activation.
*/
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
cnnlDataType_t dtype,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_

View File

@@ -0,0 +1,646 @@
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "cast_gating.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)
#ifndef ONE_LINE
#define ONE_LINE 64
#endif
#ifndef LT_NUM
#define LT_NUM 64
#endif
struct castGatingTileInfo {
int32_t block = 64;
int32_t split_k_num = 8;
int32_t block_k = 256;
};
namespace kernels {
#pragma bang walign(16)
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif
#ifndef LT_SIZE
#define LT_SIZE 16
#endif
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
do { \
uint32_t align_num = 64 / src_dsize; \
uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
uint32_t rem = size % 64; \
if (n) { \
__asm__ __volatile__( \
"move.tiling.async.nram.sram.b16" \
" [%[dst_addr]], [%[src_addr]], " \
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
[src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
[src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
[src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
[dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
} \
\
if (rem) { \
__asm__ __volatile__( \
"move.tiling.async.nram.sram.b16" \
" [%[dst_addr]], [%[src_addr]], " \
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
[src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
[src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
[dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
} \
} while (false)
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
__memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
}
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
convert_type) \
int align_n = PAD_DOWN(n, LT_NUM); \
int sn0 = ONE_LINE; \
int size_sn0 = sn0 / src_dsize; \
int sn1 = ONE_LINE / src_dsize; \
int ss1 = total_k * src_dsize; \
int sn3 = k / size_sn0; \
int sn4 = align_n / sn1; \
int ss4 = sn1 * ss1; \
int dn0 = sn0; \
int dn1 = ROW_PER_LT; \
int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
int ds1 = dst_k * dst_dsize; \
int dn2 = sn1 / ROW_PER_LT; \
int ds2 = WRAM_LT_MAP16_STRIDE; \
int ds3 = sn0 * dst_dsize / src_dsize; \
int dn4 = LT_SIZE / dn2; \
int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
int dn5 = align_n / LT_NUM; \
int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
int rem_k = k % size_sn0; \
int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
if (align_n > 0 && sn3 > 0) { \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
";\n\t" ::[dst_addr] "r"(wram_dst), \
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
[dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
[dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
[dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
sram_src += align_n * total_k; \
wram_dst += align_n / LT_SIZE * dst_k; \
} \
align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
if (align_n > 0 && sn3 > 0) { \
sn1 = align_n; \
dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
[dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
[dst_s3] "r"(ds3)); \
sram_src += align_n * total_k; \
wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
} \
if (rem_k > 0) { \
align_n = PAD_UP(n, LT_NUM); \
sn0 = rem_k * src_dsize; \
dn0 = sn0; \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
";\n\t" ::[dst_addr] "r"(wram_dst2), \
[src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
[src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
[src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
[dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
[dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
[dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
[dst_s5] "r"(0)); \
}
__mlu_func__ void warp_prompt_weight(float *wram_dst,
half *sram_src,
int32_t warp_n,
int32_t len_k,
int32_t total_k) {
#if __BANG_ARCH__ >= 500
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
".cvt.rn.f32.f16()");
#endif
}
__mlu_func__ void warp_prompt_weight(float *wram_dst,
bfloat16_t *sram_src,
int32_t warp_n,
int32_t len_k,
int32_t total_k) {
#if __BANG_ARCH__ >= 500
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst,
T *sram_src,
int32_t n,
int32_t len_k,
int32_t total_k) {
int32_t type_size = sizeof(T);
int32_t data_size = len_k * type_size;
int32_t ds0 = PAD_UP(data_size, ONE_LINE);
int32_t ss0 = total_k * type_size;
int32_t count = n / LT_NUM;
for (int32_t i = 0; i < count; ++i) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
}
count = n % LT_NUM / ROW_PER_LT;
if (count > 0) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
}
count = n % ROW_PER_LT;
if (count > 0) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
}
}
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
const int32_t &taskid,
const int32_t &taskdim,
int32_t &task_offset,
int32_t &num_cur_task) {
int32_t num_per_task = num_total_task / taskdim;
int32_t rem_idx = num_total_task % taskdim;
if (taskid < rem_idx) {
task_offset = taskid * (num_per_task + 1);
num_cur_task = num_per_task + 1;
} else {
task_offset = taskid * num_per_task + rem_idx;
num_cur_task = num_per_task;
}
}
__mlu_func__ void bidirectionBarrierOp() {
int32_t bcnt = coreDim + 1;
if (__is_ipu()) {
__asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
__asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
} else {
__asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
__asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
}
}
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
__bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
__mlu_func__ void warp_store(void *ddr_dst,
void *nram_src,
const int32_t data_num,
const int32_t dst_stride,
const int32_t src_stride,
const int32_t count,
const int32_t dt_size) {
if (src_stride == data_num && dst_stride == data_num) {
__memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
} else {
__memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, (size_t)dst_stride * dt_size,
src_stride * dt_size, count - 1);
}
}
template <typename Tc, typename Tcc>
__mlu_func__ void splitKReduce(Tcc *workspace,
Tc *output,
int32_t M,
int32_t N,
int32_t split_k_num,
int32_t ldc) {
int32_t offset_m, cta_m;
assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
if (cta_m <= 0) return;
int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
Tc *output_ddr = (Tc *)output + offset_m * ldc;
for (int32_t i = 0; i < repeat; i++) {
int32_t current_m = i == repeat - 1 ? rem_m : block_m;
int32_t data_size = N * sizeof(Tc);
int32_t data_num = current_m - 1;
if (ldc == N) {
data_size = current_m * N * sizeof(Tc);
data_num = 0;
}
__memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
__bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
split_k_num, 1, 1, 1);
__memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
N * sizeof(Tc), data_num);
workspace_ddr = workspace_ddr + block_m * N;
output_ddr = output_ddr + block_m * ldc;
}
}
template <typename Ta,
typename Tac,
typename Tb,
typename Tbc,
typename Tc,
typename Tcc,
bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A,
Tb *B,
Tc *C,
Tcc *workspace,
int32_t M,
int32_t N,
int32_t K,
int32_t lda,
int32_t ldb,
int32_t ldc,
castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
int32_t block_k = split_info.block_k;
int32_t grid_dimx = split_info.split_k_num;
int32_t block = split_info.block;
int32_t grid_idx = clusterId % grid_dimx;
int32_t grid_idy = clusterId / grid_dimx;
int32_t offset_k = 0, problem_k = 0;
assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
int32_t cta_k = k_loop == 1 ? rem_k : block_k;
int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
int32_t warp_m = cta_m, warp_offset_m = 0;
int32_t warp_n = cta_n, warp_offset_n = 0;
int32_t outer_loop = 0, outer_rem = 0;
if (EXCHANGE_AB) {
assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
if (cta_n > block && cta_n % block != 0) {
int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
if (block_tmp < block) block = block_tmp;
}
outer_loop = (cta_n + block - 1) / block;
outer_rem = cta_n % block == 0 ? block : cta_n % block;
} else {
assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
if (cta_m > block && cta_m % block != 0) {
int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
if (block_tmp < block) block = block_tmp;
}
outer_loop = (cta_m + block - 1) / block;
outer_rem = cta_m % block == 0 ? block : cta_m % block;
}
int32_t size_nram_buf =
NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
Tac *nbuf_a = (Tac *)nram_buffer;
Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
int32_t size_sram_buf = SRAM_BUFFER_SIZE;
int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
Ta *sbuf_a = (Ta *)sram_buffer;
Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
Tbc *wbuf_b = (Tbc *)wram_buffer;
int32_t a_dsize = sizeof(Ta);
int32_t b_dsize = sizeof(Tb);
int32_t k_loop_count = 0;
for (int32_t j = 0; j < outer_loop; j++) {
Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
if (EXCHANGE_AB) {
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
} else {
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
}
int32_t compute_total = warp_m * warp_n;
if (compute_total > 0 && __is_ipu()) {
if (!EXCHANGE_AB) {
__sync_io_move_compute(true, false, false, false, false, true);
}
__bang_write_zero((Tcc *)nbuf_c, compute_total);
}
int32_t i = 0;
for (; i < k_loop; i++) {
Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
cta_k = i == k_loop - 1 ? rem_k : block_k;
if (__is_mpu()) {
if (EXCHANGE_AB) {
__memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
current_block - 1);
__asm__ volatile(
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
[src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
[src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
} else {
__memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
current_block - 1);
__asm__ volatile(
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
[src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
[src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
}
a_ddr = (Ta *)a_ddr + block_k;
b_ddr = (Tb *)b_ddr + block_k;
}
bidirectionBarrierOp();
if (__is_ipu() && compute_total > 0) {
__sync_io_move_compute(false, true, false, false, false, true);
__sync_io_move_compute(false, false, true, false, true, false);
if (i >= 1) {
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
}
warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
// mvdma bound for EXCHANGE_AB when n==32
warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
(Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
}
k_loop_count += 1;
}
if (compute_total > 0) {
__sync_io_move_compute(false, true, false, false, false, true);
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
if (EXCHANGE_AB) {
__sync_io_move_compute(true, false, false, false, false, true);
__bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
}
int32_t total_offset =
grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
: (offset_m + warp_offset_m + block * j) * N);
Tcc *wks = (Tcc *)workspace + total_offset;
int32_t store_c_size = sizeof(Tcc);
int8_t *store_ddr = (int8_t *)wks;
int32_t dst_str = EXCHANGE_AB ? M : N;
if (grid_dimx == 1) {
// convert Tcc to Tc
dst_str = ldc;
store_ddr =
(int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
: (offset_m + warp_offset_m + block * j) * ldc));
}
__asm__ volatile("sync.psimd.cio;\n\t");
if (EXCHANGE_AB) {
warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
} else {
warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
}
}
}
if (grid_dimx != 1) {
__sync_all();
splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
split_info.split_k_num, ldc);
}
#endif // __BANG_ARCH__ >= 500
}
} // namespace kernels
int32_t getBlock(int32_t m,
int32_t n,
int32_t core_num,
int32_t block_k,
int32_t a_dtype_size,
int32_t b_dtype_size,
int32_t compute_dtype_size,
bool EXCHANGE_AB) {
int32_t block = 0;
if (EXCHANGE_AB) {
int32_t block_m = n;
int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
(2 * block_m * compute_dtype_size) * core_num;
int32_t wram_block_n =
WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
int32_t sram_block_n =
(SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
return block_n > 0 ? block_n : block_n_tmp;
} else {
int32_t block_n = n;
int32_t nram_block_m =
NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
int32_t sram_block_m =
(SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
return block;
}
}
void gatingTiling(int32_t m,
int32_t n,
int32_t k,
size_t a_dtype_size,
size_t b_dtype_size,
size_t compute_dtype_size,
size_t workspace_size,
int32_t union_number,
int32_t core_num,
int32_t &block,
int32_t &split_k_num,
int32_t &block_k,
bool &EXCHANGE_AB) {
block_k = std::min(k, int32_t(512 / a_dtype_size));
split_k_num = 1;
// swap A and B to reduce computing waste caused by LT_NUM-align of co dimensian
if (m >= core_num * LT_NUM &&
float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
EXCHANGE_AB = true;
}
int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
compute_dtype_size, EXCHANGE_AB);
int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
block = tmp_block;
if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
for (int32_t i = total_blocks; i <= union_number; i++) {
if (union_number % i == 0) {
int32_t tmp_split_k = union_number / i;
size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
if (workspace_size >= workspace_size_need) {
split_k_num = tmp_split_k;
block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
if (EXCHANGE_AB && block > LT_NUM * core_num) {
block = PAD_DOWN(block, LT_NUM * core_num);
}
break;
}
}
}
}
}
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
KernelStatus invokeCastGating(cnrtQueue_t queue,
void *input,
void *filter,
void *output,
int input_row,
int expert_num,
int hidden_size,
cnnlDataType_t a_dtype,
void *workspace,
size_t workspace_size_bytes) {
if (is_arch300()) {
std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (expert_num > 128) {
std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
std::cerr
<< "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (workspace_size_bytes > 0 && workspace == NULL) {
std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
"greater than 0."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t union_number, core_num;
getContxtInfo(&union_number, &core_num);
cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
cnrtDim3_t dim;
dim.x = (int32_t)func_type;
dim.y = 1;
dim.z = 1;
cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
castGatingTileInfo split_info;
bool EXCHANGE_AB = false;
gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
workspace_size_bytes, union_number, core_num, split_info.block,
split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
if (a_dtype == CNNL_DTYPE_BFLOAT16) {
if (EXCHANGE_AB) {
kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
<<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
(float *)workspace, expert_num, input_row, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
} else {
kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
<<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
(float *)workspace, input_row, expert_num, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
}
} else if (a_dtype == CNNL_DTYPE_HALF) {
if (EXCHANGE_AB) {
kernels::MLUCastGating<float, float, half, float, float, float, true>
<<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
(float *)workspace, expert_num, input_row, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
} else {
kernels::MLUCastGating<half, float, float, float, float, float, false>
<<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
(float *)workspace, input_row, expert_num, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
}
} else {
std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,50 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
#define CSRC_KERNELS_CAST_GATING_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Convert input to float32 and do gating operation.
* @param queue: The queue for mlu.
* @param input: Input. Pointer to the MLU memory that stores the input,
* the shape must be [input_row, hidden_size].
* @param filter: Input. Pointer to the MLU memory that stores the weight,
* the shape must be [expert_num, hidden_size].
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [input_row, expert_num].
* @param input_row: Input.
* @param expert_num: Input.
* @param hidden_size: Input.
* @param a_dtype: Input. The data-type of input.
* @param workspace: Input. Pointer to the MLU workspace.
* @param workspace_size_bytes: Input. The size of workspace in bytes.
* @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
* expert_num must be in range [1, 128].
* If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024.
* The data-type of filter and output must be float.
* cast_gating only supports MLU500 device or higher.
*/
KernelStatus invokeCastGating(cnrtQueue_t queue,
void *input,
void *filter,
void *output,
int input_row,
int expert_num,
int hidden_size,
cnnlDataType_t a_dtype,
void *workspace,
size_t workspace_size_bytes);
} // namespace tmo
#endif // CSRC_KERNELS_CAST_GATING_MLUH_

View File

@@ -0,0 +1,760 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include "cnrt.h"
#include "combine_result.mluh"
// clang-format off
#include <bang_device_functions_extra.h>
#include <mlu.h>
// clang-format on
#if __BANG_ARCH__ >= 592
#include <bang_fusor.h>
template <typename SrcT>
using bang_fusor = bang::experimental::fusor<SrcT>;
#endif
namespace tmo {
namespace kernels {
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
T *temp = ping;
ping = pong;
pong = temp;
}
#define GATHER_ASYNC_IO0(offset_type) \
__asm__ __volatile__( \
"gather.vector.async.nram.gdram.nram." #offset_type \
".io0 [%[dst]], [%[src]], [%[offset]], " \
"%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst), \
[src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
[transfer_num] "r"(token_count), [stride] "r"(transfer_size))
#define FUSE_MUL_CVT(dst_dtype) \
__asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1]," \
" %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
#define FUSE_MULADD_CVT(dst_dtype) \
__asm__ __volatile__("muladd.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \
" %[size], %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
} else if (std::is_same<T, float>::value) {
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
} else if (std::is_same<T, float>::value) {
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
__mlu_func__ void loadAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
__mlu_func__ void storeAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
void *src_gdram,
T_IDX *nram_offset,
int transfer_size,
int token_count) {
if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
if (std::is_same<T_IDX, uint32_t>::value) {
GATHER_ASYNC_IO0(u32);
} else {
GATHER_ASYNC_IO0(u64);
}
#else
for (int k = 0; k < token_count; k++) {
__memcpy_async((int8_t *)dst + k * transfer_size,
(int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
}
#endif
}
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
int *nram_mask,
uint8_t *nram_mask_char,
int *nram_mask_buffer,
int begin_expert_acc_tokens,
int end_expert_acc_tokens,
int token_count,
bool expert_parallelism) {
if (!expert_parallelism) {
return token_count;
}
__bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
.ge(begin_expert_acc_tokens)
.land(nram_mask_buffer)
.cvt<float>(0);
#else
__bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
__bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
__bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
__bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
int active_token_count = __bang_count((float *)nram_mask, token_count);
return active_token_count;
}
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
int *nram_idx,
uint64_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
#if __BANG_ARCH__ > 592
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
__bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
__bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
int *nram_idx,
uint32_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
__bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
token_count);
}
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
T_IDX *nram_bias_offset,
int *nram_token_idx,
int *nram_expert_tables,
int expert_num,
int token_count,
int active_token_count,
int hidden_size,
int local_hidden_begin,
int dtype_size,
int start_expert_id,
int expert_size,
int begin_expert_acc_tokens,
bool has_bias) {
// for large tensor, convert int322int64 then do multiply and add seperately.
if (active_token_count <= 0) return;
if (has_bias) {
int *nram_bias_offset_temp = (int *)nram_token_offset;
__bang_write_zero(nram_bias_offset, active_token_count);
for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
__bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
active_token_count);
__bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
active_token_count);
}
__bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
(T_IDX)local_hidden_begin * dtype_size, active_token_count);
}
int64_t offset =
((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
active_token_count);
}
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MUL_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MUL_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
}
#else
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MULADD_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MULADD_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
}
#else
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
int32_t index[32];
float reg_weight[128];
int8_t *index_ = (int8_t *)index;
int topk_divide_4 = PAD_UP(topk, 4) / 4;
int token_use_count = 0;
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
for (int i = 0; i < topk_divide_4; i++) {
index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
float *weight_begin = weight + t_i * topk + i * 4;
reg_weight[i * 4] = __load_nram(weight_begin);
if (i * 4 + 1 < topk) {
reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
}
if (i * 4 + 2 < topk) {
reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
}
if (i * 4 + 3 < topk) {
reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
}
}
int first_in_expert = 0;
float expert_coeff;
for (; first_in_expert < topk - 1; first_in_expert++) {
bool in_expert_range = index_[first_in_expert];
if (!in_expert_range) continue;
expert_coeff = reg_weight[first_in_expert];
toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
token_use_count++;
break;
}
if (first_in_expert == topk - 1) {
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
__bang_write_zero((T *)output_begin, hidden_size);
}
} else {
for (int j = first_in_expert + 1; j < topk - 1; j++) {
bool in_expert_range = index_[j];
if (!in_expert_range) continue;
expert_coeff = reg_weight[j];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
}
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
floatTo((T *)output_begin, (float *)output_begin, hidden_size);
}
}
}
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
if (topk == 1) {
for (int i = 0; i < token_count; i++) {
float expert_coeff = __load_nram(weight + i);
toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
}
return;
}
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
float expert_coeff = __load_nram(weight + t_i * topk);
toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + 1);
for (int k_i = 2; k_i < topk; k_i++) {
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + k_i);
toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
}
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
}
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
T *input,
T *bias,
T *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {
if (__is_mpu()) {
return;
}
int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
int task_avg_tokens = num_token / taskDimY;
int task_remain_tokens = num_token % taskDimY;
int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
if (local_hidden_size <= 0) return;
if (task_tokens <= 0) return;
constexpr int int32_dtype_size = (int)sizeof(int);
constexpr int fp32_dtype_size = (int)sizeof(float);
int pad_num_expert = PAD_UP(num_expert + 1, 32);
bool has_bias = bias != nullptr;
bool has_residual = residual != nullptr;
bool using_acc_sum = cusum_token_count != nullptr;
bool expert_parallelism = expert_size < num_expert;
int block_size = TOKEN_BLOCK * topk;
int pad_block_size = PAD_UP(block_size, 64);
int *nram_expert_tables = (int *)nram_buffer;
int *nram_token_idx = nram_expert_tables + pad_num_expert;
T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
T *nram_input_ping = (T *)(nram_mask + pad_block_size);
T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
float *nram_weight_ping =
(float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
float *nram_weight_pong = nram_weight_ping + pad_block_size;
int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
buffer_block_num * HIDDEN_BLOCK * sizeof(half));
int *nram_mask_buffer = (int *)nram_token_offset;
uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
int begin_expert_acc_tokens = 0;
int end_expert_acc_tokens = num_token * topk;
if (using_acc_sum) {
__memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
GDRAM2NRAM);
}
__memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
init_token_count * sizeof(int), GDRAM2NRAM);
__sync_io();
if (expert_parallelism) {
begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
}
int active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, init_token_count, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
init_token_count, active_token_count, hidden_size, local_hidden_begin,
(int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
__sync_io_move_compute(true, false, false, false, false, true);
__sync_io_move_compute(false, false, true, true, false, false);
int next_active_token_count = active_token_count;
int previous_global_token_begin = 0;
int previous_token_count = 0;
bool is_ping = false;
for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
bool is_last_loop = next_token_begin >= task_tokens;
bool is_last_2_loop = next_next_token_begin >= task_tokens;
int current_token_begin = task_begin * TOKEN_BLOCK;
int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
int current_global_token_begin = task_token_begin + current_token_begin;
int next_global_token_begin = task_token_begin + next_token_begin;
int next_next_global_token_begin = task_token_begin + next_next_token_begin;
if (!is_last_loop) {
if (!is_last_2_loop) {
loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
next_next_token_count * topk * sizeof(int), 0, 0, 0);
}
loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
next_token_count * topk * fp32_dtype_size, 0, 0, 0);
if (has_residual) {
loadAsync2d(nram_residual_ping,
residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
hidden_size * sizeof(T), next_token_count - 1);
}
gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
local_hidden_size * sizeof(T), next_active_token_count);
gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
local_hidden_size * sizeof(T), next_active_token_count);
}
if (task_begin >= 1) {
storeAsync2d(
output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
if (task_begin >= 0) {
if (has_bias && active_token_count) {
__bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
active_token_count * local_hidden_size);
}
if (expert_parallelism) {
weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
} else {
weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
}
if (has_residual) {
__bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
current_token_count * local_hidden_size);
}
}
__sync_io_move_compute();
active_token_count = next_active_token_count;
if (expert_parallelism && !is_last_loop) {
__bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
}
if (!is_last_2_loop) {
next_active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
begin_expert_acc_tokens, has_bias);
}
swap(nram_input_ping, nram_input_pong);
swap(nram_bias_ping, nram_bias_pong);
swap(nram_residual_ping, nram_residual_pong);
swap(nram_weight_ping, nram_weight_pong);
swap(nram_output_ping, nram_output_pong);
previous_global_token_begin = current_global_token_begin;
previous_token_count = current_token_count;
}
storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
#endif
} // namespace kernels
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype) {
if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
std::cerr << "[invokeMoeCombineResultKernel]: "
<< "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (bias != nullptr) {
std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
<< "cusum_token_count can not be nullptr.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
size_t data_bytes = 0;
cnnlGetSizeOfDataType(dtype, &data_bytes);
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
// 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
// TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
int convert_buffer = data_bytes == 2
? 3 * hidden_size * data_bytes
: 2 * hidden_size * data_bytes; // buffer for convert bf16/fp16->fp32
int max_input_size = (432 * 1024 - convert_buffer) /
(2 * topk * data_bytes + /*input size, double buffer*/
(bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
(residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/
2 * data_bytes); /*output size, one buffer*/
int TOKEN_BLOCK = 1;
int HIDDEN_BLOCK = 1;
int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
TOKEN_BLOCK = 1;
} else {
HIDDEN_BLOCK = hidden_size;
}
// for latency case, hidden_size is large but token is small.
if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
}
HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
task_dim_x =
(task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
uint32_t pad_dim_x = task_dim_x;
while (pad_dim_x <= cluster_num * core_num) {
if ((cluster_num * core_num % pad_dim_x == 0)) {
task_dim_x = pad_dim_x;
break;
}
pad_dim_x += core_num;
}
HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
}
TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
float max_cluster_num = core_num * cluster_num / task_dim_x;
uint32_t task_dim_y = std::min(max_cluster_num, num_token);
task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
if (dtype == CNNL_DTYPE_FLOAT) {
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else if (dtype == CNNL_DTYPE_HALF) {
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else {
std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
<< "among float/half/bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,85 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Sort tokens grouped by different experts based on index. Each token
* selects the topk hidden vectors, multiplies them by corresponding weights,
* and finally reduces the topk vectors for each token. This process involves
* bias and residual, calculated as (x + bias) * weight + residual.
* @example
* input:
* [[[1, 2, 1, 1],
* [1, 1, 1, 2]],
* [[2, 1, 1, 1],
* [1, 1, 1, 1]]]
* num_token = 2, topk = 2
* cusum_token_count = [0, 2, 4]
* index:
* [0, 1, 2, 3]
* weight:
* [0, 0, 1, 1]
* bias:
* [[0, 0, 0, 0],
* [1, 1, 1, 1]]
* residual:
* [[1, 1, 1, 1],
* [0, 0, 0, 0]]
* output:
* [[1, 1, 1, 1],
* [5, 4, 4, 4]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* The shape is [num_token, hidden_size].
* @param input: Input. Pointer to the MLU memory that stores input tokens.
* The shape is [num_token * topk, hidden_size].
* @param bias: Input. Pointer to the MLU memory that stores bias.
* The shape is [num_expert, hidden_size].
* @param residual: Input. Pointer to the MLU memory that stores residual.
* The shape is [num_token, hidden_size].
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
* The shape is [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
* token number of each expert. The shape is [num_expert + 1].
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
* The shape is [num_token * topk].
* @param num_token: The total number of tokens.
* @param topk: The number of expert.
* @param num_expert: The number of expert.
* @param hidden_size: The size of lowest dimension.
* @param start_expert_id: The id of the first processed expert.
* @param expert_size: The number of processed experts.
* @param dtype: Data type.
* @note Currently does not support bias.
*/
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_

View File

@@ -0,0 +1,219 @@
#include <bang_device_functions_extra.h>
#include <mlu.h>
#include "cnnl.h"
#include "cnrt.h"
#include "expand_input.mluh"
namespace tmo {
namespace kernels {
#define RESERVED_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define MEMCPY_BURST_SIZE 128
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output,
void *input,
int *index,
int num_token,
int hidden_size,
int num_index) {
void *input_ptr = input;
uint64_t input_size = data_size * num_token * hidden_size;
// whether SRAM_BUFFER_SIZE can hold the input data
// sram_enable is true ==> is_nram_output is true
bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
if (sram_enable) {
input_ptr = (void *)sram_buffer;
__memcpy(input_ptr, input, input_size, GDRAM2SRAM);
__sync_cluster();
}
if (__is_mpu()) {
return;
}
// Each ipu core processes no less than 128B of data, and the remaining cores can idle
int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
uint32_t total_num = num_index;
uint32_t base = total_num / maxTaskDim;
uint32_t tail = total_num - base * maxTaskDim;
if (taskId >= maxTaskDim) {
return;
}
uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
// nram
/*
* first: compute offset: index[i] * data_size
* second: gather data
* -------------------------------------------------
* addr || index/offset | output |
* type || int32_t/T_offset | T |
* num || n | n * hidden_size |
* -------------------------------------------------
*/
uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
// whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
uint32_t per_num =
is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
int *index_base = index + batch_step;
T_offset *nram_offset = (T_offset *)nram_buffer;
int32_t *nram_index;
if (std::is_same<T_offset, int64_t>::value) {
nram_index = (int32_t *)nram_offset + per_num;
} else {
nram_index = (int32_t *)nram_offset;
}
int8_t *nram_output = (int8_t *)(nram_offset + per_num);
uint32_t repeat = batch_per_core / per_num;
uint32_t remain = batch_per_core - repeat * per_num;
uint32_t deal_num = per_num;
uint32_t is_remain = remain != 0 ? 1 : 0;
for (int32_t i = 0; i < repeat + is_remain; i++) {
if (i == repeat) {
deal_num = remain;
}
int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
int32_t *index_ptr = index_base + i * per_num;
// index -> offset
__memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
if (std::is_same<T_offset, uint64_t>::value) {
#if __BANG_ARCH__ > 592
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
}
__bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
// copy
if (is_nram_output) {
__bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
// GDRAM or SRAM -> NRAM -> GDRAM
#if __BANG_ARCH__ >= 592 // gather requires
__gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
hidden_size * data_size, deal_num);
#else
for (int32_t j = 0; j < deal_num; j++) {
T_offset offset_value = *(nram_offset + j);
int8_t *input_offset = (int8_t *)input_ptr + offset_value;
__memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
dir);
}
#endif // __BANG_ARCH__
__memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
} else {
// GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592 // gather requires
__gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
hidden_size * data_size, deal_num);
#else
for (int32_t j = 0; j < deal_num; j++) {
T_offset offset_value = *(nram_offset + j);
int8_t *input_offset = (int8_t *)input + offset_value;
__memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
hidden_size * data_size, GDRAM2GDRAM);
}
#endif // __BANG_ARCH__
}
}
}
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
void *hidden_state,
int *gather_idx,
int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
int total_expert_num,
int start_expert_id,
int expert_count) {
int32_t num_index = num_token * topk;
int *gather_start_idx = (int *)gather_idx;
if (cusum_token_count != nullptr) {
num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
*((int *)cusum_token_count + start_expert_id);
gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
}
ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
num_token, hidden_size, num_index);
}
// instantiate kernels
#define INSTANTIATE_ONE(data_size, T_offset) \
template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
void *, void *, int *, int *, int, int, int, int, int, int);
INSTANTIATE_ONE(1, uint32_t)
INSTANTIATE_ONE(2, uint32_t)
INSTANTIATE_ONE(4, uint32_t)
INSTANTIATE_ONE(8, uint32_t)
// large tensor
INSTANTIATE_ONE(1, uint64_t)
INSTANTIATE_ONE(2, uint64_t)
INSTANTIATE_ONE(4, uint64_t)
INSTANTIATE_ONE(8, uint64_t)
} // namespace kernels
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
size_t type_size = 0;
cnnlGetSizeOfDataType(data_type, &type_size);
int max_cluster_num =
(uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
(void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
(int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
expert_count);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,81 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count.
* @example
* hidden_state:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [9, 10, 11, 12]]
* gather_idx:
* [[1, 0, 2, 2, 1, 0]]
* cusum_token_count: NULL
* num_token = 3
* hidden_size = 4
* topk = 2
* expand_hidden_state:
* [[5, 6, 7, 8],
* [1, 2, 3, 4],
* [9, 10, 11, 12],
* [9, 10, 11, 12],
* [5, 6, 7, 8],
* [1, 2, 3, 4]]
* @param queue: The queue for mlu.
* @param hidden_state: Input. Pointer to the MLU memory that store the input,
* the shape must be [num_token, hidden_size].
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
* the shape must be [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
* token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The
* gather operation will be performed as follows: if cusum_token_count is not NULL: index =
* gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]]
* expand_hidden_state = hidden_state[index]
* else:
* index = gather_idx[:]
* expand_hidden_state = hidden_state[index]
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output,
* if cusum_token_count is not NULL, the shape shoule be [num_index * topk ,hidden_size] in
* which num_index =
* cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise,
* the shape should be [num_token * topk, hidden_size].
* @param num_token: the number of token.
* @param hidden_size: the slice size.
* @param topk: the number of topk.
* @param data_type: Data type of hidden_state.
* @param total_expert_num: the total number of expert.
* @param start_expert_id: the first expert id.
* @param expert_count: the number of experts currently being processed.
*/
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_

View File

@@ -0,0 +1,935 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>
#include "gen_idx.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
#define ALIGN_16 (16)
#define EXPERT_AVG_COUNT_TEST (0)
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
// Generate integer sequence data from 0 to length-1
__mlu_func__ void generateIntSeq(int *dst, int length) {
int count = 64;
__bang_move(dst, range, std::min(count, length) * sizeof(int));
while (count < length) {
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
count *= 2;
}
}
// genIdx Block kernel, use only 1 core to process
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
/* NRAM space */
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
// --------------------------------------------------------------
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
// | combine_idx | expand_idx | | scatter_offset |
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
// --------------------------------------------------------------
// ------------------------------
// |token_count|token_count_presum|
// | | |
// | num_expert| num_expert |
// ------------------------------
uint32_t token_total_num = num_token * topk;
// num align to 16, size align to 64B
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
#if __BANG_ARCH__ >= 592
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
#endif
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
// Load current core input expert_id and generate int sequence
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
GDRAM2NRAM);
generateIntSeq((int *)gen_idx_onchip, token_total_num);
__sync();
// Initialize sort idx offset
uint32_t sorted_idx_offset = 0;
// Initialize token count first presum with 0
((int *)token_count_presum_onchip)[0] = 0;
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
// Loop on each expert, eq, count, filter index
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
token_total_num);
// Use filter to sort gen_idx, output with sorted_idx_offset
uint32_t cur_expert_count =
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
(float *)cur_expert_result, token_total_num);
sorted_idx_offset += cur_expert_count;
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
// Compute cusum token count and store
if (need_cusum_token_count) {
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
}
}
#if EXPERT_AVG_COUNT_TEST
// NOTE: test avg expert code here:
uint32_t token_count_avg = token_total_num / num_expert;
uint32_t expert_remain_num = token_total_num % num_expert;
for (int i = 0; i < num_expert; i++) {
((int *)token_count_onchip)[i] =
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
((int *)token_count_presum_onchip)[i + 1] =
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
}
#endif
__sync_compute();
// Store token_count and cusum token count
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
NRAM2GDRAM);
if (need_cusum_token_count) {
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// Use sorted idx to generate gather idx for expand and combine
#if __BANG_ARCH__ >= 592
// scatter_offset = sorted_idx mul_scalar sizeof(int);
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
token_total_num);
#else
// scatter dst GDRAM addr should align to 64B
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
#endif
__sync_compute();
#if __BANG_ARCH__ >= 592
// scatter_async to NRAM
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
#endif
// expand_idx = sorted_idx div(topk)
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
// Store expand_idx and combine_idx
__sync_compute();
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
NRAM2GDRAM);
#if __BANG_ARCH__ >= 592
__sync_move();
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
token_total_num * sizeof(int), NRAM2GDRAM);
#else
// 370 directly scatter to GDRAM
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
#endif
}
// Only MLU500 series support NRAM2SRAM scatter direction
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
#if __BANG_ARCH__ >= 592
// When length larger than 65535(maximum segnum in bang_scatter),
// and src/offset address should align to 64B
int seg_repeat = length / 32768;
int seg_remain = length % 32768;
int seg_offset = 0;
for (int seg = 0; seg < seg_repeat; seg++) {
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
seg_offset += 32768;
}
if (seg_remain > 0) {
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
}
#endif
}
// Scatter sequence, transfer size is sizeof(int)
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
// When length larger than 65535(maximum segnum in bang_scatter),
// and src/offset address should align to 64B
int seg_repeat = length / 32768;
int seg_remain = length % 32768;
int seg_offset = 0;
for (int seg = 0; seg < seg_repeat; seg++) {
#if __BANG_ARCH__ >= 592
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
#else
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
#endif
seg_offset += 32768;
}
if (seg_remain > 0) {
#if __BANG_ARCH__ >= 592
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
#else
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
#endif
}
}
// 1. Get token count
__mlu_func__ void getTokenCount(int *token_count,
int *expert_id,
int token_cur_core,
int cur_token_start,
int num_expert) {
// 1. Partition on [num_token*topk],
// each core for-loop on all expert_id, use eq and count instructions,
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
// And sync for all cores.
// NRAM:
// ------------------------------------------------------
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
// | deal_num | deal_num | num_expert |
// ------------------------------------------------------
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
// Current core data loop
uint32_t repeat = token_cur_core / deal_num;
uint32_t remain = token_cur_core % deal_num;
uint32_t total_repeat = repeat + (int)(remain > 0);
uint32_t token_addr_offset = cur_token_start;
// Initialize token_count with 0
if (taskId == 0) {
__gdramset((int *)token_count, num_expert, 0);
}
// Sync for initialize token_count
__sync_all_ipu();
// Initialize expert count onchip with 0
if (token_cur_core > 0) {
__bang_write_zero((int *)expert_count_onchip, num_expert);
}
// actual num in loop
int cur_deal_num = deal_num;
for (int i = 0; i < total_repeat; i++) {
if (i == total_repeat - 1 && remain > 0) {
cur_deal_num = remain;
}
// Load current core input expert_id
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
cur_deal_num * sizeof(int), GDRAM2NRAM);
token_addr_offset += cur_deal_num;
// Loop on each expert, eq, count
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
// NOTE: __bang_count() only support floating data type
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
}
}
// AtomicAdd(reduce) all cores token count results
if (token_cur_core > 0) {
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
}
// Sync for all cores, get accumulate of token_count
__sync_all_ipu();
}
// 2. Get token count presum, for each expert index start address after sorting
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
int *token_count,
const int num_expert) {
// 2. After first process, already get token_count.
// Then use one core to pre-sum on token_count, consider size of int32,
// first expert id start address should be zero.
// to get each expert id start address after sorting, store to workspace,
// token_count_presum.
// And sync for all cores.
// NRAM:
// load token_count to token_count_presum[1~num_expert+1],
// for i = 0 to num_expert:
// token_count_presum[i+1] += token_count_presum[i]
// store token_count_presum[0~num_expert]
// -------------------------
// |token_count_presum_onchip|
// | {0}, num_expert |
// -------------------------
if (taskId == 0) {
// Initialize count presum onchip with a first 0
int8_t *token_count_presum_onchip = nram_buffer;
((int *)token_count_presum_onchip)[0] = 0;
// Load token_count with an offset of 1
__memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
GDRAM2NRAM);
// Calculate presum of token count by each expert
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
((int *)token_count_presum_onchip)[cur_expert + 1] +=
((int *)token_count_presum_onchip)[cur_expert];
}
// Store token count presum to workspace
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// Sync for all cores, get presum of token count
__sync_all_ipu();
}
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
int *token_count,
const uint32_t token_total_num,
const int num_expert) {
uint32_t token_count_avg = token_total_num / num_expert;
uint32_t expert_remain_num = token_total_num % num_expert;
int8_t *token_count_onchip = nram_buffer;
int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int);
((int *)token_count_presum_onchip)[0] = 0;
for (int i = 0; i < num_expert; i++) {
((int *)token_count_onchip)[i] =
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
((int *)token_count_presum_onchip)[i + 1] =
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
}
__memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// 3. Get expert position index after sorting
__mlu_func__ void getSortedIdx(int *sorted_idx,
int *expert_id,
int *token_count_presum,
const int token_total_num,
const int num_expert,
const int expert_cur_core,
const int cur_expert_start,
const int cur_expert_end) {
// 3. Partition on num_expert, each core generate position index from 0,
// and for-loop on all expert_id data, use eq with own each expert_id,
// and filter on index, stores to each expert_id start address of
// sorted_idx on workspace.
// And sync for all cores.
// NRAM:
// -------------------------------------------------------------------
// |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
// | deal_num | deal_num | deal_num | deal_num |
// -------------------------------------------------------------------
// |expert_start_addr|
// | num_expert |
// -----------------
// Calculate new deal_num of sorting process
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
// Each core deal with whole token expert_id data
int repeat = token_total_num / deal_num;
int remain = token_total_num % deal_num;
int token_addr_offset = 0;
int8_t *expert_id_onchip = nram_buffer;
int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
// When num_expert < taskDim, not all cores need to sort
if (expert_cur_core > 0) {
// Generate position index from 0
if (deal_num <= token_total_num) {
generateIntSeq((int *)gen_idx_onchip, deal_num);
} else { // only remainder part
generateIntSeq((int *)gen_idx_onchip, token_total_num);
}
// Initialize expert start address with presum of token count
__memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
GDRAM2NRAM);
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core expert_id
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
token_addr_offset += deal_num;
// Loop for current core expert, eq, filter position index
// use filter, store to sorted_idx[expert_start_addr]
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
// NOTE: __bang_filter() only support floating data type
uint32_t cur_expert_count =
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
(float *)cur_expert_result, deal_num);
// Store to the corresponding address of sorted_idx
if (cur_expert_count > 0) {
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
cur_expert_count * sizeof(int), NRAM2GDRAM);
// Update address offset of current expert
((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
}
}
// Update position index for each data loop
__bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
}
// remainder part
if (remain > 0) {
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
// NOTE: __bang_filter() only support floating data type
uint32_t cur_expert_count =
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
(float *)cur_expert_result, remain);
// Store to the corresponding address of sorted_idx
if (cur_expert_count > 0) {
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
cur_expert_count * sizeof(int), NRAM2GDRAM);
}
}
}
}
// Sync for all cores, get position index after sorting
__sync_all_ipu();
}
// 4. Get gather index for expand and combine
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
int *gather_combine_idx,
int *sorted_idx,
const int token_cur_core,
const int cur_token_start,
const int topk) {
// 4. Partition on [num_token*topk],
// load sorted_idx onchip,
// generate sequence according to position index from 0, add token offset
// gather_combine_idx = scatter(seq, sorted_idx)
// gather_expand_idx = sorted_idx / topk
// update sequence
// NRAM:
// -------------------------------------------------------------------
// |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
// | deal_num | deal_num | deal_num | deal_num |
// -------------------------------------------------------------------
// Calculate new deal_num of generate gather index
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
// scatter dst GDRAM addr should align to 64B
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
int8_t *sorted_idx_onchip = nram_buffer;
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
// Generate position index from 0
// Add base offset to sequence according to current core token start address
if (token_cur_core > 0) {
if (deal_num <= token_cur_core) {
generateIntSeq((int *)scatter_sequence, deal_num);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
deal_num);
} else { // only remainder part
generateIntSeq((int *)scatter_sequence, token_cur_core);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
token_cur_core);
}
}
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
if (is_sram_scatter) {
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
deal_num);
} else {
// GDRAM addr should align to 64B
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), deal_num);
}
// Sync for scatter
__sync_compute();
if (is_sram_scatter) {
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
deal_num);
} else {
// Scatter to output gather_combine_idx
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
(uint32_t *)scatter_offset, deal_num);
}
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
deal_num * sizeof(int), NRAM2GDRAM);
if (is_sram_scatter) {
// if scatter to SRAM, need to sync compute with mv
__sync_move();
}
// Add offset to sequence and token_address
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
if (is_sram_scatter) {
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
remain);
} else {
// GDRAM addr should align to 64B
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), remain);
}
// Sync for scatter
__sync_compute();
if (is_sram_scatter) {
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
remain);
} else {
// Scatter to output gather_combine_idx
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
(uint32_t *)scatter_offset, remain);
}
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
remain * sizeof(int), NRAM2GDRAM);
}
}
// 4.1 Get gather combine index on SRAM
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
const int token_cur_core,
const int cur_token_start) {
// 4.1 Partition on [num_token*topk], with only 1 union
// load sorted_idx onchip,
// generate sequence according to position index from 0, add token offset
// gather_combine_idx = scatter(seq, sorted_idx)
// update sequence
// NRAM:
// -------------------------------
// |scatter_offset|scatter_sequence|
// | deal_num | deal_num |
// -------------------------------
// Calculate new deal_num of generate gather index
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
int8_t *scatter_offset = nram_buffer;
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
// Generate position index from 0
// Add base offset to sequence according to current core token start address
if (token_cur_core > 0) {
if (deal_num <= token_cur_core) {
generateIntSeq((int *)scatter_sequence, deal_num);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
deal_num);
} else { // only remainder part
generateIntSeq((int *)scatter_sequence, token_cur_core);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
token_cur_core);
}
}
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
// Sync for scatter
__sync_compute();
// Scatter to SRAM
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
deal_num);
__sync_move();
// Add offset to sequence and token_address
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
// Sync for scatter
__sync_compute();
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
}
}
// 4.2 Get gather expand index
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
int *sorted_idx,
const int token_cur_core,
const int cur_token_start,
const int topk) {
// 4.2 Partition on [num_token*topk],
// load sorted_idx onchip,
// gather_expand_idx = sorted_idx / topk
// NRAM:
// -----------------------------------
// |sorted_idx_onchip|expand_idx_onchip|
// | deal_num | deal_num |
// -----------------------------------
// Calculate new deal_num of generate gather index
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
int8_t *sorted_idx_onchip = nram_buffer;
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
deal_num * sizeof(int), NRAM2GDRAM);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
remain * sizeof(int), NRAM2GDRAM);
}
}
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
// Store token count presum result, shape [num_expert + 1]
int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
// Store position index after sorting, shape [num_token*topk]
int *sorted_idx = ((int *)workspace) + num_expert + 1;
// Calculate partition information for different processes
// Partition on [num_token*topk]
uint32_t token_total_num = num_token * topk;
uint32_t token_cur_core = token_total_num / taskDim;
uint32_t token_remain_num = token_total_num % taskDim;
token_cur_core += (uint32_t)(taskId < token_remain_num);
// Current core range according to partition on [num_token*topk]
uint32_t cur_token_start = (taskId < token_remain_num)
? token_cur_core * taskId
: token_cur_core * taskId + token_remain_num;
// Partition on [num_expert]
uint32_t expert_cur_core = num_expert / taskDim;
uint32_t expert_remain_num = num_expert % taskDim;
expert_cur_core += (uint32_t)(taskId < expert_remain_num);
// Current core range according to partition on [num_expert]
uint32_t cur_expert_start = (taskId < expert_remain_num)
? expert_cur_core * taskId
: expert_cur_core * taskId + expert_remain_num;
uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
// Use Union1 SRAM to scatter, only MLU500 series support now
#if __BANG_ARCH__ >= 592
bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
bool is_sram_scatter = false;
#endif
if (__is_ipu()) {
// 1. Get token count
getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
num_expert);
// 2. Get presum of token count
getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
// 3. Get expert position index after sorting
getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
}
#if EXPERT_AVG_COUNT_TEST
// NOTE: test avg expert code here:
if (__is_ipu() && taskId == 0) {
modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
num_expert);
}
__sync_cluster();
#endif
// 4. Get gather index for expand and combine
if (is_sram_scatter) {
// Only use Union1 SRAM
uint32_t scatter_idx_cur_core = token_total_num / 4;
uint32_t scatter_idx_remain_num = token_total_num % 4;
scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
? scatter_idx_cur_core * taskId
: scatter_idx_cur_core * taskId + scatter_idx_remain_num;
// Only Union1 task type,
// deal once num is same with deal_num in getGatherIdx,
// which means only 1 repeat to generate both expand and combine idx on NRAM
const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
if (taskDim <= 4 || token_total_num < deal_once_num) {
if (taskId < 4) {
if (__is_ipu()) {
getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
scatter_idx_cur_core, cur_idx_start, topk);
// sync for ipu and mpu
__sync_cluster();
} else {
// sync for ipu and mpu
__sync_cluster();
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
token_total_num * sizeof(int), SRAM2GDRAM);
}
}
} else {
// If taskDim > 4, use first union to generate combine idx,
// use other union to generate expand idx
if (taskId < 4) {
if (__is_ipu()) {
// Scatter combine idx to SRAM
getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
__sync_cluster();
} else {
__sync_cluster();
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
token_total_num * sizeof(int), SRAM2GDRAM);
}
} else {
// Other union generate expand idx
if (__is_ipu()) {
uint32_t expand_dim = taskDim - 4;
uint32_t expand_id = taskId - 4;
uint32_t expand_token_cur_core = token_total_num / expand_dim;
uint32_t expand_token_remain_num = token_total_num % expand_dim;
expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
uint32_t expand_cur_token_start =
(expand_id < expand_token_remain_num)
? expand_token_cur_core * expand_id
: expand_token_cur_core * expand_id + expand_token_remain_num;
getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
expand_cur_token_start, topk);
}
}
}
} else {
// not use SRAM to generate both expand and combine idx
if (__is_ipu()) {
getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
token_cur_core, cur_token_start, topk);
}
}
// step 5 does not need MPU
if (__is_mpu()) {
return;
}
} // end of kernel
} // namespace kernels
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
const int token_total_num = num_token * topk;
// For partition on num_token*topk, single core processes at least 128 num
const int single_core_num_limit = 1024;
int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
// When partition on num_expert, each core at least processes one expert
need_core_num = std::max(num_expert, need_core_num);
// When consider UnionX cnrt func type, reset cluster_num
if (token_total_num <= 4096) { // Block
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
cnrtDim3_t k_dim{1, 1, 1};
// Block kernel does not need workspace
kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
num_expert, topk);
return KernelStatus::KERNEL_STATUS_SUCCESS;
} else if (need_core_num <= 4) { // Union1
cluster_num = 1;
} else if (need_core_num <= 8) { // Union2
cluster_num = std::min(cluster_num, 2);
} else if (need_core_num <= 16) { // Union4
cluster_num = std::min(cluster_num, 4);
} else if (need_core_num <= 32) { // Union8
cluster_num = std::min(cluster_num, 8);
}
cnrtFunctionType_t k_type;
cnrtDim3_t k_dim{1, 1, 1};
// Find max UnionX cnrt func type
if (cluster_num == 1) {
k_type = cnrtFuncTypeUnion1;
k_dim.x = 4;
} else if (cluster_num < 4) { // cluster num is 2 or 3
k_type = cnrtFuncTypeUnion2;
k_dim.x = 8;
} else if (cluster_num < 8) { // cluster num is 4,5,6,7
k_type = cnrtFuncTypeUnion4;
k_dim.x = 16;
} else { // cluster num larger than 8
k_type = cnrtFuncTypeUnion8;
k_dim.x = 32;
}
// The expert_id is int data type
kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
num_token, num_expert, topk);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
#undef EXPERT_AVG_COUNT_TEST // undef test macro
} // namespace tmo

View File

@@ -0,0 +1,58 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#include <vector>
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Apply generate MOE index operation, which performs the following
* tasks:
* - 1. Generate gather_expand_idx and gather_combine_idx.
* - 2. Output token_count, the token number of each expert.
* - 3. Prepare inputs and outputs address for group_gemm.
* @param queue: The queue of mlu.
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
* gather index for expand hidden state operation, the shape must be
* [num_token * topk].
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
* gather index for combine MOE operation, the shape must be
* [num_token * topk].
* @param token_count: Output. Pointer to the MLU memory that stores the token
* number of each expert, the shape must be [num_expert].
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
* cumulative sum of the token number of each expert, the shape must be
* [num_expert + 1]. It can be set to nullptr if don't need cusum output.
* @param workspace: Input. A pointer to the extra workspace required in the
* operation, the size must be larger than
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
* of each token, the shape must be [num_token, topk].
* @param num_token: The number of tokens.
* @param num_expert: The number of experts.
* @param topk: The number of expert selected by each token.
*/
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_

View File

@@ -0,0 +1,21 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_MOE_MLUH_
#define CSRC_KERNELS_MOE_MOE_MLUH_
#include "add_bias_activation.mluh"
#include "combine_result.mluh"
#include "expand_input.mluh"
#include "gen_idx.mluh"
#include "softmax_topk.mluh"
#endif // CSRC_KERNELS_MOE_MOE_MLUH_

View File

@@ -0,0 +1,602 @@
#include <mlu.h>
#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"
namespace tmo {
namespace kernels {
#define SCATTER_ALIGN (64) // align for __scatter()
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
#define TILING_ALIGN (64)
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
#define __TRANS_TILING(TYPE, CONVERT) \
__asm__ volatile("trans.tiling." TYPE \
" [%[dst]], [%[src]]," \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
"%[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
const SRC_DTYPE *src,
const uint32_t in0,
const uint32_t in1,
const uint32_t is1,
const uint32_t in2,
const uint32_t is2,
const uint32_t in3,
const uint32_t is3,
const uint32_t in4,
const uint32_t is4,
const uint32_t in5,
const uint32_t is5,
const uint32_t dn0,
const uint32_t dn1,
const uint32_t ds1,
const uint32_t dn2,
const uint32_t ds2,
const uint32_t dn3,
const uint32_t ds3,
const uint32_t dn4,
const uint32_t ds4,
const uint32_t dn5,
const uint32_t ds5) {
if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
if (std::is_same<SRC_DTYPE, float>::value) {
__TRANS_TILING("nram.sram.b32", ";")
} else if (std::is_same<SRC_DTYPE, half>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
} else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
}
}
}
/* 将shape为[h,w]的数据转置为[w,h](带转数)分4块分别进行处理。
* dst: dst地址
* src: src地址
* h: h方向大小
* w: w方向大小
*/
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
uint32_t w_align = w / align_num;
uint32_t w_rem = w % align_num;
uint32_t h_align = h / align_num;
uint32_t h_rem = h % align_num;
uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
uint32_t in3 = w_align, is3 = TILING_ALIGN;
uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
/* 1. h_align * w_align */
if (w_align > 0 && h_align > 0) {
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
}
/* 2. h_align * w_rem */
if (w_rem > 0 && h_align > 0) {
SRC_DTYPE *src_temp = src + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = TILING_ALIGN;
in1 = align_num;
is1 = w * sizeof(SRC_DTYPE);
in4 = h_align;
is4 = w * TILING_ALIGN;
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 3. h_rem * w_align */
if (w_align > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w;
DST_DTYPE *dst_temp = dst + h_align * align_num;
in0 = TILING_ALIGN;
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
in4 = w_align;
is4 = TILING_ALIGN;
dn1 = align_num;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = h * align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 4. h_rem * w_rem */
if (w_rem > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
}
}
__mlu_func__ void getTopk(float *value_buffer,
uint32_t *index_buffer,
float *src_buffer,
float *compute_buffer,
float *max_buffer,
float *temp_buffer,
uint32_t *i_buffer,
uint32_t *col_buffer,
uint32_t topk,
uint32_t num_expert_group,
uint32_t col,
uint32_t row,
uint32_t value_index_stride,
uint32_t group_size,
bool is_deal_group) {
__bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
for (int k = 0; k < topk; k++) {
if (is_deal_group) {
__bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
1, num_expert_group, 1, 1);
__bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
col);
} else {
__bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
value_index_stride);
__bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
}
#if __BANG_ARCH__ >= 592
__bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
__scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
col); // replace max value with -inf
#else
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram(col_buffer + i);
max_buffer[index] = -INFINITY;
}
#endif
#if __BANG_ARCH__ < 500
if (is_deal_group) {
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
__memcpy(compute_buffer + i * row + index * group_size,
src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
}
}
#endif
}
#if __BANG_ARCH__ >= 592
if (is_deal_group) {
__bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
__bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
__bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
topk - 1);
__bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
__bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
(uint32_t *)compute_buffer, col * topk, col * topk);
__gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
NRAM2NRAM, group_size * sizeof(float), col * topk);
__bang_write_value(src_buffer, row * col, -INFINITY);
__scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
group_size * sizeof(float), col * topk);
}
#endif
}
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
T *load_buffer,
float *src_buffer,
float *compute_buffer,
float *group_max_buffer,
float *nramout_value,
uint32_t *nramout_index,
uint32_t *i_buffer,
uint32_t *col_buffer,
float *softmax_buffer,
uint32_t row,
uint32_t nram_compute_col_num,
uint32_t mask_num,
uint32_t nram_max_col_num,
uint32_t topk,
int num_expert_group,
uint32_t topk_group,
uint32_t top_num,
uint32_t nram_col_offset,
int normalize_mode,
bool valid_mask,
bool split_mask) {
uint32_t nram_compute_num = nram_compute_col_num * row;
// convert to float for half/bf16 datatype
if (std::is_same<T, half>::value) {
__bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
}
// transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
// compute softmax
int tmp = 0x3fb8aa3b;
float log2e = *(float *)&tmp; // for exp
// src_buffer reuse as buffer for max/sum.
__bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
__bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
nram_compute_col_num);
__bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
__bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
__bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
__bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
nram_compute_col_num);
__sync_cluster();
// move mask and compute
if (valid_mask) {
if (!split_mask) {
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
if (std::is_same<T, half>::value) {
__memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
SRAM2NRAM);
__bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
mask_num * row);
} else if (std::is_same<T, bfloat16_t>::value) {
__memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
mask_num * row * sizeof(T), SRAM2NRAM);
__bang_bfloat162float((float *)compute_buffer,
(bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
} else {
__memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
}
__bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
mask_num * row);
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
} else {
transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
nram_compute_col_num, row);
__sync();
__bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
}
}
if (normalize_mode == 2) {
__bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
}
if (num_expert_group <= 1) {
// num_expert_group <= 1, maintain original topk calculation logic
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * topk * sizeof(float), 0, false);
} else {
// num_expert_group > 1, use grouped_topk calculation logic
uint32_t group_size = row / num_expert_group;
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
__bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
group_size, 1, group_size, 1, 1);
__bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
// get topk_group
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
(float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
// get topk
#if __BANG_ARCH__ < 500
__bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#else
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
} // end else
// normalize result
if (normalize_mode == 1) {
// compute_buffer reuse as buffer for sum.
__bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
__bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
} else if (normalize_mode == 2) {
__bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
}
// transpose back. src and dst of transpose can not be the same address.
__bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
__bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
T *mask,
int *index_out,
float *value_out,
int col,
int row,
int mask_num,
int topk,
int num_expert_group,
int topk_group,
int normalize_mode) {
bool valid_mask = (mask != nullptr);
int top_num = topk >= topk_group ? topk : topk_group;
uint32_t nram_low_space =
PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
SCATTER_ALIGN);
if (num_expert_group <= 1) {
nram_low_space =
PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
}
uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
nram_max_col_num = col / taskDim + (col % taskDim > 0);
}
nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
if (nram_max_col_num <= 0) {
nram_max_col_num = SCATTER_ALIGN / sizeof(float);
}
uint32_t nram_deal_num = nram_max_col_num * row;
uint32_t batch = col / mask_num;
// nram split:
// |--------------------------|--------------------------|--------------------|...
// | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
// | src_buffer | compute_buffer | group_max_buffer |...
// |--------------------------|--------------------------|--------------------|...
// |----------------------------------------|---------------|--------------|
// | nram_col_num*3 | col*topk | col*topk |
// | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
// |----------------------------------------|---------------|--------------|
float *src_buffer = (float *)nram_buffer;
float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
float *group_max_buffer = compute_buffer + nram_deal_num;
uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
if (num_expert_group <= 1) {
i_buffer = (uint32_t *)group_max_buffer;
}
uint32_t *col_buffer = i_buffer + nram_max_col_num;
float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
if (normalize_mode != 2) {
softmax_buffer = (float *)col_buffer;
}
float *nramout_value = softmax_buffer + nram_max_col_num;
uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
if (num_expert_group <= 1) {
nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
}
T *load_buffer = (T *)src_buffer;
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
load_buffer = load_buffer + nram_deal_num;
}
// set i_buffer
for (uint32_t i = 0; i < nram_max_col_num; i++) {
i_buffer[i] = i;
}
// input[batch, mask, low], mask[mask, low]
if (nram_max_col_num >= mask_num) { // nram can deal complete mask
bool split_mask = false;
uint32_t batch_seg = nram_max_col_num / mask_num;
uint32_t batch_rem = batch % batch_seg;
uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
int repeat = DIV_UP(batch_seg_num, taskDim);
for (int i = 0; i < repeat; i++) {
uint32_t seg_id = i * taskDim + taskId;
uint32_t sram_load_num = mask_num * row;
uint32_t sram_load_offset = 0;
uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
? batch_rem * mask_num
: batch_seg * mask_num;
uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
} else {
bool split_mask = true;
uint32_t mask_seg = nram_max_col_num;
uint32_t mask_rem = mask_num % mask_seg;
uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
uint32_t batch_idx = i / sram_mask_seg_num;
uint32_t mask_idx = i % sram_mask_seg_num;
uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
uint32_t sram_load_num = sram_deal_mask_num * row;
uint32_t sram_mask_offset = mask_idx < sram_mask_rem
? mask_idx * (sram_average_mask_num + 1)
: mask_idx * sram_average_mask_num + sram_mask_rem;
uint32_t sram_load_offset = sram_mask_offset * row;
uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
uint32_t nram_load_num = nram_deal_mask_num * row;
uint32_t nram_col_offset = taskIdX < nram_mask_rem
? taskIdX * (nram_average_mask_num + 1)
: taskIdX * nram_average_mask_num + nram_mask_rem;
uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
uint32_t nram_store_num = nram_deal_mask_num * topk;
uint32_t nram_store_offset =
(batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, nram_col_offset,
normalize_mode, valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
}
}
} // namespace kernels
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
int top_num = topk >= topk_group ? topk : topk_group;
if (num_expert_group <= 1) {
if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
} else {
if (num_expert >
(NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (topk > num_expert) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
<< "topk:" << topk << ". num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (num_expert_group > 1) {
if (mask != nullptr) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
}
if (num_expert % num_expert_group != 0) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
<< "divisible by num_expert_group, but now num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk_group <= 0 || topk_group > num_expert_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
<< "larger than 0 and less than or equal to num_expert_group, but now topk_group"
<< topk_group << ", num_expert group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk > (num_expert / num_expert_group) * topk_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
<< "than or equal to (num_expert / num_expert_group) * topk_group, but now"
<< "topk :" << topk << ", num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
num_mask, topk, num_expert_group, topk_group, normalize_mode);
} else {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Execute MOE Softmax Top-K Kernel.
*
* This function executes the MOE Softmax Top-K Kernel, which computes
* the Top-K values along a specified dimension after applying softmax to the input data.
* It is specifically designed for reduction along the lowest dimension.
*
* @param queue CNRT queue used to specify the queue for execution.
* @param reduce_weight Pointer to store the Top-K values.
* The shape must be [num_token, topk].
* @param expert_id Pointer to store the indices of the Top-K values.
* The shape must be [num_token, topk].
* @param input Pointer to the input data containing the values to be computed.
* The shape must be [num_token, num_expert].
* @param mask Pointer to the input data containing the mask value to be computed after
* computing softmax, Mask can be nullptr, which means no need to compute,
* otherwise the shape and datatype of mask should be the same as input.
* @param num_token Number of channels in the input data.
* @param num_expert Specified dimension. Note that num_expert should not exceed 32768.
* @param num_mask Number of channels in the mask data.
* @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
* @param num_expert_group Group numbers of num_expert. If num_expert_group > 0, num_expert
* should be divisible by num_expert_group. Otherwise, num_expert_group and topk_group
* is not valid.
* @param topk_group Number of Top-K group values to compute. Topk_group should not be larger
* than num_expert_group.
* @param dtype Data type of the input data, should match the actual data type.
* float, half, bfloat16 is supported.
* @param normalize_mode Whether and how to normalize the output, if normalize_mode == 0, no
normalization is performed; if normalize_mode == 1, the normalized denominator is
the sum of topk; if normalize_mode == 2, the normalized denominator is the sum of
* the products of softmax_result mask.
*/
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_

View File

@@ -0,0 +1,425 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cassert>
#include <iostream>
#include "offline_quant_to_linear_cache.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
#define sizeof_(T) (uint32_t)sizeof(T)
template <typename T>
__mlu_func__ void quantify(int8_t *nram_output,
float *nram_input_float,
T *nram_input,
float *nram_scale,
int input_num,
int scale_num) {
if (std::is_same<half, T>::value) {
__bang_half2float(nram_input_float, (half *)nram_input, input_num);
} else if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
__bang_bfloat162float(nram_input_float, (bfloat16_t *)nram_input, input_num);
#endif
}
__bang_cycle_mul(nram_input_float, nram_input_float, nram_scale, input_num, scale_num);
__bang_float2int8_rn(nram_output, (float *)nram_input_float, input_num, 0);
}
template <typename T>
__mlu_func__ void quantPerHead(int8_t *output_gdram,
int8_t *output_nram,
const T *input_gdram,
T *input_nram,
const float *scale_gdram,
float *scale_nram,
float *input_nram_float,
T *trans_nram,
int seq,
int head_num,
int head_size,
size_t in_hstr_bytes, // context head_num stide bytes
size_t in_sstr_bytes, // context seq stide bytes
size_t scale_hstr_bytes, // scale head_num stide bytes
size_t out_hstr_bytes, // cache head_num stride bytes
size_t out_sstr_bytes // cache seq stride bytes
) {
constexpr int dtype_size = sizeof_(T);
// nram_input: (head_num, seq, head_size)
int io1_size = head_size * dtype_size;
__memcpy(trans_nram, input_gdram, io1_size, GDRAM2NRAM, seq * io1_size, head_num - 1, io1_size,
seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
// nram_scale:(head_num, seq);
int io2_size = seq * sizeof_(float);
__memcpy(scale_nram, scale_gdram, io2_size, GDRAM2NRAM, io2_size, scale_hstr_bytes, head_num - 1);
__bang_recip(scale_nram, scale_nram, head_num * seq);
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
__bang_transpose((half *)input_nram, (half *)trans_nram, head_num * seq, head_size);
} else {
__bang_transpose(input_nram_float, (float *)trans_nram, head_num * seq, head_size);
}
quantify<T>(output_nram, input_nram_float, input_nram, scale_nram, head_size * head_num * seq,
head_num * seq);
__bang_transpose((int8_t *)trans_nram, output_nram, head_size, head_num * seq);
__memcpy(output_gdram, trans_nram, head_size, NRAM2GDRAM, out_hstr_bytes, head_num - 1,
out_sstr_bytes, seq - 1, seq * head_size, head_num - 1, head_size, seq - 1);
}
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerHead(int8_t *key_cache,
int8_t *value_cache,
const float *key_cache_scale,
const float *value_cache_scale,
const int *cache_bs_offsets,
const int *cache_seq_offsets,
const T *key,
const T *value,
const int *context_seq_offsets,
const int *context_lens,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int seq_block) {
bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
if ((!handle_key) && (!handle_value)) {
return;
}
constexpr int dtype_size = sizeof_(T);
size_t in_hstr_bytes = context_head_stride * dtype_size;
size_t in_sstr_bytes = context_seq_stride * dtype_size;
size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);
/* ***************************nram space ****************************
* | scale | input/output | trans |
* scale size:[head_num, seq_block], float
* input size:[head_num, seq_block, head_size], float
* trans size:, [head_size, head_num, seq_block], T
*/
float *scale_nram = (float *)nram_buffer;
float *input_nram_float = nullptr;
T *trans_nram = nullptr, *input_nram = nullptr;
input_nram_float = scale_nram + head_num * seq_block;
trans_nram = (T *)(input_nram_float + head_num * seq_block * head_size);
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
// need cast from input_nram to input_nram_float
input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
} else {
input_nram = (T *)input_nram_float;
}
int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space
for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
int context_len = __load_gdram(context_lens + bs_idx);
int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
int task_seq_begin = taskIdZ * seq_block;
if (task_seq_begin >= seq_len) continue;
int seq = std::min(seq_len - task_seq_begin, seq_block);
// context offset
size_t context_offset = 0;
if (packed) {
context_offset = (context_len + task_seq_begin) * context_seq_stride;
} else {
int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
context_offset =
(bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
}
// cache offset
int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
if (cache_seq_offset < 0 || cache_bs_offset < 0) {
continue;
}
cache_seq_offset += task_seq_begin;
size_t cache_offset = (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
// per_head, nram input[head_num, seq, head_size], nram scale[head_num, seq]
if (handle_key) {
quantPerHead(key_cache + cache_offset, output_nram, key + context_offset, input_nram,
key_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
out_hstr_bytes, out_sstr_bytes);
}
if (handle_value) {
quantPerHead(value_cache + cache_offset, output_nram, value + context_offset, input_nram,
value_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
out_hstr_bytes, out_sstr_bytes);
}
}
}
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerChannel(
int8_t *key_cache,
int8_t *value_cache,
const float *key_cache_scale,
const float *value_cache_scale,
const int *cache_bs_offsets,
const int *cache_seq_offsets,
const T *key,
const T *value,
const int *context_seq_offsets,
const int *context_lens,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int seq_block) {
bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
if ((!handle_key) && (!handle_value)) {
return;
}
constexpr int dtype_size = sizeof_(T);
size_t in_hstr_bytes = context_head_stride * dtype_size;
size_t in_sstr_bytes = context_seq_stride * dtype_size;
size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);
/* *********************************nram space **************************************
* per_chennel: |scale[head_num, head_size]| input[seq_block, head_num, head_size]|
*/
float *scale_nram = (float *)nram_buffer;
float *input_nram_float = scale_nram + head_num * head_size;
T *input_nram = (T *)input_nram_float;
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
// need cast from input_nram to input_nram_float
input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
}
int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space
int size1 = head_size * sizeof_(float);
int size2 = head_size * dtype_size;
int scale_num = head_num * head_size;
if (handle_key) {
// load offline scale nram_scale:(head_num, head_size);
__memcpy(scale_nram, key_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes, head_num - 1);
__bang_recip(scale_nram, scale_nram, scale_num);
for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
int context_len = __load_gdram(context_lens + bs_idx);
int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
int task_seq_begin = taskIdZ * seq_block;
if (task_seq_begin >= seq_len) continue;
int seq = std::min(seq_len - task_seq_begin, seq_block);
// context offset
size_t context_offset = 0;
if (packed) {
context_offset = (context_len + task_seq_begin) * context_seq_stride;
} else {
int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
context_offset =
(bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
}
// cache offset
int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
if (cache_seq_offset < 0 || cache_bs_offset < 0) {
continue;
}
cache_seq_offset += task_seq_begin;
size_t cache_offset =
(cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
__memcpy(input_nram, key + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
scale_num);
__memcpy(key_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
}
}
if (handle_value) {
// load offline scale nram_scale:(head_num, head_size);
__memcpy(scale_nram, value_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes,
head_num - 1);
__bang_recip(scale_nram, scale_nram, scale_num);
for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
int context_len = __load_gdram(context_lens + bs_idx);
int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
int task_seq_begin = taskIdZ * seq_block;
if (task_seq_begin >= seq_len) continue;
int seq = std::min(seq_len - task_seq_begin, seq_block);
// context offset
size_t context_offset = 0;
if (packed) {
context_offset = (context_len + task_seq_begin) * context_seq_stride;
} else {
int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
context_offset =
(bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
}
// cache offset
int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
if (cache_seq_offset < 0 || cache_bs_offset < 0) {
continue;
}
cache_seq_offset += task_seq_begin;
size_t cache_offset =
(cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
__memcpy(input_nram, value + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
scale_num);
__memcpy(value_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
}
}
}
} // namespace kernels
#define LAUNCH_OFFLINE_QUANT_KERNEL(Dtype, Name) \
kernels::MLUOfflineQuantToLinearCacheKernel##Name<Dtype><<<dim, cnrtFuncTypeBlock, queue>>>( \
(int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, \
(float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (Dtype *)key, \
(Dtype *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, \
max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, \
cache_bs_stride, cache_head_stride, cache_seq_stride, cache_scale_head_stride, packed, \
seq_block);
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *key_cache_scale,
const void *value_cache_scale,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
const void *key,
const void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int quant_mode) {
constexpr int nram_size = 480 * 1024;
int dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof(float) : sizeof(half);
int seq_block = 0;
if (quant_mode == 0) {
seq_block = nram_size / (head_num * head_size * sizeof(float)) - 1;
if (seq_block <= 0) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * head_size * sizeof(float) should be less than 240KB when "
"quant_mode is 0."
<< std::endl;
}
} else {
seq_block = nram_size /
(head_num * sizeof(float) + head_num * head_size * (sizeof(float) + dtype_size));
if (seq_block <= 0) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * sizeof(float) + head_num * head_size * (sizeof(float) + "
"context_dtype_size)) "
<< " should be less than 480KB when quant_mode is not 0." << std::endl;
}
}
seq_block = std::min(seq_block, max_context_len);
if (seq_block > 16 && seq_block < max_context_len) {
seq_block = seq_block / 16 * 16;
}
int seq_seg = (max_context_len + seq_block - 1) / seq_block;
CNdev dev;
int cluster_dim, core_dim;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
uint32_t core_num = cluster_dim * core_dim;
uint32_t task_y_dim = std::min((uint32_t)batch, core_num);
cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg};
if (dtype == CNNL_DTYPE_HALF) {
if (quant_mode == 0) {
LAUNCH_OFFLINE_QUANT_KERNEL(half, PerChannel);
} else {
LAUNCH_OFFLINE_QUANT_KERNEL(half, PerHead);
}
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (quant_mode == 0) {
LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerChannel);
} else {
LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerHead);
}
} else {
if (quant_mode == 0) {
LAUNCH_OFFLINE_QUANT_KERNEL(float, PerChannel);
} else {
LAUNCH_OFFLINE_QUANT_KERNEL(float, PerHead);
}
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,103 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Quantize current key and value, Then store key and value to key_cache and value_cache.
* @param queue: The queue for mlu.
* @param key_cache: Pointer to the MLU memory that stores the key cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of key_cache must be int8. key_cache could be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of value_cache must be int8. value_cache could be nullptr.
* @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale,
* the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
* and [head_num, head_size] when quant_mode is zero. Data type of key_cache_scale
* must be float. value_cache could be nullptr.
* @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale,
* the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
* and [head_num, head_size] when quant_mode is zero. Data type of value_cache_scale
* must be float. value_cache_scale could be nullptr.
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is 0 for every batch.
* @param key: Pointer to the MLU memory that stores the key,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* Data type of key couble be float/half/bfloat16. key could be nullptr.
* @param value: Pointer to the MLU memory that stores the value,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* Data type of value couble be float/half/bfloat16, value could be nullptr.
* @param context_seq_offsets: Pointer to the MLU memory that stores the
* sequence offset of context, the shape must be [batch]. if it's nullptr,
* the default value is 0 for every batch. It must be nullptr when packed is true.
* @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative
* sequence length of context. when packed is false, the shape must be [batch], which
* indicates sequence length of context. when packed is true, the shape must be [batch + 1],
which
* indicates cumulative sequence length of context.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
* @param max_contxt_len: The maximum sequence length of context.
* @param cache_mem_len: The maximum sequence length of cache.
* @param contxt_bs_stride: The stride of batch in context, does not work when packed is true.
* @param contxt_head_stride: The stride of head_num in context.
* @param contxt_seq_stride: The stride of max_contxt_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param cache_seq_stride: The stride of cache_mem_len in cache.
* @param cache_scale_bs_stride: The stride of batch in cache scale.
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param packed: A boolean value indicates whether to use pack mode.
* @param quant_mode: A int value indicates the quantify mode, 0 means quantify by per_channel, and
others value means quantify by per_head.
* @note If one of key/key_cache/key_cache_scale is nullptr, nothing todo for key.
If one of value/value_cache/value_cache_scale is nullptr, nothing todo for value.
*/
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *key_cache_scale,
const void *value_cache_scale,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
const void *key,
const void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int quant_mode);
} // namespace tmo
#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,232 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <climits>
#include "offline_quant_to_paged_cache.mluh"
namespace tmo {
namespace kernels {
#define sizeof_(T) (uint32_t)sizeof(T)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#define REM_FOR_STACK (32 * 1024)
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK];
__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
template <typename T>
__mlu_func__ void quantifyToInt8(T *nram_input, float *nram_scale, int token_handle, int head_len) {
// quantify
if (std::is_same<half, T>::value) {
__bang_half2float((float *)nram_input,
(half *)((int8_t *)nram_input + token_handle * head_len * sizeof_(half)),
token_handle * head_len);
}
if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
__bang_bfloat162float(
(float *)nram_input,
(bfloat16_t *)((int8_t *)nram_input + token_handle * head_len * sizeof_(bfloat16_t)),
token_handle * head_len);
#endif
}
__bang_cycle_mul((float *)nram_input, (float *)nram_input, nram_scale, token_handle * head_len,
head_len);
__bang_float2int8_rn((int8_t *)nram_input, (float *)nram_input, token_handle * head_len, 0);
}
template <typename T>
__mlu_global__ void MLUOfflineQuantToPagedCacheKernel(T *key,
T *value,
int8_t *key_cache,
int8_t *value_cache,
float *key_cache_scale,
float *value_cache_scale,
int *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int tokens_num,
int head_num,
int block_size,
int head_size,
int tokens_block) {
/*******************************************************nram space***********************
* nram:| input | scale | cache_offset | scale_offset | mask | temp | index |
* input size: tokens_block * head_num * head_size * sizeof(float)
* scale size: head_num * head_size * sizeof(float)
* cache_offset size: tokens_block * head_num * sizeof(float)
* scale_offset size: equal to cache_offset size
* mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t)
* temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int)
* index size: head_num * sizeof(int)
****************************************************************************************/
#if __BANG_ARCH__ > 500
int token_begin = taskId * tokens_block;
if (token_begin >= tokens_num) return;
int token_handle = std::min(tokens_block, tokens_num - token_begin);
int seq_len = token_handle * head_num;
int head_len = head_num * head_size;
int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT;
int input_size = seq_len * head_size * sizeof_(float);
int8_t *nram_input = nram_buffer;
float *nram_scale = (float *)(nram_buffer + input_size);
int *cache_offset = (int *)(nram_scale + head_len);
int *scale_offset = cache_offset + pad8_num;
int *nram_mask = scale_offset + pad8_num;
int *nram_temp = nram_mask + pad8_num;
int *head_index = nram_temp + pad8_num;
// generate range: (0, 1, 2, ..., (head_num - 1))
__memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM);
int begin = 32;
while (begin < head_num) {
int count = std::min(begin, head_num - begin);
__bang_add_scalar(head_index + begin, head_index, begin, count);
begin += count;
}
// load slot(token_handle) -> expand(head_num, token_handle) ->transpose(token_handle, head_num)
int token_size = token_handle * sizeof_(int);
__memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM);
__memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1);
__bang_transpose(scale_offset, nram_temp, head_num, token_handle);
__bang_write_zero((float *)nram_temp, pad8_num);
__bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num);
// calculate cache/scale scatter offset
__bang_div(cache_offset, scale_offset, (int)block_size, seq_len);
__bang_rem(scale_offset, scale_offset, (int)block_size, seq_len);
__bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len);
__bang_mul_scalar(head_index, head_index, block_size, head_num);
__bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num);
__bang_add(scale_offset, cache_offset, scale_offset, seq_len);
__bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len);
__bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len);
int hidden_bytes = head_num * head_size * sizeof_(T);
bool half_size = (sizeof(T) == sizeof(half));
if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) {
// load key_cache_scale
__memcpy(nram_scale, key_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
__bang_recip(nram_scale, nram_scale, head_len);
// (token_handle, head_num, head_size)
__memcpy(nram_input + half_size * token_handle * hidden_bytes, key + token_begin * key_stride0,
hidden_bytes, GDRAM2NRAM, hidden_bytes, key_stride0 * sizeof_(T), token_handle - 1);
// quantify
quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
// scatter to gdram
__scatter(key_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
NRAM2GDRAM, head_size, seq_len);
}
if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) {
// load key_cache_scale
__memcpy(nram_scale, value_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
__bang_recip(nram_scale, nram_scale, head_len);
// (token_handle, head_num, head_size)
__memcpy(nram_input + half_size * token_handle * hidden_bytes,
value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes,
value_stride0 * sizeof_(T), token_handle - 1);
// quantify
quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
// scatter to gdram
__scatter(value_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
NRAM2GDRAM, head_size, seq_len);
}
#endif
}
} // namespace kernels
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size) {
if (is_arch300()) {
std::cerr << "[invokeOfflineQuantToPagedCache]: kernel does not support MLU300 devices."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int dtype_size = 1;
if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
dtype_size = 2;
} else if (data_type == CNNL_DTYPE_FLOAT) {
dtype_size = 4;
} else {
std::cerr << "invokeOfflineQuantToPagedCache: unsupport data type\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
int64_t kv_cache_range = block_num * block_size * num_heads * head_size * dtype_size;
if (kv_cache_range > UINT32_MAX) {
std::cerr
<< "invokeOfflineQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
// nram_size_need: token_block * head_num * head_size + head_num * head_size * sizeof(float)
// token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int)
// nram uesd: 480KB
int nram_size = 480 * 1024 - num_heads * sizeof(int) - num_heads * head_size * sizeof(float);
int hidden_bytes = num_heads * head_size * sizeof(float) +
4 * CEIL_DIV(num_heads, CHAR_BIT) * CHAR_BIT * sizeof(int);
int seq_block = nram_size / hidden_bytes;
if (seq_block <= 0) {
std::cerr << "invokeOfflineQuantToPagedCache: "
<< "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
<< "should be less than 480KB.\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (seq_block > 16) {
seq_block = seq_block / 16 * 16;
}
int cluster_num, core_dim;
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
int core_num = core_dim * cluster_num;
seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num));
uint32_t task_dim = CEIL_DIV(num_tokens, seq_block);
cnrtDim3_t dim{1, task_dim, 1};
if (data_type == CNNL_DTYPE_FLOAT) {
kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
} else if (data_type == CNNL_DTYPE_HALF) {
kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
} else {
kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,62 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform offline_quant_to_paged_cache operation.
* @param queue[in]: The queue for mlu.
* @param data_type[in]: The cnnl data type of key.
* @param key[in]: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
* num_heads, head_size]. Data type of key must be half/bfloat16_t/float.
* @param value[in]: Pointer to the MLU memory that stores the value tensor which has shape
* [num_tokens, num_heads, head_size]. Data type of value must be half/bfloat16_t/float.
* @param key_cache[out]: Pointer to the MLU memory that stores the key_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t.
* @param value_cache[out]: Pointer to the MLU memory that stores the value_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t.
* @param key_cache_scale[in]: Pointer to the MLU memory that stores the key_cache_scale tensor
* which has shape [num_heads, head_size]. Data type of key cache scale must be float.
* @param value_cache_scale[in]: Pointer to the MLU memory that stores the value_cache_scale tensor
* which has shape [num_heads, head_size]. Data type of value cache scale must be float.
* @param slot_mapping[in]: Pointer to the MLU memory that stores the slot_mapping tensor which has
* shape [num_tokens]. Data type of slot mapping must be int32_t.
* @param key_stride0[in]: The first dimension stride length of key_cache tensor.
* @param value_stride0[in]: The first dimension stride length of value_cache tensor.
* @param num_tokens[in]: Total number of tokens.
* @param num_heads[in]: Head number.
* @param block_num[in]: Total number of blocks.
* @param block_size[in]: Number of tokens per block.
* @param head_size[in]: Head size.
* @note: offline_quant_to_paged_cache does not support MLU300 device.
*/
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size);
} // namespace tmo
#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,156 @@
#include <algorithm>
#include "cnrt.h"
#include "operate_cu_seq_lens.mluh"
namespace {
constexpr int pair_elem_num = 2;
}
namespace tmo {
namespace kernels {
#define ONCHIP_DATA_NUM ((int)((__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) / sizeof(int)))
__nram__ int nram_buffer[ONCHIP_DATA_NUM];
__nram__ const int acc_seq_lens[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
__mlu_func__ void genSeqLens(int *seq_len_nram, int start, int multi, int elem_count) {
constexpr int acc_seq_lens_size = 16;
int count = std::min(acc_seq_lens_size, elem_count);
int add_on = multi * acc_seq_lens_size;
__bang_mul_scalar(seq_len_nram, acc_seq_lens, multi, count);
__bang_add_scalar(seq_len_nram, seq_len_nram, start, count);
while (count < elem_count) {
__bang_add_scalar(seq_len_nram + count, seq_len_nram, add_on,
std::min(count, elem_count - count));
count *= 2;
add_on *= 2;
}
}
__mlu_global__ void MLUSliceCuSeqlens(int *cu_seq_lens,
int *sliced_cu_seq_lens,
int batch,
int every,
int remain,
int loop) {
int cu_seq_lens_elem_count = batch + 1;
int sliced_cu_seq_lens_elem_count = batch + loop;
int *cu_seq_lens_narm = nram_buffer;
int *sliced_cu_seq_lens_narm = cu_seq_lens_narm + cu_seq_lens_elem_count;
int *sliced_cu_seq_lens_narm_start = sliced_cu_seq_lens_narm;
__memcpy(cu_seq_lens_narm, cu_seq_lens, cu_seq_lens_elem_count * sizeof(int), GDRAM2NRAM);
__bang_write_zero(sliced_cu_seq_lens_narm, sliced_cu_seq_lens_elem_count);
for (int i = 0; i < loop; ++i) {
int elem_num = 1 + (i == loop - 1 && remain != 0 ? remain : every);
__bang_sub_scalar(sliced_cu_seq_lens_narm, cu_seq_lens_narm, cu_seq_lens_narm[0], elem_num);
cu_seq_lens_narm += elem_num - 1;
sliced_cu_seq_lens_narm += elem_num;
}
__memcpy(sliced_cu_seq_lens, sliced_cu_seq_lens_narm_start,
sliced_cu_seq_lens_elem_count * sizeof(int), NRAM2GDRAM);
}
__mlu_global__ void MLUGenerateKVCuSeqlens(int *gen_cu_seq_lens,
int every,
int remain,
int loop,
int seq_len,
bool is_causal_mask,
int seg_data_num,
int task_num) {
int offset = seg_data_num * taskIdX;
int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
int seq_len_elem_num = total_elem_num / pair_elem_num;
int *gen_cu_seq_lens_narm = nram_buffer;
__bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);
if (is_causal_mask) {
int *seq_lens_narm = gen_cu_seq_lens_narm + total_elem_num;
genSeqLens(seq_lens_narm, every * offset / pair_elem_num, every, seq_len_elem_num);
__memcpy(gen_cu_seq_lens_narm + 1, seq_lens_narm, sizeof(int), NRAM2NRAM,
pair_elem_num * sizeof(int), sizeof(int), seq_len_elem_num - 1);
if (remain != 0 && taskIdX == task_num - 1) {
gen_cu_seq_lens_narm[total_elem_num - 1] -= (every - remain);
}
} else {
__bang_write_value(gen_cu_seq_lens_narm + 1, 1, seq_len, pair_elem_num * sizeof(int),
seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
}
__memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
NRAM2GDRAM);
}
__mlu_global__ void MLUGenerateQCuSeqlens(int *gen_cu_seq_lens,
int every,
int remain,
int loop,
int seg_data_num,
int task_num) {
int offset = seg_data_num * taskIdX;
int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
int seq_len_elem_num = total_elem_num / pair_elem_num;
int *gen_cu_seq_lens_narm = nram_buffer;
__bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);
__bang_write_value(gen_cu_seq_lens_narm + 1, 1, every, pair_elem_num * sizeof(int),
seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
if (remain != 0 && taskIdX == task_num - 1) {
gen_cu_seq_lens_narm[total_elem_num - 1] = remain;
}
__memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
NRAM2GDRAM);
}
} // namespace kernels
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
int *cu_seq_lens,
int *sliced_cu_seq_lens,
int batch,
int parallel_num) {
int every = (batch + parallel_num - 1) / parallel_num;
int repeat = batch / every;
int remain = batch % every;
int loop = repeat + (remain != 0);
cnrtDim3_t dim{1, 1, 1};
kernels::MLUSliceCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(cu_seq_lens, sliced_cu_seq_lens,
batch, every, remain, loop);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
int *gen_cu_seq_lens,
int seq_len,
int parallel_num,
bool is_causal_mask,
bool is_kv_seq_len) {
int every = (seq_len + parallel_num - 1) / parallel_num;
int repeat = seq_len / every;
int remain = seq_len % every;
int loop = repeat + (remain != 0);
int seg_data_num = ONCHIP_DATA_NUM / 2;
if (is_kv_seq_len && is_causal_mask) {
// max segnum for 2d memcpy is 64k
seg_data_num = std::min(pair_elem_num * 64 * 1024, seg_data_num);
}
int total_elem_num = loop * pair_elem_num;
int task_num = (total_elem_num + seg_data_num - 1) / seg_data_num;
cnrtDim3_t dim{(unsigned int)task_num, 1, 1};
if (is_kv_seq_len) {
kernels::MLUGenerateKVCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
gen_cu_seq_lens, every, remain, loop, seq_len, is_causal_mask, seg_data_num, task_num);
} else {
kernels::MLUGenerateQCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
gen_cu_seq_lens, every, remain, loop, seg_data_num, task_num);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
#define CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief slice cu_seq_lens and cu_k_seq_lens for parallel context when attention split batch;
* @example
* cu_seq_lens: [0, 2, 5, 10, 20, 33, 46, 51, 77]
* batch: 8, parallel_num: 3
* sliced_cu_seq_lens: [0, 2, 5, 10, 0, 10, 23, 36, 0, 5, 31]
* @param queue: The queue for mlu.
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the current seq lens, the shape
* is [batch + 1].
* @param sliced_cu_seq_lens: Output. Pointer to the MLU memory that stores the sliced current seq
* lens, the shape is [batch + loop_time].
* @param batch: Batch size.
* @param parallel_num: Parallel num of batch.
*/
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
int *cu_seq_lens,
int *sliced_cu_seq_lens,
int batch,
int parallel_num);
/**
* @brief generate cu_seq_lens and cu_k_seq_lens for parallel context when attention split seq;
* @example
* seq_len: 11, parallel_num: 3
* gen_cu_seq_lens for q: [0, 4, 0, 4, 0, 3]
* @example
* seq_len: 11, parallel_num: 3
* is_causal_mask false, gen_cu_seq_lens for kv: [0, 11, 0, 11, 0, 11]
* is_causal_mask true , gen_cu_seq_lens for kv: [0, 4, 0, 8, 0, 11]
* @param queue: The queue for mlu.
* @param gen_cu_seq_lens: Output. Pointer to the MLU memory that stores the generated current seq
* lens, the shape is [2 * loop_time].
* @param seq_len: Sequence length.
* @param parallel_num: Parallel num of sequence length.
* @param is_causal_mask: Whether self attention use causal mask.
* @param is_kv_seq_len: The gen_cu_seq_lens is for q or kv.
*/
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
int *gen_cu_seq_lens,
int seq_len,
int parallel_num,
bool is_causal_mask,
bool is_kv_seq_len);
} // namespace tmo
#endif // CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_

View File

@@ -0,0 +1,81 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "preload.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
__mlu_func__ void split(const int64_t total,
const int64_t num,
const int64_t id,
size_t &every,
size_t &offset) {
int64_t base = total / num;
int64_t tail = total - base * num;
every = base + (id < tail ? 1 : 0);
offset = base * id + (id < tail ? id : tail);
}
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
#if __BANG_ARCH__ > 372
size_t cluster_preload_size = 0;
size_t cluster_preload_offset = 0;
split(preload_size, taskDimY, taskIdY, cluster_preload_size, cluster_preload_offset);
size_t load_repeat = cluster_preload_size / SRAM_SIZE;
size_t load_remain = cluster_preload_size % SRAM_SIZE;
for (size_t i = 0; i < load_repeat + 1; i++) {
if (i == load_repeat && load_remain == 0) {
break;
}
size_t loop_load_size = (i < load_repeat ? SRAM_SIZE : load_remain);
int8_t *gdram_ptr = (int8_t *)filter_ptr + cluster_preload_offset + i * SRAM_SIZE;
if (loop_load_size > 0) {
__memcpy(sram_buffer, gdram_ptr, loop_load_size, GDRAM2SRAM);
}
}
#endif
}
} // namespace kernels
KernelStatus invokePreload(cnrtQueue_t queue,
void *filter_ptr,
size_t filter_size,
size_t preload_size) {
if (preload_size == 0) {
std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (preload_size > filter_size) {
preload_size = filter_size;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
cnrtDim3_t dim{.x = 4, .y = (uint32_t)cluster_num, .z = 1};
if (cluster_num == 1) {
dim.y = 1;
} else if (cluster_num >= 2) {
dim.y = 2;
}
kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,34 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_PRELOAD_MLUH_
#define CSRC_KERNELS_PRELOAD_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief When tp is greater than 1, while executing reducesum, the weight of ffn
* or selfattention to be calculated is loaded into LLC in advance.
* @param queue: The queue for mlu.
* @param filter_ptr: Input. Pointer to the MLU memory that stores the weight of ffn or
* selfattention.
* @param filter_size: The weight size of ffn or selfattention.
* @param preload_size: The size of the preload weight.
* @note The weights of ffn or selfattention must be continuous in filter_ptr.
*/
KernelStatus invokePreload(cnrtQueue_t queue,
void *filter_ptr,
size_t filter_size,
size_t preload_size);
} // namespace tmo
#endif // CSRC_KERNELS_PRELOAD_MLUH_

View File

@@ -0,0 +1,476 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <mlu.h>
#include <cassert>
#include <iostream>
#include <map>
#include "quant_to_linear_cache.mluh"
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ uint8_t post_table_nram[64];
#define sizeof_(T) (uint32_t)sizeof(T)
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
#define __reshape_nhwc2nchw_smallhw(TYPE) \
asm volatile( \
"trans.tiling.nram.nram." TYPE \
"[%[dst]], [%[src]], " \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \
".posttable.nram([%[post]]); \n\t" ::[dst] "r"(dst), \
[src] "r"(src), [post] "r"(post_table), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), \
[in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \
[is4] "r"(batch_offset), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \
[ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \
[ds4] "r"(batch_offset), [dn5] "i"(1), [ds5] "i"(0));
template <typename T>
__mlu_func__ void __reshape_nhwc2nchw_smallhw_init(uint8_t *post_table_nram, uint32_t hw) {
uint32_t align_num = 64 / sizeof(T);
int tmp = hw - 1;
asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
int align_hw = 1 << (tmp + 1);
int repeat = align_num / align_hw;
for (int i = 0; i < 64; i++) {
int idx = i / sizeof_(T);
int tmp_idx = (idx % hw) * repeat + idx / hw;
int real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T);
int mask = idx < repeat * hw ? 0x80 : 0x0;
post_table_nram[i] = (uint8_t)real_idx + mask;
}
}
template <typename T>
__mlu_func__ void trans_nhwc2nchw(T *dst,
const T *src,
uint8_t *post_table,
const uint32_t n,
const uint32_t hw,
const uint32_t c) {
uint32_t align_num = 64 / sizeof_(T);
int tmp = hw - 1;
asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
int align_hw = 1 << (tmp + 1);
int repeat = align_num / align_hw;
int in0 = 64;
int in1 = hw;
int is1 = c * sizeof(T);
int in3 = c / align_num;
int is3 = in0;
int batch_offset = hw * c * sizeof(T);
int dn0 = hw * repeat * sizeof_(T);
int dn1 = align_hw;
int ds1 = dn0;
int dn3 = in3;
int ds3 = dn0 * dn1;
align_hw = in3 > 0 ? align_hw : 0;
if (align_hw == 2) {
__reshape_nhwc2nchw_smallhw("b256");
} else if (align_hw == 4) {
__reshape_nhwc2nchw_smallhw("b128");
} else if (align_hw == 8) {
__reshape_nhwc2nchw_smallhw("b64");
} else if (align_hw == 16) {
__reshape_nhwc2nchw_smallhw("b32");
} else if (align_hw == 32) {
__reshape_nhwc2nchw_smallhw("b16");
}
constexpr uint32_t bw = 8 * sizeof_(T);
int in3_rem = c % align_num;
int tail_in0 = in3_rem * sizeof_(T);
int tail_dn0 = hw * sizeof_(T);
if (in3_rem > 0) {
asm volatile(
"trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \
%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5], \
%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::
[bw] "i"(bw),
[dst] "r"(dst + dn3 * ds3 / sizeof_(T)), [src] "r"(src + is3 * in3 / sizeof_(T)),
[in0] "r"(tail_in0), [in1] "r"(hw), [is1] "r"(c * sizeof_(T)), [in2] "i"(1), [is2] "i"(0),
[in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(batch_offset), [in5] "i"(1),
[is5] "i"(0), [dn0] "r"(tail_dn0), [dn1] "r"(in3_rem), [ds1] "r"(tail_dn0), [dn2] "i"(1),
[ds2] "i"(0), [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(batch_offset),
[dn5] "i"(1), [ds5] "i"(0));
}
}
template <typename T>
__mlu_func__ void quantify(T *src_input,
float *nram_input,
float *transpose,
float *scale,
float *scale_recip,
int high_dim,
int low_dim,
int quant_bit) {
float recip_data = 1.0f / ((1 << (quant_bit - 1)) - 1);
if (std::is_same<half, T>::value) {
__bang_transpose((half *)transpose, (half *)src_input, high_dim, low_dim);
__bang_abs((half *)src_input, (half *)transpose, high_dim * low_dim);
__bang_maxpool((half *)scale, (half *)src_input, high_dim, low_dim, 1, low_dim, 1, 1, 1);
__bang_half2float(scale_recip, (half *)scale, high_dim);
} else if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
__bang_transpose((bfloat16_t *)transpose, (bfloat16_t *)src_input, high_dim, low_dim);
__bang_abs((bfloat16_t *)src_input, (bfloat16_t *)transpose, high_dim * low_dim);
__bang_maxpool((bfloat16_t *)scale, (bfloat16_t *)src_input, high_dim, low_dim, 1, low_dim, 1,
1, 1);
__bang_bfloat162float(scale_recip, (bfloat16_t *)scale, high_dim);
#endif
} else {
__bang_transpose(transpose, (float *)src_input, high_dim, low_dim);
__bang_abs((float *)src_input, transpose, high_dim * low_dim);
__bang_maxpool(scale_recip, (float *)src_input, high_dim, low_dim, 1, low_dim, 1, 1, 1);
}
__bang_mul_scalar(scale, scale_recip, recip_data, high_dim);
__bang_recip(scale_recip, scale, high_dim);
if (std::is_same<half, T>::value) {
__bang_half2float(nram_input, (half *)transpose, high_dim * low_dim);
__bang_cycle_mul(transpose, nram_input, scale_recip, high_dim * low_dim, high_dim);
} else if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
__bang_bfloat162float(nram_input, (bfloat16_t *)transpose, high_dim * low_dim);
__bang_cycle_mul(transpose, nram_input, scale_recip, high_dim * low_dim, high_dim);
#endif
} else {
__bang_cycle_mul(transpose, transpose, scale_recip, high_dim * low_dim, high_dim);
}
}
__mlu_func__ void castKeyToIntx(int8_t *dst, float *src, int high_dim, int low_dim, int quant_bit) {
if (quant_bit == 8) {
__bang_float2int8_rn((int8_t *)src, (float *)dst, high_dim * low_dim, 0);
__bang_transpose((int8_t *)dst, (int8_t *)src, low_dim, high_dim);
} else if (quant_bit == 4) {
__bang_transpose(src, (float *)dst, low_dim, high_dim);
__bang_float2int4_rn((int4x2_t *)dst, src, high_dim * low_dim, 0);
}
}
__mlu_func__ void castValueToIntx(int8_t *value_cache_begin,
int8_t *value_cache_end,
uint8_t *post_table_nram,
float *dst,
float *src,
size_t cache_head_stride,
int seq,
int head_num,
int head_size,
int group_num,
int quant_bit,
bool need_pad_front,
bool need_pad_back) {
if (quant_bit == 8) {
__bang_float2int8_rn((int8_t *)src, dst, seq * head_num * head_size, 0);
__bang_transpose((int8_t *)dst, (int8_t *)src, head_size / group_num,
seq * head_num * group_num);
} else {
__bang_transpose(src, dst, head_size / group_num, seq * head_num * group_num);
__reshape_nhwc2nchw_smallhw_init<int8_t>(post_table_nram, 2);
if (!(need_pad_front || need_pad_back)) {
__bang_float2int8_rn((int8_t *)src, src, seq * head_num * head_size,
0); // [head_num, seq, head_size]
trans_nhwc2nchw((int8_t *)dst, (int8_t *)src, post_table_nram, seq / 2, 2,
head_num * head_size);
__bang_int82int4_rn((int4x2_t *)dst, (int8_t *)dst, seq * head_num * head_size, 0, 0);
} else {
int origin_seq = seq;
if (need_pad_front) {
__memcpy((int8_t *)dst, value_cache_begin, head_size, GDRAM2NRAM, head_size,
cache_head_stride, head_num - 1);
__bang_band_scalar((int8_t *)dst, (int8_t *)dst, 0x0F, head_num * head_size);
seq += 1;
}
if (need_pad_back) {
__memcpy((int8_t *)dst + seq * head_num * head_size, value_cache_end, head_size, GDRAM2NRAM,
head_size, cache_head_stride, head_num - 1);
__bang_srl((int8_t *)dst + seq * head_num * head_size,
(int8_t *)dst + seq * head_num * head_size, 4, head_num * head_size);
seq += 1;
}
__bang_float2int8_rn((int8_t *)dst + need_pad_front * head_num * head_size, src,
origin_seq * head_num * head_size, 0); // [new_seq, head_num, head_size]
trans_nhwc2nchw((int8_t *)src, (int8_t *)dst, post_table_nram, seq / 2, 2,
head_num * head_size); // [seq / 2, 2, head_num, head_size]
__bang_int82int4_rn((int4x2_t *)dst, (int8_t *)src, seq * head_num * head_size, 0, 0);
}
}
}
// [head_num, batch, seq_seg]
template <typename T>
__mlu_global__ void MLUQuantToLinearCacheKernel(int8_t *key_cache,
int8_t *value_cache,
float *key_cache_scale,
float *value_cache_scale,
int *cache_bs_offsets,
int *cache_seq_offsets,
T *key,
T *value,
int *context_seq_offsets,
int *context_lens,
int batch,
int head_num,
int head_size,
int max_context_len,
int cache_mem_len,
size_t context_bs_stride,
size_t context_head_stride,
size_t context_seq_stride,
size_t cache_bs_stride,
size_t cache_head_stride,
size_t key_cache_seq_stride,
size_t value_cache_seq_stride,
size_t cache_scale_bs_stride,
size_t cache_scale_head_stride,
bool packed,
int seq_block,
int quant_bit,
int group_num) {
float *nram_input = (float *)nram_buffer;
T *src_input = (T *)nram_input;
if (sizeof_(T) == sizeof_(half)) {
src_input = (T *)(nram_buffer + seq_block * head_num * head_size * sizeof_(T));
}
float *nram_trans = nram_input + seq_block * head_num * head_size;
float *nram_scale = nram_trans + seq_block * head_num * head_size;
float *nram_scale_recip = nram_scale + seq_block * head_num * group_num;
constexpr int dtype_size = sizeof_(T);
int head_size_store = head_size * quant_bit / 8;
for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
int seq_offset =
(packed || context_seq_offsets == nullptr) ? 0 : __load_gdram(context_seq_offsets + bs_idx);
int cache_seq_offset =
cache_seq_offsets == nullptr ? 0 : __load_gdram(cache_seq_offsets + bs_idx);
int context_len = __load_gdram(context_lens + bs_idx);
int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
int key_seq_begin = taskIdZ * seq_block;
int first_value_seq = (quant_bit == 4 && cache_seq_offset % 2 == 1)
? std::min(seq_len, seq_block - 1)
: std::min(seq_len, seq_block);
int value_seq_begin = taskIdZ == 0 ? 0 : first_value_seq + (taskIdZ - 1) * seq_block;
int key_seq = std::min(seq_len - key_seq_begin, seq_block);
int value_seq = taskIdZ == 0 ? first_value_seq : std::min(seq_len - value_seq_begin, seq_block);
size_t key_context_offset = 0;
size_t value_context_offset = 0;
if (packed) {
key_context_offset += (context_len + key_seq_begin) * context_seq_stride * dtype_size;
value_context_offset += (context_len + value_seq_begin) * context_seq_stride * dtype_size;
} else {
key_context_offset +=
(bs_idx * context_bs_stride + (key_seq_begin + seq_offset) * context_seq_stride) *
dtype_size;
value_context_offset +=
(bs_idx * context_bs_stride + (value_seq_begin + seq_offset) * context_seq_stride) *
dtype_size;
}
int key_cache_seq_offset = cache_seq_offset;
int value_cache_seq_offset = cache_seq_offset;
int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
if (key_cache_seq_offset < 0 || cache_bs_offset < 0) {
continue;
}
key_cache_seq_offset += key_seq_begin;
value_cache_seq_offset += value_seq_begin;
size_t key_cache_offset =
(cache_bs_offset * cache_bs_stride + key_cache_seq_offset * key_cache_seq_stride);
size_t key_cache_scale_offset =
cache_bs_offset * cache_scale_bs_stride + key_cache_seq_offset * group_num;
if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr &&
key_seq_begin < seq_len) {
int8_t *key_cache_begin = key_cache + key_cache_offset;
float *key_cache_scale_begin = key_cache_scale + key_cache_scale_offset;
char *key_begin = (char *)key + key_context_offset;
// nram_input: (head_num, seq, head_size)
__memcpy(src_input, key_begin, head_size * dtype_size, GDRAM2NRAM,
key_seq * head_size * dtype_size, head_num - 1, head_size * dtype_size, key_seq - 1,
context_head_stride * dtype_size, head_num - 1, context_seq_stride * dtype_size,
key_seq - 1);
// [head_num, seq, head_size]
quantify(src_input, nram_input, nram_trans, nram_scale, nram_scale_recip,
key_seq * head_num * group_num, head_size / group_num, quant_bit);
castKeyToIntx((int8_t *)nram_trans, nram_input, key_seq * head_num * group_num,
head_size / group_num, quant_bit);
// after quantify: (head_num, seq, head_size * quant_bit / 8)
__memcpy(key_cache_begin, nram_trans, head_size_store, NRAM2GDRAM, key_cache_seq_stride,
key_seq - 1, cache_head_stride, head_num - 1, head_size_store, key_seq - 1,
key_seq * head_size_store, head_num - 1);
// nram_scale: (head_num, seq, group_num)
__memcpy(key_cache_scale_begin, nram_scale, key_seq * group_num * sizeof_(float), NRAM2GDRAM,
cache_scale_head_stride * sizeof_(float), key_seq * group_num * sizeof_(float),
head_num - 1);
}
if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr &&
value_seq_begin < seq_len) {
size_t value_cache_offset =
cache_bs_offset * cache_bs_stride + value_cache_seq_offset * value_cache_seq_stride;
if (quant_bit == 4) {
value_cache_offset = cache_bs_offset * cache_bs_stride +
(value_cache_seq_offset / 2) * value_cache_seq_stride;
}
size_t value_cache_scale_offset =
cache_bs_offset * cache_scale_bs_stride + value_cache_seq_offset * group_num;
int8_t *value_cache_begin = value_cache + value_cache_offset;
int8_t *value_cache_end =
value_cache + cache_bs_offset * cache_bs_stride +
(DIV_UP(cache_seq_offset + seq_len, 2) - 1) * value_cache_seq_stride;
float *value_cache_scale_begin = value_cache_scale + value_cache_scale_offset;
char *value_begin = (char *)value + value_context_offset;
bool need_pad_front = (quant_bit == 4) && (taskIdZ == 0) && (cache_seq_offset % 2 == 1);
bool need_pad_back = (quant_bit == 4) && (value_seq_begin + value_seq >= seq_len) &&
((cache_seq_offset + seq_len) % 2 == 1);
// quant_bit == 8 : nram_input: (head_num, seq, head_size)
if (quant_bit == 8) {
__memcpy(src_input, value_begin, head_size * dtype_size, GDRAM2NRAM,
value_seq * head_size * dtype_size, head_num - 1, head_size * dtype_size,
value_seq - 1, context_head_stride * dtype_size, head_num - 1,
context_seq_stride * dtype_size, value_seq - 1);
} else if (quant_bit == 4) {
// quant_bit == 4 : nram_input: (seq, head_num. head_size)
__memcpy(src_input, value_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
head_num - 1, head_num * head_size * dtype_size, value_seq - 1,
context_head_stride * dtype_size, head_num - 1, context_seq_stride * dtype_size,
value_seq - 1);
}
quantify(src_input, nram_input, nram_trans, nram_scale, nram_scale_recip,
value_seq * head_num * group_num, head_size / group_num, quant_bit);
castValueToIntx(value_cache_begin, value_cache_end, post_table_nram, nram_trans, nram_input,
cache_head_stride, value_seq, head_num, head_size, group_num, quant_bit,
need_pad_front, need_pad_back);
// [head_num, seq, head_size]
if (quant_bit == 8) {
__memcpy(value_cache_scale_begin, nram_scale, value_seq * group_num * sizeof_(float),
NRAM2GDRAM, cache_scale_head_stride * sizeof_(float),
value_seq * group_num * sizeof_(float), head_num - 1);
__memcpy(value_cache_begin, nram_trans, head_size, NRAM2GDRAM, value_cache_seq_stride,
value_seq - 1, cache_head_stride, head_num - 1, head_size, value_seq - 1,
value_seq * head_size, head_num - 1);
} else if (quant_bit == 4) {
int new_seq = value_seq + ((int)need_pad_front + (int)need_pad_back);
__sync();
__memcpy_async(nram_scale_recip, nram_scale, group_num * sizeof_(float), NRAM2NRAM,
value_seq * group_num * sizeof_(float), head_num - 1,
group_num * sizeof_(float), value_seq - 1, group_num * sizeof_(float),
head_num - 1, head_num * group_num * sizeof_(float), value_seq - 1);
__memcpy_async(value_cache_begin, nram_trans, head_size, NRAM2GDRAM, cache_head_stride,
head_num - 1, value_cache_seq_stride, new_seq / 2 - 1, head_size,
head_num - 1, head_num * head_size, new_seq / 2 - 1);
__sync();
__memcpy(value_cache_scale_begin, nram_scale_recip, value_seq * group_num * sizeof_(float),
NRAM2GDRAM, cache_scale_head_stride * sizeof_(float),
value_seq * group_num * sizeof_(float), head_num - 1);
}
} // end if
} // end for
}
} // namespace kernels
void invokeQuantToLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
void *key,
void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t cache_scale_bs_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int quant_bit,
const int group_size) {
constexpr int nram_size = 480 * 1024;
int group_num = head_size / group_size;
int hidden_bytes = head_num * (head_size + group_num) * sizeof(float) * 2;
int seq_block = nram_size / hidden_bytes;
if (seq_block <= 1) {
std::cerr << __func__ << "," << __LINE__
<< " :head_num * (head_size + group_num) * sizeof(float) should be less than 120KB."
<< std::endl;
}
if (seq_block > 16) {
seq_block = seq_block / 16 * 16;
} else {
seq_block = seq_block / 2 * 2;
}
int seq_seg = max_context_len / seq_block + 1;
int cluster_num, core_dim;
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
int core_num = core_dim * cluster_num;
uint32_t task_y_dim = std::min(batch, core_num);
cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg};
if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUQuantToLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale,
(float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (half *)key,
(half *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size,
max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride,
cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride,
cache_scale_bs_stride, cache_scale_head_stride, packed, seq_block, quant_bit, group_num);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
kernels::MLUQuantToLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale,
(float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets,
(bfloat16_t *)key, (bfloat16_t *)value, (int *)context_seq_offsets, (int *)context_lens,
batch, head_num, head_size, max_context_len, cache_mem_len, context_bs_stride,
context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride,
cache_scale_head_stride, packed, seq_block, quant_bit, group_num);
} else {
kernels::MLUQuantToLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale,
(float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (float *)key,
(float *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size,
max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride,
cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride,
cache_scale_bs_stride, cache_scale_head_stride, packed, seq_block, quant_bit, group_num);
}
}
} // namespace tmo

View File

@@ -0,0 +1,105 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_
#include "cnnl.h"
namespace tmo {
/**
* @brief Quantize current key and value, Then store key and value to key_cache and value_cache.
* @param queue: The queue for mlu.
* @param key_cache: Pointer to the MLU memory that stores the key cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of key_cache must be int8. key_cache could be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of value_cache must be int8. value_cache could be nullptr.
* @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale,
* the shape must be [max_batch, head_num, cache_mem_len].
* Data type of key_cache_scale must be float. key_cache_scale could be nullptr.
* @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale,
* the shape must be [max_batch, head_num, cache_mem_len].
* Data type of value_cache_scale must be float. value_cache_scale could be nullptr.
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is 0 for every batch.
* @param key: Pointer to the MLU memory that stores the key,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* Data type of key couble be float/half/bfloat16. key could be nullptr.
* @param value: Pointer to the MLU memory that stores the value,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* Data type of value couble be float/half/bfloat16, value could be nullptr.
* @param context_seq_offsets: Pointer to the MLU memory that stores the
* sequence offset of context, the shape must be [batch]. if it's nullptr,
* the default value is 0 for every batch. It must be nullptr when packed is true.
* @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative
* sequence length of context. when packed is false, the shape must be [batch], which
* indicates sequence length of context. when packed is true, the shape must be [batch + 1], which
* indicates cumulative sequence length of context.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
* @param max_contxt_len: The maximum sequence length of context.
* @param cache_mem_len: The maximum sequence length of cache.
* @param contxt_bs_stride: The stride of batch in context, does not work when packed is true.
* @param contxt_head_stride: The stride of head_num in context.
* @param contxt_seq_stride: The stride of max_contxt_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param key_cache_seq_stride: The stride of cache_mem_len in key_cache.
* @param value_cache_seq_stride: The stride of cache_mem_len in value_cache.
* @param cache_scale_bs_stride: The stride of batch in cache scale.
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param packed: A boolean value indicates whether to use pack mode.
* @param quant_bit: Bit width of quantified results.
* @param group_size: Size of a group during group quantization,
* @note If one of key/key_cache/key_cache_scale is nullptr, nothing todo for key.
If one of value/value_cache/value_cache_scale is nullptr, nothing todo for value.
A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for
the corresponding batch.
*/
void invokeQuantToLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
void *key,
void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t cache_scale_bs_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int quant_bit,
const int group_size);
} // namespace tmo
#endif // CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,241 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <climits>
#include "quant_to_paged_cache.mluh"
namespace tmo {
namespace kernels {
#define sizeof_(T) (uint32_t)sizeof(T)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#define REM_FOR_STACK (32 * 1024)
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK];
#if __BANG_ARCH__ > 500
__nram__ const int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
#endif
template <typename T>
__mlu_func__ void quantifyToInt8(int8_t *nram_output,
float *nram_float,
T *nram_input,
float *nram_scale,
float *nram_temp,
int seq_len,
int head_size) {
// quantify
__bang_transpose((T *)nram_input, (T *)nram_scale, seq_len, head_size);
if (std::is_same<half, T>::value) {
__bang_half2float(nram_float, (half *)nram_input, head_size * seq_len);
}
if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
__bang_bfloat162float(nram_float, (bfloat16_t *)nram_input, head_size * seq_len);
#endif
}
__bang_abs(nram_scale, nram_float, head_size * seq_len);
__bang_maxpool(nram_scale, nram_scale, seq_len, head_size, 1, head_size, 1, 1, 1);
__bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, seq_len);
__bang_recip(nram_temp, nram_scale, seq_len);
__bang_cycle_mul(nram_float, nram_float, nram_temp, head_size * seq_len, seq_len);
__bang_float2int8_rn((int8_t *)nram_float, nram_float, head_size * seq_len, 0);
__bang_transpose((int8_t *)nram_output, (int8_t *)nram_float, head_size, seq_len);
}
template <typename T>
__mlu_global__ void MLUQuantToPagedCacheKernel(T *key,
T *value,
int8_t *key_cache,
int8_t *value_cache,
float *key_cache_scale,
float *value_cache_scale,
int *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int tokens_num,
int head_num,
int block_size,
int head_size,
int tokens_block) {
/*******************************************************nram space***********************
* nram:| input | scale | cache_offset | scale_offset | mask | temp | index |
* input size: tokens_block * head_num * head_size * sizeof(float)
* scale size: equal to input size
* cache_offset size: tokens_block * head_num * sizeof(float)
* scale_offset size: equal to cache_offset size
* mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t)
* temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int)
* index size: head_num * sizeof(int)
****************************************************************************************/
#if __BANG_ARCH__ > 500
int token_begin = taskId * tokens_block;
if (token_begin >= tokens_num) return;
int token_handle = std::min(tokens_block, tokens_num - token_begin);
int seq_len = token_handle * head_num;
int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT;
int input_size = seq_len * head_size * sizeof_(float);
int8_t *nram_input = nram_buffer;
float *nram_scale = (float *)(nram_buffer + input_size);
int *cache_offset = (int *)((int8_t *)nram_scale + input_size);
int *scale_offset = cache_offset + pad8_num;
int *nram_mask = scale_offset + pad8_num;
int *nram_temp = nram_mask + pad8_num;
int *head_index = nram_temp + pad8_num;
// generate range: (0, 1, 2, ..., (head_num - 1))
__memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM);
int begin = 32;
while (begin < head_num) {
int count = std::min(begin, head_num - begin);
__bang_add_scalar(head_index + begin, head_index, begin, count);
begin += count;
}
// load slot(token_handle) -> expand(head_num, token_handle) ->transpose(token_handle, head_num)
int token_size = token_handle * sizeof_(int);
__memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM);
__memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1);
__bang_transpose(scale_offset, nram_temp, head_num, token_handle);
__bang_write_zero((float *)nram_temp, pad8_num);
__bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num);
// calculate cache/scale scatter offset
__bang_div(cache_offset, scale_offset, (int)block_size, seq_len);
__bang_rem(scale_offset, scale_offset, (int)block_size, seq_len);
__bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len);
__bang_mul_scalar(head_index, head_index, block_size, head_num);
__bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num);
__bang_add(scale_offset, cache_offset, scale_offset, seq_len);
__bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len);
__bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len);
int hidden_bytes = head_num * head_size * sizeof_(T);
int8_t *nram_output = (int8_t *)(nram_input + seq_len * head_size * sizeof_(half));
T *nram_input_origin = (T *)nram_input;
if (!std::is_same<float, T>::value) {
nram_input_origin = (T *)(nram_input + seq_len * head_size * (sizeof_(float) - sizeof_(T)));
}
if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) {
// (token_handle, head_num, head_size)
__memcpy(nram_scale, key + token_begin * key_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes,
key_stride0 * sizeof_(T), token_handle - 1);
// quantify
quantifyToInt8(nram_output, (float *)nram_input, nram_input_origin, nram_scale,
(float *)nram_temp, seq_len, head_size);
// scatter to gdram
__scatter(key_cache, (int8_t *)nram_output, (uint32_t *)cache_offset, nram_mask, head_size,
NRAM2GDRAM, head_size, seq_len);
__scatter(key_cache_scale, nram_scale, (uint32_t *)scale_offset, nram_mask, sizeof_(float),
NRAM2GDRAM, sizeof_(float), seq_len);
}
if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) {
// (token_handle, head_num, head_size)
__memcpy(nram_scale, value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM,
hidden_bytes, value_stride0 * sizeof_(T), token_handle - 1);
// quantify
quantifyToInt8(nram_output, (float *)nram_input, nram_input_origin, nram_scale,
(float *)nram_temp, seq_len, head_size);
// scatter to gdram
__scatter(value_cache, (int8_t *)nram_output, (uint32_t *)cache_offset, nram_mask, head_size,
NRAM2GDRAM, head_size, seq_len);
__scatter(value_cache_scale, nram_scale, (uint32_t *)scale_offset, nram_mask, sizeof_(float),
NRAM2GDRAM, sizeof_(float), seq_len);
}
#endif
}
} // namespace kernels
KernelStatus invokeQuantToPagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size) {
int dtype_size = 1;
if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
dtype_size = 2;
} else if (data_type == CNNL_DTYPE_FLOAT) {
dtype_size = 4;
} else {
std::cerr << "invokeQuantToPagedCache: unsupport data type\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
int64_t kv_range = block_num * block_size * num_heads * head_size * dtype_size;
if (kv_range > UINT32_MAX) {
std::cerr << "invokeQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (is_arch300()) {
std::cerr << "[invokeQuantToPagedCache]: kernel does not support MLU300 devices." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
// nram_size_need: token_block * head_num * head_size * 2 +
// token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int)
// nram uesd: 480KB
int nram_size = 480 * 1024 - num_heads * sizeof(int);
int hidden_bytes = num_heads * head_size * 2 * sizeof(float) + 4 * num_heads * sizeof(int);
int seq_block = nram_size / hidden_bytes;
if (seq_block <= 0) {
std::cerr << "invokeQuantToPagedCache: "
<< "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
<< "should be less than 480KB.\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (seq_block > 16) {
seq_block = seq_block / 16 * 16;
}
int cluster_num, core_dim;
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
int core_num = core_dim * cluster_num;
seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num));
uint32_t task_dim = CEIL_DIV(num_tokens, seq_block);
cnrtDim3_t dim{1, task_dim, 1};
if (data_type == CNNL_DTYPE_FLOAT) {
kernels::MLUQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
} else if (data_type == CNNL_DTYPE_HALF) {
kernels::MLUQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
} else {
kernels::MLUQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,63 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform quant_to_paged_cache operation.
* @param handle: The handle of cnnl.
* @param data_type: The cnnl data type of key.
* @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
* num_heads, head_size]. Data type of key must be half/bfloat16_t/float.
* @param value: Pointer to the MLU memory that stores the value tensor which has shape [num_tokens,
* num_heads, head_size]. Data type of key must be half/bfloat16_t/float.
* @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t.
* @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t.
* @param key_cache_scale: Pointer to the MLU memory that stores the key_cache_scale tensor which
* has shape [num_blocks, num_heads, block_size]. Data type of key cache scale must be float.
* @param value_cache_scale: Pointer to the MLU memory that stores the value_cache_scale tensor
* which has shape [num_blocks, num_heads, block_size]. Data type of value cache scale must be
* float.
* @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has
* shape [num_tokens]. Data type of slot mapping must be int32_t.
* @param key_stride0: The first dimension stride length of key_cache tensor.
* @param value_stride0: The first dimension stride length of value_cache tensor.
* @param num_tokens: Total number of tokens.
* @param num_heads: Head number.
* @param block_num: Total number of blocks.
* @param block_size: Number of tokens per block.
* @param head_size: Head size.
* @note: quant_to_paged_cache does not support MLU300 device.
*/
KernelStatus invokeQuantToPagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size);
} // namespace tmo
#endif // CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,313 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANT_UTILS_H_
#define CSRC_KERNELS_QUANT_UTILS_H_
#include <cassert>
#include <iostream>
#include <string>
#include "cnnl.h"
#include "cnrt.h"
#include "kernel_utils.h"
namespace tmo {
#ifndef LT_NUM
#define LT_NUM (64)
#endif
#ifndef ANT_LT_ROW
#define ANT_LT_ROW (4)
#endif
#ifndef LT_NUM_ANT
#define LT_NUM_ANT (16)
#endif
#ifndef ONE_LINE
#define ONE_LINE (64)
#endif
#ifndef sizeof_
#define sizeof_(T) (uint32_t)sizeof(T)
#endif
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 16)
#endif
#ifndef TRANS_TABLE_SIZE
#define TRANS_TABLE_SIZE (64)
#endif
#ifndef DIV_UP
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
#endif
#ifndef CONV_FUSE_OP_CVT
#define CONV_FUSE_OP_CVT(dtype, op, cvt, op_data) \
asm volatile("conv.nram.rn.f32" dtype dtype \
"[%[dst]], [%[src]], [%[kernel]], %[src_channel], " \
"%[src_height], 1, 1, 1, 1, 1, %[dst_channel]" op cvt \
";\n\t" ::[dst] "r"((Td *)output), \
[src] "r"((Ts *)input), [kernel] "r"((Ts *)filter), [src_channel] "r"(k), \
[src_height] "r"(m), [dst_channel] "r"(n), [operand0] "r"(op_data));
#endif
#define __reshape_nhwc2nchw_smallc(TYPE) \
asm volatile( \
"trans.tiling.nram.nram." TYPE \
"[%[dst]], [%[src]], " \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \
".pretable.nram([%[pre]]); \n\t" ::[dst] "r"((T *)dst), \
[src] "r"((T *)src), [pre] "r"((uint8_t *)pre_table), [in0] "r"(in0), [in1] "r"(in1), \
[is1] "r"(in0), [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \
[is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \
[ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \
[ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
__mlu_func__ void next_power_of_two(int32_t &align_num, const int32_t num) {
int32_t tmp = num - 1;
asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
align_num = 1 << (tmp + 1);
}
/* copy from cnnl utils/trans_small.py by xwm. */
template <typename T>
__mlu_func__ void __reshape_nhwc2nchw_smallc_init(uint8_t *pre_table_nram, uint32_t channel) {
int32_t align_c;
next_power_of_two(align_c, channel);
int32_t align_num = ONE_LINE / sizeof_(T);
int32_t repeat = align_num / align_c;
for (int i = 0; i < 64; ++i) {
int32_t idx = i / sizeof_(T);
int32_t tmp_idx = (idx % repeat) * channel + idx / repeat;
int32_t real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T);
__store_nram((uint8_t *)pre_table_nram + i, (uint8_t)real_idx + 0x80);
}
}
template <typename T>
__mlu_func__ void trans_nhwc2nchw_smallc(T *dst,
T *src,
uint8_t *pre_table,
uint32_t n,
uint32_t h,
uint32_t w,
uint32_t c) {
int32_t align_c;
next_power_of_two(align_c, c);
int32_t align_num = 64 / sizeof_(T);
int32_t hw = h * w;
int32_t repeat = align_num / align_c;
int32_t in0 = c * repeat * sizeof_(T);
int32_t in1 = align_c;
int32_t in3 = hw / align_num;
int32_t is3 = in0 * in1;
int32_t n_stride = hw * c * sizeof_(T);
int32_t dn0 = 64;
int32_t dn1 = c;
int32_t ds1 = hw * sizeof_(T);
int32_t dn3 = in3;
int32_t ds3 = dn0;
align_c = in3 > 0 ? align_c : 0;
if (align_c == 2) {
__reshape_nhwc2nchw_smallc("b256");
} else if (align_c == 4) {
__reshape_nhwc2nchw_smallc("b128");
} else if (align_c == 8) {
__reshape_nhwc2nchw_smallc("b64");
} else if (align_c == 16) {
__reshape_nhwc2nchw_smallc("b32");
} else if (align_c == 32) {
__reshape_nhwc2nchw_smallc("b16");
}
constexpr int32_t bw = 8 * sizeof_(T);
int32_t in3_rem = hw % align_num;
int32_t tail_in0 = c * sizeof_(T);
int32_t tail_dn0 = in3_rem * sizeof_(T);
if (in3_rem) {
asm volatile(
"trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \
%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], \
%[is3], %[in4], %[is4], %[in5], %[is5], \
%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], \
%[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::[bw] "i"(bw),
[dst] "r"((T *)dst + dn3 * ds3 / sizeof_(T)), [src] "r"((T *)src + is3 * in3 / sizeof_(T)),
[in0] "r"(tail_in0), [in1] "r"(in3_rem), [is1] "r"(tail_in0), [in2] "i"(1), [is2] "i"(0),
[in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0),
[dn0] "r"(tail_dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0),
[dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
}
}
__mlu_func__ void convert(float *dst, int8_t *src, int32_t num) {
__bang_int82float((float *)dst, (int8_t *)src, num, 0);
}
__mlu_func__ void convert(float *dst, int4x2_t *src, int32_t num) {
__bang_int42float((float *)dst, (int4x2_t *)src, num, 0);
}
__mlu_func__ void convert(half *dst, float *src, int32_t num) {
__bang_float2half((half *)dst, (float *)src, num);
}
__mlu_func__ void convert(bfloat16_t *dst, float *src, int32_t num) {
#if __BANG_ARCH__ >= 500
__bang_float2bfloat16((bfloat16_t *)dst, (float *)src, num);
#endif
}
__mlu_func__ void convert(int8_t *dst, int4x2_t *src, int32_t num) {
__bang_int42int8((int8_t *)dst, (int4x2_t *)src, num, 0, 0);
}
// if the dst dtype == src dtype, do nothing. if you want to mv, use mv directly
__mlu_func__ void convert(float *dst, float *src, int32_t num) {}
__mlu_func__ void convert(int8_t *dst, int8_t *src, int32_t num) {}
template <typename T>
__mlu_func__ void transpose(T *dst, T *src, int32_t dim1, int32_t dim2) {
__bang_transpose((T *)dst, (T *)src, dim1, dim2);
}
// if data type is int4x2_t, transpose is not supported directly
__mlu_func__ void transpose(int4x2_t *dst, int4x2_t *src, int32_t dim1, int32_t dim2) {}
template <typename T>
__mlu_func__ void mvNram2WramLT16(int8_t *wram_dst,
int8_t *nram_src,
int32_t n,
int32_t k,
int32_t total_k) {
int32_t data_size = k * sizeof_(T);
int32_t ds0 = PAD_UP(data_size, ONE_LINE);
int32_t ss0 = total_k * sizeof_(T);
int32_t count = DIV_UP(n, LT_NUM);
if (count > 0) {
for (int i = 0; i < count; ++i) {
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
WRAM_LT_MAP16_STRIDE, LT_NUM_ANT - 1, ss0, LT_NUM - 1, 0, 0);
wram_dst += ANT_LT_ROW * ds0;
nram_src += LT_NUM * ss0;
}
}
count = n % LT_NUM / ANT_LT_ROW;
if (count > 0) {
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ANT_LT_ROW - 1, 0, 0);
wram_dst += count * WRAM_LT_MAP16_STRIDE;
nram_src += count * ANT_LT_ROW * ss0;
}
count = n % ANT_LT_ROW;
if (count) {
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ss0, count - 1);
}
}
template <typename Td, typename Ts>
__mlu_func__ void
conv_fuse_mul_cvt(Td *output, Ts *input, Ts *filter, float *partial, int m, int n, int k) {
if (std::is_same<Td, half>::value && std::is_same<Ts, float>::value) {
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.f16()", partial)
} else if (std::is_same<Td, bfloat16_t>::value && std::is_same<Ts, float>::value) {
#if __BANG_ARCH__ > 500
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.bf16()", partial)
#endif
} else if (std::is_same<Td, float>::value && std::is_same<Ts, float>::value) {
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", "", partial)
}
}
template <bool ProcessOffsets>
__mlu_func__ void process_offsets(int32_t *lens_nram,
int32_t *offsets_nram,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t batch_size) {
if constexpr (ProcessOffsets) {
__memcpy((int32_t *)lens_nram, (int32_t *)context_lens, sizeof_(int32_t) * batch_size,
GDRAM2NRAM);
int total_lens = 0;
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
__store_nram((int32_t *)offsets_nram + batch_idx, total_lens);
total_lens += __load_nram((int32_t *)lens_nram + batch_idx);
}
}
}
template <bool ProcessOffsets>
__mlu_func__ void load_len_offset(int32_t &seq_len,
int32_t &seq_offset,
const int32_t *lens_nram,
const int32_t *offsets_nram,
const int32_t *context_lens,
const int32_t *context_seq_offsets,
const int32_t batch_idx) {
if (ProcessOffsets) {
seq_len = __load_nram((int32_t *)lens_nram + batch_idx);
seq_offset = __load_nram((int32_t *)offsets_nram + batch_idx);
} else {
seq_len = __load_gdram((int32_t *)context_lens + batch_idx);
seq_offset = __load_gdram((int32_t *)context_seq_offsets + batch_idx);
}
}
template <typename T>
__mlu_func__ void load_scale_once(T *scale_nram,
const T *scale,
const int32_t head_num,
const int32_t head_size,
const size_t scale_bs_stride,
const size_t scale_head_stride) {
__memcpy((T *)scale_nram, (T *)scale, head_size * sizeof_(T), GDRAM2NRAM, head_size * sizeof_(T),
scale_head_stride * sizeof_(T), head_num - 1);
}
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize(T *output_nram,
Tc *input_nram,
Ts *scale_nram,
Ts *start_nram,
const int32_t input_num,
const int32_t scale_num) {
convert((float *)output_nram, (Tc *)input_nram, input_num);
convert((float *)start_nram, (Ts *)scale_nram, input_num);
__bang_cycle_mul((float *)output_nram, (float *)output_nram, (float *)start_nram, input_num,
scale_num);
convert((T *)output_nram, (float *)output_nram, input_num);
}
inline void getDeviceCoreAndRam(int32_t &cluster_dim,
int32_t &core_dim,
int32_t &nram_size,
int32_t &wram_size,
int32_t &sram_size,
const int32_t rem_for_stack) {
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&nram_size, cnrtAttrNramSizePerMcore, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&wram_size, cnrtAttrWramSizePerMcore, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&sram_size, cnrtAttrSramSizePerMcore, dev));
nram_size -= rem_for_stack;
sram_size -= rem_for_stack;
}
} // namespace tmo
#endif // CSRC_KERNELS_QUANT_UTILS_H_

View File

@@ -0,0 +1,186 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include <type_traits>
#include "cnnl.h"
#include "cnrt.h"
#include "quantize.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define NRAM_BUFFER_SIZE (480 * 1024)
namespace kernels {
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
template <typename TSrc>
__mlu_func__ void quantify(int8_t *nram_dst,
TSrc *nram_src,
TSrc *nram_scale_temp,
float *nram_scale,
float *scale_origin,
int core_deal_tokens,
int hidden) {
__bang_abs((TSrc *)nram_dst, nram_src, core_deal_tokens * hidden);
__bang_maxpool(nram_scale_temp, (TSrc *)nram_dst, core_deal_tokens, hidden, 1, hidden, 1, 1, 1);
if (std::is_same<half, TSrc>::value) {
__bang_half2float(nram_scale, (half *)nram_scale_temp, core_deal_tokens);
}
__bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, core_deal_tokens);
__bang_recip(scale_origin, nram_scale, core_deal_tokens);
if (std::is_same<half, TSrc>::value) {
__bang_float2half_rn((half *)scale_origin, scale_origin, core_deal_tokens);
}
__bang_cycle_mul(nram_src, nram_src, (TSrc *)scale_origin, core_deal_tokens * hidden,
core_deal_tokens);
if (std::is_same<half, TSrc>::value) {
__bang_half2int8_rn((int8_t *)nram_src, (half *)nram_src, core_deal_tokens * hidden, 0);
} else if (std::is_same<float, TSrc>::value) {
__bang_float2int8_rn((int8_t *)nram_src, (float *)nram_src, core_deal_tokens * hidden, 0);
}
__bang_transpose(nram_dst, (int8_t *)nram_src, hidden, core_deal_tokens);
}
template <typename TDst, typename TSrc, typename TScale>
__mlu_global__ void MLUQuantizePerHead(
TDst *dst, // [bs, seq, head_num, head_size], may not be continuous
TScale *scale, // [bs, seq], must becontinuous
const TSrc *src, // [bs, seq, head_num, head_size], may not be continuous
int bs,
int seq_len,
int head_num,
int head_size,
int src_bs_stride,
int src_seq_stride,
int src_head_stride,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride) {
int total_bs = bs * seq_len;
int hidden = head_num * head_size;
int core_average_tokens = (total_bs + taskDim - 1) / taskDim;
int core_begin_tokens = core_average_tokens * taskId;
int core_deal_tokens = std::min(total_bs - core_begin_tokens, core_average_tokens);
if (__is_mpu()) {
return;
}
if (core_deal_tokens <= 0) {
return;
}
TScale *nram_scale = (TScale *)nram_buffer;
TScale *scale_origin = nram_scale + core_deal_tokens * head_num;
TSrc *nram_scale_temp =
(TSrc *)(nram_buffer + core_deal_tokens * head_num * (sizeof(TScale) - sizeof(TSrc)));
TSrc *nram_ping = (TSrc *)(scale_origin + core_deal_tokens * head_num);
TSrc *nram_temp = nram_ping + core_deal_tokens * hidden;
const TSrc *src_begin = src + core_begin_tokens * src_seq_stride;
TDst *dst_begin = dst + core_begin_tokens * dst_seq_stride;
TScale *scale_begin = scale + core_begin_tokens * head_num;
// load
__memcpy(nram_ping, src_begin, head_size * sizeof(TSrc), GDRAM2NRAM, head_size * sizeof(TSrc),
head_num - 1, hidden * sizeof(TSrc), core_deal_tokens - 1,
src_head_stride * sizeof(TSrc), head_num - 1, src_seq_stride * sizeof(TSrc),
core_deal_tokens - 1);
__bang_transpose(nram_temp, nram_ping, core_deal_tokens * head_num, head_size);
quantify((TDst *)nram_ping, nram_temp, nram_scale_temp, nram_scale, scale_origin,
core_deal_tokens * head_num, head_size);
// store scale
__memcpy(scale_begin, nram_scale, core_deal_tokens * head_num * sizeof(TScale), NRAM2GDRAM);
// store
__memcpy(dst_begin, nram_ping, head_size * sizeof(TDst), NRAM2GDRAM,
dst_head_stride * sizeof(TDst), head_num - 1, dst_seq_stride * sizeof(TDst),
core_deal_tokens - 1, head_size * sizeof(TDst), head_num - 1, hidden * sizeof(TDst),
core_deal_tokens - 1);
}
} // namespace kernels
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
void *dst,
void *scale,
const void *src,
cnnlDataType_t dst_dtype,
cnnlDataType_t scale_dtype,
cnnlDataType_t src_dtype,
int bs,
int seq_len,
int head_num,
int head_size,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride,
int src_bs_stride,
int src_seq_stride,
int src_head_stride) {
// bs must be continuous, for pack mode, bs = 1, seq_len equals to sum of all bs seq_len.
if (dst_bs_stride != seq_len * dst_seq_stride) {
std::cerr
<< "[invokeMluQuantizePerToken]: dst_bs_stride must equal to seq_len * dst_seq_stride."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (dst_head_stride != head_size) {
std::cerr << "[invokeMluQuantizePerToken]: dst_head_stride must equal to head_size."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (src_bs_stride != seq_len * src_seq_stride) {
std::cerr
<< "[invokeMluQuantizePerToken]: src_bs_stride must equal to seq_len * src_seq_stride."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (src_head_stride != head_size) {
std::cerr << "[invokeMluQuantizePerToken]: src_head_stride must equal to head_size."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
int scale_buffer_size = 64 * 1024; // scale on nram
int dtype_size = (src_dtype == CNNL_DTYPE_HALF || src_dtype == CNNL_DTYPE_BFLOAT16) ? 2 : 4;
int bs_once = (NRAM_BUFFER_SIZE - scale_buffer_size) / (2 * head_num * head_size * dtype_size);
int bs_once_ = scale_buffer_size / 2 / sizeof(float);
bs_once = std::min(bs_once, bs_once_);
uint32_t task_dim = std::min(bs * seq_len, cluster_num * core_num);
task_dim = std::max((uint32_t)(bs * seq_len + bs_once - 1) / bs_once, task_dim);
cnrtDim3_t dim{task_dim, 1, 1};
if (src_dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)dst, (float *)scale, (const float *)src, bs, seq_len, head_num, head_size,
src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
dst_head_stride);
} else if (src_dtype == CNNL_DTYPE_HALF) {
kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)dst, (float *)scale, (const half *)src, bs, seq_len, head_num, head_size,
src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
dst_head_stride);
} else if (src_dtype == CNNL_DTYPE_BFLOAT16) {
std::cerr << __func__ << "," << __LINE__ << " :currently does not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,60 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANTIZE_MLUH_
#define CSRC_KERNELS_QUANTIZE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief quantize tensor by per head.
* @param queue: The queue for mlu.
* @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num,
* head_size], may not be continuous.
* @param scale: Input. Pointer to the destination scale MLU memory. Shape is [bs, seq, head_num],
* must be continuous.
* @param src: Input. Pointer to the source MLU memory. Shape is [bs, seq, head_num, head_size], may
* not be continuous.
* @param dst_dtype: Data type of destination tensor. Must be int8.
* @param scale_dtype: Data type of destination scale tensor. Must be float32.
* @param src_dtype: Data type of src tensor. Must be float or half.
* @param bs: batch_size of dst or src tensor.
* @param seq_len: seq_len of dst or src tensor.
* @param head_num: head_num of dst or src tensor.
* @param head_size: head_size of dst or src tensor.
* @param dst_bs_stride: stride of batch_size dim of dst tensor.
* @param dst_seq_stride: stride of seq_len dim of dst tensor.
* @param dst_head_stride: stride of head_num dim of dst tensor.
* @param src_bs_stride: stride of batch_size dim of src tensor.
* @param src_seq_stride: stride of seq_len dim of src tensor.
* @param src_head_stride: stride of head_num dim of src tensor.
*/
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
void *dst,
void *scale,
const void *src,
cnnlDataType_t dst_dtype,
cnnlDataType_t scale_dtype,
cnnlDataType_t src_dtype,
int bs,
int seq_len,
int head_num,
int head_size,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride,
int src_bs_stride,
int src_seq_stride,
int src_head_stride);
} // namespace tmo
#endif // CSRC_KERNELS_QUANTIZE_MLUH_

View File

@@ -0,0 +1,134 @@
#include "reshape_linear_cache.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// [head_num, batch, seq_seg]
__mlu_global__ void MLUReshapeLinearCacheKernel(int8_t *key_cache,
int8_t *value_cache,
int *cache_bs_offsets,
int *cache_seq_offsets,
int8_t *key,
int8_t *value,
int *context_seq_offsets,
int *context_lens,
int batch,
int head_num,
int head_size,
int max_context_len,
int cache_mem_len,
size_t context_bs_stride,
size_t context_head_stride,
size_t context_seq_stride,
size_t cache_bs_stride,
size_t cache_head_stride,
size_t cache_seq_stride,
bool packed,
int dtype_size,
int SEQ_BLOCK) {
int head_repeat = taskDimX > 1 ? 1 : head_num;
for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
int seq_offset = (packed || context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
int task_seq_begin = taskIdZ * SEQ_BLOCK;
int seq_len = packed ? (context_lens[bs_idx + 1] - context_lens[bs_idx]) : context_lens[bs_idx];
if (task_seq_begin >= seq_len) continue;
int seq = std::min(seq_len - task_seq_begin, SEQ_BLOCK);
size_t context_offset = taskIdX * context_head_stride * dtype_size;
if (packed) {
context_offset += (context_lens[bs_idx] + task_seq_begin) * context_seq_stride * dtype_size;
} else {
context_offset +=
(bs_idx * context_bs_stride + (task_seq_begin + seq_offset) * context_seq_stride) *
dtype_size;
}
int cache_seq_offset = cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx];
int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
if (cache_seq_offset < 0 || cache_bs_offset < 0) {
continue;
}
cache_seq_offset += task_seq_begin;
if (key != nullptr && key_cache != nullptr) {
int8_t *key_cache_begin =
key_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
cache_seq_offset * cache_seq_stride) *
dtype_size;
int8_t *key_begin = key + context_offset;
__memcpy(key_cache_begin, key_begin, head_size * dtype_size, GDRAM2GDRAM,
cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
context_head_stride * dtype_size, head_repeat - 1);
}
if (value != nullptr && value_cache != nullptr) {
int8_t *value_cache_begin =
value_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
cache_seq_offset * cache_seq_stride) *
dtype_size;
int8_t *value_begin = value + context_offset;
__memcpy(value_cache_begin, value_begin, head_size * dtype_size, GDRAM2GDRAM,
cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
context_head_stride * dtype_size, head_repeat - 1);
}
}
}
} // namespace kernels
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
void *key,
void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const int context_bs_stride,
const int context_head_stride,
const int context_seq_stride,
const int cache_bs_stride,
const int cache_head_stride,
const int cache_seq_stride,
const bool packed) {
constexpr int SEQ_BLOCK = 512;
int seq_seg = (max_context_len + SEQ_BLOCK - 1) / SEQ_BLOCK;
bool is_decoder_case = head_num * max_context_len < SEQ_BLOCK;
uint32_t task_x_dim = is_decoder_case ? 1 : head_num;
uint32_t task_y_dim = is_decoder_case ? std::min(batch, 48) : batch;
cnrtDim3_t dim{task_x_dim, task_y_dim, (uint32_t)seq_seg};
int dtype_size = 1;
if (dtype == CNNL_DTYPE_HALF || dtype == CNNL_DTYPE_BFLOAT16) {
dtype_size = 2;
} else if (dtype == CNNL_DTYPE_INT8) {
dtype_size = 1;
} else if (dtype == CNNL_DTYPE_FLOAT) {
dtype_size = 4;
} else {
std::cerr << "invokeReshapeLinearCache: unsupport dtype" << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUReshapeLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)key_cache, (int8_t *)value_cache, (int *)cache_bs_offsets, (int *)cache_seq_offsets,
(int8_t *)key, (int8_t *)value, (int *)context_seq_offsets, (int *)context_lens, batch,
head_num, head_size, max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed, dtype_size,
SEQ_BLOCK);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,106 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief In the context stage, concate the result of multi head attention
* key and value to key_cache and value_cache.
* @example
* input:
* cache:
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
* context:
* [[1, 2, 3, 4, 5],
* [6, 7, 8, 9, 10]]
* cache_bs_offsets: [1, 2]
* cache_seq_offsets: [3, 4]
* context_seq_offsets: [0, 1]
* context_lens: [4, 3]
* output:
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 1, 2, 3, 4, 0, 0, 0],
* [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]]
* @param queue: The queue for mlu.
* @param key_cache: Pointer to the MLU memory that stores the key cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* key_cache could be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* value_cache could be nullptr.
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets: Input. Pointer to the MLU memory that stores the sequence
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is 0 for every batch.
* @param key: Pointer to the MLU memory that stores the key,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* key could be nullptr.
* @param value: Pointer to the MLU memory that stores the value,
* the shape must be [batch, max_contxt_len, head_num, head_size].
* value could be nullptr.
* @param context_seq_offsets: Pointer to the MLU memory that stores the
* sequence offset of context, the shape must be [batch]. if it's nullptr,
* the default value is 0 for every batch. It must be nullptr when packed is true.
* @param context_lens: Input. Pointer to the MLU memory that stores the sequence length or
* cumulative sequence length of context. when packed is false, the shape must be [batch], which
* indicates sequence length of context. when packed is true, the shape must be [batch + 1], which
* indicates cumulative sequence length of context.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
* @param max_contxt_len: The maximum sequence length of context.
* @param cache_mem_len: The maximum sequence length of cache.
* @param contxt_bs_stride: The stride of batch in context, does not work when packed is true.
* @param contxt_head_stride: The stride of head_num in context.
* @param contxt_seq_stride: The stride of max_contxt_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param cache_seq_stride: The stride of cache_mem_len in cache.
* @param packed: A boolean value indicates whether to use pack mode.
* @note If key and key_cache are nullptr, nothing todo for key.
If value and value_cache are nullptr, nothing todo for value.
A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for
the corresponding batch.
*/
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
void *key,
void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const int context_bs_stride,
const int context_head_stride,
const int context_seq_stride,
const int cache_bs_stride,
const int cache_head_stride,
const int cache_seq_stride,
const bool packed);
} // namespace tmo
#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,166 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "reshape_paged_cache.mluh"
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
#define sizeof_(T) (uint32_t)sizeof(T)
__mlu_global__ void MLUReshapePagedCacheKernel(int8_t *key,
int8_t *value,
int8_t *key_cache,
int8_t *value_cache,
int *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_size,
int head_size,
int dtype_size,
int seq_block) {
#if __BANG_ARCH__ > 500
int seq_begin = taskId * seq_block;
if (seq_begin >= num_tokens) return;
int seq = std::min(seq_block, num_tokens - seq_begin);
int head_bytes = head_size * dtype_size;
int head_stride = block_size * head_bytes;
int block_stride = num_heads * head_stride;
int hidden_bytes = num_heads * head_bytes;
int8_t *nram_input = nram_buffer;
int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes);
int pad_8_size = (num_heads * seq + 7) / 8 * 8;
int *nram_block_offset = nram_token_offset + pad_8_size;
int *nram_offset = nram_block_offset + pad_8_size;
int *nram_mask = nram_offset + pad_8_size;
__memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int), GDRAM2NRAM);
__bang_rem(nram_token_offset, nram_offset, (int)block_size, seq);
__bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq);
__bang_div(nram_block_offset, nram_offset, (int)block_size, seq);
__bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq);
// (num_heads, seq)
__memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
num_heads - 1);
// (num_heads, seq) -> (seq, num_heads)
__bang_transpose(nram_token_offset, nram_offset, num_heads, seq);
// (num_heads, seq)
__memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
num_heads - 1);
// (num_heads, seq) -> (seq, num_heads)
__bang_transpose(nram_block_offset, nram_offset, num_heads, seq);
__bang_write_zero(nram_offset, pad_8_size);
__bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset, (float *)nram_offset,
pad_8_size);
// generate range: (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride)
__memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int), NRAM2NRAM);
int begin = 32;
while (begin < num_heads) {
int count = std::min(begin, num_heads - begin);
__bang_add_scalar(nram_offset + begin, nram_offset, begin, count);
begin += count;
}
__bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads);
__bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset, seq * num_heads, num_heads);
__bang_add(nram_offset, nram_token_offset, nram_block_offset, seq * num_heads);
if (key != nullptr && key_cache != nullptr) {
// (seq, num_heads, head_size)
__memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
hidden_bytes, key_stride0 * dtype_size, seq - 1);
__scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
head_bytes, seq * num_heads);
}
if (value != nullptr && value_cache != nullptr) {
__memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
hidden_bytes, value_stride0 * dtype_size, seq - 1);
__scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
head_bytes, seq * num_heads);
}
#endif
}
} // namespace kernels
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size) {
if (is_arch300()) {
std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300 devices." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int dtype_size = 1;
if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
dtype_size = 2;
} else if (data_type == CNNL_DTYPE_INT8) {
dtype_size = 1;
} else if (data_type == CNNL_DTYPE_FLOAT) {
dtype_size = 4;
} else {
std::cerr << "invokeReshapePagedCache: unsupport data type\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
int64_t kv_cache_range = block_num * block_size * num_heads * head_size * dtype_size;
if (kv_cache_range > UINT32_MAX) {
std::cerr << "[invokeReshapePagedCache]: The addressing range of kv_cache cannot exceed 4G."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
constexpr int nram_size = 224 * 1024;
int hidden_bytes = num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int);
int seq_block = nram_size / hidden_bytes;
if (seq_block <= 0) {
std::cerr << "invokeReshapePagedCache: "
<< "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
<< "should be less than 224KB.\n";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (seq_block > 16) {
seq_block = seq_block / 16 * 16;
}
uint32_t task_dim = (num_tokens + seq_block - 1) / seq_block;
task_dim = std::max(task_dim, (uint32_t)8);
task_dim = std::min(task_dim, (uint32_t)num_tokens);
cnrtDim3_t dim{task_dim, 1, 1};
kernels::MLUReshapePagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
(int8_t *)key, (int8_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
(int *)slot_mapping, key_stride0, value_stride0, num_tokens, num_heads, block_size, head_size,
dtype_size, seq_block);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform reshape_paged_cache operation.
* @param handle: The handle of cnnl.
* @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
* num_heads, head_size].
* @param value: Pointer to the MLU memory that stores the value tensor which has shape [num_tokens,
* num_heads, head_size].
* @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has shape
* [num_blocks, num_heads, block_size, head_size].
* @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has shape
* [num_blocks, num_heads, block_size, head_size].
* @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has
* shape [num_tokens]. Data type of slot mapping must be int32_t.
* @param key_stride0: The first dimension stride length of key_cache tensor.
* @param value_stride0: The first dimension stride length of value_cache tensor.
* @param num_tokens: Total number of tokens.
* @param num_heads: Head number.
* @param block_num: Total number of blocks.
* @param block_size: Number of tokens per block.
* @note: reshape_paged_cache does not support MLU300 device.
*/
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size);
} // namespace tmo
#endif // CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,628 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cstddef>
#include <type_traits>
#include "rotary_embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ float nram_meta_mask[32] = {1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f,
0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f,
1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f};
__nram__ float nram_mask[1024];
__nram__ int nram_offsets[1024];
__mlu_func__ void loadTableAsync(void *nram_table,
void *gdram_table,
int *nram_offset,
int rotary_dim,
int rotary_stride,
int seq_block,
int seq_begin,
int dtype_size,
bool discrete,
bool decoder_mode) {
if (!discrete) {
int src_stride = decoder_mode ? 0 : rotary_stride * dtype_size;
__memcpy_async(nram_table, gdram_table, rotary_dim * dtype_size, GDRAM2NRAM,
rotary_dim * dtype_size, src_stride, seq_block - 1);
} else {
#if __BANG_ARCH__ >= 592
__gather_async(nram_table, gdram_table, (uint32_t *)nram_offset, rotary_dim * dtype_size,
GDRAM2NRAM, rotary_dim * dtype_size, seq_block);
#else
for (int i = 0; i < seq_block; i++) {
__memcpy_async((int8_t *)nram_table + i * rotary_dim * dtype_size,
(int8_t *)gdram_table + nram_offset[i], rotary_dim * dtype_size, GDRAM2NRAM);
}
#endif
}
}
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
}
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
}
}
template <typename T>
__mlu_func__ void initMask(float *mask, int rotary_dim, bool interleaved) {
if (interleaved) {
T *mask0 = (T *)mask;
T *mask1 = (T *)(mask + 512);
int seg = (rotary_dim + 31) / 32;
__memcpy(mask0, nram_meta_mask, 32 * sizeof(float), NRAM2NRAM, 32 * sizeof(float), 0, seg - 1);
floatTo((T *)mask0, (float *)mask0, rotary_dim);
__bang_add_scalar(mask1, mask0, (T)-1, rotary_dim);
} else {
__bang_write_value((T *)mask, rotary_dim / 2, (T)-1);
__bang_write_value((T *)mask + rotary_dim / 2, rotary_dim / 2, (T)1);
}
}
/*
* half: mask, in, sl, sr
* float: sl, , sr, , sin, cos
*/
template <typename T>
__mlu_func__ void crossRotaryEmbedding(T *output,
T *input,
T *sin_table,
T *cos_table,
int *seq_offsets,
int head_num,
int seq_block,
int head_size,
int rotary_dim,
int rotary_stride,
size_t input_head_stride,
size_t input_seq_stride,
size_t output_head_stride,
size_t output_seq_stride,
int seq_begin,
bool discrete,
bool decoder_mode = false) {
int float_size = sizeof(float);
int dtype_size = sizeof(T);
int seq_rotary = seq_block * rotary_dim;
int block_head = head_num * seq_rotary;
float *q_1 = (float *)nram_buffer;
float *sincos = q_1 + block_head + 2;
float *q_2 = sincos + block_head + 2;
T *temp = (T *)q_2 + block_head + 2;
if (seq_offsets != nullptr && (discrete || decoder_mode)) {
__memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
__bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
}
bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;
T *mask0 = (T *)nram_mask;
T *mask1 = (T *)(nram_mask + 512);
T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * (block_head + 2));
T *sincos_ = (T *)((int8_t *)sincos + (float_size - dtype_size) * (block_head + 2));
T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * (block_head + 2));
// if dtype is float, temp point to a new buffer, and temp_ is temp;
// if dtype is half/bfloat16, temp is q_2_, and temp_ is (T*)q_2;
T *temp_ = dtype_size == 4 ? temp : (T *)q_2;
// load input
__memcpy_async(q_1_, input, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size,
seq_block - 1, seq_rotary * dtype_size, head_num - 1,
input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
head_num - 1);
__bang_write_zero(q_2_ + block_head, 2);
__sync();
// copy input
__memcpy_async(q_2_, q_1_, block_head * dtype_size, NRAM2NRAM);
__bang_cycle_mul(temp_, q_1_, mask0, block_head, rotary_dim);
__sync();
// load cos
loadTableAsync(sincos_, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
dtype_size, gather_table, decoder_mode);
__bang_cycle_mul(q_2_, q_2_, mask1, block_head, rotary_dim);
// rotary_input
__bang_add(q_2_ + 2, temp_, q_2_ + 2, block_head);
toFloat(q_1, q_1_, block_head);
__sync();
toFloat(sincos, sincos_, block_head);
// input * cos
__bang_cycle_mul(q_1, q_1, sincos, block_head, seq_rotary);
__sync();
toFloat(q_2, q_2_, block_head + 2);
// load sin
loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
dtype_size, gather_table, decoder_mode);
__sync();
toFloat(sincos, sincos_, block_head);
// rotary_input * sin
__bang_cycle_mul(q_2, q_2 + 1, sincos, block_head, seq_rotary);
// input_cos + rotary_input_sin
__bang_add(q_1, q_1, q_2, block_head);
floatTo((T *)q_1, q_1, block_head);
if ((head_size - rotary_dim) > 0) {
__memcpy_async(output + rotary_dim, input + rotary_dim, (head_size - rotary_dim) * dtype_size,
GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1,
output_head_stride * dtype_size, head_num - 1, input_seq_stride * dtype_size,
seq_block - 1, input_head_stride * dtype_size, head_num - 1);
}
// copy out
__memcpy(output, q_1, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
seq_block - 1, output_head_stride * dtype_size, head_num - 1, rotary_dim * dtype_size,
seq_block - 1, seq_rotary * dtype_size, head_num - 1);
}
template <typename T>
__mlu_func__ void foldRotaryEmbedding(T *output,
T *input,
T *sin_table,
T *cos_table,
int *seq_offsets,
int head_num,
int seq_block,
int head_size,
int rotary_dim,
int rotary_stride,
size_t input_head_stride,
size_t input_seq_stride,
size_t output_head_stride,
size_t output_seq_stride,
int seq_begin,
bool discrete,
bool decoder_mode,
bool loop_head,
int once_head_num) {
once_head_num = loop_head ? once_head_num : head_num;
int loop_num = (head_num + once_head_num - 1) / once_head_num;
// int head_per_loop = loop_head ? 1 : head_num;
int seq_rotary = seq_block * rotary_dim;
int block_head = once_head_num * seq_rotary;
int buffer_blocks = loop_head ? 2 : 1;
float *buffer = (float *)nram_buffer;
float *q_2 = buffer + block_head * buffer_blocks;
float *sin = q_2 + block_head;
float *cos = sin + seq_rotary;
int float_size = sizeof(float);
int dtype_size = sizeof(T);
T *sincos_ = (T *)((int8_t *)sin + (float_size - dtype_size) * seq_rotary * 2);
T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * block_head);
if (seq_offsets != nullptr && (discrete || decoder_mode)) {
__memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
__bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
__sync_io_move_compute();
}
bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;
int load_head_num = 0;
int calc_head_num = 0;
int store_head_num = 0;
for (int i = 0; i < loop_num + 2; i++) {
// store
if (i > 1) {
store_head_num = std::min(once_head_num, head_num - (i - 2) * once_head_num);
if ((head_size - rotary_dim) > 0) {
__memcpy_async(output + (i - 2) * once_head_num * output_head_stride + rotary_dim,
input + (i - 2) * once_head_num * input_head_stride + rotary_dim,
(head_size - rotary_dim) * dtype_size, GDRAM2GDRAM,
output_seq_stride * dtype_size, seq_block - 1,
output_head_stride * dtype_size, store_head_num - 1,
input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
store_head_num - 1);
}
float *nram_store = buffer + (i % 2) * block_head;
__memcpy_async(output + (i - 2) * once_head_num * output_head_stride, nram_store,
rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
seq_block - 1, output_head_stride * dtype_size, store_head_num - 1,
rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size,
store_head_num - 1);
}
// load
float *temp_load = buffer + (i % 2) * block_head;
T *nram_load = (T *)((int8_t *)temp_load + (float_size - dtype_size) * block_head);
if (i < loop_num) {
load_head_num = std::min(once_head_num, head_num - i * once_head_num);
__memcpy_async(nram_load, input + i * once_head_num * input_head_stride,
rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1,
seq_block * rotary_dim * dtype_size, load_head_num - 1,
input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
load_head_num - 1);
}
if (i == 1) {
loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block,
seq_begin, dtype_size, gather_table, decoder_mode);
loadTableAsync(sincos_ + seq_rotary, cos_table, nram_offsets, rotary_dim, rotary_stride,
seq_block, seq_begin, dtype_size, gather_table, decoder_mode);
}
// compute
if (i > 0 && i < loop_num + 1) {
float *q_1 = buffer + ((i + 1) % 2) * block_head;
T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * block_head);
calc_head_num = std::min(once_head_num, head_num - (i - 1) * once_head_num);
__memcpy_async(q_2_, q_1_ + rotary_dim / 2, rotary_dim / 2 * dtype_size, NRAM2NRAM,
rotary_dim * dtype_size, rotary_dim * dtype_size,
calc_head_num * seq_block - 1);
__memcpy_async(q_2_ + rotary_dim / 2, q_1_, rotary_dim / 2 * dtype_size, NRAM2NRAM,
rotary_dim * dtype_size, rotary_dim * dtype_size,
calc_head_num * seq_block - 1);
__sync_move();
toFloat(q_1, q_1_, block_head);
__bang_cycle_mul(q_2_, q_2_, (T *)nram_mask, block_head, rotary_dim);
toFloat(q_2, q_2_, block_head);
if (i == 1) {
__sync_io();
toFloat(sin, sincos_, seq_rotary * 2);
}
__bang_cycle_mul(q_1, q_1, cos, block_head, seq_rotary);
__bang_cycle_mul(q_2, q_2, sin, block_head, seq_rotary);
__bang_add(q_1, q_1, q_2, block_head);
floatTo((T *)q_1, q_1, block_head);
}
__sync_io_move_compute();
}
}
// [bs, seq_block]
template <typename T, bool interleaved>
__mlu_global__ void MluRotaryEmebdding(void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int seq_once,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool discrete,
bool dynamic_ntk,
bool decoder_mode,
bool loop_head,
int once_head_num) {
initMask<T>(nram_mask, rotary_dim, interleaved);
int head_begin = taskIdX;
int head_per_task = taskDimX == 1 ? head_num : 1;
// decode mode little diff: no loop
if (decoder_mode) {
int task_begin_seq = taskIdY * seq_once;
int seq_block = std::min(batch - task_begin_seq, seq_once);
if (seq_block <= 0 || __is_mpu()) {
return;
}
size_t input_offset = task_begin_seq * input_seq_stride + head_begin * input_head_stride;
size_t output_offset = task_begin_seq * output_seq_stride + head_begin * output_head_stride;
T *input_begin = (T *)input + input_offset;
T *output_begin = (T *)output + output_offset;
if (interleaved) {
crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
(int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
output_seq_stride, task_begin_seq, discrete, decoder_mode);
} else {
foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
(int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
output_seq_stride, task_begin_seq, discrete, decoder_mode, loop_head,
once_head_num);
}
return;
}
int seq_begin = cu_seq_lens == nullptr ? taskIdY * max_seq_len : cu_seq_lens[taskIdY];
int seq_len = cu_seq_lens == nullptr ? max_seq_len : cu_seq_lens[taskIdY + 1] - seq_begin;
for (int i = taskIdZ * seq_once; i < seq_len; i += taskDimZ * seq_once) {
int seq_block = std::min(seq_once, seq_len - i);
int global_seq_begin = seq_begin + i;
int seq_block_begin = i;
size_t input_offset = global_seq_begin * input_seq_stride + head_begin * input_head_stride;
size_t output_offset = global_seq_begin * output_seq_stride + head_begin * output_head_stride;
size_t bs_table_offset = dynamic_ntk ? (size_t)taskIdY * rotary_seq_len * rotary_stride : 0;
T *input_begin = (T *)input + input_offset;
T *output_begin = (T *)output + output_offset;
T *sin_table_begin = (T *)sin_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
T *cos_table_begin = (T *)cos_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
if (seq_offsets != nullptr && !discrete) {
sin_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
cos_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
} else if (seq_offsets != nullptr && discrete) {
sin_table_begin = (T *)sin_table + bs_table_offset;
cos_table_begin = (T *)cos_table + bs_table_offset;
}
if (interleaved) {
crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
(T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
head_size, rotary_dim, rotary_stride, input_head_stride,
input_seq_stride, output_head_stride, output_seq_stride,
global_seq_begin, discrete, decoder_mode);
__sync_io_move_compute();
} else {
foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
(T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride,
output_head_stride, output_seq_stride, global_seq_begin, discrete,
decoder_mode, loop_head, once_head_num);
}
}
}
#if __BANG_ARCH__ < 592
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, true>(void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int seq_once,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool discrete,
bool dynamic_ntk,
bool decoder_mode,
bool loop_head,
int once_head_num) {}
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, false>(void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int seq_once,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool discrete,
bool dynamic_ntk,
bool decoder_mode,
bool loop_head,
int once_head_num) {}
#endif
} // namespace kernels
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
bool discrete,
bool dynamic_ntk,
cnnlDataType_t data_type) {
void (*rotary_embedding_kernels[])(void *, /* output */
const void *, /* input */
const void *, /* sin_table */
const void *, /* cos_table */
const int *, /* seq_offsets */
const int *, /* cu_seq_lens */
int, /* batch */
int, /* max_seq_len */
int, /* head_num */
int, /* head_size */
int, /* rotary_seq_len */
int, /* rotary_dim */
int, /* seq_once */
int, /* rotary_stride */
size_t, /* input_seq_stride */
size_t, /* input_head_stride */
size_t, /* output_seq_stride */
size_t, /* output_head_stride */
bool, /* discrete, */
bool, /* dynamic_ntk */
bool, /* decoder_mode */
bool, /* loop_head */
int) /* once_head_num */
= {kernels::MluRotaryEmebdding<half, true>,
kernels::MluRotaryEmebdding<half, false>,
kernels::MluRotaryEmebdding<bfloat16_t, true>,
kernels::MluRotaryEmebdding<bfloat16_t, false>,
kernels::MluRotaryEmebdding<float, true>,
kernels::MluRotaryEmebdding<float, false>};
int kernel_index = 0;
if (data_type == CNNL_DTYPE_HALF) {
kernel_index = interleaved ? 0 : 1;
} else if (data_type == CNNL_DTYPE_BFLOAT16) {
kernel_index = interleaved ? 2 : 3;
} else if (data_type == CNNL_DTYPE_FLOAT) {
kernel_index = interleaved ? 4 : 5;
}
if (head_size > 256) {
std::cerr << "[invokeRotaryEmbedding]: only supported head_size <= 256, currently head_size = "
<< head_size << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
int total_core_num = cluster_num * core_num;
uint32_t seq_once = data_type == CNNL_DTYPE_FLOAT ? (rotary_dim > 128 ? 64 : 128)
: (rotary_dim > 128 ? 128 : 256);
// decode场景需要判断空间是否够fold场景下最大限制为每个ipu处理64cross限制为batch*head小于等于sq_once
int batch_per_core = (batch + total_core_num - 1) / total_core_num;
int batch_per_core_cap = 64;
bool batch_limit = interleaved ? (batch_per_core * head_num <= seq_once)
: (batch_per_core <= batch_per_core_cap);
bool decoder_mode = batch_limit && max_seq_len == 1 && dynamic_ntk == false;
bool do_one_head_per_task = (head_num > 32 && max_seq_len > 2048) || head_num > seq_once;
seq_once = do_one_head_per_task ? seq_once : seq_once / head_num;
// fold rotary做了流水拆分有所不同。
bool loop_head = true;
int once_head_num = 1;
if (!interleaved) {
seq_once = rotary_dim > 128 ? 64 : 128;
// 小seq情况下不够拆需要减小seq_once
if (batch * (max_seq_len + seq_once - 1) / seq_once < total_core_num) {
seq_once = std::max(1, max_seq_len / (total_core_num / batch));
}
do_one_head_per_task = false;
// 判断decode场景能否一次性处理完所有head
if (decoder_mode) {
loop_head = false;
int nram_buffer_size = 480 * 1024;
int nram_input_size = batch_per_core * head_num * rotary_dim * sizeof(float);
int nram_q2_size = batch_per_core * head_num * rotary_dim * sizeof(float);
int nram_table_size = batch_per_core * rotary_dim * sizeof(float) * 2;
int total_nram_size = nram_input_size + nram_q2_size + nram_table_size;
loop_head = total_nram_size > nram_buffer_size;
if (loop_head) {
// 如果需要循环,则重新计算每次处理多少头
once_head_num = (nram_buffer_size - nram_table_size) /
(batch_per_core * rotary_dim * sizeof(float) * 3);
}
// rebalance
int loop_num = (head_num + once_head_num - 1) / once_head_num;
once_head_num = (head_num + loop_num - 1) / loop_num;
}
}
uint32_t seq_segments = ((uint32_t)max_seq_len + seq_once - 1) / seq_once;
uint32_t task_dimx = do_one_head_per_task ? head_num : 1;
uint32_t task_dimz = total_core_num > seq_segments ? seq_segments : total_core_num;
uint32_t task_dimy =
decoder_mode && !do_one_head_per_task ? (uint32_t)total_core_num : (uint32_t)batch;
seq_once = decoder_mode ? (batch + task_dimy - 1) / task_dimy : seq_once;
cnrtDim3_t dim = {task_dimx, task_dimy, task_dimz};
if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
std::cerr << "[invokeRotaryEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
rotary_embedding_kernels[kernel_index]<<<dim, cnrtFuncTypeBlock, queue>>>(
output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num,
head_size, rotary_seq_len, rotary_dim, seq_once, rotary_stride, input_seq_stride,
input_head_stride, output_seq_stride, output_head_stride, discrete, dynamic_ntk, decoder_mode,
loop_head, once_head_num);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int total_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
cnnlDataType_t data_type) {
size_t type_size = 0;
cnnlGetSizeOfDataType(data_type, &type_size);
invokeRotaryEmbedding(queue, output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch,
max_seq_len, head_num, head_size / 2, rotary_seq_len, head_size / 2,
rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
output_head_stride, interleaved, true, false, data_type);
invokeRotaryEmbedding(queue, (int8_t *)output + head_size / 2 * type_size,
(int8_t *)input + head_size / 2 * type_size, sin_table, cos_table,
seq_offsets + total_seq_len, cu_seq_lens, batch, max_seq_len, head_num,
head_size / 2, rotary_seq_len, head_size / 2, rotary_stride,
input_seq_stride, input_head_stride, output_seq_stride, output_head_stride,
interleaved, true, false, data_type);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,129 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
#define CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Apply rotary embedding.
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [total_seq_len, head_num, head_size]
* @param input: Input. Pointer to the MLU memory that stores the input
* the shape must be [total_seq_len, head_num, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
* continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If
* dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
* continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If
* dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim].
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each
* batch. If discrete is true, the shape must be [total_seq_len]. If discrete is false, the shape
* must be [batch]. Seq_offsets could be nullptr if discrete is false, which means no offset for
* each batch.
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all
* batches is max_seq_Len.
* @param batch: Batch size.
* @param max_seq_len: The maximum sequence length of input.
* @param head_num: Head number.
* @param head_size: Head size.
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
* @param rotary_dim: The rotary dimension of sin_table and cos_table.
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
* @param input_seq_stride: The stride of total_seq_len in input.
* @param input_head_stride: The stride of head_num in input.
* @param output_seq_stride: The stride of total_seq_len in output.
* @param output_head_stride: The stride of head_num in output.
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
* @param discrete: A boolean value indicates whether all input tokens have offsets.
* @param dynamic_ntk: A boolean value indicates whether all batches have different sin_table and
* cos_table.
* @param data_type: Data type of all inputs and outputs.
*/
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
bool discrete,
bool dynamic_ntk,
cnnlDataType_t data_type);
/**
* @brief Apply rotary embedding.
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [total_seq_len, head_num, head_size]
* @param input: Input. Pointer to the MLU memory that stores the input
* the shape must be [total_seq_len, head_num, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
* continous. The shape must be [rotary_seq_len, head_size / 2].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
* continous. The shape must be [rotary_seq_len, head_size / 2].
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each
* batch. The Shape must be [2, total_seq_len].
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all
* batches is max_seq_Len.
* @param batch: Batch size.
* @param max_seq_len: The maximum sequence length of input.
* @param head_num: Head number.
* @param head_size: Head size.
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
* @param rotary_stride: The stride of rotary_seq_len stride in sin_table and cos_table.
* @param input_seq_stride: The stride of total_seq_len in input.
* @param input_head_stride: The stride of head_num in input.
* @param output_seq_stride: The stride of total_seq_len in output.
* @param output_head_stride: The stride of head_num in output.
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
* @param data_type: Data type of all inputs and outputs.
*/
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int total_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
cnnlDataType_t data_type);
} // namespace tmo
#endif // CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_

View File

@@ -0,0 +1,22 @@
#include "swap_blocks.mluh"
namespace tmo {
KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle,
void *dst,
const void *src,
const int64_t &block_size_in_bytes,
const cnrtMemTransDir_t &memcpy_type,
const std::map<int64_t, int64_t> &block_mapping) {
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
for (const auto &pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cnrtMemcpyAsync((int8_t *)dst + dst_offset, (int8_t *)src + src_offset, block_size_in_bytes,
queue, memcpy_type);
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,39 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_SWAP_BLOCKS_MLUH_
#define CSRC_KERNELS_SWAP_BLOCKS_MLUH_
#include <map>
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform swap_blocks operation.
* @param handle: The handle of cnnl.
* @param dst: Output. Pointer to the MLU memory that stores the dst tensor which has shape
* [num_blocks, num_heads, block_size, head_size].
* @param src: Input. Pointer to the MLU memory that stores the src tensor which has shape
* [num_blocks, num_heads, block_size, head_size].
* @param block_size_in_bytes: Data block size for each copy.
* @param memcpy_type: Copy direction, including h2d, d2h and d2d.
* @param block_mapping: Mapping table of src and dst.
*/
KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle,
void *dst,
const void *src,
const int64_t &block_size_in_bytes,
const cnrtMemTransDir_t &memcpy_type,
const std::map<int64_t, int64_t> &block_mapping);
} // namespace tmo
#endif // CSRC_KERNELS_SWAP_BLOCKS_MLUH_

View File

@@ -0,0 +1,447 @@
// clang-format off
#include <mlu.h>
// clang-format on
#include "kernel_utils.h"
#include "update_out_and_lse.mluh"
namespace tmo {
namespace kernels {
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
#define INF (2139095040)
__nram__ char nram_buffer[NRAM_BUFFER_SIZE];
__mlu_func__ void splitTask(int32_t total_task, int32_t &task_length, int32_t &task_offset) {
task_length = (total_task + taskDimX - 1) / taskDimX;
task_offset = taskIdX * task_length;
task_length = std::min(total_task - task_offset, task_length);
}
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int32_t num) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, num);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(dst, (bfloat16_t *)src, num);
} else if (std::is_same<T, float>::value) {
__bang_move(dst, src, num * sizeof(float));
}
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int32_t num) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, num);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, num);
} else if (std::is_same<T, float>::value) {
__bang_move(dst, src, num * sizeof(float));
}
}
template <typename T>
__mlu_func__ void subCvt(T *dst, float *src0, float *src1, int32_t num) {
#if __BANG_ARCH__ >= 500
if (std::is_same<T, half>::value) {
__asm__("sub.nram.crn.f16.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst),
[src0] "r"(src0), [src1] "r"(src1), [num] "r"(num));
} else if (std::is_same<T, bfloat16_t>::value) {
__asm__("sub.nram.crn.bf16.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst),
[src0] "r"(src0), [src1] "r"(src1), [num] "r"(num));
} else {
__asm__("sub.nram.crn.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst),
[src0] "r"(src0), [src1] "r"(src1), [num] "r"(num));
}
#else
__bang_sub((float *)dst, src0, src1, num);
floatTo(dst, (float *)dst, num);
#endif
}
// task_length在非decoder模式下为每个实际的block_seq_len的长度在decoder模式下是batch*head_num的长度。
// out_task_stride/block_task_stride在非deocder模式下为seq_stride在decoder模式下是head_stride。
// out = out - sigmoid(block_lse - lse) * (out - block_out)
// lse = lse - logsigmoid(lse - block_lse)
// sigmoid: y = 1 / (1 + e ^ (-x))
// logsigmoid: y = log(1 / (1 + e ^ (-x))) = -log(1 + e ^ (-x))
template <typename T>
__mlu_func__ void updateOutAndLse(T *out,
float *lse,
const T *block_out,
const float *block_lse,
int32_t head_size,
int32_t task_length,
int64_t out_task_stride,
int64_t block_task_stride) {
constexpr int32_t tmp = 0x3fb8aa3b;
const float log2e = *(float *)&tmp;
const float neg_log2e = (-1) * log2e;
const float recip_log2e = 1 / log2e;
constexpr bool is_fp32 = std::is_same<T, float>::value;
constexpr int32_t buffer_num = is_fp32 ? 5 : 3;
constexpr int32_t pad_length = is_fp32 ? 16 : 32;
int32_t task_once =
PAD_DOWN(NRAM_BUFFER_SIZE / ((head_size * buffer_num + 3) * sizeof(float)), pad_length);
int32_t task_loop = (task_length + task_once - 1) / task_once;
if (task_loop > 1) {
task_once = PAD_UP((task_length + task_loop - 1) / task_loop, pad_length);
}
float *nram_out = (float *)nram_buffer;
float *nram_block_out = nram_out + head_size * task_once * (is_fp32 + 1);
float *nram_lse = nram_block_out + head_size * task_once * (is_fp32 + 1);
float *nram_block_lse = nram_lse + task_once;
float *nram_sigmoid_result = nram_block_lse + task_once;
__attribute__((unused)) float *nram_end = nram_sigmoid_result + task_once;
int32_t task_deal{task_once};
for (int32_t task_i = 0; task_i < task_loop; ++task_i) {
task_deal = std::min(task_once, task_length - task_i * task_once);
/*
读取 lse block_lse
*/
__memcpy_async(nram_lse, lse, task_deal * sizeof(float), GDRAM2NRAM);
__memcpy_async(nram_block_lse, block_lse, task_deal * sizeof(float), GDRAM2NRAM);
__sync_io();
/*
计算 lse + 读取 out
block_lse = (block_lse - lse) * -log2e
sigmoid_result = 1 / (pow2(block_lse) + 1)
*/
__bang_fusion(FUSION_FSM, nram_block_lse, nram_block_lse, nram_lse, neg_log2e, task_deal,
task_deal);
__bang_pow2(nram_sigmoid_result, nram_block_lse, task_deal);
__bang_add_scalar(nram_sigmoid_result, nram_sigmoid_result, (float)1, task_deal);
__bang_recip(nram_sigmoid_result, nram_sigmoid_result, task_deal);
__memcpy_async(nram_out, out, head_size * sizeof(T), GDRAM2NRAM, head_size * sizeof(T),
out_task_stride * sizeof(T), task_deal - 1);
__sync_io();
/*
转置和升位宽 out + 读取 block_out
*/
__bang_transpose((T *)nram_out + task_deal * head_size, (T *)nram_out, task_deal, head_size);
toFloat(nram_out, (T *)nram_out + task_deal * head_size, task_deal * head_size);
__memcpy_async(nram_block_out, block_out, head_size * sizeof(T), GDRAM2NRAM,
head_size * sizeof(T), block_task_stride * sizeof(T), task_deal - 1);
__sync_io();
/*
转置和升位宽 block_out
*/
__bang_transpose((T *)nram_block_out + task_deal * head_size, (T *)nram_block_out, task_deal,
head_size);
toFloat(nram_block_out, (T *)nram_block_out + task_deal * head_size, task_deal * head_size);
/*
bang_fusor的计算流
((out - block_out) * sigmoid_result * -1 + out).tofp16()
*/
__bang_sub(nram_block_out, nram_out, nram_block_out, task_deal * head_size);
__bang_cycle_mul(nram_block_out, nram_block_out, nram_sigmoid_result, task_deal * head_size,
task_deal);
subCvt((T *)nram_out, nram_out, nram_block_out, task_deal * head_size);
__bang_transpose((T *)nram_out + task_deal * head_size, (T *)nram_out, head_size, task_deal);
__sync_compute();
__memcpy_async(out, (T *)nram_out + task_deal * head_size, head_size * sizeof(T), NRAM2GDRAM,
out_task_stride * sizeof(T), head_size * sizeof(T), task_deal - 1);
/*
算法上: lse = lse - logsigmoid(lse - block_lse)
= lse - (-log(1 + e ^ (-(lse - block_lse))))
= lse + log(1 + e ^ (block_lse - lse))
之前block_lse = (block_lse - lse) * -log2e
实际逻辑如下:
block_lse = block_lse * -1
= (block_lse - lse) * log2e
block_lse = log2(pow2(block_lse) + 1) / log2e + lse
= lse + log(1 + e ^ (block_lse - lse))
*/
__bang_mul_scalar(nram_block_lse, nram_block_lse, -1, task_deal);
__bang_pow2(nram_sigmoid_result, nram_block_lse, task_deal);
__bang_add_scalar(nram_sigmoid_result, nram_sigmoid_result, (float)1.0f, task_deal);
__bang_log2(nram_sigmoid_result, nram_sigmoid_result, task_deal);
__bang_mul_scalar(nram_sigmoid_result, nram_sigmoid_result, recip_log2e, task_deal);
/*
nram_sigmoid_result 中的值为log(1 + e ^ (block_lse - lse))
在一些数值分布场景例如block_lse - lse大于85左右时这个值会出现inf。
gpu采用的是std::log1p,
在原始公式中log(1 / (1 + e ^ (block_lse - lse)))中的log里的数值会极限接近0
采用log1p会比普通log在靠近0时有更高精度。
mlu这里由于做了对数倒数外提*-1所以log里的数值会变为inf(如果不外提log2(0)同样会出现inf)。
logsigmoid在大数值场景下是等于原值的例如logsigmoid(-100) = -100
所以以下逻辑用于 识别inf的值对inf的位置进行写入原值。
*/
__bang_ne_scalar((uint32_t *)nram_block_out, (uint32_t *)nram_sigmoid_result, (uint32_t)INF,
task_deal);
__bang_mul((uint32_t *)nram_sigmoid_result, (uint32_t *)nram_sigmoid_result,
(uint32_t *)nram_block_out, task_deal);
__bang_not((uint32_t *)nram_block_out, (uint32_t *)nram_block_out, task_deal);
__bang_mul((uint32_t *)nram_block_lse, (uint32_t *)nram_block_lse, (uint32_t *)nram_block_out,
task_deal);
__bang_mul_scalar(nram_block_lse, nram_block_lse, recip_log2e,
task_deal); // block_lse里的原值是*了log2e的这里需要除回去。
__bang_fusion(FUSION_FAA, nram_lse, nram_lse, nram_sigmoid_result, nram_block_lse, task_deal,
task_deal);
__sync_compute();
__memcpy_async(lse, nram_lse, task_deal * sizeof(float), NRAM2GDRAM);
lse += task_deal;
block_lse += task_deal;
out += task_deal * out_task_stride;
block_out += task_deal * out_task_stride;
}
}
// 非decoder模式采用taskDimZ拆分batchtaskDimY拆分head每个task内部进行block_seq_len的循环
template <typename T>
__mlu_global__ void MluUpdateOutAndLse(void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int32_t *seq_offsets,
const int32_t *cu_seqs,
const int32_t *block_cu_seqs,
const int32_t batch,
const int32_t head_num,
const int32_t head_size,
const int32_t max_seq_len,
const int32_t block_seq_len,
const bool packed,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride) {
#if __BANG_ARCH__ > 300
if (!(std::is_same<T, bfloat16_t>::value && __BANG_ARCH__ < 500)) {
int64_t kernel_out_offset = 0;
int64_t kernel_block_out_offset = 0;
int64_t kernel_lse_offset = 0;
int64_t kernel_block_lse_offset = 0;
int64_t kernel_seq_offset = 0;
int32_t block_seq_len_real{block_seq_len};
if (seq_offsets != nullptr) {
kernel_seq_offset = __load_gdram(seq_offsets + taskIdZ);
}
if (!packed) {
kernel_out_offset =
taskIdZ * bs_stride + taskIdY * head_stride + kernel_seq_offset * seq_stride;
kernel_block_out_offset = taskIdZ * block_bs_stride + taskIdY * block_head_stride;
kernel_lse_offset =
taskIdZ * max_seq_len * head_num + taskIdY * max_seq_len + kernel_seq_offset;
kernel_block_lse_offset = taskIdZ * block_seq_len * head_num + taskIdY * block_seq_len;
} else {
int32_t block_seq_begin = __load_gdram(block_cu_seqs + taskIdZ);
int32_t block_seq_end = __load_gdram(block_cu_seqs + taskIdZ + 1);
int32_t out_seq_begin = __load_gdram(cu_seqs + taskIdZ);
block_seq_len_real = block_seq_end - block_seq_begin;
kernel_out_offset = (out_seq_begin + kernel_seq_offset) * seq_stride + taskIdY * head_stride;
kernel_block_out_offset = block_seq_begin * block_seq_stride + taskIdY * head_stride;
kernel_lse_offset =
taskIdZ * max_seq_len * head_num + taskIdY * max_seq_len + kernel_seq_offset;
kernel_block_lse_offset = taskIdZ * block_seq_len * head_num + taskIdY * block_seq_len;
}
auto kernel_out = (T *)out + kernel_out_offset;
auto kernel_lse = lse + kernel_lse_offset;
auto kernel_block_out = (T *)block_out + kernel_block_out_offset;
auto kernel_block_lse = block_lse + kernel_block_lse_offset;
updateOutAndLse<T>(kernel_out, kernel_lse, kernel_block_out, kernel_block_lse, head_size,
block_seq_len_real, seq_stride, block_seq_stride);
}
#endif
}
// decoder模式下采用launch 所有core内部拆分batch*head_num维度
// 每个core处理 batch*head_num/taskDimX
template <typename T>
__mlu_global__ void MluUpdateOutAndLseDecoder(void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int32_t *seq_offsets,
const int32_t *cu_seqs,
const int32_t *block_cu_seqs,
const int32_t batch,
const int32_t head_num,
const int32_t head_size,
const int32_t max_seq_len,
const int32_t block_seq_len,
const bool packed,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride) {
#if __BANG_ARCH__ > 300
if (!(std::is_same<T, bfloat16_t>::value && __BANG_ARCH__ < 500)) {
int32_t task_length{0}, task_begin{0};
splitTask(batch * head_num, task_length, task_begin);
if (task_length <= 0) {
return;
}
int32_t batch_idx = task_begin / head_num;
int32_t head_idx = task_begin % head_num;
auto kernel_out_offset = batch_idx * bs_stride + head_idx * head_stride;
auto kernel_block_out_offset = batch_idx * block_bs_stride + head_idx * block_head_stride;
auto kernel_lse_offset = batch_idx * max_seq_len * head_num + head_idx * max_seq_len;
auto kernel_block_lse_offset = batch_idx * block_seq_len * head_num + head_idx * block_seq_len;
auto kernel_out = (T *)out + kernel_out_offset;
auto kernel_lse = lse + kernel_lse_offset;
auto kernel_block_out = (T *)block_out + kernel_block_out_offset;
auto kernel_block_lse = block_lse + kernel_block_lse_offset;
updateOutAndLse<T>(kernel_out, kernel_lse, kernel_block_out, kernel_block_lse, head_size,
task_length, head_stride, block_head_stride);
}
#endif
}
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MluUpdateOutAndLseDecoder<bfloat16_t>(void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int32_t *seq_offsets,
const int32_t *cu_seqs,
const int32_t *block_cu_seqs,
const int32_t batch,
const int32_t head_num,
const int32_t head_size,
const int32_t max_seq_len,
const int32_t block_seq_len,
const bool packed,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride) {}
template <>
__mlu_global__ void MluUpdateOutAndLse<bfloat16_t>(void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int32_t *seq_offsets,
const int32_t *cu_seqs,
const int32_t *block_cu_seqs,
const int32_t batch,
const int32_t head_num,
const int32_t head_size,
const int32_t max_seq_len,
const int32_t block_seq_len,
const bool packed,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride) {}
#endif
} // namespace kernels
inline int32_t dtype_index(cnnlDataType_t dtype) {
switch (dtype) {
case CNNL_DTYPE_HALF:
return 0;
break;
case CNNL_DTYPE_BFLOAT16:
return 1;
break;
case CNNL_DTYPE_FLOAT:
return 2;
break;
default:
return 0;
break;
}
}
KernelStatus invokeUpdateOutAndLse(cnrtQueue_t queue,
void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int32_t *seq_offsets,
const int32_t *cu_seqs,
const int32_t *block_cu_seqs,
const int32_t batch,
const int32_t head_num,
const int32_t head_size,
const int32_t max_seq_len,
const int32_t block_seq_len,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride,
const bool packed,
const cnnlDataType_t dtype) {
void (*update_out_and_lse_kernels[])(
void *, float *, const void *, const float *, const int32_t *, const int32_t *,
const int32_t *, const int32_t, const int32_t, const int32_t, const int32_t, const int32_t,
const bool, const int64_t, const int64_t, const int64_t, const int64_t, const int64_t,
const int64_t) = {kernels::MluUpdateOutAndLse<half>,
kernels::MluUpdateOutAndLse<bfloat16_t>,
kernels::MluUpdateOutAndLse<float>,
kernels::MluUpdateOutAndLseDecoder<half>,
kernels::MluUpdateOutAndLseDecoder<bfloat16_t>,
kernels::MluUpdateOutAndLseDecoder<float>};
// 非decoder模式采用taskDimZ拆分batchtaskDimY拆分head每个task内部进行block_seq_len的循环
uint32_t task_dimx = 1;
uint32_t task_dimy = head_num;
uint32_t task_dimz = batch;
bool decoder_mode = (block_seq_len == 1 && max_seq_len == 1);
if (decoder_mode) {
// decoder模式下采用launch 所有core内部拆分batch*head_num维度
// 每个core处理 batch*head_num/taskDimX
CNdev dev;
cnCtxGetDevice(&dev);
int32_t cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int32_t core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
task_dimx = std::min(PAD_UP(batch * head_num, 16) / 16, cluster_num * core_num);
task_dimy = 1;
task_dimz = 1;
}
cnrtDim3_t task_dim = {task_dimx, task_dimy, task_dimz};
cnrtFunctionType_t func_type = cnrtFuncTypeBlock;
int32_t kernel_index = dtype_index(dtype) + decoder_mode * 3;
if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
std::cerr << "[invokeUpdateOutAndLse]: MLU300 devices do not support bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
update_out_and_lse_kernels[kernel_index]<<<task_dim, func_type, queue>>>(
out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs, batch, head_num,
head_size, max_seq_len, block_seq_len, packed, bs_stride, seq_stride, head_stride,
block_bs_stride, block_seq_stride, block_head_stride);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_UPDATE_OUT_AND_LSE_MLUH_
#define CSRC_KERNELS_UPDATE_OUT_AND_LSE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Update out and log-sum-exp(lse) according to block out and block lse.
* @param queue: The queue for mlu.
* @param out: Input/Output. Pointer to the MLU memory that stores the origin out.
* In pad mode, the shape must be [batch, max_seq_len, head_num, head_size].
* In pack mode, the shape must be [total_seq_len, head_num, head_size].
* Dim of seq_len and head_num may have stride.
* @param lse: Input/Output. Pointer to the MLU memory that stores the origin lse.
* The shape must be [batch, head_num, max_seq_len]. Lse must be continuous.
* @param block_out: Input. Pointer to the MLU memory that stores the block out.
* In pad mode, the shape must be [batch, block_seq_len, head_num, head_size].
* In pack mode, the shape must be [total_block_seq_len, head_num, head_size].
* Dim of seq_len and head_num may have stride.
* @param block_lse: Input. Pointer to the MLU memory that stores the origin lse.
* The shape must be [batch, head_num, block_seq_len]. Block_lse must be continuous.
* @param seq_offsets: Input. Pointer to the MLU memory that stores the origin out
* and lse sequence offset. The shape must be [batch].
* Seq_offsets must be continuous, and could be nullptr.
* @param cu_seqs: Input. Pointer to the MLU memory that stores the cumulative sum of out seq_lens,
* In pack mode, the shape must be [batch + 1].
* In pad mode. cu_seqs does not work, could be nullptr.
* @param block_cu_seqs: Input. Pointer to the MLU memory that stores the cumulative sum of block
* out seq_lens, In pack mode, the shape must be [batch + 1]. In pad mode. block_cu_seqs does not
* work, could be nullptr.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
* @param max_seq_len: The sequence length of origin out.
* @param block_seq_len: The sequence length of block out.
* @param bs_stride: The stride of batch in origin out, does not work when packed is true.
* @param seq_stride: The stride of seq_len in origin out.
* @param head_stride: The stride of head_num in origin out.
* @param block_bs_stride: The stride of batch in block out, does not work when packed is true.
* @param block_seq_stride: The stride of seq_len in block out.
* @param block_head_stride: The stride of head_num in block out.
* @param packed: A boolean value indicates whether to use pack mode.
* @note All seq_lens in block out should be less than or equal to origin out.
*/
KernelStatus invokeUpdateOutAndLse(cnrtQueue_t queue,
void *out,
float *lse,
const void *block_out,
const float *block_lse,
const int *seq_offsets,
const int *cu_seqs,
const int *block_cu_seqs,
const int batch,
const int head_num,
const int head_size,
const int max_seq_len,
const int block_seq_len,
const int64_t bs_stride,
const int64_t seq_stride,
const int64_t head_stride,
const int64_t block_bs_stride,
const int64_t block_seq_stride,
const int64_t block_head_stride,
const bool packed,
const cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_