// NOTE(review): the following web-viewer metadata was scraped together with the
// source and is not valid C++; it is preserved here as a comment so the file
// compiles. Original header:
//   File: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
//   Retrieved: 2026-02-04 17:39:32 +08:00 — 603 lines, 29 KiB
//   (viewer warning about ambiguous Unicode characters omitted)

#include <mlu.h>
#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"
namespace tmo {
namespace kernels {
#define SCATTER_ALIGN (64) // align for __scatter()
// Usable on-chip scratch sizes: hardware NRAM/SRAM capacity minus 32 KB of
// reserved headroom.
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
// Byte alignment used by the trans.tiling transpose paths below.
#define TILING_ALIGN (64)
// Ceiling division for non-negative integers.
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
// Whole-NRAM / whole-SRAM scratch buffers shared by every kernel in this file.
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
// Emits a single inline-asm "trans.tiling.<TYPE>" instruction: a tiled,
// strided transpose-copy described by six (count, byte-stride) levels on both
// the source side (in0..in5 / is1..is5) and the destination side
// (dn0..dn5 / ds1..ds5). CONVERT optionally appends an on-the-fly dtype
// conversion suffix (e.g. ", .cvt.f32.f16();"). All operand names are taken
// from the scope where the macro is expanded (see __mlvm_trans below).
#define __TRANS_TILING(TYPE, CONVERT) \
__asm__ volatile("trans.tiling." TYPE \
" [%[dst]], [%[src]]," \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
"%[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
// Tiled, strided transpose-copy built on the __TRANS_TILING asm macro.
//
// Only the SRAM2NRAM direction with a float destination is implemented:
// float sources are copied as b32, half sources are converted f16->f32 in
// flight, and bfloat16 sources (on __BANG_ARCH__ >= 500 only) bf16->f32.
// Any other <SRC_DTYPE, DST_DTYPE, dir> combination is a silent no-op.
//
// (in0, dn0) are innermost sizes in bytes; each (inN, isN) / (dnN, dsN) pair
// is a repeat count and byte stride for one tiling level — see transhw2wh for
// how these descriptors are derived. NOTE(review): exact operand semantics
// follow the trans.tiling ISA encoding; confirm against the BANG ISA manual.
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
                               const SRC_DTYPE *src,
                               const uint32_t in0,
                               const uint32_t in1,
                               const uint32_t is1,
                               const uint32_t in2,
                               const uint32_t is2,
                               const uint32_t in3,
                               const uint32_t is3,
                               const uint32_t in4,
                               const uint32_t is4,
                               const uint32_t in5,
                               const uint32_t is5,
                               const uint32_t dn0,
                               const uint32_t dn1,
                               const uint32_t ds1,
                               const uint32_t dn2,
                               const uint32_t ds2,
                               const uint32_t dn3,
                               const uint32_t ds3,
                               const uint32_t dn4,
                               const uint32_t ds4,
                               const uint32_t dn5,
                               const uint32_t ds5) {
  if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
    if (std::is_same<SRC_DTYPE, float>::value) {
      __TRANS_TILING("nram.sram.b32", ";")
    } else if (std::is_same<SRC_DTYPE, half>::value) {
      // Convert f16 -> f32 during the transposing copy.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
    } else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
      // bf16 -> f32 conversion requires __BANG_ARCH__ >= 500.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
    }
  }
}
/* Transpose data of shape [h, w] into [w, h] (with in-flight dtype conversion
 * when SRC_DTYPE != DST_DTYPE), processed as 4 separate tiles: the 64-byte
 * aligned core plus the remainder strips in each dimension.
 * dst: destination address
 * src: source address
 * h: extent along h
 * w: extent along w
 */
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
  // Number of source elements in one TILING_ALIGN-byte unit.
  uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
  uint32_t w_align = w / align_num;
  uint32_t w_rem = w % align_num;
  uint32_t h_align = h / align_num;
  uint32_t h_rem = h % align_num;
  // Default descriptors for the fully-aligned tile (case 1).
  uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
  uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
  uint32_t in3 = w_align, is3 = TILING_ALIGN;
  uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
  uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
  uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
  uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
  /* 1. h_align * w_align (aligned core) */
  if (w_align > 0 && h_align > 0) {
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
                                            dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
  }
  /* 2. h_align * w_rem (right remainder strip) */
  if (w_rem > 0 && h_align > 0) {
    SRC_DTYPE *src_temp = src + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = TILING_ALIGN;
    in1 = align_num;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = h_align;
    is4 = w * TILING_ALIGN;
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 3. h_rem * w_align (bottom remainder strip) */
  if (w_align > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w;
    DST_DTYPE *dst_temp = dst + h_align * align_num;
    in0 = TILING_ALIGN;
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = w_align;
    is4 = TILING_ALIGN;
    dn1 = align_num;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = h * align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 4. h_rem * w_rem (bottom-right corner) */
  if (w_rem > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
                                            0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
  }
}
// Iterative top-k selection over `col` independent columns.
//
// Each of the `topk` rounds finds the per-column maximum of `max_buffer` with
// a maxpool, records the winner, then overwrites the winning slot with -inf
// so the next round selects the next-largest value (via __scatter on
// __BANG_ARCH__ >= 592, or a scalar loop otherwise).
//
// value_buffer / index_buffer: outputs, one `col`-wide row per round
//   (indices in the is_deal_group path, values+indices otherwise).
// max_buffer: the scores being reduced; DESTROYED (winners set to -inf).
// temp_buffer: `col` floats preset to -inf, the scatter source.
// i_buffer: iota vector 0..col-1, combined with the per-column winner via FMA
//   to form flat offsets into the [*, col]-transposed layout.
// value_index_stride: byte stride handed to __bang_maxpool_value_index —
//   presumably the distance between the value and index output planes;
//   TODO(review) confirm against the BANG intrinsic reference.
// is_deal_group: true when selecting top expert *groups* (group maxima in
//   max_buffer, `group_size` experts per group); in that mode the selected
//   groups' scores are copied into compute_buffer (< arch 500) or, on
//   >= 592, src_buffer is rebuilt to contain only the selected groups with
//   every other slot set to -inf, ready for the subsequent in-group topk.
__mlu_func__ void getTopk(float *value_buffer,
                          uint32_t *index_buffer,
                          float *src_buffer,
                          float *compute_buffer,
                          float *max_buffer,
                          float *temp_buffer,
                          uint32_t *i_buffer,
                          uint32_t *col_buffer,
                          uint32_t topk,
                          uint32_t num_expert_group,
                          uint32_t col,
                          uint32_t row,
                          uint32_t value_index_stride,
                          uint32_t group_size,
                          bool is_deal_group) {
  __bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
  for (int k = 0; k < topk; k++) {
    if (is_deal_group) {
      // Winner group index per column; FMA turns it into a flat offset.
      __bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
                           1, num_expert_group, 1, 1);
      __bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
                    col);
    } else {
      // Winner value and index per column in one pass.
      __bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
                                 value_index_stride);
      __bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
    }
#if __BANG_ARCH__ >= 592
    __bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
    __scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
              col); // replace max value with -inf
#else
    // No __scatter on older archs: knock out the winners one by one.
    for (int i = 0; i < col; i++) {
      uint32_t index = __load_nram(col_buffer + i);
      max_buffer[index] = -INFINITY;
    }
#endif
#if __BANG_ARCH__ < 500
    if (is_deal_group) {
      // Copy each selected group's scores so non-selected groups stay -inf.
      for (int i = 0; i < col; i++) {
        uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
        __memcpy(compute_buffer + i * row + index * group_size,
                 src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
      }
    }
#endif
  }
#if __BANG_ARCH__ >= 592
  if (is_deal_group) {
    // Vectorized equivalent of the < 500 copy above: gather the selected
    // groups' scores and scatter them back into an all--inf src_buffer.
    __bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
    __bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
    __bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
                topk - 1);
    __bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
    // Flat byte offset = group_index * group_bytes + column_base.
    __bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
                  (uint32_t *)compute_buffer, col * topk, col * topk);
    __gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
             NRAM2NRAM, group_size * sizeof(float), col * topk);
    __bang_write_value(src_buffer, row * col, -INFINITY);
    __scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
              group_size * sizeof(float), col * topk);
  }
#endif
}
// Per-tile compute: softmax over the `row` (expert) dimension for
// `nram_compute_col_num` columns already resident in load_buffer, optional
// mask multiply, then (grouped) top-k selection and optional normalization.
//
// Pipeline:
//   1. widen half/bf16 inputs to float;
//   2. transpose [col, row] -> [row, col] so max/sum reduce with pool ops;
//   3. softmax = exp2((x - max) * log2(e)) / sum;
//   4. if valid_mask, multiply by the mask staged in sram_buffer (whole mask
//      when !split_mask, or this core's slice via transhw2wh when split_mask);
//   5. topk via getTopk — directly when num_expert_group <= 1, otherwise
//      topk_group group selection first, then topk within selected groups;
//   6. normalize_mode 1: divide by the sum of the selected topk values;
//      normalize_mode 2: divide by the pre-topk softmax sum (softmax_buffer);
//   7. transpose results back to [col, topk]: values land in compute_buffer,
//      indices in nramout_value (as uint32) — the caller stores from there.
//
// Buffers are reused aggressively; statement order is load-bearing.
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
                                     T *load_buffer,
                                     float *src_buffer,
                                     float *compute_buffer,
                                     float *group_max_buffer,
                                     float *nramout_value,
                                     uint32_t *nramout_index,
                                     uint32_t *i_buffer,
                                     uint32_t *col_buffer,
                                     float *softmax_buffer,
                                     uint32_t row,
                                     uint32_t nram_compute_col_num,
                                     uint32_t mask_num,
                                     uint32_t nram_max_col_num,
                                     uint32_t topk,
                                     int num_expert_group,
                                     uint32_t topk_group,
                                     uint32_t top_num,
                                     uint32_t nram_col_offset,
                                     int normalize_mode,
                                     bool valid_mask,
                                     bool split_mask) {
  uint32_t nram_compute_num = nram_compute_col_num * row;
  // convert to float for half/bf16 datatype
  if (std::is_same<T, half>::value) {
    __bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
  }
  // transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
  __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
  // compute softmax
  int tmp = 0x3fb8aa3b;
  float log2e = *(float *)&tmp; // for exp: exp(x) == 2^(x * log2(e))
  // src_buffer reuse as buffer for max/sum.
  __bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
  __bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
                nram_compute_col_num);
  __bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
  __bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
  __bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
  __bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
                   nram_compute_col_num);
  // Wait for the async mask copy into sram_buffer issued by the caller.
  __sync_cluster();
  // move mask and compute
  if (valid_mask) {
    if (!split_mask) {
      // Whole mask fits: transpose softmax back, widen the mask to float if
      // needed, then broadcast-multiply (mask repeats every mask_num columns).
      __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
      if (std::is_same<T, half>::value) {
        __memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
                 SRAM2NRAM);
        __bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
                          mask_num * row);
      } else if (std::is_same<T, bfloat16_t>::value) {
        __memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
                 mask_num * row * sizeof(T), SRAM2NRAM);
        __bang_bfloat162float((float *)compute_buffer,
                              (bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
      } else {
        __memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
      }
      __bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
                       mask_num * row);
      __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    } else {
      // Mask is split across cores: transpose-convert this core's slice
      // straight from SRAM and multiply elementwise.
      transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
                                      nram_compute_col_num, row);
      __sync();
      __bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
    }
  }
  if (normalize_mode == 2) {
    // Save the full softmax sums before topk clobbers compute_buffer.
    __bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
  }
  if (num_expert_group <= 1) {
    // num_expert_group <= 1, maintain original topk calculation logic
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * topk * sizeof(float), 0, false);
  } else {
    // num_expert_group > 1, use grouped_topk calculation logic
    uint32_t group_size = row / num_expert_group;
    __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
    __bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
                   group_size, 1, group_size, 1, 1);
    __bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
    // get topk_group
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
            (float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
            nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
    // get topk
#if __BANG_ARCH__ < 500
    __bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#else
    __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
  } // end else
  // normalize result
  if (normalize_mode == 1) {
    // compute_buffer reuse as buffer for sum.
    __bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
    __bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  } else if (normalize_mode == 2) {
    __bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  }
  // transpose back. src and dst of transpose can not be the same address.
  __bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
  __bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
// Kernel entry: softmax over `row` (= num_expert) then topk selection for
// each of `col` (= num_token) columns.
//   input:  [col, row] scores (dtype T); mask: optional, indexed as
//           [mask_num, row] with batch = col / mask_num.
//   index_out / value_out: [col, topk] selected expert ids and weights.
// Sizes NRAM per-column requirements, carves nram_buffer into the regions in
// the diagram below, then iterates one of two tiling regimes:
//   - whole-mask mode (nram_max_col_num >= mask_num): each task processes a
//     batch segment containing complete masks;
//   - split-mask mode: one mask row is split across cores (taskDimX) and
//     batch*segments across clusters (taskDimY).
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
                                         T *mask,
                                         int *index_out,
                                         float *value_out,
                                         int col,
                                         int row,
                                         int mask_num,
                                         int topk,
                                         int num_expert_group,
                                         int topk_group,
                                         int normalize_mode) {
  bool valid_mask = (mask != nullptr);
  int top_num = topk >= topk_group ? topk : topk_group;
  // Bytes of NRAM needed per processed column (both softmax panels, topk
  // outputs, scratch vectors, and group maxima in grouped mode).
  uint32_t nram_low_space =
      PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
             SCATTER_ALIGN);
  if (num_expert_group <= 1) {
    nram_low_space =
        PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
  }
  // Max columns one core can hold; clamp to this task's fair share, round
  // down for __scatter alignment, but keep at least one aligned chunk.
  uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
  if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
    nram_max_col_num = col / taskDim + (col % taskDim > 0);
  }
  nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
  if (nram_max_col_num <= 0) {
    nram_max_col_num = SCATTER_ALIGN / sizeof(float);
  }
  uint32_t nram_deal_num = nram_max_col_num * row;
  uint32_t batch = col / mask_num;
  // nram split:
  // |--------------------------|--------------------------|--------------------|...
  // | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
  // | src_buffer | compute_buffer | group_max_buffer |...
  // |--------------------------|--------------------------|--------------------|...
  // |----------------------------------------|---------------|--------------|
  // | nram_col_num*3 | col*topk | col*topk |
  // | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
  // |----------------------------------------|---------------|--------------|
  float *src_buffer = (float *)nram_buffer;
  float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
  float *group_max_buffer = compute_buffer + nram_deal_num;
  uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
  if (num_expert_group <= 1) {
    // No group maxima needed; collapse that region.
    i_buffer = (uint32_t *)group_max_buffer;
  }
  uint32_t *col_buffer = i_buffer + nram_max_col_num;
  float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
  if (normalize_mode != 2) {
    // Softmax sums only kept for normalize_mode == 2; collapse the region.
    softmax_buffer = (float *)col_buffer;
  }
  float *nramout_value = softmax_buffer + nram_max_col_num;
  uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
  if (num_expert_group <= 1) {
    nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
  }
  T *load_buffer = (T *)src_buffer;
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // Narrow input is loaded into the back half of src_buffer so the
    // widening conversion can write the front without overlap.
    load_buffer = load_buffer + nram_deal_num;
  }
  // set i_buffer (iota 0..nram_max_col_num-1 used by getTopk's FMA)
  for (uint32_t i = 0; i < nram_max_col_num; i++) {
    i_buffer[i] = i;
  }
  // input[batch, mask, low], mask[mask, low]
  if (nram_max_col_num >= mask_num) { // nram can deal complete mask
    bool split_mask = false;
    uint32_t batch_seg = nram_max_col_num / mask_num;
    uint32_t batch_rem = batch % batch_seg;
    uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
    int repeat = DIV_UP(batch_seg_num, taskDim);
    for (int i = 0; i < repeat; i++) {
      uint32_t seg_id = i * taskDim + taskId;
      uint32_t sram_load_num = mask_num * row;
      uint32_t sram_load_offset = 0;
      uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
                                          ? batch_rem * mask_num
                                          : batch_seg * mask_num;
      // Tail tasks past batch_seg_num load/store nothing but still run
      // compute and the cluster sync below.
      uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
      uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
      uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
      uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
      // Load (mask copy is async; computeSoftmaxTopk syncs before using it)
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
                            valid_mask, split_mask);
      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  } else {
    // Mask too large for one core: split each mask row across taskDimX cores
    // and distribute (batch x mask-segment) work items across taskDimY.
    bool split_mask = true;
    uint32_t mask_seg = nram_max_col_num;
    uint32_t mask_rem = mask_num % mask_seg;
    uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
    uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
    uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
    uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
    for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
      uint32_t batch_idx = i / sram_mask_seg_num;
      uint32_t mask_idx = i % sram_mask_seg_num;
      // First sram_mask_rem segments carry one extra mask column.
      uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
      uint32_t sram_load_num = sram_deal_mask_num * row;
      uint32_t sram_mask_offset = mask_idx < sram_mask_rem
                                      ? mask_idx * (sram_average_mask_num + 1)
                                      : mask_idx * sram_average_mask_num + sram_mask_rem;
      uint32_t sram_load_offset = sram_mask_offset * row;
      // Sub-split this SRAM segment across the cores of the cluster.
      uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
      uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
      uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
      uint32_t nram_load_num = nram_deal_mask_num * row;
      uint32_t nram_col_offset = taskIdX < nram_mask_rem
                                     ? taskIdX * (nram_average_mask_num + 1)
                                     : taskIdX * nram_average_mask_num + nram_mask_rem;
      uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
      uint32_t nram_store_num = nram_deal_mask_num * topk;
      uint32_t nram_store_offset =
          (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, nram_col_offset,
                            normalize_mode, valid_mask, split_mask);
      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  }
}
} // namespace kernels
/* Host-side launcher for the MoE softmax+topk kernel (MLUSoftmaxTopkKernel).
 *
 * Validates the problem size against usable NRAM and the grouped-topk
 * constraints, then enqueues the kernel as a Union1 task with
 * dim = {mcores-per-cluster, cluster-count, 1}.
 *
 * queue:          CNRT queue to launch on.
 * reduce_weight:  output topk weights (float), `topk` values per token.
 * expert_id:      output topk expert indices (int), parallel to reduce_weight.
 * input:          router scores, element type selected by `dtype`
 *                 (fp32 / fp16 / bf16); laid out [num_token, num_expert].
 * mask:           optional mask applied after softmax; nullptr disables it.
 *                 NOTE(review): the kernel indexes it as [num_mask, num_expert]
 *                 with num_token divisible by num_mask — confirm with callers.
 * normalize_mode: 0 = none, 1 = normalize by sum of selected topk values,
 *                 2 = normalize by the full softmax sum (see computeSoftmaxTopk).
 *
 * Returns KERNEL_STATUS_SUCCESS, or KERNEL_STATUS_FAILED for unsupported or
 * inconsistent parameters (with a diagnostic on std::cerr).
 */
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        const int num_token,
                                        const int num_expert,
                                        const int num_mask,
                                        const int topk,
                                        const int num_expert_group,
                                        const int topk_group,
                                        const cnnlDataType_t dtype,
                                        const int normalize_mode) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  int top_num = topk >= topk_group ? topk : topk_group;
  // NRAM capacity check; the bound mirrors the per-column nram_low_space
  // computed inside MLUSoftmaxTopkKernel for each mode.
  if (num_expert_group <= 1) {
    const size_t max_num_expert =
        (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float);
    if ((size_t)num_expert > max_num_expert) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << " Supported max num_expert:" << max_num_expert
                << ". Current num_expert:" << num_expert << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  } else {
    const size_t max_num_expert =
        (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float);
    if ((size_t)num_expert > max_num_expert) {
      // Bug fix: the reported limit previously used the ungrouped formula and
      // did not match the bound actually checked here.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << " Supported max num_expert:" << max_num_expert
                << ". Current num_expert:" << num_expert << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  if (topk > num_expert) {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
              << " topk:" << topk << ". num_expert:" << num_expert << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_expert_group > 1) {
    if (mask != nullptr) {
      // NOTE(review): this is only a warning — execution continues and the
      // mask is passed through. Confirm whether this should fail instead.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr"
                << std::endl;
    }
    if (num_expert % num_expert_group != 0) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be "
                << "divisible by num_expert_group, but now num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk_group <= 0 || topk_group > num_expert_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be "
                << "larger than 0 and less than or equal to num_expert_group, but now topk_group:"
                << topk_group << ", num_expert_group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk > (num_expert / num_expert_group) * topk_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less "
                << "than or equal to (num_expert / num_expert_group) * topk_group, but now "
                << "topk:" << topk << ", num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  // Dispatch by element type; all variants share the same launch geometry.
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
        num_mask, topk, num_expert_group, topk_group, normalize_mode);
  } else {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported " << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo