/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_QUANT_UTILS_H_
#define CSRC_KERNELS_QUANT_UTILS_H_

// NOTE(review): the three system-header names below are missing (the
// angle-bracket targets look like they were lost); restore them from the
// upstream file. <type_traits> is presumably one of them, since std::is_same
// is used further down in this header -- confirm.
#include
#include
#include
#include "cnnl.h"
#include "cnrt.h"
#include "kernel_utils.h"

namespace tmo {

// Tiling constants consumed by the helpers in this header. Every constant is
// wrapped in #ifndef, so a translation unit may pre-define its own value
// before including this file.
#ifndef LT_NUM
#define LT_NUM (64)
#endif

#ifndef ANT_LT_ROW
#define ANT_LT_ROW (4)
#endif

#ifndef LT_NUM_ANT
#define LT_NUM_ANT (16)
#endif

#ifndef ONE_LINE
#define ONE_LINE (64)
#endif

// sizeof(T) cast to uint32_t.
#ifndef sizeof_
#define sizeof_(T) (uint32_t)sizeof(T)
#endif

// 1/16 of the WRAM size in bytes (__MLU_WRAM_SIZE__ is multiplied by 1024).
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 16)
#endif

#ifndef TRANS_TABLE_SIZE
#define TRANS_TABLE_SIZE (64)
#endif

// Integer quotient x / y, plus one when x % y is positive.
#ifndef DIV_UP
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
#endif

// Emits a single `conv.nram.rn.f32<dtype><dtype>` instruction with the `op`
// and `cvt` suffix strings appended to it. The expansion is not hygienic: the
// enclosing scope must provide `output`, `input`, `filter`, `k`, `m`, `n` and
// the types `Td` / `Ts`; `op_data` is bound to the `operand0` asm operand.
// The macro body already ends with ';'.
#ifndef CONV_FUSE_OP_CVT
#define CONV_FUSE_OP_CVT(dtype, op, cvt, op_data)                                        \
  asm volatile("conv.nram.rn.f32" dtype dtype                                            \
               "[%[dst]], [%[src]], [%[kernel]], %[src_channel], "                       \
               "%[src_height], 1, 1, 1, 1, 1, %[dst_channel]" op cvt                     \
               ";\n\t" ::[dst] "r"((Td *)output),                                        \
               [src] "r"((Ts *)input), [kernel] "r"((Ts *)filter), [src_channel] "r"(k), \
               [src_height] "r"(m), [dst_channel] "r"(n), [operand0] "r"(op_data));
#endif

// Emits a `trans.tiling.nram.nram.<TYPE>` instruction that takes a pre-table
// operand. The expansion is not hygienic: it reads `dst`, `src`, `pre_table`,
// `in0`, `in1`, `in3`, `is3`, `n`, `n_stride`, `dn0`, `dn1`, `ds1`, `dn3`,
// `ds3` and the type `T` from the enclosing scope.
#define __reshape_nhwc2nchw_smallc(TYPE)                                                        \
  asm volatile(                                                                                 \
      "trans.tiling.nram.nram." TYPE                                                            \
      "[%[dst]], [%[src]], "                                                                    \
      "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \
      "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \
      ".pretable.nram([%[pre]]); \n\t" ::[dst] "r"((T *)dst),                                   \
      [src] "r"((T *)src), [pre] "r"((uint8_t *)pre_table), [in0] "r"(in0), [in1] "r"(in1),     \
      [is1] "r"(in0), [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \
      [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1),          \
      [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \
      [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));

// Writes 1 << (r + 1) to align_num, where r is the result of the
// `findlast1.gpr.b32` instruction applied to num - 1.
// NOTE(review): presumably findlast1 returns the index of the highest set bit,
// which would make align_num the smallest power of two >= num -- confirm,
// including the result for num <= 1.
__mlu_func__ void next_power_of_two(int32_t &align_num, const int32_t num) {
  int32_t tmp = num - 1;
  asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
  align_num = 1 << (tmp + 1);
}

/* copy from cnnl utils/trans_small.py by xwm. */
// Fills the 64-byte table at pre_table_nram: entry i holds the computed source
// byte index for byte i, with 0x80 added.
// NOTE(review): the template parameter list is missing after `template`; the
// body uses `T` as a type, so presumably `template <typename T>` -- confirm.
// NOTE(review): `repeat` is 0 when the rounded-up channel count exceeds
// ONE_LINE / sizeof(T), which would make `idx % repeat` divide by zero;
// callers presumably only pass small channel counts -- confirm.
template
__mlu_func__ void __reshape_nhwc2nchw_smallc_init(uint8_t *pre_table_nram, uint32_t channel) {
  int32_t align_c;
  next_power_of_two(align_c, channel);
  int32_t align_num = ONE_LINE / sizeof_(T);
  int32_t repeat = align_num / align_c;
  for (int i = 0; i < 64; ++i) {
    // idx: element index of byte i; i % sizeof_(T) is the byte within it.
    int32_t idx = i / sizeof_(T);
    int32_t tmp_idx = (idx % repeat) * channel + idx / repeat;
    int32_t real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T);
    __store_nram((uint8_t *)pre_table_nram + i, (uint8_t)real_idx + 0x80);
  }
}

// Layout transform from src to dst for a small channel count c (NHWC -> NCHW,
// going by the function name). Works in two parts:
//   1. hw / align_num full groups go through the pre-table tiling instruction,
//      selected by the rounded-up channel count (2, 4, 8, 16 or 32);
//   2. the remaining hw % align_num positions go through a plain tiling
//      transpose.
// There is no branch for any other rounded-up channel count, so in that case
// part 1 is skipped without any diagnostic.
// NOTE(review): the template parameter list is missing after `template`; the
// signature uses `T`, so presumably `template <typename T>` -- confirm.
template
__mlu_func__ void trans_nhwc2nchw_smallc(T *dst,
                                         T *src,
                                         uint8_t *pre_table,
                                         uint32_t n,
                                         uint32_t h,
                                         uint32_t w,
                                         uint32_t c) {
  int32_t align_c;
  next_power_of_two(align_c, c);
  int32_t align_num = 64 / sizeof_(T);
  int32_t hw = h * w;
  int32_t repeat = align_num / align_c;
  // The names below are picked up by the __reshape_nhwc2nchw_smallc macro.
  int32_t in0 = c * repeat * sizeof_(T);
  int32_t in1 = align_c;
  int32_t in3 = hw / align_num;
  int32_t is3 = in0 * in1;
  int32_t n_stride = hw * c * sizeof_(T);
  int32_t dn0 = 64;
  int32_t dn1 = c;
  int32_t ds1 = hw * sizeof_(T);
  int32_t dn3 = in3;
  int32_t ds3 = dn0;
  // With no full group (in3 == 0) force align_c to 0 so no branch below fires.
  align_c = in3 > 0 ? align_c : 0;
  if (align_c == 2) {
    __reshape_nhwc2nchw_smallc("b256");
  } else if (align_c == 4) {
    __reshape_nhwc2nchw_smallc("b128");
  } else if (align_c == 8) {
    __reshape_nhwc2nchw_smallc("b64");
  } else if (align_c == 16) {
    __reshape_nhwc2nchw_smallc("b32");
  } else if (align_c == 32) {
    __reshape_nhwc2nchw_smallc("b16");
  }
  // Tail: element width in bits and the leftover positions after full groups.
  constexpr int32_t bw = 8 * sizeof_(T);
  int32_t in3_rem = hw % align_num;
  int32_t tail_in0 = c * sizeof_(T);
  int32_t tail_dn0 = in3_rem * sizeof_(T);
  if (in3_rem) {
    // dst/src are advanced past the region covered by the full groups.
    asm volatile(
        "trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \
        %[in0], %[in1], %[is1], %[in2], %[is2], %[in3], \
        %[is3], %[in4], %[is4], %[in5], %[is5], \
        %[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], \
        %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::[bw] "i"(bw),
        [dst] "r"((T *)dst + dn3 * ds3 / sizeof_(T)), [src] "r"((T *)src + is3 * in3 / sizeof_(T)),
        [in0] "r"(tail_in0), [in1] "r"(in3_rem), [is1] "r"(tail_in0), [in2] "i"(1), [is2] "i"(0),
        [in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0),
        [dn0] "r"(tail_dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0),
        [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
  }
}

// convert(dst, src, num): overload set selecting the conversion intrinsic from
// the pointer types; num is forwarded to the intrinsic unchanged.

// int8 -> float via __bang_int82float (last argument 0).
__mlu_func__ void convert(float *dst, int8_t *src, int32_t num) {
  __bang_int82float((float *)dst, (int8_t *)src, num, 0);
}

// int4x2 -> float via __bang_int42float (last argument 0).
__mlu_func__ void convert(float *dst, int4x2_t *src, int32_t num) {
  __bang_int42float((float *)dst, (int4x2_t *)src, num, 0);
}

// float -> half via __bang_float2half.
__mlu_func__ void convert(half *dst, float *src, int32_t num) {
  __bang_float2half((half *)dst, (float *)src, num);
}

// float -> bfloat16 via __bang_float2bfloat16. The body is compiled out when
// __BANG_ARCH__ < 500, leaving this overload a no-op there.
__mlu_func__ void convert(bfloat16_t *dst, float *src, int32_t num) {
#if __BANG_ARCH__ >= 500
  __bang_float2bfloat16((bfloat16_t *)dst, (float *)src, num);
#endif
}

// int4x2 -> int8 via __bang_int42int8 (last two arguments 0, 0).
__mlu_func__ void convert(int8_t *dst, int4x2_t *src, int32_t num) {
  __bang_int42int8((int8_t *)dst, (int4x2_t *)src, num, 0, 0);
}

// if the dst dtype == src dtype, do nothing.
// Same-type convert overloads are deliberate no-ops: they copy nothing. If the
// data has to be moved, issue an explicit move/copy instead.
__mlu_func__ void convert(float *dst, float *src, int32_t num) {}

__mlu_func__ void convert(int8_t *dst, int8_t *src, int32_t num) {}

// Transposes src into dst by forwarding dim1 / dim2 to __bang_transpose.
template <typename T>
__mlu_func__ void transpose(T *dst, T *src, int32_t dim1, int32_t dim2) {
  __bang_transpose((T *)dst, (T *)src, dim1, dim2);
}

// if data type is int4x2_t, transpose is not supported directly
// (this overload is an intentional no-op).
__mlu_func__ void transpose(int4x2_t *dst, int4x2_t *src, int32_t dim1, int32_t dim2) {}

// Copies n rows of k elements of T from NRAM to WRAM.
//
//   wram_dst  destination in WRAM; each row occupies PAD_UP(k * sizeof(T),
//             ONE_LINE) bytes
//   nram_src  source in NRAM; consecutive rows are total_k * sizeof(T) bytes
//             apart
//   n         number of rows
//   k         elements copied per row
//   total_k   source row pitch, in elements
//
// The copy runs in three steps:
//   1. n / LT_NUM full tiles of LT_NUM rows, written as ANT_LT_ROW rows into
//      each of LT_NUM_ANT regions spaced WRAM_LT_MAP16_STRIDE bytes apart;
//   2. the remaining whole groups of ANT_LT_ROW rows, one group per region;
//   3. the last n % ANT_LT_ROW rows, written consecutively.
// NOTE(review): the stride/segment argument order of the multi-dimensional
// __memcpy calls is kept exactly as it was; see the BANG C __memcpy reference
// for the meaning of each position.
template <typename T>
__mlu_func__ void mvNram2WramLT16(int8_t *wram_dst,
                                  int8_t *nram_src,
                                  int32_t n,
                                  int32_t k,
                                  int32_t total_k) {
  int32_t data_size = k * sizeof_(T);
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);
  int32_t ss0 = total_k * sizeof_(T);

  // Step 1: full tiles only. This has to be a floor division -- the partial
  // tile belongs to steps 2 and 3. Rounding up here would issue an extra
  // LT_NUM-row copy that reads past the n source rows and would leave both
  // pointers advanced too far for the remainder steps.
  int32_t count = n / LT_NUM;
  for (int i = 0; i < count; ++i) {
    __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
             WRAM_LT_MAP16_STRIDE, LT_NUM_ANT - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst += ANT_LT_ROW * ds0;
    nram_src += LT_NUM * ss0;
  }

  // Step 2: whole groups of ANT_LT_ROW rows left over after the full tiles.
  count = n % LT_NUM / ANT_LT_ROW;
  if (count > 0) {
    __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
             WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ANT_LT_ROW - 1, 0, 0);
    wram_dst += count * WRAM_LT_MAP16_STRIDE;
    nram_src += count * ANT_LT_ROW * ss0;
  }

  // Step 3: the final rows that do not fill a group.
  count = n % ANT_LT_ROW;
  if (count) {
    __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ss0, count - 1);
  }
}

// Issues one fused conv instruction (see CONV_FUSE_OP_CVT) that multiplies by
// `partial` (`.mul.partial.rn`) and optionally converts the destination type.
//
//   output          destination buffer, element type Td
//   input / filter  conv operands, element type Ts
//   partial         operand of the fused multiply
//   m, n, k         bound to src_height, dst_channel and src_channel
//
// Supported <Td, Ts> pairs: <half, float>, <bfloat16_t, float> (only compiled
// in when __BANG_ARCH__ > 500) and <float, float>. Any other pair emits no
// instruction at all.
// NOTE(review): the template head and the std::is_same arguments were
// reconstructed from the `.cvt.dst.rn.f16()` / `.cvt.dst.rn.bf16()` suffixes
// and the ".f32" operand dtype; confirm against the upstream file.
template <typename Td, typename Ts>
__mlu_func__ void conv_fuse_mul_cvt(Td *output,
                                    Ts *input,
                                    Ts *filter,
                                    float *partial,
                                    int m,
                                    int n,
                                    int k) {
  // CONV_FUSE_OP_CVT already ends with ';', hence no semicolon after the calls.
  if (std::is_same<Td, half>::value && std::is_same<Ts, float>::value) {
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.f16()", partial)
  } else if (std::is_same<Td, bfloat16_t>::value && std::is_same<Ts, float>::value) {
#if __BANG_ARCH__ > 500
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.bf16()", partial)
#endif
  } else if (std::is_same<Td, float>::value && std::is_same<Ts, float>::value) {
    CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", "", partial)
  }
}
template __mlu_func__ void process_offsets(int32_t *lens_nram, int32_t *offsets_nram, const int32_t *context_lens, const int32_t *context_seq_offsets, const int32_t batch_size) { if constexpr (ProcessOffsets) { __memcpy((int32_t *)lens_nram, (int32_t *)context_lens, sizeof_(int32_t) * batch_size, GDRAM2NRAM); int total_lens = 0; for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { __store_nram((int32_t *)offsets_nram + batch_idx, total_lens); total_lens += __load_nram((int32_t *)lens_nram + batch_idx); } } } template __mlu_func__ void load_len_offset(int32_t &seq_len, int32_t &seq_offset, const int32_t *lens_nram, const int32_t *offsets_nram, const int32_t *context_lens, const int32_t *context_seq_offsets, const int32_t batch_idx) { if (ProcessOffsets) { seq_len = __load_nram((int32_t *)lens_nram + batch_idx); seq_offset = __load_nram((int32_t *)offsets_nram + batch_idx); } else { seq_len = __load_gdram((int32_t *)context_lens + batch_idx); seq_offset = __load_gdram((int32_t *)context_seq_offsets + batch_idx); } } template __mlu_func__ void load_scale_once(T *scale_nram, const T *scale, const int32_t head_num, const int32_t head_size, const size_t scale_bs_stride, const size_t scale_head_stride) { __memcpy((T *)scale_nram, (T *)scale, head_size * sizeof_(T), GDRAM2NRAM, head_size * sizeof_(T), scale_head_stride * sizeof_(T), head_num - 1); } template __mlu_func__ void dequantize(T *output_nram, Tc *input_nram, Ts *scale_nram, Ts *start_nram, const int32_t input_num, const int32_t scale_num) { convert((float *)output_nram, (Tc *)input_nram, input_num); convert((float *)start_nram, (Ts *)scale_nram, input_num); __bang_cycle_mul((float *)output_nram, (float *)output_nram, (float *)start_nram, input_num, scale_num); convert((T *)output_nram, (float *)output_nram, input_num); } inline void getDeviceCoreAndRam(int32_t &cluster_dim, int32_t &core_dim, int32_t &nram_size, int32_t &wram_size, int32_t &sram_size, const int32_t rem_for_stack) { CNdev dev; 
cnCtxGetDevice(&dev); CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev)); CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); CNRT_CHECK(cnrtDeviceGetAttribute(&nram_size, cnrtAttrNramSizePerMcore, dev)); CNRT_CHECK(cnrtDeviceGetAttribute(&wram_size, cnrtAttrWramSizePerMcore, dev)); CNRT_CHECK(cnrtDeviceGetAttribute(&sram_size, cnrtAttrSramSizePerMcore, dev)); nram_size -= rem_for_stack; sram_size -= rem_for_stack; } } // namespace tmo #endif // CSRC_KERNELS_QUANT_UTILS_H_