forked from EngineX-Cambricon/enginex-mlu370-vllm
314 lines
13 KiB
C++
314 lines
13 KiB
C++
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_QUANT_UTILS_H_
|
|
#define CSRC_KERNELS_QUANT_UTILS_H_
|
|
|
|
#include <cassert>
|
|
#include <iostream>
|
|
#include <string>
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "kernel_utils.h"
|
|
|
|
namespace tmo {
|
|
|
|
// ---------------------------------------------------------------------------
// Tile/layout constants and asm helper macros.
// Every constant is #ifndef-guarded so a build system may override it on the
// compile command line.
// ---------------------------------------------------------------------------

// Rows in one full LT copy group (see mvNram2WramLT16 below).
#ifndef LT_NUM
#define LT_NUM (64)
#endif

// Rows copied per __memcpy segment when filling the ANT LT-mapped WRAM.
#ifndef ANT_LT_ROW
#define ANT_LT_ROW (4)
#endif

// Number of ANT_LT_ROW bundles per LT_NUM group (LT_NUM / ANT_LT_ROW == 16).
#ifndef LT_NUM_ANT
#define LT_NUM_ANT (16)
#endif

// Bytes per WRAM line; row payloads are padded up to this (ds0 in
// mvNram2WramLT16).
#ifndef ONE_LINE
#define ONE_LINE (64)
#endif

// sizeof(...) as uint32_t, to keep the int32_t stride arithmetic in this
// header free of signed/unsigned promotion surprises.
#ifndef sizeof_
#define sizeof_(T) (uint32_t)sizeof(T)
#endif

// Byte stride between the 16 LT partitions of WRAM.
// NOTE(review): assumes __MLU_WRAM_SIZE__ is expressed in KiB -- confirm
// against the toolchain's predefined-macro documentation.
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 16)
#endif

// Transpose lookup-table size in bytes (not referenced elsewhere in this
// file; presumably consumed by callers).
#ifndef TRANS_TABLE_SIZE
#define TRANS_TABLE_SIZE (64)
#endif

// Ceiling integer division: DIV_UP(7, 4) == 2.
#ifndef DIV_UP
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
#endif

// Emits one fused "conv.nram" instruction with an optional post-op (`op`,
// e.g. partial multiply) and optional destination conversion (`cvt`).
// Textual macro: relies on the expansion site (conv_fuse_mul_cvt) providing
// `output`, `input`, `filter`, `m`, `n`, `k` and the template types Td/Ts.
#ifndef CONV_FUSE_OP_CVT
#define CONV_FUSE_OP_CVT(dtype, op, cvt, op_data) \
  asm volatile("conv.nram.rn.f32" dtype dtype \
               "[%[dst]], [%[src]], [%[kernel]], %[src_channel], " \
               "%[src_height], 1, 1, 1, 1, 1, %[dst_channel]" op cvt \
               ";\n\t" ::[dst] "r"((Td *)output), \
               [src] "r"((Ts *)input), [kernel] "r"((Ts *)filter), [src_channel] "r"(k), \
               [src_height] "r"(m), [dst_channel] "r"(n), [operand0] "r"(op_data));
#endif

// Emits the pretable-driven "trans.tiling" transpose instruction used by
// trans_nhwc2nchw_smallc. Textual macro: expects `dst`, `src`, `pre_table`,
// type T and all in*/is*/dn*/ds*/n/n_stride shape variables to be in scope
// at the expansion site. Note [is1] is deliberately bound to in0.
#define __reshape_nhwc2nchw_smallc(TYPE) \
  asm volatile( \
      "trans.tiling.nram.nram." TYPE \
      "[%[dst]], [%[src]], " \
      "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \
      "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \
      ".pretable.nram([%[pre]]); \n\t" ::[dst] "r"((T *)dst), \
      [src] "r"((T *)src), [pre] "r"((uint8_t *)pre_table), [in0] "r"(in0), [in1] "r"(in1), \
      [is1] "r"(in0), [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \
      [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \
      [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \
      [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
|
|
|
|
// Rounds `num` up to the next power of two (returns num itself when it is
// already a power of two) via the MLU `findlast1` bit-scan instruction.
// @param align_num [out] receives the rounded value.
// @param num       value to round; NOTE(review): behaviour for num <= 1
//                  depends on what findlast1 returns for an all-zero input
//                  -- confirm against the ISA manual.
__mlu_func__ void next_power_of_two(int32_t &align_num, const int32_t num) {
  int32_t tmp = num - 1;
  // tmp <- index of the most significant set bit of (num - 1).
  asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp));
  align_num = 1 << (tmp + 1);
}
|
|
|
|
/* copy from cnnl utils/trans_small.py by xwm. */
// Builds the 64-byte pretable consumed by the trans.tiling instruction in
// trans_nhwc2nchw_smallc. Byte i of the table encodes which source byte
// (real_idx) supplies destination byte i within one 64-byte line, for a
// small channel count `channel` replicated `repeat` times per line.
// NOTE(review): the 0x80 added to each entry sets the high bit -- presumably
// a "valid entry" flag required by the pretable format; confirm against the
// ISA documentation.
template <typename T>
__mlu_func__ void __reshape_nhwc2nchw_smallc_init(uint8_t *pre_table_nram, uint32_t channel) {
  int32_t align_c;
  next_power_of_two(align_c, channel);
  // Elements of T per 64-byte line, and channel groups fitting in one line.
  int32_t align_num = ONE_LINE / sizeof_(T);
  int32_t repeat = align_num / align_c;
  for (int i = 0; i < 64; ++i) {
    // Element index within the line, then its transposed source position.
    int32_t idx = i / sizeof_(T);
    int32_t tmp_idx = (idx % repeat) * channel + idx / repeat;
    // Back to a byte offset, keeping the byte-within-element position.
    int32_t real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T);
    __store_nram((uint8_t *)pre_table_nram + i, (uint8_t)real_idx + 0x80);
  }
}
|
|
|
|
// Transposes an NHWC tensor with a small channel count C into NCHW layout,
// entirely in NRAM, using the pretable-driven trans.tiling instruction.
// The pretable must have been prepared by __reshape_nhwc2nchw_smallc_init
// for the same T and c. The bulk of HW is processed in full 64-byte lines;
// a remainder of hw % align_num elements is handled by a second, pretable-
// free trans.tiling at the end.
// @param dst       NRAM destination (NCHW).
// @param src       NRAM source (NHWC).
// @param pre_table 64-byte pretable in NRAM.
// @param n,h,w,c   tensor dimensions.
template <typename T>
__mlu_func__ void trans_nhwc2nchw_smallc(T *dst,
                                         T *src,
                                         uint8_t *pre_table,
                                         uint32_t n,
                                         uint32_t h,
                                         uint32_t w,
                                         uint32_t c) {
  int32_t align_c;
  next_power_of_two(align_c, c);
  // Elements of T per 64-byte line; full lines cover hw / align_num groups.
  int32_t align_num = 64 / sizeof_(T);
  int32_t hw = h * w;
  int32_t repeat = align_num / align_c;
  // Source loop shape/strides (in* = extents, is* = strides, bytes).
  int32_t in0 = c * repeat * sizeof_(T);
  int32_t in1 = align_c;
  int32_t in3 = hw / align_num;
  int32_t is3 = in0 * in1;
  int32_t n_stride = hw * c * sizeof_(T);
  // Destination loop shape/strides (dn*/ds*).
  int32_t dn0 = 64;
  int32_t dn1 = c;
  int32_t ds1 = hw * sizeof_(T);
  int32_t dn3 = in3;
  int32_t ds3 = dn0;
  // No full lines -> skip the pretable path entirely.
  align_c = in3 > 0 ? align_c : 0;
  // Dispatch on the aligned channel count; the TYPE string selects the
  // trans.tiling element width (wider elements for fewer channels).
  if (align_c == 2) {
    __reshape_nhwc2nchw_smallc("b256");
  } else if (align_c == 4) {
    __reshape_nhwc2nchw_smallc("b128");
  } else if (align_c == 8) {
    __reshape_nhwc2nchw_smallc("b64");
  } else if (align_c == 16) {
    __reshape_nhwc2nchw_smallc("b32");
  } else if (align_c == 32) {
    __reshape_nhwc2nchw_smallc("b16");
  }

  // Tail: hw not divisible by align_num. Transpose the remaining
  // in3_rem x c block (per batch) starting where the bulk pass stopped.
  constexpr int32_t bw = 8 * sizeof_(T);
  int32_t in3_rem = hw % align_num;
  int32_t tail_in0 = c * sizeof_(T);
  int32_t tail_dn0 = in3_rem * sizeof_(T);
  if (in3_rem) {
    asm volatile(
        "trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \
        %[in0], %[in1], %[is1], %[in2], %[is2], %[in3], \
        %[is3], %[in4], %[is4], %[in5], %[is5], \
        %[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], \
        %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::[bw] "i"(bw),
        [dst] "r"((T *)dst + dn3 * ds3 / sizeof_(T)), [src] "r"((T *)src + is3 * in3 / sizeof_(T)),
        [in0] "r"(tail_in0), [in1] "r"(in3_rem), [is1] "r"(tail_in0), [in2] "i"(1), [is2] "i"(0),
        [in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0),
        [dn0] "r"(tail_dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0),
        [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0));
  }
}
|
|
|
|
// Widens `num` signed 8-bit integers from `src` to fp32 in `dst`
// (final 0 = no fixed-point position shift).
__mlu_func__ void convert(float *dst, int8_t *src, int32_t num) {
  __bang_int82float(dst, src, num, 0);
}
|
|
|
|
// Widens `num` packed 4-bit integers from `src` to fp32 in `dst`
// (final 0 = no fixed-point position shift).
__mlu_func__ void convert(float *dst, int4x2_t *src, int32_t num) {
  __bang_int42float(dst, src, num, 0);
}
|
|
|
|
// Narrows `num` fp32 values from `src` to fp16 in `dst`.
__mlu_func__ void convert(half *dst, float *src, int32_t num) {
  __bang_float2half(dst, src, num);
}
|
|
|
|
// Narrows `num` fp32 values from `src` to bfloat16 in `dst`.
// Only available on MLU architectures >= 500; a no-op elsewhere.
__mlu_func__ void convert(bfloat16_t *dst, float *src, int32_t num) {
#if __BANG_ARCH__ >= 500
  __bang_float2bfloat16(dst, src, num);
#endif
}
|
|
|
|
// Unpacks `num` packed 4-bit integers from `src` into 8-bit integers in
// `dst` (trailing zeros = no position/offset adjustment).
__mlu_func__ void convert(int8_t *dst, int4x2_t *src, int32_t num) {
  __bang_int42int8(dst, src, num, 0, 0);
}
|
|
|
|
// Same-dtype overload: nothing to convert, so this is deliberately a no-op.
// It does NOT copy the data -- callers that need the bytes moved must use a
// memcpy/mv themselves (or pass dst == src).
__mlu_func__ void convert(float *dst, float *src, int32_t num) {}
|
|
|
|
// Same-dtype overload: deliberate no-op (no conversion, no copy).
__mlu_func__ void convert(int8_t *dst, int8_t *src, int32_t num) {}
|
|
|
|
template <typename T>
|
|
__mlu_func__ void transpose(T *dst, T *src, int32_t dim1, int32_t dim2) {
|
|
__bang_transpose((T *)dst, (T *)src, dim1, dim2);
|
|
}
|
|
|
|
// int4x2_t packs two 4-bit values per byte, so it cannot be transposed
// directly; this overload is a deliberate no-op. Presumably callers unpack
// to int8 first (see convert(int8_t*, int4x2_t*, ...)) -- confirm at call
// sites.
__mlu_func__ void transpose(int4x2_t *dst, int4x2_t *src, int32_t dim1, int32_t dim2) {}
|
|
|
|
template <typename T>
|
|
__mlu_func__ void mvNram2WramLT16(int8_t *wram_dst,
|
|
int8_t *nram_src,
|
|
int32_t n,
|
|
int32_t k,
|
|
int32_t total_k) {
|
|
int32_t data_size = k * sizeof_(T);
|
|
int32_t ds0 = PAD_UP(data_size, ONE_LINE);
|
|
int32_t ss0 = total_k * sizeof_(T);
|
|
int32_t count = DIV_UP(n, LT_NUM);
|
|
if (count > 0) {
|
|
for (int i = 0; i < count; ++i) {
|
|
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
|
|
WRAM_LT_MAP16_STRIDE, LT_NUM_ANT - 1, ss0, LT_NUM - 1, 0, 0);
|
|
wram_dst += ANT_LT_ROW * ds0;
|
|
nram_src += LT_NUM * ss0;
|
|
}
|
|
}
|
|
|
|
count = n % LT_NUM / ANT_LT_ROW;
|
|
if (count > 0) {
|
|
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1,
|
|
WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ANT_LT_ROW - 1, 0, 0);
|
|
wram_dst += count * WRAM_LT_MAP16_STRIDE;
|
|
nram_src += count * ANT_LT_ROW * ss0;
|
|
}
|
|
|
|
count = n % ANT_LT_ROW;
|
|
if (count) {
|
|
__memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ss0, count - 1);
|
|
}
|
|
}
|
|
|
|
template <typename Td, typename Ts>
|
|
__mlu_func__ void
|
|
conv_fuse_mul_cvt(Td *output, Ts *input, Ts *filter, float *partial, int m, int n, int k) {
|
|
if (std::is_same<Td, half>::value && std::is_same<Ts, float>::value) {
|
|
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.f16()", partial)
|
|
} else if (std::is_same<Td, bfloat16_t>::value && std::is_same<Ts, float>::value) {
|
|
#if __BANG_ARCH__ > 500
|
|
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.bf16()", partial)
|
|
#endif
|
|
} else if (std::is_same<Td, float>::value && std::is_same<Ts, float>::value) {
|
|
CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", "", partial)
|
|
}
|
|
}
|
|
|
|
// Stages per-batch sequence lengths into NRAM and computes their exclusive
// prefix sum (the start offset of each batch's sequence).
// The entire body is discarded at compile time when ProcessOffsets is false;
// in that mode load_len_offset<false> reads the precomputed GDRAM arrays
// directly instead.
// @param lens_nram    [out] NRAM buffer receiving batch_size lengths.
// @param offsets_nram [out] NRAM buffer receiving batch_size exclusive
//                     prefix-sum offsets.
// @param context_lens GDRAM array of batch_size sequence lengths.
// @param context_seq_offsets unused here; kept so the signature parallels
//                     load_len_offset.
// @param batch_size   number of batches.
template <bool ProcessOffsets>
__mlu_func__ void process_offsets(int32_t *lens_nram,
                                  int32_t *offsets_nram,
                                  const int32_t *context_lens,
                                  const int32_t *context_seq_offsets,
                                  const int32_t batch_size) {
  if constexpr (ProcessOffsets) {
    // Bulk copy of all lengths GDRAM -> NRAM in a single transfer.
    __memcpy((int32_t *)lens_nram, (int32_t *)context_lens, sizeof_(int32_t) * batch_size,
             GDRAM2NRAM);
    // Exclusive scan: offsets[i] = sum(lens[0..i-1]).
    int total_lens = 0;
    for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
      __store_nram((int32_t *)offsets_nram + batch_idx, total_lens);
      total_lens += __load_nram((int32_t *)lens_nram + batch_idx);
    }
  }
}
|
|
|
|
template <bool ProcessOffsets>
|
|
__mlu_func__ void load_len_offset(int32_t &seq_len,
|
|
int32_t &seq_offset,
|
|
const int32_t *lens_nram,
|
|
const int32_t *offsets_nram,
|
|
const int32_t *context_lens,
|
|
const int32_t *context_seq_offsets,
|
|
const int32_t batch_idx) {
|
|
if (ProcessOffsets) {
|
|
seq_len = __load_nram((int32_t *)lens_nram + batch_idx);
|
|
seq_offset = __load_nram((int32_t *)offsets_nram + batch_idx);
|
|
} else {
|
|
seq_len = __load_gdram((int32_t *)context_lens + batch_idx);
|
|
seq_offset = __load_gdram((int32_t *)context_seq_offsets + batch_idx);
|
|
}
|
|
}
|
|
|
|
// Gathers one scale vector per head from GDRAM into a contiguous NRAM
// buffer: head_num segments of head_size elements, spaced
// scale_head_stride elements apart in the source.
// NOTE(review): scale_bs_stride is unused -- the copy always starts at
// `scale`, so presumably the caller pre-offsets the pointer by batch;
// confirm at call sites. The (T *) cast on `scale` drops const, presumably
// because __memcpy takes non-const pointers.
template <typename T>
__mlu_func__ void load_scale_once(T *scale_nram,
                                  const T *scale,
                                  const int32_t head_num,
                                  const int32_t head_size,
                                  const size_t scale_bs_stride,
                                  const size_t scale_head_stride) {
  __memcpy((T *)scale_nram, (T *)scale, head_size * sizeof_(T), GDRAM2NRAM, head_size * sizeof_(T),
           scale_head_stride * sizeof_(T), head_num - 1);
}
|
|
|
|
template <typename T, typename Tc, typename Ts>
|
|
__mlu_func__ void dequantize(T *output_nram,
|
|
Tc *input_nram,
|
|
Ts *scale_nram,
|
|
Ts *start_nram,
|
|
const int32_t input_num,
|
|
const int32_t scale_num) {
|
|
convert((float *)output_nram, (Tc *)input_nram, input_num);
|
|
convert((float *)start_nram, (Ts *)scale_nram, input_num);
|
|
__bang_cycle_mul((float *)output_nram, (float *)output_nram, (float *)start_nram, input_num,
|
|
scale_num);
|
|
convert((T *)output_nram, (float *)output_nram, input_num);
|
|
}
|
|
|
|
// Queries the current MLU device for its cluster count, cores per cluster,
// and per-core NRAM/WRAM/SRAM sizes, then reserves `rem_for_stack` bytes of
// the NRAM and SRAM budgets for the kernel stack (WRAM is left untouched).
inline void getDeviceCoreAndRam(int32_t &cluster_dim,
                                int32_t &core_dim,
                                int32_t &nram_size,
                                int32_t &wram_size,
                                int32_t &sram_size,
                                const int32_t rem_for_stack) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  // Table-driven queries; order matches the output-parameter order.
  const decltype(cnrtAttrClusterCount) queries[] = {
      cnrtAttrClusterCount, cnrtAttrMcorePerCluster, cnrtAttrNramSizePerMcore,
      cnrtAttrWramSizePerMcore, cnrtAttrSramSizePerMcore};
  int32_t *const results[] = {&cluster_dim, &core_dim, &nram_size, &wram_size, &sram_size};
  for (int idx = 0; idx < 5; ++idx) {
    CNRT_CHECK(cnrtDeviceGetAttribute(results[idx], queries[idx], dev));
  }
  nram_size -= rem_for_stack;
  sram_size -= rem_for_stack;
}
|
|
} // namespace tmo
|
|
#endif // CSRC_KERNELS_QUANT_UTILS_H_
|