// NOTE(review): this is Cambricon BANG (MLU) kernel code, not CUDA. The file
// reached review with the contents of every angle-bracket pair stripped by an
// extraction step: the seven leading #include directives lost their header
// names, every `template` header lost its parameter list, every
// `std::is_same::value` lost its two type arguments, and the kernel launches
// lost their `<<<dim, ktype, queue>>>` configuration. All code tokens below are
// kept exactly as found; each stripped spot is flagged with a TODO(review).

// TODO(review): angle-bracket header names stripped — restore the original
// seven system includes (the code uses at least std::is_same / <type_traits>
// and std::cerr / <iostream>).
#include
#include
#include
#include
#include
#include
#include
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"

namespace tmo {
namespace kernels {

#define SCATTER_ALIGN (64)  // align for __scatter()
// Usable on-chip scratch sizes: total NRAM/SRAM minus a 32 KB reserve.
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
#define TILING_ALIGN (64)
// Integer ceiling division.
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))

// Whole-of-NRAM / whole-of-SRAM scratch arenas, carved up manually inside the
// kernel (see the "nram split" diagram in MLUSoftmaxTopkKernel).
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];

// Emits one `trans.tiling` MLVM instruction: a strided, tiled transpose-copy
// described by six (count, stride) levels on both the source (inN, isN) and the
// destination (dnN, dsN). TYPE selects the move width/direction variant
// ("nram.sram.b32" / "nram.sram.b16"); CONVERT optionally appends an
// on-the-fly dtype conversion suffix (f16/bf16 -> f32).
#define __TRANS_TILING(TYPE, CONVERT)                                                       \
  __asm__ volatile("trans.tiling." TYPE                                                     \
                   " [%[dst]], [%[src]],"                                                   \
                   "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4],"        \
                   "%[is4], %[in5], %[is5],"                                                \
                   "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4],"        \
                   "%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst),                       \
                   [src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1),          \
                   [in2] "r"(in2), [is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3),          \
                   [in4] "r"(in4), [is4] "r"(is4), [in5] "r"(in5), [is5] "r"(is5),          \
                   [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "r"(dn2),          \
                   [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4),          \
                   [ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));

// Dispatches __TRANS_TILING by transfer direction and element types:
//   f32 -> f32 : raw b32 move;
//   f16 -> f32 : b16 move with .cvt.f32.f16 conversion;
//   bf16 -> f32: b16 move with .cvt.f32.bf16 (guarded by __BANG_ARCH__ >= 500).
// Only the SRAM->NRAM direction with a float destination is handled; any other
// combination is a silent no-op.
// TODO(review): template parameter list stripped by extraction — presumably
// `<dir, DST_DTYPE, SRC_DTYPE>`; the std::is_same type arguments below are
// likewise stripped. Restore from the original source.
template
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst, const SRC_DTYPE *src, const uint32_t in0,
                               const uint32_t in1, const uint32_t is1, const uint32_t in2,
                               const uint32_t is2, const uint32_t in3, const uint32_t is3,
                               const uint32_t in4, const uint32_t is4, const uint32_t in5,
                               const uint32_t is5, const uint32_t dn0, const uint32_t dn1,
                               const uint32_t ds1, const uint32_t dn2, const uint32_t ds2,
                               const uint32_t dn3, const uint32_t ds3, const uint32_t dn4,
                               const uint32_t ds4, const uint32_t dn5, const uint32_t ds5) {
  if (SRAM2NRAM == dir && std::is_same::value) {
    if (std::is_same::value) {
      __TRANS_TILING("nram.sram.b32", ";")
    } else if (std::is_same::value) {
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
    } else if (std::is_same::value) {
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
    }
  }
}

/* Transpose data of shape [h, w] into [w, h] (with on-the-fly dtype
 * conversion), processed as 4 separate tiles: the TILING_ALIGN-aligned body,
 * the w-remainder strip, the h-remainder strip, and the corner remainder.
 * dst: destination address
 * src: source address
 * h:   extent along h
 * w:   extent along w
 * (Original comment was in Chinese; translated.)
 * TODO(review): template parameter list stripped by extraction; the
 * __mlvm_trans calls below presumably carried explicit template arguments
 * (direction/types) that were stripped as well.
 */
template
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
  // Elements per TILING_ALIGN (64-byte) tile edge for the source dtype.
  uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
  uint32_t w_align = w / align_num;
  uint32_t w_rem = w % align_num;
  uint32_t h_align = h / align_num;
  uint32_t h_rem = h % align_num;
  // Six-level (count, stride) descriptors; levels 2 and 5 stay degenerate
  // (count 1, stride 0) in every call below.
  uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
  uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
  uint32_t in3 = w_align, is3 = TILING_ALIGN;
  uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
  uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
  uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
  uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
  /* 1. h_align * w_align — fully aligned body */
  if (w_align > 0 && h_align > 0) {
    __mlvm_trans(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0, dn0, dn1, ds1, 1, 0,
                 dn3, ds3, dn4, ds4, 1, 0);
  }
  /* 2. h_align * w_rem — rightmost partial-width strip */
  if (w_rem > 0 && h_align > 0) {
    SRC_DTYPE *src_temp = src + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = TILING_ALIGN;
    in1 = align_num;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = h_align;
    is4 = w * TILING_ALIGN;
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = align_num * sizeof(DST_DTYPE);
    __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4, 1, 0, dn0, dn1, ds1,
                 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 3. h_rem * w_align — bottom partial-height strip */
  if (w_align > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w;
    DST_DTYPE *dst_temp = dst + h_align * align_num;
    in0 = TILING_ALIGN;
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = w_align;
    is4 = TILING_ALIGN;
    dn1 = align_num;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = h * align_num * sizeof(DST_DTYPE);
    __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4, 1, 0, dn0, dn1, ds1,
                 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 4. h_rem * w_rem — bottom-right corner remainder */
  if (w_rem > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1, 0, dn0, dn1, ds1, 1,
                 0, 1, 0, 1, 0, 1, 0);
  }
}

// Iterative top-k over `col` independent rows of length `row` (data laid out
// transposed, so each maxpool reduces across the `row` axis for all `col`
// lanes at once). Each of the `topk` iterations:
//   1. finds the per-lane max (value and/or argmax index),
//   2. converts the per-lane index into a flat offset via FUSION_FMA
//      (col_buffer = argmax * col + lane_id, using i_buffer = [0..col)),
//   3. knocks the found maximum out of max_buffer by overwriting it with
//      -INFINITY (vector __scatter from the -inf-filled temp_buffer on
//      arch >= 592, scalar loop otherwise) so the next iteration finds the
//      next-largest element.
// When is_deal_group is true the reduction runs over num_expert_group groups
// instead of `row`, and the selected groups' score blocks (group_size wide)
// are copied/gathered from src_buffer so that only chosen groups survive —
// presumably the group-filtering phase of grouped top-k MoE routing (TODO
// confirm against callers).
// value_index_stride: stride passed to __bang_maxpool_value_index separating
// the value and index output planes, in bytes.
__mlu_func__ void getTopk(float *value_buffer, uint32_t *index_buffer, float *src_buffer,
                          float *compute_buffer, float *max_buffer, float *temp_buffer,
                          uint32_t *i_buffer, uint32_t *col_buffer, uint32_t topk,
                          uint32_t num_expert_group, uint32_t col, uint32_t row,
                          uint32_t value_index_stride, uint32_t group_size,
                          bool is_deal_group) {
  __bang_write_value((float *)temp_buffer, col, -INFINITY);  // set -inf vector
  for (int k = 0; k < topk; k++) {
    if (is_deal_group) {
      // Group mode: only the argmax group index is needed per lane.
      __bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1,
                           num_expert_group, 1, num_expert_group, 1, 1);
      __bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer,
                    col, col);
    } else {
      // Plain mode: record both max value and its index for this round.
      __bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
                                 value_index_stride);
      __bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
    }
#if __BANG_ARCH__ >= 592
    __bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col);  // index in byte
    __scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM,
              sizeof(uint32_t), col);  // replace max value with -inf
#else
    // No vector scatter on older archs: scalar fallback.
    for (int i = 0; i < col; i++) {
      uint32_t index = __load_nram(col_buffer + i);
      max_buffer[index] = -INFINITY;
    }
#endif
#if __BANG_ARCH__ < 500
    // Old-arch group mode: copy each selected group's score block as it is
    // chosen (the >= 592 path batches this after the loop instead).
    if (is_deal_group) {
      for (int i = 0; i < col; i++) {
        uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
        __memcpy(compute_buffer + i * row + index * group_size,
                 src_buffer + i * row + index * group_size, group_size * sizeof(float),
                 NRAM2NRAM);
      }
    }
#endif
  }
#if __BANG_ARCH__ >= 592
  if (is_deal_group) {
    // Batched group extraction: build flat byte offsets of every selected
    // (lane, group) block, gather those blocks out of src_buffer, then scatter
    // them back over an all--inf background so non-selected groups vanish.
    __bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
    __bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
    __bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
                topk - 1);
    __bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
    __bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
                  (uint32_t *)compute_buffer, col * topk, col * topk);
    __gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
             NRAM2NRAM, group_size * sizeof(float), col * topk);
    __bang_write_value(src_buffer, row * col, -INFINITY);
    __scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
              group_size * sizeof(float), col * topk);
  }
#endif
}

// Core per-tile pipeline: dtype-promote -> transpose -> softmax over `row`
// -> optional mask multiply -> (grouped) top-k -> optional renormalize ->
// transpose results back to [col, topk] layout for store.
// normalize_mode: 1 = renormalize the selected top-k weights to sum to 1;
//                 2 = divide by the full (pre-selection) softmax sum captured
//                     in softmax_buffer; anything else = no renormalization.
// split_mask: false = the whole mask fits in NRAM (cycle-multiply path);
//             true  = only a slice is resident, fetched via transhw2wh from
//                     SRAM at nram_col_offset.
// TODO(review): template parameter list stripped by extraction (presumably
// `<typename T>`); the std::is_same type arguments below are stripped too.
template
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer, T *load_buffer, float *src_buffer,
                                     float *compute_buffer, float *group_max_buffer,
                                     float *nramout_value, uint32_t *nramout_index,
                                     uint32_t *i_buffer, uint32_t *col_buffer,
                                     float *softmax_buffer, uint32_t row,
                                     uint32_t nram_compute_col_num, uint32_t mask_num,
                                     uint32_t nram_max_col_num, uint32_t topk,
                                     int num_expert_group, uint32_t topk_group,
                                     uint32_t top_num, uint32_t nram_col_offset,
                                     int normalize_mode, bool valid_mask, bool split_mask) {
  uint32_t nram_compute_num = nram_compute_col_num * row;
  // convert to float for half/bf16 datatype
  if (std::is_same::value) {
    __bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
  } else if (std::is_same::value) {
    __bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
  }
  // transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
  __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
  // compute softmax
  // 0x3fb8aa3b is the IEEE-754 bit pattern of log2(e); combined with
  // __bang_pow2 this computes exp(x) as 2^(x * log2(e)).
  int tmp = 0x3fb8aa3b;
  float log2e = *(float *)&tmp;  // for exp
  // src_buffer reuse as buffer for max/sum.
  __bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);  // max
  // Fused (x - max) * log2e over each lane.
  __bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e,
                nram_compute_num, nram_compute_col_num);
  __bang_pow2(compute_buffer, compute_buffer, nram_compute_num);  // exp(input - max)
  __bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);  // sum
  __bang_recip(src_buffer, src_buffer, nram_compute_col_num);  // 1/sum
  __bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
                   nram_compute_col_num);
  // Wait for the async mask copy into SRAM issued by the caller.
  __sync_cluster();
  // move mask and compute
  if (valid_mask) {
    if (!split_mask) {
      // Whole mask resident: promote it to float (for half/bf16 inputs) and
      // cycle-multiply it across the softmax output in [col, row] layout.
      __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
      if (std::is_same::value) {
        __memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
                 SRAM2NRAM);
        __bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
                          mask_num * row);
      } else if (std::is_same::value) {
        __memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
                 mask_num * row * sizeof(T), SRAM2NRAM);
        __bang_bfloat162float((float *)compute_buffer,
                              (bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
      } else {
        __memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
      }
      __bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
                       mask_num * row);
      __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    } else {
      // Mask split across cores: fetch this core's transposed slice straight
      // from SRAM (with dtype conversion) and multiply element-wise.
      transhw2wh(src_buffer, sram_buffer + nram_col_offset * row, nram_compute_col_num, row);
      __sync();
      __bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
    }
  }
  if (normalize_mode == 2) {
    // Capture the post-mask softmax sum per lane for later renormalization.
    __bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
  }
  if (num_expert_group <= 1) {
    // num_expert_group <= 1, maintain original topk calculation logic
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer,
            src_buffer, i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num,
            row, nram_max_col_num * topk * sizeof(float), 0, false);
  } else {
    // num_expert_group > 1, use grouped_topk calculation logic
    uint32_t group_size = row / num_expert_group;
    __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
    // Per-group maxima feed the group-selection top-k below.
    __bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
                   group_size, 1, group_size, 1, 1);
    __bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
    // get topk_group
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
            (float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
            nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size,
            true);
    // get topk
#if __BANG_ARCH__ < 500
    __bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer,
            compute_buffer, i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num,
            row, nram_max_col_num * top_num * sizeof(float), 0, false);
#else
    __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer,
            src_buffer, i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num,
            row, nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
  }  // end else
  // normalize result
  if (normalize_mode == 1) {
    // compute_buffer reuse as buffer for sum.
    __bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
    __bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  } else if (normalize_mode == 2) {
    __bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  }
  // transpose back. src and dst of transpose can not be the same address.
  __bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
  __bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}

// Entry kernel: tiles `col` token columns across tasks, loads input (and
// optionally mask) from GDRAM, runs computeSoftmaxTopk per tile, and stores
// [tile, topk] values (float) and indices (int) back to GDRAM.
// Two tiling strategies, chosen by whether a whole mask row fits in one NRAM
// tile: batch-segmented (split_mask = false) vs. mask-segmented across
// cluster/core dimensions (split_mask = true).
// TODO(review): template parameter list stripped by extraction (presumably
// `<typename T>`).
template
__mlu_global__ void MLUSoftmaxTopkKernel(T *input, T *mask, int *index_out, float *value_out,
                                         int col, int row, int mask_num, int topk,
                                         int num_expert_group, int topk_group,
                                         int normalize_mode) {
  bool valid_mask = (mask != nullptr);
  int top_num = topk >= topk_group ? topk : topk_group;
  // Bytes of NRAM needed per handled column; determines max columns per tile.
  uint32_t nram_low_space = PAD_UP(
      (row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
      SCATTER_ALIGN);
  if (num_expert_group <= 1) {
    nram_low_space =
        PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
  }
  uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
  // Don't exceed this task's fair share of columns.
  if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
    nram_max_col_num = col / taskDim + (col % taskDim > 0);
  }
  nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
  if (nram_max_col_num <= 0) {
    nram_max_col_num = SCATTER_ALIGN / sizeof(float);
  }
  uint32_t nram_deal_num = nram_max_col_num * row;
  uint32_t batch = col / mask_num;
  // nram split:
  // |--------------------------|--------------------------|--------------------|...
  // | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
  // |        src_buffer        |      compute_buffer      |  group_max_buffer  |...
  // |--------------------------|--------------------------|--------------------|...
  // |------------------------------------------|----------------|---------------|
  // |              nram_col_num*3              |    col*topk    |   col*topk    |
  // | i_buffer | col_buffer | softmax_buffer   | nramout_value  | nramout_index |
  // |------------------------------------------|----------------|---------------|
  float *src_buffer = (float *)nram_buffer;
  float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
  float *group_max_buffer = compute_buffer + nram_deal_num;
  uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
  if (num_expert_group <= 1) {
    // No grouping: group_max_buffer region is unused, alias i_buffer onto it.
    i_buffer = (uint32_t *)group_max_buffer;
  }
  uint32_t *col_buffer = i_buffer + nram_max_col_num;
  float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
  if (normalize_mode != 2) {
    // softmax sums not needed: collapse softmax_buffer onto col_buffer.
    softmax_buffer = (float *)col_buffer;
  }
  float *nramout_value = softmax_buffer + nram_max_col_num;
  uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
  if (num_expert_group <= 1) {
    nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
  }
  T *load_buffer = (T *)src_buffer;
  if (std::is_same::value || std::is_same::value) {
    // Half/bf16 input lands in the upper half of src_buffer so the in-place
    // widening to float in computeSoftmaxTopk doesn't overwrite unread input.
    load_buffer = load_buffer + nram_deal_num;
  }
  // set i_buffer
  for (uint32_t i = 0; i < nram_max_col_num; i++) {
    i_buffer[i] = i;
  }
  // input[batch, mask, low], mask[mask, low]
  if (nram_max_col_num >= mask_num) {
    // nram can deal complete mask
    bool split_mask = false;
    uint32_t batch_seg = nram_max_col_num / mask_num;  // batches per tile
    uint32_t batch_rem = batch % batch_seg;
    uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
    int repeat = DIV_UP(batch_seg_num, taskDim);
    for (int i = 0; i < repeat; i++) {
      uint32_t seg_id = i * taskDim + taskId;
      uint32_t sram_load_num = mask_num * row;
      uint32_t sram_load_offset = 0;
      uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
                                          ? batch_rem * mask_num
                                          : batch_seg * mask_num;
      uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
      uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
      uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
      uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
      // Load (async mask copy overlapped with the input copy; the compute
      // stage synchronizes on it via __sync_cluster before using the mask)
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T),
                       GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                         group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                         softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
                         topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
                         valid_mask, split_mask);
      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  } else {
    // Mask row does not fit in one tile: split it across core (X) and cluster
    // (Y) task dimensions.
    bool split_mask = true;
    uint32_t mask_seg = nram_max_col_num;
    uint32_t mask_rem = mask_num % mask_seg;
    uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
    uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
    uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
    uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
    for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
      uint32_t batch_idx = i / sram_mask_seg_num;
      uint32_t mask_idx = i % sram_mask_seg_num;
      // First sram_mask_rem segments take one extra mask column each.
      uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
      uint32_t sram_load_num = sram_deal_mask_num * row;
      uint32_t sram_mask_offset = mask_idx < sram_mask_rem
                                      ? mask_idx * (sram_average_mask_num + 1)
                                      : mask_idx * sram_average_mask_num + sram_mask_rem;
      uint32_t sram_load_offset = sram_mask_offset * row;
      // Subdivide this SRAM segment across cores along taskDimX.
      uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
      uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
      uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
      uint32_t nram_load_num = nram_deal_mask_num * row;
      uint32_t nram_col_offset = taskIdX < nram_mask_rem
                                     ? taskIdX * (nram_average_mask_num + 1)
                                     : taskIdX * nram_average_mask_num + nram_mask_rem;
      uint32_t nram_load_offset =
          (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
      uint32_t nram_store_num = nram_deal_mask_num * topk;
      uint32_t nram_store_offset =
          (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T),
                       GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                         group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                         softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
                         topk, num_expert_group, topk_group, top_num, nram_col_offset,
                         normalize_mode, valid_mask, split_mask);
      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  }
}

}  // namespace kernels

// Host-side launcher: validates parameters against NRAM capacity and MoE
// grouping constraints, queries the device topology to size the launch grid,
// and dispatches the kernel for float / half / bfloat16 inputs.
// Returns KERNEL_STATUS_FAILED on any rejected parameter, SUCCESS otherwise.
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue, float *reduce_weight, int *expert_id,
                                        const void *input, const void *mask,
                                        const int num_token, const int num_expert,
                                        const int num_mask, const int topk,
                                        const int num_expert_group, const int topk_group,
                                        const cnnlDataType_t dtype, const int normalize_mode) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // x = cores per cluster, y = clusters; consumed by the (stripped) launch
  // configuration below.
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  int top_num = topk >= topk_group ? topk : topk_group;
  if (num_expert_group <= 1) {
    if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << "Supported max num_expert:"
                << (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
                << ". Current num_expert:" << num_expert;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  } else {
    if (num_expert >
        (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
      // NOTE(review): this message reports the capacity formula of the
      // num_expert_group <= 1 branch ((topk * 2 + 3)), not the one actually
      // checked above ((top_num * 2 + 2 + num_expert_group)) — the printed
      // "Supported max num_expert" is wrong here; confirm and fix upstream.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << "Supported max num_expert:"
                << (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
                << ". Current num_expert:" << num_expert;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  if (topk > num_expert) {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
              << "topk:" << topk << ". num_expert:" << num_expert;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_expert_group > 1) {
    if (mask != nullptr) {
      // NOTE(review): this prints a diagnostic but deliberately(?) does not
      // return FAILED — the mask is simply incompatible with grouped mode;
      // confirm whether falling through is intended.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
    }
    if (num_expert % num_expert_group != 0) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
                << "divisible by num_expert_group, but now num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk_group <= 0 || topk_group > num_expert_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
                << "larger than 0 and less than or equal to num_expert_group, but now topk_group"
                << topk_group << ", num_expert group:" << num_expert_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk > (num_expert / num_expert_group) * topk_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
                << "than or equal to (num_expert / num_expert_group) * topk_group, but now"
                << "topk :" << topk << ", num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  // TODO(review): the explicit template argument (<float>/<half>/<bfloat16_t>)
  // and the launch configuration inside <<<...>>> (presumably dim, a
  // cnrtFunctionType_t, and queue) were stripped by extraction on all three
  // launches below — restore from the original source.
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUSoftmaxTopkKernel<<>>((float *)input, (float *)mask, expert_id, reduce_weight,
                                      num_token, num_expert, num_mask, topk, num_expert_group,
                                      topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUSoftmaxTopkKernel<<>>((half *)input, (half *)mask, expert_id, reduce_weight,
                                      num_token, num_expert, num_mask, topk, num_expert_group,
                                      topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUSoftmaxTopkKernel<<>>((bfloat16_t *)input, (bfloat16_t *)mask, expert_id,
                                      reduce_weight, num_token, num_expert, num_mask, topk,
                                      num_expert_group, topk_group, normalize_mode);
  } else {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo