// Dequantization kernels for Cambricon MLU (pasted file-viewer metadata removed).
#include <algorithm>  // std::min (dequantifyPerChannel)
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>

#include "cnnl.h"
#include "cnrt.h"
#include "dequantify.mluh"

// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Number of quantized values packed into one storage element of type T.
// Generic case: one value per element (e.g. int8_t).
template <typename T>
struct PackValueNum {
  static const int value = 1;
};
template <>
|
|
struct PackValueNum<int4x2_t> {
|
|
const static int value = 2;
|
|
};
|
|
|
|
// Shared NRAM scratch for the kernels below. Sized to 3/8 of the core's NRAM
// (the `* 1024` suggests __MLU_NRAM_SIZE__ is expressed in KiB — TODO confirm).
// dequantifyPerTensor splits it in half for ping/pong double buffering;
// dequantifyPerChannel uses it whole as a single src/dst staging area.
__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)];

// NRAM staging area for per-channel scale values (read by dequantifyPerChannel).
__nram__ int8_t nram_buf_scale[8192];
// Dequantize `count` int8 values to float: dst[i] = (float)src[i] * scale.
// Callers pass `count` as the number of output elements. The trailing 0 is the
// intrinsic's mode/position argument — TODO confirm against the BANG C docs.
__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) {
__bang_int82float(dst, src, count, 0);
__bang_mul_scalar(dst, dst, scale, count);
}
// Dequantize packed 4-bit values to float: dst[i] = unpack(src)[i] * scale.
// Callers pass `count` as the number of *unpacked* output elements
// (2x the number of int4x2_t bytes). `_rn` = round-to-nearest variant.
__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) {
__bang_int42float_rn(dst, src, count, 0);
__bang_mul_scalar(dst, dst, scale, count);
}
// Dequantize `count` int8 values to half: dst[i] = (half)src[i] * (half)scale.
// Note the scale is narrowed to half before the multiply, unlike the float
// overloads which multiply in float.
__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) {
__bang_int82half(dst, src, count, 0);
__bang_mul_scalar(dst, dst, (half)scale, count);
}
// Dequantize packed 4-bit values to half. `count` is the number of *unpacked*
// output elements (2x the int4x2_t byte count). `_rn` = round-to-nearest.
__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) {
__bang_int42half_rn(dst, src, count, 0);
__bang_mul_scalar(dst, dst, (half)scale, count);
}
template <typename T>
|
|
__mlu_func__ void swap(T *&ping, T *&pong) {
|
|
T *tmp = ping;
|
|
ping = pong;
|
|
pong = tmp;
|
|
}
|
|
|
|
template <typename TDst, typename TSrc>
|
|
__mlu_global__ void dequantifyPerTensor(void *all_dst,
|
|
const void *all_src,
|
|
size_t all_src_count,
|
|
float scale) {
|
|
scale = 1.0f / scale;
|
|
size_t src_per_core = all_src_count / taskDim;
|
|
size_t src_remain = all_src_count % taskDim;
|
|
size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain);
|
|
const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0);
|
|
TDst *dst = reinterpret_cast<TDst *>(all_dst) + start * PackValueNum<TSrc>::value;
|
|
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + start;
|
|
|
|
constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong
|
|
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
|
|
128; // align to 128
|
|
constexpr int src_num_unit = size_unit / sizeof(TSrc);
|
|
constexpr int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
|
|
int8_t *nram_buf_ping = nram_buf;
|
|
int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2;
|
|
|
|
TSrc *nram_src_ping = reinterpret_cast<TSrc *>(nram_buf_ping);
|
|
TDst *nram_dst_ping =
|
|
reinterpret_cast<TDst *>(nram_buf_ping + static_cast<int>(sizeof(TSrc)) * size_unit);
|
|
TSrc *nram_src_pong = reinterpret_cast<TSrc *>(nram_buf_pong);
|
|
TDst *nram_dst_pong =
|
|
reinterpret_cast<TDst *>(nram_buf_pong + static_cast<int>(sizeof(TSrc)) * size_unit);
|
|
|
|
int loop_count = src_count / src_num_unit;
|
|
int remain_count = src_count % src_num_unit;
|
|
|
|
// L
|
|
__memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
|
swap(nram_src_ping, nram_src_pong);
|
|
swap(nram_dst_ping, nram_dst_pong);
|
|
__sync_io_move_compute();
|
|
|
|
// L C
|
|
__memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
|
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
|
swap(nram_src_ping, nram_src_pong);
|
|
swap(nram_dst_ping, nram_dst_pong);
|
|
__sync_io_move_compute();
|
|
|
|
// L C S
|
|
for (int i = 0; i < loop_count - 2; ++i) {
|
|
__memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit,
|
|
GDRAM2NRAM);
|
|
__memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
|
|
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
|
swap(nram_src_ping, nram_src_pong);
|
|
swap(nram_dst_ping, nram_dst_pong);
|
|
__sync_io_move_compute();
|
|
}
|
|
|
|
// C S
|
|
__memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
|
|
NRAM2GDRAM);
|
|
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
|
swap(nram_src_ping, nram_src_pong);
|
|
swap(nram_dst_ping, nram_dst_pong);
|
|
__sync_io_move_compute();
|
|
|
|
// S
|
|
__memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
|
|
NRAM2GDRAM);
|
|
|
|
__sync_io_move_compute();
|
|
|
|
if (remain_count > 0) {
|
|
__memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count,
|
|
GDRAM2NRAM);
|
|
convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum<TSrc>::value, scale);
|
|
__memcpy(dst + loop_count * dst_num_unit, nram_dst_ping,
|
|
sizeof(TDst) * remain_count * PackValueNum<TSrc>::value, NRAM2GDRAM);
|
|
}
|
|
}
|
|
|
|
// does not use a pipeline because per channel is more complicated but it's a one-time operation, so
|
|
// performance doesn't matter.
|
|
template <typename TDst, typename TSrc>
|
|
__mlu_global__ void dequantifyPerChannel(void *all_dst,
|
|
const void *all_src,
|
|
int src_ci,
|
|
int all_co,
|
|
const void *scale) {
|
|
const int co_per_core = all_co / taskDim;
|
|
const int co_remain = all_co % taskDim;
|
|
const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain);
|
|
const int co_count = co_per_core + (taskId < co_remain ? 1 : 0);
|
|
assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst));
|
|
|
|
constexpr int size_unit = sizeof(nram_buf) /
|
|
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
|
|
128; // align to 128
|
|
// yes, we only deal with 1 channel at a time
|
|
// no, there's no need to optimize a one-time operation
|
|
const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci);
|
|
const int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
|
|
TSrc *const nram_src = reinterpret_cast<TSrc *>(nram_buf);
|
|
TDst *const nram_dst =
|
|
reinterpret_cast<TDst *>(nram_buf + static_cast<int>(sizeof(TSrc)) * size_unit);
|
|
|
|
const TDst *nram_scale = reinterpret_cast<const TDst *>(nram_buf_scale);
|
|
|
|
const int loop_one_channel = src_ci / src_num_unit;
|
|
const int remain_one_channel = src_ci % src_num_unit;
|
|
|
|
for (int o = start_co; o < start_co + co_count; ++o) {
|
|
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + o * src_ci;
|
|
TDst *dst = reinterpret_cast<TDst *>(all_dst) + o * src_ci;
|
|
const TDst scale_value = 1. / nram_scale[o];
|
|
for (int i = 0; i < loop_one_channel; ++i) {
|
|
__memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
|
convert(nram_dst, nram_src, dst_num_unit, scale_value);
|
|
__memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
|
|
}
|
|
if (remain_one_channel > 0) {
|
|
__memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel,
|
|
GDRAM2NRAM);
|
|
convert(nram_dst, nram_src, remain_one_channel * PackValueNum<TSrc>::value, scale_value);
|
|
__memcpy(dst + loop_one_channel * dst_num_unit, nram_dst,
|
|
sizeof(TDst) * remain_one_channel * PackValueNum<TSrc>::value, NRAM2GDRAM);
|
|
}
|
|
}
|
|
}
|
|
|
|
}  // namespace kernels
// Dispatch table: (source bitwidth, destination cnnl dtype) -> per-tensor
// dequantize kernel. 4-bit sources use the packed int4x2_t path (two values
// per byte); 8-bit sources use plain int8_t.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerTensor<half, int4x2_t>)>
    per_tensor_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int8_t>},
};
// Host-side launcher for per-tensor dequantization.
//
// Enqueues the kernel on the handle's queue (asynchronous; the caller owns
// synchronization). `src_count` is in source elements — packed bytes for the
// 4-bit path. Returns KERNEL_STATUS_FAILED for unsupported
// (src_bitwidth, dst_dtype) combinations.
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
                                       const void *src,
                                       int src_bitwidth,
                                       void *dst,
                                       cnnlDataType_t dst_dtype,
                                       size_t src_count,
                                       float scale) {
  // Resolve the kernel first so unsupported combinations fail before any
  // device queries are made.
  auto iter = per_tensor_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
  if (iter == per_tensor_func_map.end()) {
    // FIX: the original message had no newline/flush, so it could be lost or
    // glued onto the next log line.
    std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Union1 task: 4 cores per cluster across every cluster on the device.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};

  iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_count, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
// Dispatch table: (source bitwidth, destination cnnl dtype) -> per-channel
// dequantize kernel. Mirrors per_tensor_func_map above.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerChannel<half, int4x2_t>)>
    per_channel_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int8_t>},
};
// Host-side launcher for per-channel dequantization.
//
// Enqueues the kernel on the handle's queue (asynchronous; the caller owns
// synchronization). `src_ci` is the per-channel source element count (packed
// bytes for 4-bit), `co` is the channel count, and `scale` is a GDRAM vector
// of `co` scale values in the destination dtype. Returns KERNEL_STATUS_FAILED
// for unsupported (src_bitwidth, dst_dtype) combinations.
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
                                        const void *src,
                                        int src_bitwidth,
                                        void *dst,
                                        cnnlDataType_t dst_dtype,
                                        int src_ci,
                                        int co,
                                        const void *scale) {
  // Resolve the kernel first so unsupported combinations fail before any
  // device queries are made.
  auto iter = per_channel_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
  if (iter == per_channel_func_map.end()) {
    // FIX: the original message had no newline/flush, so it could be lost or
    // glued onto the next log line.
    std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Union1 task: 4 cores per cluster across every cluster on the device.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};

  iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_ci, co, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
}  // namespace tmo