// (extraction residue, preserved as a comment so the header stays compilable:
//  original listing read "Files / 2026-02-04 17:39:32 +08:00 / 314 lines /
//  13 KiB / C++")
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANT_UTILS_H_
#define CSRC_KERNELS_QUANT_UTILS_H_
#include <cassert>
#include <iostream>
#include <string>
#include "cnnl.h"
#include "cnrt.h"
#include "kernel_utils.h"
namespace tmo {
// ---------------------------------------------------------------------------
// Tiling / layout constants.  NOTE(review): these describe the MLU LT
// (weight-tile) geometry and WRAM banking and are architecture-specific --
// confirm against the target arch before changing.
// ---------------------------------------------------------------------------
// Rows per LT weight group.
#ifndef LT_NUM
#define LT_NUM (64)
#endif
// Consecutive rows copied per LT in the "ant" layout (see mvNram2WramLT16).
#ifndef ANT_LT_ROW
#define ANT_LT_ROW (4)
#endif
// Number of LTs spread across the WRAM map regions in the "ant" layout.
#ifndef LT_NUM_ANT
#define LT_NUM_ANT (16)
#endif
// Bytes in one hardware line.
#ifndef ONE_LINE
#define ONE_LINE (64)
#endif
// sizeof() as uint32_t, to keep stride arithmetic unsigned-consistent.
#ifndef sizeof_
#define sizeof_(T) (uint32_t)sizeof(T)
#endif
// Byte stride between the 16 LT map regions of WRAM.
// (__MLU_WRAM_SIZE__ is presumably in KiB, hence the * 1024 -- confirm.)
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 16)
#endif
// Size in bytes of a trans.tiling permutation pre-table.
#ifndef TRANS_TABLE_SIZE
#define TRANS_TABLE_SIZE (64)
#endif
// Integer ceiling division, e.g. DIV_UP(7, 4) == 2.
#ifndef DIV_UP
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
#endif
// Emits one fused conv.nram instruction: f32 convolution of `input` with
// `filter` (dimensions m, k, n bound to src_height/src_channel/dst_channel),
// optionally followed by an elementwise op (`op`, fed op_data via
// %[operand0]) and a destination convert (`cvt`).  Relies on Td, Ts, output,
// input, filter, m, n, k being in scope at the expansion site.
#ifndef CONV_FUSE_OP_CVT
#define CONV_FUSE_OP_CVT(dtype, op, cvt, op_data) \
asm volatile("conv.nram.rn.f32" dtype dtype \
"[%[dst]], [%[src]], [%[kernel]], %[src_channel], " \
"%[src_height], 1, 1, 1, 1, 1, %[dst_channel]" op cvt \
";\n\t" ::[dst] "r"((Td *)output), \
[src] "r"((Ts *)input), [kernel] "r"((Ts *)filter), [src_channel] "r"(k), \
[src_height] "r"(m), [dst_channel] "r"(n), [operand0] "r"(op_data));
#endif
// Reorders a small-channel NHWC tile to NCHW with a single
// trans.tiling.nram instruction, steered by a 64-byte permutation pre-table
// in NRAM (built by __reshape_nhwc2nchw_smallc_init).  Relies on T, dst,
// src, pre_table and the in*/is*/dn*/ds* shape variables being in scope at
// the expansion site.
#define __reshape_nhwc2nchw_smallc(TYPE) \
asm volatile( \
"trans.tiling.nram.nram." TYPE \
"[%[dst]], [%[src]], " \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \
".pretable.nram([%[pre]]); \n\t" ::[dst] "r"((T *)dst), \
[src] "r"((T *)src), [pre] "r"((uint8_t *)pre_table), [in0] "r"(in0), [in1] "r"(in1), \
[is1] "r"(in0), [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \
[is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \
[ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \
[ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
// Rounds `num` up to the nearest power of two and writes it to `align_num`
// (an exact power of two maps to itself, e.g. 8 -> 8, 9 -> 16).
// NOTE(review): behaviour for num <= 1 depends on what findlast1 returns for
// an all-zero input -- confirm on the target ISA.
__mlu_func__ void next_power_of_two(int32_t &align_num, const int32_t num) {
  int32_t tmp = num - 1;
  // findlast1.gpr.b32: index of the most significant set bit of `tmp`.
  asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
  align_num = 1 << (tmp + 1);
}
/* Builds the 64-byte permutation pre-table consumed by trans.tiling in
   trans_nhwc2nchw_smallc.  (Ported from cnnl utils/trans_small.py by xwm.)
   For element type T and `channel` channels, table byte i maps one
   destination byte within a 64-byte line back to its source byte; 0x80 is
   added on top -- presumably a flag bit required by the pretable encoding,
   confirm against the ISA documentation. */
template <typename T>
__mlu_func__ void __reshape_nhwc2nchw_smallc_init(uint8_t *pre_table_nram, uint32_t channel) {
  int32_t align_c;
  next_power_of_two(align_c, channel);        // channels padded to a power of two
  int32_t align_num = ONE_LINE / sizeof_(T);  // elements per 64-byte line
  int32_t repeat = align_num / align_c;       // padded channel groups per line
  for (int i = 0; i < 64; ++i) {              // 64 == ONE_LINE table bytes
    int32_t idx = i / sizeof_(T);             // element slot of byte i
    int32_t tmp_idx = (idx % repeat) * channel + idx / repeat;  // NHWC -> NCHW slot
    int32_t real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T);   // back to bytes
    __store_nram((uint8_t *)pre_table_nram + i, (uint8_t)real_idx + 0x80);
  }
}
// Transposes one n x (h*w) x c NHWC tensor (small c) to NCHW layout on NRAM.
// The bulk of hw is moved align_num (= 64 / sizeof(T)) rows at a time by a
// pre-table-driven trans.tiling (macro __reshape_nhwc2nchw_smallc); the
// hw % align_num remainder is handled by the plain trans.tiling at the end.
// pre_table must have been filled by __reshape_nhwc2nchw_smallc_init<T>
// for the same channel count c.
template <typename T>
__mlu_func__ void trans_nhwc2nchw_smallc(T *dst,
                                         T *src,
                                         uint8_t *pre_table,
                                         uint32_t n,
                                         uint32_t h,
                                         uint32_t w,
                                         uint32_t c) {
  int32_t align_c;
  next_power_of_two(align_c, c);        // c padded to a power of two
  int32_t align_num = 64 / sizeof_(T);  // elements per 64-byte line
  int32_t hw = h * w;
  int32_t repeat = align_num / align_c;  // padded channel groups per line
  // Source iteration shape: in0 contiguous bytes, in1 x in3 tiles, n batches.
  int32_t in0 = c * repeat * sizeof_(T);
  int32_t in1 = align_c;
  int32_t in3 = hw / align_num;  // number of full 64-byte lines in hw
  int32_t is3 = in0 * in1;
  int32_t n_stride = hw * c * sizeof_(T);  // bytes per batch image
  // Destination iteration shape: dn0-byte lines into c planes of hw elements.
  int32_t dn0 = 64;
  int32_t dn1 = c;
  int32_t ds1 = hw * sizeof_(T);
  int32_t dn3 = in3;
  int32_t ds3 = dn0;
  // Skip the pre-table path entirely when hw has no full line (in3 == 0).
  align_c = in3 > 0 ? align_c : 0;
  // Dispatch on padded channel count; TYPE is the transfer granularity used
  // by trans.tiling for this geometry.  NOTE(review): align_c == 1 or > 32
  // falls through with no main transfer -- confirm callers restrict c.
  if (align_c == 2) {
    __reshape_nhwc2nchw_smallc("b256");
  } else if (align_c == 4) {
    __reshape_nhwc2nchw_smallc("b128");
  } else if (align_c == 8) {
    __reshape_nhwc2nchw_smallc("b64");
  } else if (align_c == 16) {
    __reshape_nhwc2nchw_smallc("b32");
  } else if (align_c == 32) {
    __reshape_nhwc2nchw_smallc("b16");
  }
  constexpr int32_t bw = 8 * sizeof_(T);  // element width in bits
  int32_t in3_rem = hw % align_num;       // leftover rows past the full lines
  int32_t tail_in0 = c * sizeof_(T);
  int32_t tail_dn0 = in3_rem * sizeof_(T);
  if (in3_rem) {
    // Remainder: plain (un-pretabled) trans.tiling, starting where the main
    // transfers left off in both src and dst.
    asm volatile(
"trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \
%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], \
%[is3], %[in4], %[is4], %[in5], %[is5], \
%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], \
%[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::[bw] "i"(bw),
        [dst] "r"((T *)dst + dn3 * ds3 / sizeof_(T)), [src] "r"((T *)src + is3 * in3 / sizeof_(T)),
        [in0] "r"(tail_in0), [in1] "r"(in3_rem), [is1] "r"(tail_in0), [in2] "i"(1), [is2] "i"(0),
        [in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0),
        [dn0] "r"(tail_dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0),
        [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
  }
}
// Widen int8 values to float32 on NRAM.
__mlu_func__ void convert(float *dst, int8_t *src, int32_t num) {
  // Trailing 0 is the position/shift argument of the BANG API, kept as-is.
  __bang_int82float(dst, src, num, 0);
}
// Unpack int4x2 (two 4-bit values per byte) into float32 on NRAM.
__mlu_func__ void convert(float *dst, int4x2_t *src, int32_t num) {
  // Trailing 0 is the position/shift argument of the BANG API, kept as-is.
  __bang_int42float(dst, src, num, 0);
}
// Narrow float32 to half (fp16) on NRAM.
__mlu_func__ void convert(half *dst, float *src, int32_t num) {
  __bang_float2half(dst, src, num);
}
// Narrow float32 to bfloat16 on NRAM.  Only emitted for MLU arch >= 500; on
// older architectures this overload is a silent no-op (dst left untouched).
// NOTE(review): conv_fuse_mul_cvt gates its bf16 path with `> 500` while
// this uses `>= 500` -- confirm which bound is intended.
__mlu_func__ void convert(bfloat16_t *dst, float *src, int32_t num) {
#if __BANG_ARCH__ >= 500
  __bang_float2bfloat16((bfloat16_t *)dst, (float *)src, num);
#endif
}
// Widen packed int4x2 values to one int8 per element on NRAM.
__mlu_func__ void convert(int8_t *dst, int4x2_t *src, int32_t num) {
  // The two trailing zeros are forwarded unchanged from the original call.
  __bang_int42int8(dst, src, num, 0, 0);
}
// if the dst dtype == src dtype, do nothing. if you want to mv, use mv directly
// (deliberate no-op overloads so generic call sites such as dequantize()
// compile for the identity case without issuing a copy).
__mlu_func__ void convert(float *dst, float *src, int32_t num) {}
__mlu_func__ void convert(int8_t *dst, int8_t *src, int32_t num) {}
// Transpose a dim1 x dim2 matrix of T held on NRAM.
template <typename T>
__mlu_func__ void transpose(T *dst, T *src, int32_t dim1, int32_t dim2) {
  __bang_transpose(dst, src, dim1, dim2);
}
// if data type is int4x2_t, transpose is not supported directly
// (deliberate no-op overload: packed nibbles cannot be moved element-wise by
// __bang_transpose; presumably callers widen via convert(int8_t*, int4x2_t*)
// first -- confirm).
__mlu_func__ void transpose(int4x2_t *dst, int4x2_t *src, int32_t dim1, int32_t dim2) {}
// Copies an n-row x k-element weight tile (element type T) from NRAM into
// WRAM following the LT "map16" layout: ANT_LT_ROW consecutive rows per LT,
// LT_NUM_ANT LTs scattered across the 16 WRAM map regions.
// total_k is the NRAM row pitch in elements (>= k).
// NOTE(review): phase 1 uses DIV_UP(n, LT_NUM), so a partial last group is
// first copied as a full LT_NUM group (reading past the valid rows) and then
// re-copied by the remainder phases -- confirm n is always a multiple of
// LT_NUM or the NRAM buffer is padded accordingly.
template <typename T>
__mlu_func__ void mvNram2WramLT16(int8_t *wram_dst,
                                  int8_t *nram_src,
                                  int32_t n,
                                  int32_t k,
                                  int32_t total_k) {
  int32_t data_size = k * sizeof_(T);         // payload bytes per row
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);  // WRAM row pitch, line aligned
  int32_t ss0 = total_k * sizeof_(T);         // NRAM row pitch in bytes
  int32_t count = DIV_UP(n, LT_NUM);
  if (count > 0) {
    // Phase 1: whole LT_NUM-row groups, spread over the 16 map regions.
    for (int i = 0; i < count; ++i) {
      __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
               WRAM_LT_MAP16_STRIDE, LT_NUM_ANT - 1, ss0, LT_NUM - 1, 0, 0);
      wram_dst += ANT_LT_ROW * ds0;
      nram_src += LT_NUM * ss0;
    }
  }
  // Phase 2: remaining whole ANT_LT_ROW-row groups of a partial LT_NUM group.
  count = n % LT_NUM / ANT_LT_ROW;
  if (count > 0) {
    __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
             WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ANT_LT_ROW - 1, 0, 0);
    wram_dst += count * WRAM_LT_MAP16_STRIDE;
    nram_src += count * ANT_LT_ROW * ss0;
  }
  // Phase 3: final partial group of fewer than ANT_LT_ROW rows.
  count = n % ANT_LT_ROW;
  if (count) {
    __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ss0, count - 1);
  }
}
// Fused NRAM convolution (m x k by k x n) + elementwise mul with `partial`
// + destination dtype convert, dispatched on <Td, Ts>.  Unsupported type
// combinations compile to a silent no-op.
// Uses `if constexpr` (the file is already C++17 -- see process_offsets) so
// only the asm variant matching <Td, Ts> is instantiated, instead of
// compiling every branch and selecting at runtime.
// NOTE(review): the bf16 path is gated with `> 500` while the bf16 convert()
// overload uses `>= 500` -- confirm which arch bound is intended.
template <typename Td, typename Ts>
__mlu_func__ void
conv_fuse_mul_cvt(Td *output, Ts *input, Ts *filter, float *partial, int m, int n, int k) {
  if constexpr (std::is_same<Td, half>::value && std::is_same<Ts, float>::value) {
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.f16()", partial)
  } else if constexpr (std::is_same<Td, bfloat16_t>::value && std::is_same<Ts, float>::value) {
#if __BANG_ARCH__ > 500
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.bf16()", partial)
#endif
  } else if constexpr (std::is_same<Td, float>::value && std::is_same<Ts, float>::value) {
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", "", partial)
  }
}
// When ProcessOffsets is true, stages the per-batch context lengths from
// GDRAM into lens_nram and writes their exclusive prefix sum (per-batch
// sequence start offsets) into offsets_nram.  When false the whole body is
// compiled out and load_len_offset<false> reads GDRAM directly instead.
// context_seq_offsets is unused here; kept for signature symmetry with
// load_len_offset.
template <bool ProcessOffsets>
__mlu_func__ void process_offsets(int32_t *lens_nram,
                                  int32_t *offsets_nram,
                                  const int32_t *context_lens,
                                  const int32_t *context_seq_offsets,
                                  const int32_t batch_size) {
  if constexpr (ProcessOffsets) {
    // Bulk-copy the lengths, then accumulate offsets scalar-by-scalar.
    __memcpy((int32_t *)lens_nram, (int32_t *)context_lens, sizeof_(int32_t) * batch_size,
             GDRAM2NRAM);
    int total_lens = 0;
    for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
      __store_nram((int32_t *)offsets_nram + batch_idx, total_lens);
      total_lens += __load_nram((int32_t *)lens_nram + batch_idx);
    }
  }
}
// Fetches the sequence length and start offset for `batch_idx`: from the
// NRAM staging buffers filled by process_offsets<true> when ProcessOffsets
// is true, otherwise straight from the GDRAM arrays.
// Changed the plain `if` on the template bool to `if constexpr` for
// consistency with process_offsets; only the branch matching the template
// argument is compiled, so the unused pointer pair can be null/invalid.
template <bool ProcessOffsets>
__mlu_func__ void load_len_offset(int32_t &seq_len,
                                  int32_t &seq_offset,
                                  const int32_t *lens_nram,
                                  const int32_t *offsets_nram,
                                  const int32_t *context_lens,
                                  const int32_t *context_seq_offsets,
                                  const int32_t batch_idx) {
  if constexpr (ProcessOffsets) {
    seq_len = __load_nram((int32_t *)lens_nram + batch_idx);
    seq_offset = __load_nram((int32_t *)offsets_nram + batch_idx);
  } else {
    seq_len = __load_gdram((int32_t *)context_lens + batch_idx);
    seq_offset = __load_gdram((int32_t *)context_seq_offsets + batch_idx);
  }
}
// Loads head_num rows of head_size scale values from GDRAM into NRAM,
// packing them contiguously; source rows are scale_head_stride elements
// apart in GDRAM.
// NOTE(review): scale_bs_stride is accepted but unused here -- presumably
// the caller advances `scale` by it per batch; confirm against call sites.
template <typename T>
__mlu_func__ void load_scale_once(T *scale_nram,
                                  const T *scale,
                                  const int32_t head_num,
                                  const int32_t head_size,
                                  const size_t scale_bs_stride,
                                  const size_t scale_head_stride) {
  __memcpy((T *)scale_nram, (T *)scale, head_size * sizeof_(T), GDRAM2NRAM, head_size * sizeof_(T),
           scale_head_stride * sizeof_(T), head_num - 1);
}
// Dequantizes input_num values: output = T(float(input) * scale), with the
// scale_num scales cycled over the input by __bang_cycle_mul.  output_nram
// doubles as the float32 work buffer (must hold input_num floats);
// start_nram is scratch for the widened scales.
// NOTE(review): the scale convert below runs over input_num (not scale_num)
// elements, and when Ts == float convert() is a no-op, in which case
// start_nram must already alias or hold the scales -- confirm the caller
// contract for both points.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize(T *output_nram,
                             Tc *input_nram,
                             Ts *scale_nram,
                             Ts *start_nram,
                             const int32_t input_num,
                             const int32_t scale_num) {
  convert((float *)output_nram, (Tc *)input_nram, input_num);   // widen input
  convert((float *)start_nram, (Ts *)scale_nram, input_num);    // widen scales
  __bang_cycle_mul((float *)output_nram, (float *)output_nram, (float *)start_nram, input_num,
                   scale_num);
  convert((T *)output_nram, (float *)output_nram, input_num);   // narrow to T
}
// Host-side helper: queries the current device's cluster/core topology and
// per-core memory sizes via CNRT.  rem_for_stack bytes are reserved out of
// the reported NRAM and SRAM budgets (kernel stack lives there); the WRAM
// size is returned unreduced.
inline void getDeviceCoreAndRam(int32_t &cluster_dim,
                                int32_t &core_dim,
                                int32_t &nram_size,
                                int32_t &wram_size,
                                int32_t &sram_size,
                                const int32_t rem_for_stack) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&nram_size, cnrtAttrNramSizePerMcore, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&wram_size, cnrtAttrWramSizePerMcore, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&sram_size, cnrtAttrSramSizePerMcore, dev));
  nram_size -= rem_for_stack;
  sram_size -= rem_for_stack;
}
} // namespace tmo
#endif // CSRC_KERNELS_QUANT_UTILS_H_