forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
254
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu
Normal file
254
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu
Normal file
@@ -0,0 +1,254 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "dequantify.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
template <typename T>
|
||||
struct PackValueNum {
|
||||
const static int value = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackValueNum<int4x2_t> {
|
||||
const static int value = 2;
|
||||
};
|
||||
|
||||
// Per-core NRAM scratch buffer shared by both kernels below; sized to 3/8 of
// total NRAM (__MLU_NRAM_SIZE__ is presumably in KB, hence * 1024 — TODO confirm).
__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)];

// Separate small NRAM buffer reserved for per-channel dequantization scales.
__nram__ int8_t nram_buf_scale[8192];
|
||||
|
||||
// Dequantize `count` int8 values to float: dst[i] = (float)src[i] * scale.
__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) {
  __bang_int82float(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}

// Dequantize `count` packed-int4 values to float (round-to-nearest unpack),
// then scale: dst[i] = (float)unpack(src)[i] * scale.
__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42float_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}

// Dequantize `count` int8 values to half; note scale is narrowed to half
// before the multiply.
__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) {
  __bang_int82half(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}

// Dequantize `count` packed-int4 values to half (round-to-nearest unpack);
// scale is narrowed to half before the multiply.
__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42half_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void swap(T *&ping, T *&pong) {
|
||||
T *tmp = ping;
|
||||
ping = pong;
|
||||
pong = tmp;
|
||||
}
|
||||
|
||||
template <typename TDst, typename TSrc>
|
||||
__mlu_global__ void dequantifyPerTensor(void *all_dst,
|
||||
const void *all_src,
|
||||
size_t all_src_count,
|
||||
float scale) {
|
||||
scale = 1.0f / scale;
|
||||
size_t src_per_core = all_src_count / taskDim;
|
||||
size_t src_remain = all_src_count % taskDim;
|
||||
size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain);
|
||||
const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0);
|
||||
TDst *dst = reinterpret_cast<TDst *>(all_dst) + start * PackValueNum<TSrc>::value;
|
||||
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + start;
|
||||
|
||||
constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong
|
||||
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
|
||||
128; // align to 128
|
||||
constexpr int src_num_unit = size_unit / sizeof(TSrc);
|
||||
constexpr int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
|
||||
int8_t *nram_buf_ping = nram_buf;
|
||||
int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2;
|
||||
|
||||
TSrc *nram_src_ping = reinterpret_cast<TSrc *>(nram_buf_ping);
|
||||
TDst *nram_dst_ping =
|
||||
reinterpret_cast<TDst *>(nram_buf_ping + static_cast<int>(sizeof(TSrc)) * size_unit);
|
||||
TSrc *nram_src_pong = reinterpret_cast<TSrc *>(nram_buf_pong);
|
||||
TDst *nram_dst_pong =
|
||||
reinterpret_cast<TDst *>(nram_buf_pong + static_cast<int>(sizeof(TSrc)) * size_unit);
|
||||
|
||||
int loop_count = src_count / src_num_unit;
|
||||
int remain_count = src_count % src_num_unit;
|
||||
|
||||
// L
|
||||
__memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
||||
swap(nram_src_ping, nram_src_pong);
|
||||
swap(nram_dst_ping, nram_dst_pong);
|
||||
__sync_io_move_compute();
|
||||
|
||||
// L C
|
||||
__memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
||||
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
||||
swap(nram_src_ping, nram_src_pong);
|
||||
swap(nram_dst_ping, nram_dst_pong);
|
||||
__sync_io_move_compute();
|
||||
|
||||
// L C S
|
||||
for (int i = 0; i < loop_count - 2; ++i) {
|
||||
__memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit,
|
||||
GDRAM2NRAM);
|
||||
__memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
|
||||
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
||||
swap(nram_src_ping, nram_src_pong);
|
||||
swap(nram_dst_ping, nram_dst_pong);
|
||||
__sync_io_move_compute();
|
||||
}
|
||||
|
||||
// C S
|
||||
__memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
|
||||
NRAM2GDRAM);
|
||||
convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
|
||||
swap(nram_src_ping, nram_src_pong);
|
||||
swap(nram_dst_ping, nram_dst_pong);
|
||||
__sync_io_move_compute();
|
||||
|
||||
// S
|
||||
__memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
|
||||
NRAM2GDRAM);
|
||||
|
||||
__sync_io_move_compute();
|
||||
|
||||
if (remain_count > 0) {
|
||||
__memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count,
|
||||
GDRAM2NRAM);
|
||||
convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum<TSrc>::value, scale);
|
||||
__memcpy(dst + loop_count * dst_num_unit, nram_dst_ping,
|
||||
sizeof(TDst) * remain_count * PackValueNum<TSrc>::value, NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
// does not use a pipeline because per channel is more complicated but it's a one-time operation, so
|
||||
// performance doesn't matter.
|
||||
template <typename TDst, typename TSrc>
|
||||
__mlu_global__ void dequantifyPerChannel(void *all_dst,
|
||||
const void *all_src,
|
||||
int src_ci,
|
||||
int all_co,
|
||||
const void *scale) {
|
||||
const int co_per_core = all_co / taskDim;
|
||||
const int co_remain = all_co % taskDim;
|
||||
const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain);
|
||||
const int co_count = co_per_core + (taskId < co_remain ? 1 : 0);
|
||||
assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst));
|
||||
|
||||
constexpr int size_unit = sizeof(nram_buf) /
|
||||
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
|
||||
128; // align to 128
|
||||
// yes, we only deal with 1 channel at a time
|
||||
// no, there's no need to optimize a one-time operation
|
||||
const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci);
|
||||
const int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
|
||||
TSrc *const nram_src = reinterpret_cast<TSrc *>(nram_buf);
|
||||
TDst *const nram_dst =
|
||||
reinterpret_cast<TDst *>(nram_buf + static_cast<int>(sizeof(TSrc)) * size_unit);
|
||||
|
||||
const TDst *nram_scale = reinterpret_cast<const TDst *>(nram_buf_scale);
|
||||
|
||||
const int loop_one_channel = src_ci / src_num_unit;
|
||||
const int remain_one_channel = src_ci % src_num_unit;
|
||||
|
||||
for (int o = start_co; o < start_co + co_count; ++o) {
|
||||
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + o * src_ci;
|
||||
TDst *dst = reinterpret_cast<TDst *>(all_dst) + o * src_ci;
|
||||
const TDst scale_value = 1. / nram_scale[o];
|
||||
for (int i = 0; i < loop_one_channel; ++i) {
|
||||
__memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
||||
convert(nram_dst, nram_src, dst_num_unit, scale_value);
|
||||
__memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
|
||||
}
|
||||
if (remain_one_channel > 0) {
|
||||
__memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel,
|
||||
GDRAM2NRAM);
|
||||
convert(nram_dst, nram_src, remain_one_channel * PackValueNum<TSrc>::value, scale_value);
|
||||
__memcpy(dst + loop_one_channel * dst_num_unit, nram_dst,
|
||||
sizeof(TDst) * remain_one_channel * PackValueNum<TSrc>::value, NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
// Dispatch table: (source bitwidth, destination dtype) -> per-tensor kernel
// instantiation. 4-bit sources use the packed int4x2_t storage type.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerTensor<half, int4x2_t>)>
    per_tensor_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int8_t>},
};
|
||||
|
||||
// Host-side launcher for the per-tensor dequantization kernel.
// Resolves the kernel instantiation from (src_bitwidth, dst_dtype) and
// launches it as a Union1 task across every cluster of the handle's device.
// Returns KERNEL_STATUS_FAILED for unsupported bitwidth/dtype combinations.
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
                                       const void *src,
                                       int src_bitwidth,
                                       void *dst,
                                       cnnlDataType_t dst_dtype,
                                       size_t src_count,
                                       float scale) {
  const auto kernel_it = per_tensor_func_map.find({src_bitwidth, dst_dtype});
  if (kernel_it == per_tensor_func_map.end()) {
    std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // 4 cores per cluster, one grid column per cluster.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};

  (*kernel_it->second)<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_count, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
// Dispatch table: (source bitwidth, destination dtype) -> per-channel kernel
// instantiation; mirrors per_tensor_func_map.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerChannel<half, int4x2_t>)>
    per_channel_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int8_t>},
};
|
||||
|
||||
// Host-side launcher for the per-channel dequantization kernel.
// Resolves the kernel instantiation from (src_bitwidth, dst_dtype) and
// launches it as a Union1 task across every cluster of the handle's device.
// Returns KERNEL_STATUS_FAILED for unsupported bitwidth/dtype combinations.
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
                                        const void *src,
                                        int src_bitwidth,
                                        void *dst,
                                        cnnlDataType_t dst_dtype,
                                        int src_ci,
                                        int co,
                                        const void *scale) {
  const auto kernel_it = per_channel_func_map.find({src_bitwidth, dst_dtype});
  if (kernel_it == per_channel_func_map.end()) {
    std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // 4 cores per cluster, one grid column per cluster.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};

  (*kernel_it->second)<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_ci, co, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user