#include #include #include #include #include "cnnl.h" #include "cnrt.h" #include "dequantify.mluh" // clang-format off #include // clang-format on namespace tmo { namespace kernels { template struct PackValueNum { const static int value = 1; }; template <> struct PackValueNum { const static int value = 2; }; __nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)]; __nram__ int8_t nram_buf_scale[8192]; __mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) { __bang_int82float(dst, src, count, 0); __bang_mul_scalar(dst, dst, scale, count); } __mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) { __bang_int42float_rn(dst, src, count, 0); __bang_mul_scalar(dst, dst, scale, count); } __mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) { __bang_int82half(dst, src, count, 0); __bang_mul_scalar(dst, dst, (half)scale, count); } __mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) { __bang_int42half_rn(dst, src, count, 0); __bang_mul_scalar(dst, dst, (half)scale, count); } template __mlu_func__ void swap(T *&ping, T *&pong) { T *tmp = ping; ping = pong; pong = tmp; } template __mlu_global__ void dequantifyPerTensor(void *all_dst, const void *all_src, size_t all_src_count, float scale) { scale = 1.0f / scale; size_t src_per_core = all_src_count / taskDim; size_t src_remain = all_src_count % taskDim; size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain); const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0); TDst *dst = reinterpret_cast(all_dst) + start * PackValueNum::value; const TSrc *src = reinterpret_cast(all_src) + start; constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong (sizeof(TSrc) + sizeof(TDst) * PackValueNum::value) / 128 * 128; // align to 128 constexpr int src_num_unit = size_unit / sizeof(TSrc); constexpr int dst_num_unit = src_num_unit * PackValueNum::value; int8_t *nram_buf_ping = nram_buf; int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2; TSrc *nram_src_ping = reinterpret_cast(nram_buf_ping); TDst *nram_dst_ping = reinterpret_cast(nram_buf_ping + static_cast(sizeof(TSrc)) * size_unit); TSrc *nram_src_pong = reinterpret_cast(nram_buf_pong); TDst *nram_dst_pong = reinterpret_cast(nram_buf_pong + static_cast(sizeof(TSrc)) * size_unit); int loop_count = src_count / src_num_unit; int remain_count = src_count % src_num_unit; // L __memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); swap(nram_src_ping, nram_src_pong); swap(nram_dst_ping, nram_dst_pong); __sync_io_move_compute(); // L C __memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); swap(nram_src_ping, nram_src_pong); swap(nram_dst_ping, nram_dst_pong); __sync_io_move_compute(); // L C S for (int i = 0; i < loop_count - 2; ++i) { __memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); __memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); swap(nram_src_ping, nram_src_pong); swap(nram_dst_ping, nram_dst_pong); __sync_io_move_compute(); } // C S __memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); swap(nram_src_ping, nram_src_pong); swap(nram_dst_ping, nram_dst_pong); __sync_io_move_compute(); // S __memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); __sync_io_move_compute(); if (remain_count > 0) { __memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count, GDRAM2NRAM); convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum::value, scale); __memcpy(dst + loop_count * dst_num_unit, nram_dst_ping, sizeof(TDst) * remain_count * PackValueNum::value, NRAM2GDRAM); } } // does not use a pipeline because per channel is more complicated but it's a one-time operation, so // performance doesn't matter. template __mlu_global__ void dequantifyPerChannel(void *all_dst, const void *all_src, int src_ci, int all_co, const void *scale) { const int co_per_core = all_co / taskDim; const int co_remain = all_co % taskDim; const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain); const int co_count = co_per_core + (taskId < co_remain ? 1 : 0); assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst)); constexpr int size_unit = sizeof(nram_buf) / (sizeof(TSrc) + sizeof(TDst) * PackValueNum::value) / 128 * 128; // align to 128 // yes, we only deal with 1 channel at a time // no, there's no need to optimize a one-time operation const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci); const int dst_num_unit = src_num_unit * PackValueNum::value; TSrc *const nram_src = reinterpret_cast(nram_buf); TDst *const nram_dst = reinterpret_cast(nram_buf + static_cast(sizeof(TSrc)) * size_unit); const TDst *nram_scale = reinterpret_cast(nram_buf_scale); const int loop_one_channel = src_ci / src_num_unit; const int remain_one_channel = src_ci % src_num_unit; for (int o = start_co; o < start_co + co_count; ++o) { const TSrc *src = reinterpret_cast(all_src) + o * src_ci; TDst *dst = reinterpret_cast(all_dst) + o * src_ci; const TDst scale_value = 1. / nram_scale[o]; for (int i = 0; i < loop_one_channel; ++i) { __memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); convert(nram_dst, nram_src, dst_num_unit, scale_value); __memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); } if (remain_one_channel > 0) { __memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel, GDRAM2NRAM); convert(nram_dst, nram_src, remain_one_channel * PackValueNum::value, scale_value); __memcpy(dst + loop_one_channel * dst_num_unit, nram_dst, sizeof(TDst) * remain_one_channel * PackValueNum::value, NRAM2GDRAM); } } } } // namespace kernels static const std::map, decltype(&kernels::dequantifyPerTensor)> per_tensor_func_map = { {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor}, {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor}, {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor}, {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor}, }; KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle, const void *src, int src_bitwidth, void *dst, cnnlDataType_t dst_dtype, size_t src_count, float scale) { cnrtQueue_t queue; cnnlGetQueue(handle, &queue); CNdev dev; cnnlGetDevice(handle, &dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1}; auto iter = per_tensor_func_map.find(std::make_pair(src_bitwidth, dst_dtype)); if (iter == per_tensor_func_map.end()) { std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth << " dst_dtype: " << dst_dtype; return KernelStatus::KERNEL_STATUS_FAILED; } iter->second<<>>(dst, src, src_count, scale); return KernelStatus::KERNEL_STATUS_SUCCESS; } static const std::map, decltype(&kernels::dequantifyPerChannel)> per_channel_func_map = { {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel}, {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel}, {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel}, {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel}, }; KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle, const void *src, int src_bitwidth, void *dst, cnnlDataType_t dst_dtype, int src_ci, int co, const void *scale) { cnrtQueue_t queue; cnnlGetQueue(handle, &queue); CNdev dev; cnnlGetDevice(handle, &dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1}; auto iter = per_channel_func_map.find(std::make_pair(src_bitwidth, dst_dtype)); if (iter == per_channel_func_map.end()) { std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth << " dst_dtype: " << dst_dtype; return KernelStatus::KERNEL_STATUS_FAILED; } iter->second<<>>(dst, src, src_ci, co, scale); return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo