// softmax_topk: MoE softmax + top-k selection kernels for Cambricon MLU.
#include <mlu.h>

#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>

#include "cnnl.h"
#include "cnrt.h"

#include "softmax_topk.mluh"

namespace tmo {

namespace kernels {
#define SCATTER_ALIGN (64) // align for __scatter()
|
||
|
||
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
|
||
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
|
||
#define TILING_ALIGN (64)
|
||
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
|
||
__nram__ int8_t nram_buffer[NRAM_SIZE];
|
||
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
|
||
|
||
#define __TRANS_TILING(TYPE, CONVERT) \
|
||
__asm__ volatile("trans.tiling." TYPE \
|
||
" [%[dst]], [%[src]]," \
|
||
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
|
||
"%[is4], %[in5], %[is5]," \
|
||
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
|
||
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
|
||
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
|
||
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
|
||
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
|
||
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
|
||
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
|
||
|
||
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
|
||
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
|
||
const SRC_DTYPE *src,
|
||
const uint32_t in0,
|
||
const uint32_t in1,
|
||
const uint32_t is1,
|
||
const uint32_t in2,
|
||
const uint32_t is2,
|
||
const uint32_t in3,
|
||
const uint32_t is3,
|
||
const uint32_t in4,
|
||
const uint32_t is4,
|
||
const uint32_t in5,
|
||
const uint32_t is5,
|
||
const uint32_t dn0,
|
||
const uint32_t dn1,
|
||
const uint32_t ds1,
|
||
const uint32_t dn2,
|
||
const uint32_t ds2,
|
||
const uint32_t dn3,
|
||
const uint32_t ds3,
|
||
const uint32_t dn4,
|
||
const uint32_t ds4,
|
||
const uint32_t dn5,
|
||
const uint32_t ds5) {
|
||
if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
|
||
if (std::is_same<SRC_DTYPE, float>::value) {
|
||
__TRANS_TILING("nram.sram.b32", ";")
|
||
} else if (std::is_same<SRC_DTYPE, half>::value) {
|
||
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
|
||
#if __BANG_ARCH__ >= 500
|
||
} else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
|
||
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
|
||
#endif
|
||
}
|
||
}
|
||
}
|
||
|
||
/* 将shape为[h,w]的数据转置为[w,h](带转数),分4块分别进行处理。
|
||
* dst: dst地址
|
||
* src: src地址
|
||
* h: h方向大小
|
||
* w: w方向大小
|
||
*/
|
||
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
|
||
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
|
||
uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
|
||
uint32_t w_align = w / align_num;
|
||
uint32_t w_rem = w % align_num;
|
||
uint32_t h_align = h / align_num;
|
||
uint32_t h_rem = h % align_num;
|
||
uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
|
||
uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
|
||
uint32_t in3 = w_align, is3 = TILING_ALIGN;
|
||
uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
|
||
uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
|
||
uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
|
||
uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
|
||
/* 1. h_align * w_align */
|
||
if (w_align > 0 && h_align > 0) {
|
||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
|
||
dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
|
||
}
|
||
/* 2. h_align * w_rem */
|
||
if (w_rem > 0 && h_align > 0) {
|
||
SRC_DTYPE *src_temp = src + w_align * align_num;
|
||
DST_DTYPE *dst_temp = dst + w_align * align_num * h;
|
||
in0 = w_rem * sizeof(SRC_DTYPE);
|
||
dn0 = TILING_ALIGN;
|
||
in1 = align_num;
|
||
is1 = w * sizeof(SRC_DTYPE);
|
||
in4 = h_align;
|
||
is4 = w * TILING_ALIGN;
|
||
dn1 = w_rem;
|
||
ds1 = h * sizeof(DST_DTYPE);
|
||
dn4 = in4;
|
||
ds4 = align_num * sizeof(DST_DTYPE);
|
||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
|
||
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
|
||
}
|
||
/* 3. h_rem * w_align */
|
||
if (w_align > 0 && h_rem > 0) {
|
||
SRC_DTYPE *src_temp = src + h_align * align_num * w;
|
||
DST_DTYPE *dst_temp = dst + h_align * align_num;
|
||
in0 = TILING_ALIGN;
|
||
dn0 = h_rem * sizeof(SRC_DTYPE);
|
||
in1 = h_rem;
|
||
is1 = w * sizeof(SRC_DTYPE);
|
||
in4 = w_align;
|
||
is4 = TILING_ALIGN;
|
||
dn1 = align_num;
|
||
ds1 = h * sizeof(DST_DTYPE);
|
||
dn4 = in4;
|
||
ds4 = h * align_num * sizeof(DST_DTYPE);
|
||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
|
||
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
|
||
}
|
||
/* 4. h_rem * w_rem */
|
||
if (w_rem > 0 && h_rem > 0) {
|
||
SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
|
||
DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
|
||
in0 = w_rem * sizeof(SRC_DTYPE);
|
||
dn0 = h_rem * sizeof(SRC_DTYPE);
|
||
in1 = h_rem;
|
||
is1 = w * sizeof(SRC_DTYPE);
|
||
dn1 = w_rem;
|
||
ds1 = h * sizeof(DST_DTYPE);
|
||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
|
||
0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
|
||
}
|
||
}
|
||
|
||
__mlu_func__ void getTopk(float *value_buffer,
|
||
uint32_t *index_buffer,
|
||
float *src_buffer,
|
||
float *compute_buffer,
|
||
float *max_buffer,
|
||
float *temp_buffer,
|
||
uint32_t *i_buffer,
|
||
uint32_t *col_buffer,
|
||
uint32_t topk,
|
||
uint32_t num_expert_group,
|
||
uint32_t col,
|
||
uint32_t row,
|
||
uint32_t value_index_stride,
|
||
uint32_t group_size,
|
||
bool is_deal_group) {
|
||
__bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
|
||
for (int k = 0; k < topk; k++) {
|
||
if (is_deal_group) {
|
||
__bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
|
||
1, num_expert_group, 1, 1);
|
||
__bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
|
||
col);
|
||
} else {
|
||
__bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
|
||
value_index_stride);
|
||
__bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
|
||
}
|
||
#if __BANG_ARCH__ >= 592
|
||
__bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
|
||
__scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
|
||
col); // replace max value with -inf
|
||
#else
|
||
for (int i = 0; i < col; i++) {
|
||
uint32_t index = __load_nram(col_buffer + i);
|
||
max_buffer[index] = -INFINITY;
|
||
}
|
||
#endif
|
||
#if __BANG_ARCH__ < 500
|
||
if (is_deal_group) {
|
||
for (int i = 0; i < col; i++) {
|
||
uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
|
||
__memcpy(compute_buffer + i * row + index * group_size,
|
||
src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
|
||
}
|
||
}
|
||
#endif
|
||
}
|
||
#if __BANG_ARCH__ >= 592
|
||
if (is_deal_group) {
|
||
__bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
|
||
__bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
|
||
__bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
|
||
topk - 1);
|
||
__bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
|
||
__bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
|
||
(uint32_t *)compute_buffer, col * topk, col * topk);
|
||
__gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
|
||
NRAM2NRAM, group_size * sizeof(float), col * topk);
|
||
__bang_write_value(src_buffer, row * col, -INFINITY);
|
||
__scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
|
||
group_size * sizeof(float), col * topk);
|
||
}
|
||
#endif
|
||
}
|
||
|
||
template <typename T>
|
||
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
|
||
T *load_buffer,
|
||
float *src_buffer,
|
||
float *compute_buffer,
|
||
float *group_max_buffer,
|
||
float *nramout_value,
|
||
uint32_t *nramout_index,
|
||
uint32_t *i_buffer,
|
||
uint32_t *col_buffer,
|
||
float *softmax_buffer,
|
||
uint32_t row,
|
||
uint32_t nram_compute_col_num,
|
||
uint32_t mask_num,
|
||
uint32_t nram_max_col_num,
|
||
uint32_t topk,
|
||
int num_expert_group,
|
||
uint32_t topk_group,
|
||
uint32_t top_num,
|
||
uint32_t nram_col_offset,
|
||
int normalize_mode,
|
||
bool valid_mask,
|
||
bool split_mask) {
|
||
uint32_t nram_compute_num = nram_compute_col_num * row;
|
||
// convert to float for half/bf16 datatype
|
||
if (std::is_same<T, half>::value) {
|
||
__bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
|
||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||
__bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
|
||
}
|
||
// transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
|
||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||
|
||
// compute softmax
|
||
int tmp = 0x3fb8aa3b;
|
||
float log2e = *(float *)&tmp; // for exp
|
||
// src_buffer reuse as buffer for max/sum.
|
||
__bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
|
||
__bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
|
||
nram_compute_col_num);
|
||
__bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
|
||
__bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
|
||
__bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
|
||
__bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
|
||
nram_compute_col_num);
|
||
__sync_cluster();
|
||
// move mask and compute
|
||
if (valid_mask) {
|
||
if (!split_mask) {
|
||
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
|
||
if (std::is_same<T, half>::value) {
|
||
__memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
|
||
SRAM2NRAM);
|
||
__bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
|
||
mask_num * row);
|
||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||
__memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
|
||
mask_num * row * sizeof(T), SRAM2NRAM);
|
||
__bang_bfloat162float((float *)compute_buffer,
|
||
(bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
|
||
} else {
|
||
__memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
|
||
}
|
||
__bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
|
||
mask_num * row);
|
||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||
} else {
|
||
transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
|
||
nram_compute_col_num, row);
|
||
__sync();
|
||
__bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
|
||
}
|
||
}
|
||
if (normalize_mode == 2) {
|
||
__bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
|
||
}
|
||
|
||
if (num_expert_group <= 1) {
|
||
// num_expert_group <= 1, maintain original topk calculation logic
|
||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
|
||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||
nram_max_col_num * topk * sizeof(float), 0, false);
|
||
} else {
|
||
// num_expert_group > 1, use grouped_topk calculation logic
|
||
uint32_t group_size = row / num_expert_group;
|
||
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
|
||
__bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
|
||
group_size, 1, group_size, 1, 1);
|
||
__bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
|
||
// get topk_group
|
||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
|
||
(float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
|
||
nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
|
||
// get topk
|
||
#if __BANG_ARCH__ < 500
|
||
__bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
|
||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
|
||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||
nram_max_col_num * top_num * sizeof(float), 0, false);
|
||
#else
|
||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
|
||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||
nram_max_col_num * top_num * sizeof(float), 0, false);
|
||
#endif
|
||
} // end else
|
||
|
||
// normalize result
|
||
if (normalize_mode == 1) {
|
||
// compute_buffer reuse as buffer for sum.
|
||
__bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
|
||
__bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
|
||
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
|
||
nram_compute_col_num);
|
||
} else if (normalize_mode == 2) {
|
||
__bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
|
||
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
|
||
nram_compute_col_num);
|
||
}
|
||
|
||
// transpose back. src and dst of transpose can not be the same address.
|
||
__bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
|
||
__bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
|
||
}
|
||
|
||
template <typename T>
|
||
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
|
||
T *mask,
|
||
int *index_out,
|
||
float *value_out,
|
||
int col,
|
||
int row,
|
||
int mask_num,
|
||
int topk,
|
||
int num_expert_group,
|
||
int topk_group,
|
||
int normalize_mode) {
|
||
bool valid_mask = (mask != nullptr);
|
||
int top_num = topk >= topk_group ? topk : topk_group;
|
||
uint32_t nram_low_space =
|
||
PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
|
||
SCATTER_ALIGN);
|
||
if (num_expert_group <= 1) {
|
||
nram_low_space =
|
||
PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
|
||
}
|
||
uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
|
||
if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
|
||
nram_max_col_num = col / taskDim + (col % taskDim > 0);
|
||
}
|
||
nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
|
||
if (nram_max_col_num <= 0) {
|
||
nram_max_col_num = SCATTER_ALIGN / sizeof(float);
|
||
}
|
||
uint32_t nram_deal_num = nram_max_col_num * row;
|
||
uint32_t batch = col / mask_num;
|
||
|
||
// nram split:
|
||
// |--------------------------|--------------------------|--------------------|...
|
||
// | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
|
||
// | src_buffer | compute_buffer | group_max_buffer |...
|
||
// |--------------------------|--------------------------|--------------------|...
|
||
|
||
// |----------------------------------------|---------------|--------------|
|
||
// | nram_col_num*3 | col*topk | col*topk |
|
||
// | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
|
||
// |----------------------------------------|---------------|--------------|
|
||
float *src_buffer = (float *)nram_buffer;
|
||
float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
|
||
float *group_max_buffer = compute_buffer + nram_deal_num;
|
||
uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
|
||
if (num_expert_group <= 1) {
|
||
i_buffer = (uint32_t *)group_max_buffer;
|
||
}
|
||
uint32_t *col_buffer = i_buffer + nram_max_col_num;
|
||
float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
|
||
if (normalize_mode != 2) {
|
||
softmax_buffer = (float *)col_buffer;
|
||
}
|
||
|
||
float *nramout_value = softmax_buffer + nram_max_col_num;
|
||
uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
|
||
if (num_expert_group <= 1) {
|
||
nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
|
||
}
|
||
T *load_buffer = (T *)src_buffer;
|
||
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
|
||
load_buffer = load_buffer + nram_deal_num;
|
||
}
|
||
|
||
// set i_buffer
|
||
for (uint32_t i = 0; i < nram_max_col_num; i++) {
|
||
i_buffer[i] = i;
|
||
}
|
||
|
||
// input[batch, mask, low], mask[mask, low]
|
||
if (nram_max_col_num >= mask_num) { // nram can deal complete mask
|
||
bool split_mask = false;
|
||
uint32_t batch_seg = nram_max_col_num / mask_num;
|
||
uint32_t batch_rem = batch % batch_seg;
|
||
uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
|
||
int repeat = DIV_UP(batch_seg_num, taskDim);
|
||
for (int i = 0; i < repeat; i++) {
|
||
uint32_t seg_id = i * taskDim + taskId;
|
||
uint32_t sram_load_num = mask_num * row;
|
||
uint32_t sram_load_offset = 0;
|
||
uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
|
||
? batch_rem * mask_num
|
||
: batch_seg * mask_num;
|
||
uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
|
||
uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
|
||
uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
|
||
uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
|
||
|
||
// Load
|
||
if (valid_mask) {
|
||
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
|
||
}
|
||
if (nram_load_num > 0) {
|
||
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
|
||
}
|
||
|
||
// Compute
|
||
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
|
||
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
|
||
softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
|
||
topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
|
||
valid_mask, split_mask);
|
||
|
||
// Store
|
||
if (nram_store_num > 0) {
|
||
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
|
||
NRAM2GDRAM);
|
||
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
|
||
NRAM2GDRAM);
|
||
}
|
||
__sync_cluster();
|
||
}
|
||
} else {
|
||
bool split_mask = true;
|
||
uint32_t mask_seg = nram_max_col_num;
|
||
uint32_t mask_rem = mask_num % mask_seg;
|
||
uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
|
||
uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
|
||
uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
|
||
uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
|
||
for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
|
||
uint32_t batch_idx = i / sram_mask_seg_num;
|
||
uint32_t mask_idx = i % sram_mask_seg_num;
|
||
uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
|
||
uint32_t sram_load_num = sram_deal_mask_num * row;
|
||
uint32_t sram_mask_offset = mask_idx < sram_mask_rem
|
||
? mask_idx * (sram_average_mask_num + 1)
|
||
: mask_idx * sram_average_mask_num + sram_mask_rem;
|
||
uint32_t sram_load_offset = sram_mask_offset * row;
|
||
uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
|
||
uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
|
||
uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
|
||
uint32_t nram_load_num = nram_deal_mask_num * row;
|
||
uint32_t nram_col_offset = taskIdX < nram_mask_rem
|
||
? taskIdX * (nram_average_mask_num + 1)
|
||
: taskIdX * nram_average_mask_num + nram_mask_rem;
|
||
uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
|
||
uint32_t nram_store_num = nram_deal_mask_num * topk;
|
||
uint32_t nram_store_offset =
|
||
(batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
|
||
// Load
|
||
if (valid_mask) {
|
||
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
|
||
}
|
||
if (nram_load_num > 0) {
|
||
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
|
||
}
|
||
|
||
// Compute
|
||
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
|
||
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
|
||
softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
|
||
topk, num_expert_group, topk_group, top_num, nram_col_offset,
|
||
normalize_mode, valid_mask, split_mask);
|
||
|
||
// Store
|
||
if (nram_store_num > 0) {
|
||
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
|
||
NRAM2GDRAM);
|
||
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
|
||
NRAM2GDRAM);
|
||
}
|
||
__sync_cluster();
|
||
}
|
||
}
|
||
}
|
||
|
||
}  // namespace kernels

/* Host launcher for the MoE softmax + top-k kernel.
 *
 * Validates the configuration, sizes the launch grid to one Union1 task per
 * MLU core, and dispatches MLUSoftmaxTopkKernel for the requested dtype.
 *
 * queue:          CNRT queue the kernel is enqueued on.
 * reduce_weight:  output top-k (optionally normalized) weights, [num_token, topk] float.
 * expert_id:      output top-k expert indices, [num_token, topk] int.
 * input:          router logits, [num_token, num_expert], element type per `dtype`.
 * mask:           optional mask of shape [num_mask, num_expert] (same dtype) or
 *                 nullptr; should be nullptr when num_expert_group > 1.
 * normalize_mode: normalization of the selected values (0: none; 1: divide by
 *                 the top-k sum; 2: divide by the masked softmax sum — per the
 *                 kernel implementation).
 *
 * Returns KERNEL_STATUS_FAILED for unsupported configurations or dtypes,
 * KERNEL_STATUS_SUCCESS otherwise.
 */
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        const int num_token,
                                        const int num_expert,
                                        const int num_mask,
                                        const int topk,
                                        const int num_expert_group,
                                        const int topk_group,
                                        const cnnlDataType_t dtype,
                                        const int normalize_mode) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // One Union1 task per core: x = cores per cluster, y = clusters.
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  int top_num = topk >= topk_group ? topk : topk_group;

  // Largest num_expert that fits the kernel's per-column NRAM budget (two
  // float copies of each expert column plus bookkeeping). The grouped path
  // additionally needs room for per-group maxima, so its limit differs.
  const size_t max_num_expert =
      (num_expert_group <= 1)
          ? (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
          : (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float);
  if ((size_t)num_expert > max_num_expert) {
    // Fix: the grouped branch previously reported the non-grouped formula's
    // limit; report the limit that is actually enforced.
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
              << " Supported max num_expert:" << max_num_expert
              << ". Current num_expert:" << num_expert << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (topk > num_expert) {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
              << " topk:" << topk << ". num_expert:" << num_expert << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_expert_group > 1) {
    if (mask != nullptr) {
      // NOTE(review): this only warns and continues; confirm that a non-null
      // mask with grouped top-k is intentionally ignored rather than fatal.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr"
                << std::endl;
    }
    if (num_expert % num_expert_group != 0) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be "
                << "divisible by num_expert_group, but now num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk_group <= 0 || topk_group > num_expert_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be "
                << "larger than 0 and less than or equal to num_expert_group, but now topk_group:"
                << topk_group << ", num_expert_group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk > (num_expert / num_expert_group) * topk_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less "
                << "than or equal to (num_expert / num_expert_group) * topk_group, but now "
                << "topk:" << topk << ", num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }

  // Dispatch on element type; the kernel is templated on T.
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
        num_mask, topk, num_expert_group, topk_group, normalize_mode);
  } else {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported " << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
}  // namespace tmo