// enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
#include <mlu.h>
#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"
namespace tmo {
namespace kernels {
#define SCATTER_ALIGN (64) // align for __scatter()
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
#define TILING_ALIGN (64)
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
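// Static on-chip scratch: one NRAM buffer per core and one SRAM buffer per
// cluster, each sized to the full on-chip capacity minus 32 KiB of headroom.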
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
#define __TRANS_TILING(TYPE, CONVERT) \
__asm__ volatile("trans.tiling." TYPE \
" [%[dst]], [%[src]]," \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
"%[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
const SRC_DTYPE *src,
const uint32_t in0,
const uint32_t in1,
const uint32_t is1,
const uint32_t in2,
const uint32_t is2,
const uint32_t in3,
const uint32_t is3,
const uint32_t in4,
const uint32_t is4,
const uint32_t in5,
const uint32_t is5,
const uint32_t dn0,
const uint32_t dn1,
const uint32_t ds1,
const uint32_t dn2,
const uint32_t ds2,
const uint32_t dn3,
const uint32_t ds3,
const uint32_t dn4,
const uint32_t ds4,
const uint32_t dn5,
const uint32_t ds5) {
if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
if (std::is_same<SRC_DTYPE, float>::value) {
__TRANS_TILING("nram.sram.b32", ";")
} else if (std::is_same<SRC_DTYPE, half>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
} else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
}
}
}
/* Transpose data of shape [h, w] to [w, h] (with dtype conversion), handled as
 * four sub-blocks: the aligned core, the w remainder, the h remainder, and the
 * corner remainder.
 * dst: destination address
 * src: source address
 * h: size along the h dimension
 * w: size along the w dimension
 */
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
uint32_t w_align = w / align_num;
uint32_t w_rem = w % align_num;
uint32_t h_align = h / align_num;
uint32_t h_rem = h % align_num;
uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
uint32_t in3 = w_align, is3 = TILING_ALIGN;
uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
/* 1. h_align * w_align */
if (w_align > 0 && h_align > 0) {
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
}
/* 2. h_align * w_rem */
if (w_rem > 0 && h_align > 0) {
SRC_DTYPE *src_temp = src + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = TILING_ALIGN;
in1 = align_num;
is1 = w * sizeof(SRC_DTYPE);
in4 = h_align;
is4 = w * TILING_ALIGN;
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 3. h_rem * w_align */
if (w_align > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w;
DST_DTYPE *dst_temp = dst + h_align * align_num;
in0 = TILING_ALIGN;
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
in4 = w_align;
is4 = TILING_ALIGN;
dn1 = align_num;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = h * align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 4. h_rem * w_rem */
if (w_rem > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
}
}
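// Iterative top-k over max_buffer ([row, col] layout; one result per column
// and round): each of the topk rounds takes a per-column argmax via maxpool,
// converts the winner's (row, col) position to a flat offset with an FMA, and
// overwrites that position with -inf so the next round selects the runner-up.
// With is_deal_group set, the rounds instead pick the topk_group expert
// groups, and the selected groups' scores are copied back afterwards so every
// unselected group stays masked at -inf.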
__mlu_func__ void getTopk(float *value_buffer,
uint32_t *index_buffer,
float *src_buffer,
float *compute_buffer,
float *max_buffer,
float *temp_buffer,
uint32_t *i_buffer,
uint32_t *col_buffer,
uint32_t topk,
uint32_t num_expert_group,
uint32_t col,
uint32_t row,
uint32_t value_index_stride,
uint32_t group_size,
bool is_deal_group) {
__bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
for (int k = 0; k < topk; k++) {
if (is_deal_group) {
__bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
1, num_expert_group, 1, 1);
__bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
col);
} else {
__bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
value_index_stride);
__bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
}
#if __BANG_ARCH__ >= 592
__bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // element indices -> byte offsets
__scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
col); // replace max value with -inf
#else
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram(col_buffer + i);
max_buffer[index] = -INFINITY;
}
#endif
#if __BANG_ARCH__ < 500
if (is_deal_group) {
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
__memcpy(compute_buffer + i * row + index * group_size,
src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
}
}
#endif
}
#if __BANG_ARCH__ >= 592
if (is_deal_group) {
__bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
__bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
__bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
topk - 1);
__bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
__bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
(uint32_t *)compute_buffer, col * topk, col * topk);
__gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
NRAM2NRAM, group_size * sizeof(float), col * topk);
__bang_write_value(src_buffer, row * col, -INFINITY);
__scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
group_size * sizeof(float), col * topk);
}
#endif
}
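// Fused softmax + top-k for one NRAM tile of columns:
// 1. widen half/bf16 inputs to float;
// 2. transpose [col, row] -> [row, col] so max/sum reduce via pooling ops;
// 3. softmax as 2^((x - max) * log2(e)) normalized by the row sum;
// 4. optional mask multiply (whole masks copied from SRAM, or a split slice
//    transposed in via transhw2wh);
// 5. plain top-k, or grouped top-k (select topk_group groups, then topk);
// 6. optional renormalization (mode 1: by the top-k sum; mode 2: by the
//    pre-top-k softmax sum);
// 7. transpose values and indices back to [col, topk] for the store.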
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
T *load_buffer,
float *src_buffer,
float *compute_buffer,
float *group_max_buffer,
float *nramout_value,
uint32_t *nramout_index,
uint32_t *i_buffer,
uint32_t *col_buffer,
float *softmax_buffer,
uint32_t row,
uint32_t nram_compute_col_num,
uint32_t mask_num,
uint32_t nram_max_col_num,
uint32_t topk,
int num_expert_group,
uint32_t topk_group,
uint32_t top_num,
uint32_t nram_col_offset,
int normalize_mode,
bool valid_mask,
bool split_mask) {
uint32_t nram_compute_num = nram_compute_col_num * row;
// convert to float for half/bf16 datatype
if (std::is_same<T, half>::value) {
__bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
}
// transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
// compute softmax
// Bit pattern of log2(e) ~= 1.4426950f; exp(x) is computed below as 2^(x * log2(e)).
int tmp = 0x3fb8aa3b;
float log2e = *(float *)&tmp;
// src_buffer reuse as buffer for max/sum.
__bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
__bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
nram_compute_col_num);
__bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
__bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
__bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
__bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
nram_compute_col_num);
__sync_cluster();
// move mask and compute
if (valid_mask) {
if (!split_mask) {
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
if (std::is_same<T, half>::value) {
__memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
SRAM2NRAM);
__bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
mask_num * row);
} else if (std::is_same<T, bfloat16_t>::value) {
__memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
mask_num * row * sizeof(T), SRAM2NRAM);
__bang_bfloat162float((float *)compute_buffer,
(bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
} else {
__memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
}
__bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
mask_num * row);
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
} else {
transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
nram_compute_col_num, row);
__sync();
__bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
}
}
if (normalize_mode == 2) {
__bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
}
if (num_expert_group <= 1) {
// num_expert_group <= 1, maintain original topk calculation logic
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * topk * sizeof(float), 0, false);
} else {
// num_expert_group > 1, use grouped_topk calculation logic
uint32_t group_size = row / num_expert_group;
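// Keep a [col, row] copy of the probabilities in src_buffer, reduce each
// expert group to its per-column max in group_max_buffer
// ([num_expert_group, col]), and pre-fill compute_buffer with -inf so only
// the selected groups are written back.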
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
__bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
group_size, 1, group_size, 1, 1);
__bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
// get topk_group
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
(float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
// get topk
#if __BANG_ARCH__ < 500
__bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#else
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
} // end else
// normalize result
if (normalize_mode == 1) {
// compute_buffer reuse as buffer for sum.
__bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
__bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
} else if (normalize_mode == 2) {
__bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
}
// Transpose results back to [col, topk]; src and dst of transpose cannot be the same address.
__bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
__bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
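// Kernel entry point. input is [col, row] row-major (col = num_token,
// row = num_expert); mask, when present, is [mask_num, row] and is reused for
// every batch of mask_num consecutive columns; index_out/value_out are
// [col, topk]. Columns are tiled across tasks: if a whole mask fits in NRAM,
// each task processes full batches of mask_num columns; otherwise the mask is
// staged through SRAM and split across the cores of a cluster.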
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
T *mask,
int *index_out,
float *value_out,
int col,
int row,
int mask_num,
int topk,
int num_expert_group,
int topk_group,
int normalize_mode) {
bool valid_mask = (mask != nullptr);
int top_num = topk >= topk_group ? topk : topk_group;
uint32_t nram_low_space =
PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
SCATTER_ALIGN);
if (num_expert_group <= 1) {
nram_low_space =
PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
}
uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
if (nram_max_col_num > DIV_UP(col, taskDim)) {
nram_max_col_num = DIV_UP(col, taskDim);
}
nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
if (nram_max_col_num <= 0) {
nram_max_col_num = SCATTER_ALIGN / sizeof(float);
}
uint32_t nram_deal_num = nram_max_col_num * row;
uint32_t batch = col / mask_num;
// NRAM split (one layout, shown in two halves):
// |---------------------------|---------------------------|----------------------|...
// | size: nram/2 - col*topk*2 | size: nram/2 - col*topk*2 | col*num_expert_group |...
// |        src_buffer         |      compute_buffer       |   group_max_buffer   |...
// |---------------------------|---------------------------|----------------------|...
// ...|----------|------------|----------------|---------------|---------------|
// ...|  nram_col_num * 3 (nram_col_num each)  |  col * topk   |  col * topk   |
// ...| i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index |
// ...|----------|------------|----------------|---------------|---------------|
float *src_buffer = (float *)nram_buffer;
float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
float *group_max_buffer = compute_buffer + nram_deal_num;
uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
if (num_expert_group <= 1) {
i_buffer = (uint32_t *)group_max_buffer;
}
uint32_t *col_buffer = i_buffer + nram_max_col_num;
float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
if (normalize_mode != 2) {
softmax_buffer = (float *)col_buffer;
}
float *nramout_value = softmax_buffer + nram_max_col_num;
uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
if (num_expert_group <= 1) {
nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
}
T *load_buffer = (T *)src_buffer;
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
load_buffer = load_buffer + nram_deal_num;
}
// set i_buffer
for (uint32_t i = 0; i < nram_max_col_num; i++) {
i_buffer[i] = i;
}
// input is viewed as [batch, mask_num, row]; mask is [mask_num, row]
if (nram_max_col_num >= mask_num) { // NRAM can hold a complete mask's worth of columns
bool split_mask = false;
uint32_t batch_seg = nram_max_col_num / mask_num;
uint32_t batch_rem = batch % batch_seg;
uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
int repeat = DIV_UP(batch_seg_num, taskDim);
for (int i = 0; i < repeat; i++) {
uint32_t seg_id = i * taskDim + taskId;
uint32_t sram_load_num = mask_num * row;
uint32_t sram_load_offset = 0;
uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
? batch_rem * mask_num
: batch_seg * mask_num;
uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
} else {
bool split_mask = true;
uint32_t mask_seg = nram_max_col_num;
uint32_t mask_rem = mask_num % mask_seg;
uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
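// Each cluster (taskIdY) takes one (batch, SRAM mask segment) pair per
// iteration; within the cluster the segment is further split across cores
// (taskIdX), and each core pulls its mask slice from SRAM during compute.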
for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
uint32_t batch_idx = i / sram_mask_seg_num;
uint32_t mask_idx = i % sram_mask_seg_num;
uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
uint32_t sram_load_num = sram_deal_mask_num * row;
uint32_t sram_mask_offset = mask_idx < sram_mask_rem
? mask_idx * (sram_average_mask_num + 1)
: mask_idx * sram_average_mask_num + sram_mask_rem;
uint32_t sram_load_offset = sram_mask_offset * row;
uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
uint32_t nram_load_num = nram_deal_mask_num * row;
uint32_t nram_col_offset = taskIdX < nram_mask_rem
? taskIdX * (nram_average_mask_num + 1)
: taskIdX * nram_average_mask_num + nram_mask_rem;
uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
uint32_t nram_store_num = nram_deal_mask_num * topk;
uint32_t nram_store_offset =
(batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, nram_col_offset,
normalize_mode, valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
}
}
} // namespace kernels
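// Host-side launcher: validates the problem size against on-chip memory and
// the grouped-top-k constraints, then launches the kernel as a Union1 task
// (core_num x cluster_num) on the given queue for the requested dtype.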
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
int top_num = topk >= topk_group ? topk : topk_group;
if (num_expert_group <= 1) {
if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
} else {
if (num_expert >
(NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (topk > num_expert) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
<< "topk:" << topk << ". num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (num_expert_group > 1) {
if (mask != nullptr) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
}
if (num_expert % num_expert_group != 0) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
<< "divisible by num_expert_group, but now num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk_group <= 0 || topk_group > num_expert_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
<< "larger than 0 and less than or equal to num_expert_group, but now topk_group"
<< topk_group << ", num_expert group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk > (num_expert / num_expert_group) * topk_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
<< "than or equal to (num_expert / num_expert_group) * topk_group, but now"
<< "topk :" << topk << ", num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
num_mask, topk, num_expert_group, topk_group, normalize_mode);
} else {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo
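// Illustrative host-side usage -- a minimal sketch, not part of this file.
// The buffer names (d_logits, d_weights, d_ids) are hypothetical, and device
// allocation/copies are assumed to happen elsewhere. With mask == nullptr,
// num_mask only influences the internal tiling; passing num_token keeps the
// whole problem in a single batch.
//
//   cnrtQueue_t queue;
//   CNRT_CHECK(cnrtQueueCreate(&queue));
//   // d_logits : [num_token, num_expert] router logits (float32, device)
//   // d_weights: [num_token, topk] output routing weights (float32, device)
//   // d_ids    : [num_token, topk] output expert ids (int32, device)
//   tmo::KernelStatus status = tmo::invokeMoeSoftmaxTopkKernel(
//       queue, d_weights, d_ids, d_logits, /*mask=*/nullptr,
//       num_token, num_expert, /*num_mask=*/num_token, topk,
//       /*num_expert_group=*/1, /*topk_group=*/1, CNNL_DTYPE_FLOAT,
//       /*normalize_mode=*/1);
//   CNRT_CHECK(cnrtQueueSync(queue));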