Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
2026-02-04 17:39:32 +08:00

936 lines
38 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>
#include "gen_idx.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Usable on-chip scratch sizes in bytes: the compiler-provided totals
// (in KB, hence the * 1024) minus a small reservation (16 KB NRAM / 8 KB SRAM).
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
// Element-count alignment used when partitioning NRAM buffers.
#define ALIGN_16 (16)
// When set to 1, replaces the real per-expert token counts with an even
// split across experts (debug/testing only; see modifyTokenCountAndPresum).
#define EXPERT_AVG_COUNT_TEST (0)
// Cluster-shared scratch; used as the scatter destination for combine idx
// on the SRAM path (MLU500 series).
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
// Per-core scratch; each kernel/helper partitions it ad hoc.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Seed table 0..63 used by generateIntSeq to bootstrap integer sequences.
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                                16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
                                32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
                                48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
// Fill dst[0..length-1] with the integer sequence 0, 1, ..., length-1.
// Seeds the first min(64, length) entries from the constant `range` table,
// then repeatedly doubles the initialized prefix by adding the prefix size
// to a copy of it — log2(length) vector ops instead of a scalar loop.
__mlu_func__ void generateIntSeq(int *dst, int length) {
  int prefix = 64;  // number of entries already holding the correct value
  __bang_move(dst, range, std::min(prefix, length) * sizeof(int));
  while (prefix < length) {
    int copy_num = std::min(prefix, length - prefix);
    // dst[prefix + i] = dst[i] + prefix, for i in [0, copy_num)
    __bang_add_scalar(dst + prefix, dst, (int)prefix, copy_num);
    prefix *= 2;
  }
}
// genIdx Block kernel, use only 1 core to process.
// Launched only for small cases (num_token * topk <= 4096 in the host
// wrapper) so that every intermediate buffer fits into one core's NRAM.
// For the flattened routing table expert_id[num_token * topk] it writes:
//   token_count[e]       - number of slots assigned to expert e
//   cusum_token_count    - token_count prefix sum with leading 0
//                          (skipped when the pointer is null)
//   gather_expand_idx[i] - sorted_idx[i] / topk, source token of sorted slot i
//   gather_combine_idx   - scatter of the sequence 0..n-1 through sorted_idx,
//                          i.e. gather_combine_idx[sorted_idx[i]] = i
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
                                               int *gather_combine_idx,
                                               int *token_count,
                                               int *cusum_token_count,
                                               const void *expert_id,
                                               const int num_token,
                                               const int num_expert,
                                               const int topk) {
  /* NRAM space */
  // Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
  // --------------------------------------------------------------
  // | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
  // | combine_idx | expand_idx | | scatter_offset |
  // |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
  // --------------------------------------------------------------
  // ------------------------------
  // |token_count|token_count_presum|
  // | | |
  // | num_expert| num_expert |
  // ------------------------------
  uint32_t token_total_num = num_token * topk;
  // num align to 16, size align to 64B
  uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
  int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
  int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
  int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
  int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
  int8_t *scatter_offset = cur_expert_result;  // reuse cur_expert space
#if __BANG_ARCH__ >= 592
  int8_t *combine_idx_onchip = expert_id_onchip;  // reuse expert_it space
#endif
  int8_t *expand_idx_onchip = sorted_idx_onchip;  // reuse sorted_idx space
  // Load current core input expert_id and generate int sequence
  __memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
                 GDRAM2NRAM);
  generateIntSeq((int *)gen_idx_onchip, token_total_num);
  // Wait for the async load before expert_id_onchip is read below.
  __sync();
  // Initialize sort idx offset
  uint32_t sorted_idx_offset = 0;
  // Initialize token count first presum with 0
  ((int *)token_count_presum_onchip)[0] = 0;
  bool need_cusum_token_count = bool(cusum_token_count != nullptr);
  // Loop on each expert: eq, count, filter index — each pass appends the
  // slot indices of the current expert, so sorted_idx ends up grouped by
  // expert in ascending expert order.
  for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
    __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
                     token_total_num);
    // Use filter to sort gen_idx, output with sorted_idx_offset
    // NOTE: filter takes float pointers; the int32 payload is only moved,
    // and the mask is tested for zero/non-zero, so the reinterpret is safe.
    uint32_t cur_expert_count =
        __bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
                      (float *)cur_expert_result, token_total_num);
    sorted_idx_offset += cur_expert_count;
    ((int *)token_count_onchip)[cur_expert] = cur_expert_count;
    // Compute cusum token count and store
    if (need_cusum_token_count) {
      ((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
    }
  }
#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here — overwrite counts with an even split.
  uint32_t token_count_avg = token_total_num / num_expert;
  uint32_t expert_remain_num = token_total_num % num_expert;
  for (int i = 0; i < num_expert; i++) {
    ((int *)token_count_onchip)[i] =
        (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
    ((int *)token_count_presum_onchip)[i + 1] =
        ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
  }
#endif
  __sync_compute();
  // Store token_count and cusum token count
  __memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
                 NRAM2GDRAM);
  if (need_cusum_token_count) {
    __memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
                   (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // Use sorted idx to generate gather idx for expand and combine
#if __BANG_ARCH__ >= 592
  // scatter_offset = sorted_idx mul_scalar sizeof(int);
  // byte offsets for the NRAM2NRAM scatter below
  __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                    token_total_num);
#else
  // scatter dst GDRAM addr should align to 64B, so the misalignment (in int
  // elements) is folded into the offsets: (sorted_idx + align_offset) * 4
  int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
  __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
#endif
  __sync_compute();
#if __BANG_ARCH__ >= 592
  // scatter_async to NRAM: combine_idx[sorted_idx[i]] = i
  __scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
                  sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
#endif
  // expand_idx = sorted_idx div(topk)
  __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
  // Store expand_idx and combine_idx
  __sync_compute();
  __memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
                 NRAM2GDRAM);
#if __BANG_ARCH__ >= 592
  // Wait for the NRAM2NRAM scatter to finish before copying combine_idx out.
  __sync_move();
  __memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
                 token_total_num * sizeof(int), NRAM2GDRAM);
#else
  // 370 directly scatter to GDRAM
  __scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
            sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
#endif
}
// Scatter `length` int32 elements from NRAM `src` into SRAM `dst`, each at
// the byte offset given in `offset[i]`. Only MLU500 series support the
// NRAM2SRAM scatter direction; on older arch this is a no-op.
// The transfer is issued in chunks of 32768 segments to stay well below
// the 65535 segnum limit of __scatter_async; src/offset addresses should
// be 64B-aligned (instruction constraint).
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
#if __BANG_ARCH__ >= 592
  const int kSegMax = 32768;
  int done = 0;
  while (done < length) {
    int seg_num = std::min(kSegMax, length - done);
    __scatter_async((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done,
                    sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_num);
    done += seg_num;
  }
#endif
}
// Scatter `length` int32 elements from NRAM `src` to GDRAM `dst`, each at
// the byte offset given in `offset[i]`; transfer size is sizeof(int).
// Chunked to 32768 segments per instruction to stay below the 65535 segnum
// limit of bang_scatter; src/offset addresses should be 64B-aligned.
// MLU500 uses the async form; older arch uses the synchronous __scatter.
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
  const int kSegMax = 32768;
  int done = 0;
  while (done < length) {
    int seg_num = std::min(kSegMax, length - done);
#if __BANG_ARCH__ >= 592
    __scatter_async((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done,
                    sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_num);
#else
    __scatter((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done, sizeof(int),
              NRAM2GDRAM, sizeof(int), (unsigned short)seg_num);
#endif
    done += seg_num;
  }
}
// 1. Get token count
// Each core scans its own [cur_token_start, cur_token_start + token_cur_core)
// slice of expert_id, counts matches for every expert locally on NRAM, and
// finally atomically accumulates the per-expert counts into token_count on
// GDRAM. All IPU cores synchronize before (after the zero-init) and after
// the accumulation.
__mlu_func__ void getTokenCount(int *token_count,
                                int *expert_id,
                                int token_cur_core,
                                int cur_token_start,
                                int num_expert) {
  // 1. Partition on [num_token*topk],
  // each core for-loop on all expert_id, use eq and count instructions,
  // use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
  // And sync for all cores.
  // NRAM:
  // ------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|expert_count_onchip|
  // | deal_num | deal_num | num_expert |
  // ------------------------------------------------------
  uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
  int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
  // Current core data loop
  uint32_t repeat = token_cur_core / deal_num;
  uint32_t remain = token_cur_core % deal_num;
  uint32_t total_repeat = repeat + (int)(remain > 0);
  uint32_t token_addr_offset = cur_token_start;
  // Initialize token_count with 0 (single writer before the barrier)
  if (taskId == 0) {
    __gdramset((int *)token_count, num_expert, 0);
  }
  // Sync for initialize token_count
  __sync_all_ipu();
  // Initialize expert count onchip with 0
  if (token_cur_core > 0) {
    __bang_write_zero((int *)expert_count_onchip, num_expert);
  }
  // actual num in loop
  int cur_deal_num = deal_num;
  for (int i = 0; i < total_repeat; i++) {
    // Last iteration may process only the partial remainder chunk.
    if (i == total_repeat - 1 && remain > 0) {
      cur_deal_num = remain;
    }
    // Load current core input expert_id
    __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
             cur_deal_num * sizeof(int), GDRAM2NRAM);
    token_addr_offset += cur_deal_num;
    // Loop on each expert, eq, count
    for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
      __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
      // NOTE: __bang_count() only support floating data type; the int32
      // 0/1 mask is reinterpreted and only zero/non-zero matters.
      uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
      ((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
    }
  }
  // AtomicAdd(reduce) all cores token count results
  if (token_cur_core > 0) {
    __bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
  }
  // Sync for all cores, get accumulate of token_count
  __sync_all_ipu();
}
// 2. Get token count presum, for each expert index start address after sorting.
// A single core (taskId 0) computes the prefix sum over token_count:
//   presum[0] = 0, presum[i + 1] = presum[i] + token_count[i]
// so presum[i] is the int32 start offset of expert i in the sorted order.
// Results (num_expert + 1 ints) are stored to token_count_presum on GDRAM,
// and all IPU cores synchronize before returning.
// NRAM: one (num_expert + 1)-int buffer at the start of nram_buffer.
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
                                      int *token_count,
                                      const int num_expert) {
  if (taskId == 0) {
    int *presum = (int *)nram_buffer;
    // Slot 0 carries the leading zero; token_count lands shifted by one.
    presum[0] = 0;
    __memcpy(presum + 1, (int *)token_count, num_expert * sizeof(int), GDRAM2NRAM);
    // Accumulate in place into a running prefix sum.
    for (int e = 0; e < num_expert; e++) {
      presum[e + 1] += presum[e];
    }
    // Publish the num_expert + 1 prefix values.
    __memcpy((int *)token_count_presum, presum, (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // All cores wait until the prefix sum is visible.
  __sync_all_ipu();
}
// Debug helper for EXPERT_AVG_COUNT_TEST: overwrite token_count and
// token_count_presum with an even split of token_total_num across the
// num_expert experts. The first (token_total_num % num_expert) experts
// receive one extra slot; the presum keeps its leading zero.
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
                                            int *token_count,
                                            const uint32_t token_total_num,
                                            const int num_expert) {
  const uint32_t base = token_total_num / num_expert;
  const uint32_t extra = token_total_num % num_expert;
  int *count = (int *)nram_buffer;
  int *presum = (int *)(nram_buffer + num_expert * sizeof(int));
  presum[0] = 0;
  // Build the synthetic counts and their prefix sum in one pass.
  for (int e = 0; e < num_expert; e++) {
    count[e] = (e < extra) ? base + 1 : base;
    presum[e + 1] = presum[e] + count[e];
  }
  __memcpy((int *)token_count, count, num_expert * sizeof(int), NRAM2GDRAM);
  __memcpy((int *)token_count_presum, presum, (num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// 3. Get expert position index after sorting
// Partition on [num_expert]: this core owns experts
// [cur_expert_start, cur_expert_end] (expert_cur_core of them). It streams
// the WHOLE expert_id array through NRAM in deal_num chunks and, for each
// owned expert, filters the matching global position indices into that
// expert's segment of sorted_idx; segment starts come from
// token_count_presum and are advanced in expert_start_addr as chunks land.
// All IPU cores synchronize at the end.
__mlu_func__ void getSortedIdx(int *sorted_idx,
                               int *expert_id,
                               int *token_count_presum,
                               const int token_total_num,
                               const int num_expert,
                               const int expert_cur_core,
                               const int cur_expert_start,
                               const int cur_expert_end) {
  // 3. Partition on num_expert, each core generate position index from 0,
  // and for-loop on all expert_id data, use eq with own each expert_id,
  // and filter on index, stores to each expert_id start address of
  // sorted_idx on workspace.
  // And sync for all cores.
  // NRAM:
  // -------------------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------
  // |expert_start_addr|
  // | num_expert |
  // -----------------
  // Calculate new deal_num of sorting process
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
  // Each core deal with whole token expert_id data
  int repeat = token_total_num / deal_num;
  int remain = token_total_num % deal_num;
  int token_addr_offset = 0;
  int8_t *expert_id_onchip = nram_buffer;
  int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
  int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
  int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
  int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
  // When num_expert < taskDim, not all cores need to sort
  if (expert_cur_core > 0) {
    // Generate position index from 0
    if (deal_num <= token_total_num) {
      generateIntSeq((int *)gen_idx_onchip, deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)gen_idx_onchip, token_total_num);
    }
    // Initialize expert start address with presum of token count
    __memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
             GDRAM2NRAM);
    // repeat part
    for (int i = 0; i < repeat; i++) {
      // Load current core expert_id
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               deal_num * sizeof(int), GDRAM2NRAM);
      token_addr_offset += deal_num;
      // Loop for current core expert, eq, filter position index
      // use filter, store to sorted_idx[expert_start_addr]
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only support floating data type; the int32
        // payload is only moved, so the reinterpret is safe.
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, deal_num);
        // Store to the corresponding address of sorted_idx
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);
          // Update address offset of current expert
          ((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
        }
      }
      // Update position index for each data loop so indices stay global
      __bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
    }
    // remainder part (last chunk, so expert_start_addr need not advance)
    if (remain > 0) {
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               remain * sizeof(int), GDRAM2NRAM);
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only support floating data type
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, remain);
        // Store to the corresponding address of sorted_idx
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);
        }
      }
    }
  }
  // Sync for all cores, get position index after sorting
  __sync_all_ipu();
}
// 4. Get gather index for expand and combine
// is_sram_scatter == true : combine indices are scattered into the
//   cluster-shared sram_buffer (MLU500 path; an MPU copies SRAM to GDRAM
//   after a cluster sync in the caller).
// is_sram_scatter == false: combine indices are scattered straight to the
//   gather_combine_idx buffer on GDRAM.
// Partition on [num_token*topk]: this core handles token_cur_core slots
// starting at cur_token_start.
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
                               int *gather_combine_idx,
                               int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start,
                               const int topk) {
  // 4. Partition on [num_token*topk],
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset
  // gather_combine_idx = scatter(seq, sorted_idx)
  // gather_expand_idx = sorted_idx / topk
  // update sequence
  // NRAM:
  // -------------------------------------------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------
  // Calculate new deal_num of generate gather index
  // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  // scatter dst GDRAM addr should align to 64B
  int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
  int8_t *sorted_idx_onchip = nram_buffer;
  int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        deal_num);
    } else {
      // GDRAM addr should align to 64B; fold the misalignment (in int
      // elements) into the offsets: (sorted_idx + align_offset) * sizeof(int)
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), deal_num);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     deal_num);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, deal_num);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             deal_num * sizeof(int), NRAM2GDRAM);
    if (is_sram_scatter) {
      // if scatter to SRAM, need to sync compute with mv: the async scatter
      // must finish before scatter_sequence/scatter_offset are reused below
      __sync_move();
    }
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        remain);
    } else {
      // GDRAM addr should align to 64B
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), remain);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     remain);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, remain);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             remain * sizeof(int), NRAM2GDRAM);
  }
}
// 4.1 Get gather combine index on SRAM
// Used on the taskDim > 4 SRAM path: the 4 cores of the first cluster
// scatter combine indices (gather_combine_idx[sorted_idx[i]] = i) into
// sram_buffer, and the cluster MPU later copies SRAM to GDRAM after a
// cluster sync in the caller.
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
                                    const int token_cur_core,
                                    const int cur_token_start) {
  // 4.1 Partition on [num_token*topk], with only 1 union
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset
  // gather_combine_idx = scatter(seq, sorted_idx)
  // update sequence
  // NRAM:
  // -------------------------------
  // |scatter_offset|scatter_sequence|
  // | deal_num | deal_num |
  // -------------------------------
  // Calculate new deal_num of generate gather index
  // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  int8_t *scatter_offset = nram_buffer;
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx (the scatter_offset buffer is reused
    // to hold sorted_idx before it is turned into byte offsets in place)
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
             GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
    // Sync for scatter
    __sync_compute();
    // Scatter to SRAM
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                   deal_num);
    // Wait for the async scatter before the buffers are overwritten.
    __sync_move();
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
             GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
    // Sync for scatter
    __sync_compute();
    // NOTE(review): no __sync_move() after this last scatter; presumably the
    // caller's __sync_cluster() makes the SRAM data visible — confirm.
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
  }
}
// 4.2 Get gather expand index
// gather_expand_idx[i] = sorted_idx[i] / topk, i.e. the source token id for
// each slot of the sorted order. Partition on [num_token*topk]: this core
// handles token_cur_core elements starting at cur_token_start.
// NRAM: two deal_num-sized int buffers (sorted_idx_onchip | expand_idx_onchip).
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
                               int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start,
                               const int topk) {
  const int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int *sorted_idx_onchip = (int *)nram_buffer;
  int *expand_idx_onchip = (int *)(nram_buffer + deal_num * sizeof(int));
  int offset = cur_token_start;
  int rest = token_cur_core;
  // Stream through the slice in chunks of at most deal_num elements.
  while (rest > 0) {
    const int cur_num = (rest < deal_num) ? rest : deal_num;
    // Load a chunk of sorted position indices.
    __memcpy(sorted_idx_onchip, ((int *)sorted_idx) + offset, cur_num * sizeof(int), GDRAM2NRAM);
    // expand_idx = sorted_idx / topk
    __bang_div(expand_idx_onchip, sorted_idx_onchip, topk, cur_num);
    // Store the chunk back to GDRAM.
    __memcpy(((int *)gather_expand_idx) + offset, expand_idx_onchip, cur_num * sizeof(int),
             NRAM2GDRAM);
    offset += cur_num;
    rest -= cur_num;
  }
}
// UnionX kernel entry: cooperates taskDim IPU cores (plus cluster MPUs) to
// generate MoE gather indices for large cases. Steps:
//   1. getTokenCount       - per-expert slot counts (token partition)
//   2. getTokenCountPresum - prefix sum of counts (single core)
//   3. getSortedIdx        - slot indices grouped by expert (expert partition)
//   4. expand/combine idx  - via the first cluster's SRAM (MLU500 and data
//                            fits in SRAM) or straight to GDRAM
// workspace holds token_count_presum (only when cusum_token_count is null)
// followed by sorted_idx.
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
                                          int *gather_combine_idx,
                                          int *token_count,
                                          int *cusum_token_count,
                                          void *workspace,
                                          const void *expert_id,
                                          const int num_token,
                                          const int num_expert,
                                          const int topk) {
  // Store token count presum result, shape [num_expert + 1]
  int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
  // Store position index after sorting, shape [num_token*topk]
  int *sorted_idx = ((int *)workspace) + num_expert + 1;
  // Calculate partition information for different processes
  // Partition on [num_token*topk]
  uint32_t token_total_num = num_token * topk;
  uint32_t token_cur_core = token_total_num / taskDim;
  uint32_t token_remain_num = token_total_num % taskDim;
  token_cur_core += (uint32_t)(taskId < token_remain_num);
  // Current core range according to partition on [num_token*topk]
  uint32_t cur_token_start = (taskId < token_remain_num)
                                 ? token_cur_core * taskId
                                 : token_cur_core * taskId + token_remain_num;
  // Partition on [num_expert]
  uint32_t expert_cur_core = num_expert / taskDim;
  uint32_t expert_remain_num = num_expert % taskDim;
  expert_cur_core += (uint32_t)(taskId < expert_remain_num);
  // Current core range according to partition on [num_expert]
  uint32_t cur_expert_start = (taskId < expert_remain_num)
                                  ? expert_cur_core * taskId
                                  : expert_cur_core * taskId + expert_remain_num;
  uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
  // Use Union1 SRAM to scatter, only MLU500 series support now;
  // requires the whole combine table to fit in one cluster's SRAM
#if __BANG_ARCH__ >= 592
  bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
  bool is_sram_scatter = false;
#endif
  if (__is_ipu()) {
    // 1. Get token count
    getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
                  num_expert);
    // 2. Get presum of token count
    getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
    // 3. Get expert position index after sorting
    getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
                 num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
  }
#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here:
  if (__is_ipu() && taskId == 0) {
    modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
                              num_expert);
  }
  __sync_cluster();
#endif
  // 4. Get gather index for expand and combine
  if (is_sram_scatter) {
    // Only use Union1 SRAM: the combine-idx work is split over the 4 cores
    // of the first cluster
    uint32_t scatter_idx_cur_core = token_total_num / 4;
    uint32_t scatter_idx_remain_num = token_total_num % 4;
    scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
    uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
                                 ? scatter_idx_cur_core * taskId
                                 : scatter_idx_cur_core * taskId + scatter_idx_remain_num;
    // Only Union1 task type,
    // deal once num is same with deal_num in getGatherIdx,
    // which means only 1 repeat to generate both expand and combine idx on NRAM
    const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
    if (taskDim <= 4 || token_total_num < deal_once_num) {
      if (taskId < 4) {
        if (__is_ipu()) {
          getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                             scatter_idx_cur_core, cur_idx_start, topk);
          // sync for ipu and mpu
          __sync_cluster();
        } else {
          // sync for ipu and mpu: the MPU waits for the SRAM scatter, then
          // copies the finished combine table from SRAM to GDRAM
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      }
    } else {
      // If taskDim > 4, use first union to generate combine idx,
      // use other union to generate expand idx
      if (taskId < 4) {
        if (__is_ipu()) {
          // Scatter combine idx to SRAM
          getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
          __sync_cluster();
        } else {
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      } else {
        // Other union generate expand idx
        if (__is_ipu()) {
          // Re-partition [num_token*topk] over the remaining taskDim-4 cores
          uint32_t expand_dim = taskDim - 4;
          uint32_t expand_id = taskId - 4;
          uint32_t expand_token_cur_core = token_total_num / expand_dim;
          uint32_t expand_token_remain_num = token_total_num % expand_dim;
          expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
          uint32_t expand_cur_token_start =
              (expand_id < expand_token_remain_num)
                  ? expand_token_cur_core * expand_id
                  : expand_token_cur_core * expand_id + expand_token_remain_num;
          getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
                       expand_cur_token_start, topk);
        }
      }
    }
  } else {
    // not use SRAM to generate both expand and combine idx
    if (__is_ipu()) {
      getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                          token_cur_core, cur_token_start, topk);
    }
  }
  // step 5 does not need MPU
  if (__is_mpu()) {
    return;
  }
}  // end of kernel
} // namespace kernels
// Host-side launcher for the MoE gen-idx kernels.
// Dispatches a single-core Block kernel for small problems
// (num_token * topk <= 4096, everything fits in one core's NRAM) and a
// UnionX kernel otherwise, sizing the union by the number of cores needed.
// expert_id is int32; workspace backs token_count_presum and sorted_idx on
// the UnionX path (the Block kernel needs no workspace).
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
                                   int *gather_expand_idx,
                                   int *gather_combine_idx,
                                   int *token_count,
                                   int *cusum_token_count,
                                   void *workspace,
                                   const void *expert_id,
                                   const int num_token,
                                   const int num_expert,
                                   const int topk) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  const int token_total_num = num_token * topk;
  // Small problems: a single-core Block kernel handles everything on chip
  // and needs no workspace.
  if (token_total_num <= 4096) {
    cnrtDim3_t block_dim{1, 1, 1};
    kernels::launchMoeGenIdxBlockKernel<<<block_dim, cnrtFuncTypeBlock, queue>>>(
        gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id,
        num_token, num_expert, topk);
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }
  // Cores needed: one per single_core_num_limit tokens for the token
  // partition, and at least one per expert for the sorting step.
  const int single_core_num_limit = 1024;
  int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
  need_core_num = std::max(num_expert, need_core_num);
  // Cap cluster_num at the smallest union size that fits need_core_num
  // (4 cores per cluster); leave it untouched beyond 32 cores.
  if (need_core_num <= 4) {
    cluster_num = 1;
  } else if (need_core_num <= 8) {
    cluster_num = std::min(cluster_num, 2);
  } else if (need_core_num <= 16) {
    cluster_num = std::min(cluster_num, 4);
  } else if (need_core_num <= 32) {
    cluster_num = std::min(cluster_num, 8);
  }
  // Map the capped cluster count to the largest matching UnionX func type.
  cnrtFunctionType_t k_type;
  cnrtDim3_t k_dim{1, 1, 1};
  if (cluster_num == 1) {
    k_type = cnrtFuncTypeUnion1;
    k_dim.x = 4;
  } else if (cluster_num < 4) {  // 2 or 3 clusters
    k_type = cnrtFuncTypeUnion2;
    k_dim.x = 8;
  } else if (cluster_num < 8) {  // 4..7 clusters
    k_type = cnrtFuncTypeUnion4;
    k_dim.x = 16;
  } else {  // 8 or more clusters
    k_type = cnrtFuncTypeUnion8;
    k_dim.x = 32;
  }
  // The expert_id tensor is int32.
  kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
      gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
      num_token, num_expert, topk);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
#undef EXPERT_AVG_COUNT_TEST // undef test macro
} // namespace tmo