/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstdint>

#include <cn_api.h>
#include <cnrt.h>

#include "gen_idx.mluh"

// clang-format off
#include
// clang-format on

namespace tmo {
namespace kernels {

#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
#define ALIGN_16 (16)
#define EXPERT_AVG_COUNT_TEST (0)

__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ const int range[64] = {
    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
    32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
    48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};

// Generate integer sequence data from 0 to length-1
__mlu_func__ void generateIntSeq(int *dst, int length) {
  int count = 64;
  __bang_move(dst, range, std::min(count, length) * sizeof(int));
  while (count < length) {
    __bang_add_scalar(dst + count, dst, (int)count,
                      std::min(count, length - count));
    count *= 2;
  }
}

// genIdx Block kernel, uses only 1 core to process
__mlu_global__ void launchMoeGenIdxBlockKernel(
    int *gather_expand_idx, int *gather_combine_idx, int *token_count,
    int *cusum_token_count, const void *expert_id, const int num_token,
    const int num_expert, const int topk) {
  /* NRAM space */
  // Total occupancy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
  // --------------------------------------------------------------
  // |  expert_id   |  sorted_idx  |gen_idx_onchip|cur_expert_result|
  // | combine_idx  |  expand_idx  |              | scatter_offset  |
  // |num_token*topk|num_token*topk|num_token*topk| num_token*topk  |
  // --------------------------------------------------------------
  // ------------------------------------
  // |token_count|token_count_presum|
  // | num_expert|    num_expert    |
  // ------------------------------------
  uint32_t token_total_num = num_token * topk;
  // num aligned to 16, size aligned to 64B
  uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *sorted_idx_onchip =
      (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
  int8_t *gen_idx_onchip =
      (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
  int8_t *cur_expert_result =
      (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
  int8_t *token_count_onchip =
      (int8_t *)cur_expert_result + align_total_num * sizeof(int);
  int8_t *token_count_presum_onchip =
      (int8_t *)token_count_onchip + num_expert * sizeof(int);
  int8_t *scatter_offset = cur_expert_result;  // reuse cur_expert_result space
#if __BANG_ARCH__ >= 592
  int8_t *combine_idx_onchip = expert_id_onchip;  // reuse expert_id space
#endif
  int8_t *expand_idx_onchip = sorted_idx_onchip;  // reuse sorted_idx space

  // Load current core input expert_id and generate int sequence
  __memcpy_async((int *)expert_id_onchip, (int *)expert_id,
                 token_total_num * sizeof(int), GDRAM2NRAM);
  generateIntSeq((int *)gen_idx_onchip, token_total_num);
  __sync();

  // Initialize sort idx offset
  uint32_t sorted_idx_offset = 0;
  // Initialize first token count presum entry with 0
  ((int *)token_count_presum_onchip)[0] = 0;
  bool need_cusum_token_count = bool(cusum_token_count != nullptr);
  // Loop over each expert: eq, count, filter index
  for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
    __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip,
                     cur_expert, token_total_num);
    // Use filter to sort gen_idx, output at sorted_idx_offset
    uint32_t cur_expert_count =
        __bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset,
                      (float *)gen_idx_onchip, (float *)cur_expert_result,
                      token_total_num);
    sorted_idx_offset += cur_expert_count;
    ((int *)token_count_onchip)[cur_expert] = cur_expert_count;
    // Compute cusum token count and store
    if (need_cusum_token_count) {
      ((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
    }
  }

#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here:
  uint32_t token_count_avg = token_total_num / num_expert;
  uint32_t expert_remain_num = token_total_num % num_expert;
  for (int i = 0; i < num_expert; i++) {
    ((int *)token_count_onchip)[i] =
        (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
    ((int *)token_count_presum_onchip)[i + 1] =
        ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
  }
#endif

  __sync_compute();
  // Store token_count and cusum token count
  __memcpy_async((int *)token_count, (int *)token_count_onchip,
                 num_expert * sizeof(int), NRAM2GDRAM);
  if (need_cusum_token_count) {
    __memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
                   (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }

  // Use sorted idx to generate gather idx for expand and combine
#if __BANG_ARCH__ >= 592
  // scatter_offset = sorted_idx mul_scalar sizeof(int);
  __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip,
                    (int)(sizeof(int)), token_total_num);
#else
  // scatter dst GDRAM addr should be aligned to 64B
  int *combine_idx_align_addr =
      (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset =
      (int)(gather_combine_idx - combine_idx_align_addr);
  __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
#endif
  __sync_compute();
#if __BANG_ARCH__ >= 592
  // scatter_async to NRAM
  __scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip,
                  (uint32_t *)scatter_offset, sizeof(int), NRAM2NRAM,
                  sizeof(int), (unsigned short)token_total_num);
#endif
  // expand_idx = sorted_idx div(topk)
  __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk,
             token_total_num);

  // Store expand_idx and combine_idx
  __sync_compute();
  __memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip,
                 token_total_num * sizeof(int), NRAM2GDRAM);
#if __BANG_ARCH__ >= 592
  __sync_move();
  __memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
                 token_total_num * sizeof(int), NRAM2GDRAM);
#else
  // 370 scatters directly to GDRAM
  __scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip,
            (uint32_t *)scatter_offset, sizeof(int), NRAM2GDRAM, sizeof(int),
            (unsigned short)token_total_num);
#endif
}

// Only MLU500 series supports the NRAM2SRAM scatter direction
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset,
                                 int length) {
#if __BANG_ARCH__ >= 592
  // When length is larger than 65535 (maximum segnum in bang_scatter),
  // split into segments; src/offset addresses should be aligned to 64B
  int seg_repeat = length / 32768;
  int seg_remain = length % 32768;
  int seg_offset = 0;
  for (int seg = 0; seg < seg_repeat; seg++) {
    __scatter_async((int *)dst, ((int *)src) + seg_offset,
                    ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2SRAM,
                    sizeof(int), (unsigned short)32768);
    seg_offset += 32768;
  }
  if (seg_remain > 0) {
    __scatter_async((int *)dst, ((int *)src) + seg_offset,
                    ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2SRAM,
                    sizeof(int), (unsigned short)seg_remain);
  }
#endif
}

// Scatter sequence, transfer size is sizeof(int)
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset,
                                 int length) {
  // When length is larger than 65535 (maximum segnum in bang_scatter),
  // split into segments; src/offset addresses should be aligned to 64B
  int seg_repeat = length / 32768;
  int seg_remain = length % 32768;
  int seg_offset = 0;
  for (int seg = 0; seg < seg_repeat; seg++) {
#if __BANG_ARCH__ >= 592
    __scatter_async((int *)dst, ((int *)src) + seg_offset,
                    ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2GDRAM,
                    sizeof(int), (unsigned short)32768);
#else
    __scatter((int *)dst, ((int *)src) + seg_offset,
              ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2GDRAM,
              sizeof(int), (unsigned short)32768);
#endif
    seg_offset += 32768;
  }
  if (seg_remain > 0) {
#if __BANG_ARCH__ >= 592
    __scatter_async((int *)dst, ((int *)src) + seg_offset,
                    ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2GDRAM,
                    sizeof(int), (unsigned short)seg_remain);
#else
    __scatter((int *)dst, ((int *)src) + seg_offset,
              ((uint32_t *)offset) + seg_offset, sizeof(int), NRAM2GDRAM,
              sizeof(int), (unsigned short)seg_remain);
#endif
  }
}

// 1. Get token count
__mlu_func__ void getTokenCount(int *token_count, int *expert_id,
                                int token_cur_core, int cur_token_start,
                                int num_expert) {
  // 1. Partition on [num_token*topk]:
  // each core loops over all expert ids, uses eq and count instructions,
  // and uses AtomicAdd to accumulate all expert_id token counts on GDRAM.
  // Then sync for all cores.
  // NRAM:
  // ------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|expert_count_onchip|
  // |    deal_num    |     deal_num    |     num_expert    |
  // ------------------------------------------------------
  uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *cur_expert_result =
      (int8_t *)expert_id_onchip + deal_num * sizeof(int);
  int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);

  // Current core data loop
  uint32_t repeat = token_cur_core / deal_num;
  uint32_t remain = token_cur_core % deal_num;
  uint32_t total_repeat = repeat + (int)(remain > 0);
  uint32_t token_addr_offset = cur_token_start;

  // Initialize token_count with 0
  if (taskId == 0) {
    __gdramset((int *)token_count, num_expert, 0);
  }
  // Sync after initializing token_count
  __sync_all_ipu();

  // Initialize expert count onchip with 0
  if (token_cur_core > 0) {
    __bang_write_zero((int *)expert_count_onchip, num_expert);
  }
  // actual num in loop
  int cur_deal_num = deal_num;
  for (int i = 0; i < total_repeat; i++) {
    if (i == total_repeat - 1 && remain > 0) {
      cur_deal_num = remain;
    }
    // Load current core input expert_id
    __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
             cur_deal_num * sizeof(int), GDRAM2NRAM);
    token_addr_offset += cur_deal_num;
    // Loop over each expert: eq, count
    for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
      __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip,
                       cur_expert, cur_deal_num);
      // NOTE: __bang_count() only supports floating data types
      uint32_t cur_expert_count =
          __bang_count((float *)cur_expert_result, cur_deal_num);
      ((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
    }
  }
  // AtomicAdd(reduce) all cores' token count results
  if (token_cur_core > 0) {
    __bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip,
                             num_expert);
  }
  // Sync for all cores to get the accumulated token_count
  __sync_all_ipu();
}

// 2. Get token count presum, i.e. each expert's index start address after
//    sorting
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
                                      int *token_count, const int num_expert) {
  // 2. The first step already produced token_count.
  // Use one core to prefix-sum token_count (int32 elements); the first
  // expert's start address is zero. This gives each expert's start address
  // after sorting, stored to workspace token_count_presum.
  // Then sync for all cores.
  // NRAM:
  // load token_count to token_count_presum[1 ~ num_expert+1],
  // for i = 0 to num_expert:
  //   token_count_presum[i+1] += token_count_presum[i]
  // store token_count_presum[0 ~ num_expert]
  // -------------------------
  // |token_count_presum_onchip|
  // |     {0}, num_expert     |
  // -------------------------
  if (taskId == 0) {
    // Initialize count presum onchip with a leading 0
    int8_t *token_count_presum_onchip = nram_buffer;
    ((int *)token_count_presum_onchip)[0] = 0;
    // Load token_count with an offset of 1
    __memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count,
             num_expert * sizeof(int), GDRAM2NRAM);
    // Calculate presum of token count for each expert
    for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
      ((int *)token_count_presum_onchip)[cur_expert + 1] +=
          ((int *)token_count_presum_onchip)[cur_expert];
    }
    // Store token count presum to workspace
    __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
             (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // Sync for all cores to get the presum of token count
  __sync_all_ipu();
}

__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
                                            int *token_count,
                                            const uint32_t token_total_num,
                                            const int num_expert) {
  uint32_t token_count_avg = token_total_num / num_expert;
  uint32_t expert_remain_num = token_total_num % num_expert;
  int8_t *token_count_onchip = nram_buffer;
  int8_t *token_count_presum_onchip =
      token_count_onchip + num_expert * sizeof(int);
  ((int *)token_count_presum_onchip)[0] = 0;
  for (int i = 0; i < num_expert; i++) {
    ((int *)token_count_onchip)[i] =
        (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
    ((int *)token_count_presum_onchip)[i + 1] =
        ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
  }
  __memcpy((int *)token_count, (int *)token_count_onchip,
           num_expert * sizeof(int), NRAM2GDRAM);
  __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
           (num_expert + 1) * sizeof(int), NRAM2GDRAM);
}

// 3. Get expert position index after sorting
__mlu_func__ void getSortedIdx(int *sorted_idx, int *expert_id,
                               int *token_count_presum,
                               const int token_total_num, const int num_expert,
                               const int expert_cur_core,
                               const int cur_expert_start,
                               const int cur_expert_end) {
  // 3. Partition on num_expert: each core generates position indices from 0,
  // loops over all expert_id data, uses eq with each of its own expert ids,
  // filters on the index, and stores to each expert's start address in
  // sorted_idx on workspace.
  // Then sync for all cores.
  // NRAM:
  // -------------------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
  // |    deal_num    |     deal_num    |   deal_num   |     deal_num    |
  // -------------------------------------------------------------------
  // |expert_start_addr|
  // |    num_expert   |
  // -----------------
  // Calculate new deal_num for the sorting process
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
  // Each core deals with the whole token expert_id data
  int repeat = token_total_num / deal_num;
  int remain = token_total_num % deal_num;
  int token_addr_offset = 0;
  int8_t *expert_id_onchip = nram_buffer;
  int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
  int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
  int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
  int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);

  // When num_expert < taskDim, not all cores need to sort
  if (expert_cur_core > 0) {
    // Generate position index from 0
    if (deal_num <= token_total_num) {
      generateIntSeq((int *)gen_idx_onchip, deal_num);
    } else {
      // only remainder part
      generateIntSeq((int *)gen_idx_onchip, token_total_num);
    }
    // Initialize expert start address with presum of token count
    __memcpy((int *)expert_start_addr, (int *)token_count_presum,
             num_expert * sizeof(int), GDRAM2NRAM);
    // repeat part
    for (int i = 0; i < repeat; i++) {
      // Load current core expert_id
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               deal_num * sizeof(int), GDRAM2NRAM);
      token_addr_offset += deal_num;
      // Loop over this core's experts: eq, filter position index,
      // store to sorted_idx[expert_start_addr]
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end;
           cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip,
                         cur_expert, deal_num);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only supports floating data types
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, deal_num);
        // Store to the corresponding address of sorted_idx
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset,
                   (int *)filter_idx_onchip, cur_expert_count * sizeof(int),
                   NRAM2GDRAM);
          // Update address offset of current expert
          ((int *)expert_start_addr)[cur_expert] =
              cur_expert_offset + cur_expert_count;
        }
      }
      // Update position index for each data loop
      __bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip,
                        (int)(deal_num), deal_num);
    }
    // remainder part
    if (remain > 0) {
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               remain * sizeof(int), GDRAM2NRAM);
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end;
           cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip,
                         cur_expert, remain);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only supports floating data types
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, remain);
        // Store to the corresponding address of sorted_idx
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset,
                   (int *)filter_idx_onchip, cur_expert_count * sizeof(int),
                   NRAM2GDRAM);
        }
      }
    }
  }
  // Sync for all cores to get the position index after sorting
  __sync_all_ipu();
}
// 4. Get gather index for expand and combine
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx, int *gather_combine_idx,
                               int *sorted_idx, const int token_cur_core,
                               const int cur_token_start, const int topk) {
  // 4. Partition on [num_token*topk]:
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset,
  // gather_combine_idx = scatter(seq, sorted_idx)
  // gather_expand_idx = sorted_idx / topk
  // update sequence
  // NRAM:
  // -------------------------------------------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
  // |     deal_num    |     deal_num    |   deal_num   |    deal_num    |
  // -------------------------------------------------------------------
  // Calculate new deal_num for generating the gather index
  // NOTE: deal_num should be aligned to 64 bytes due to bang_scatter()
  // constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  // scatter dst GDRAM addr should be aligned to 64B
  int *combine_idx_align_addr =
      (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset =
      (int)(gather_combine_idx - combine_idx_align_addr);
  int8_t *sorted_idx_onchip = nram_buffer;
  int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);

  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                        (int)token_addr_offset, deal_num);
    } else {
      // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                        (int)token_addr_offset, token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip,
                        (int)(sizeof(int)), deal_num);
    } else {
      // GDRAM addr should be aligned to 64B
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), deal_num);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, deal_num);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, deal_num);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk,
               deal_num);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset,
             (int *)expand_idx_onchip, deal_num * sizeof(int), NRAM2GDRAM);
    if (is_sram_scatter) {
      // if scattering to SRAM, need to sync compute with move
      __sync_move();
    }
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                      (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip,
                        (int)(sizeof(int)), remain);
    } else {
      // GDRAM addr should be aligned to 64B
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), remain);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, remain);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, remain);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk,
               remain);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset,
             (int *)expand_idx_onchip, remain * sizeof(int), NRAM2GDRAM);
  }
}

// 4.1 Get gather combine index on SRAM
__mlu_func__ void getCombineIdxSram(int *sorted_idx, const int token_cur_core,
                                    const int cur_token_start) {
  // 4.1 Partition on [num_token*topk], with only 1 union:
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset,
  // gather_combine_idx = scatter(seq, sorted_idx)
  // update sequence
  // NRAM:
  // -------------------------------
  // |scatter_offset|scatter_sequence|
  // |   deal_num   |    deal_num    |
  // -------------------------------
  // Calculate new deal_num for generating the gather index
  // NOTE: deal_num should be aligned to 64 bytes due to bang_scatter()
  // constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  int8_t *scatter_offset = nram_buffer;
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);

  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                        (int)token_addr_offset, deal_num);
    } else {
      // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                        (int)token_addr_offset, token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset,
                      (int)(sizeof(int)), deal_num);
    // Sync for scatter
    __sync_compute();
    // Scatter to SRAM
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence,
                   (uint32_t *)scatter_offset, deal_num);
    __sync_move();
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence,
                      (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset,
                      (int)(sizeof(int)), remain);
    // Sync for scatter
    __sync_compute();
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence,
                   (uint32_t *)scatter_offset, remain);
  }
}

// 4.2 Get gather expand index
__mlu_func__ void getExpandIdx(int *gather_expand_idx, int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start, const int topk) {
  // 4.2 Partition on [num_token*topk]:
  // load sorted_idx onchip,
  // gather_expand_idx = sorted_idx / topk
  // NRAM:
  // -----------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|
  // |     deal_num    |     deal_num    |
  // -----------------------------------
  // Calculate new deal_num for generating the gather index
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  int8_t *sorted_idx_onchip = nram_buffer;
  int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk,
               deal_num);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset,
             (int *)expand_idx_onchip, deal_num * sizeof(int), NRAM2GDRAM);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk,
               remain);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset,
             (int *)expand_idx_onchip, remain * sizeof(int), NRAM2GDRAM);
  }
}

__mlu_global__ void launchMoeGenIdxKernel(
    int *gather_expand_idx, int *gather_combine_idx, int *token_count,
    int *cusum_token_count, void *workspace, const void *expert_id,
    const int num_token, const int num_expert, const int topk) {
  // Store token count presum result, shape [num_expert + 1]
  int *token_count_presum =
      (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
  // Store position index after sorting, shape [num_token*topk]
  int *sorted_idx = ((int *)workspace) + num_expert + 1;

  // Calculate partition information for the different processing steps
  // Partition on [num_token*topk]
  uint32_t token_total_num = num_token * topk;
  uint32_t token_cur_core = token_total_num / taskDim;
  uint32_t token_remain_num = token_total_num % taskDim;
  token_cur_core += (uint32_t)(taskId < token_remain_num);
  // Current core range according to partition on [num_token*topk]
  uint32_t cur_token_start = (taskId < token_remain_num)
                                 ? token_cur_core * taskId
                                 : token_cur_core * taskId + token_remain_num;
  // Partition on [num_expert]
  uint32_t expert_cur_core = num_expert / taskDim;
  uint32_t expert_remain_num = num_expert % taskDim;
  expert_cur_core += (uint32_t)(taskId < expert_remain_num);
  // Current core range according to partition on [num_expert]
  uint32_t cur_expert_start =
      (taskId < expert_remain_num)
          ? expert_cur_core * taskId
          : expert_cur_core * taskId + expert_remain_num;
  uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;

  // Use Union1 SRAM to scatter; only MLU500 series supports this now
#if __BANG_ARCH__ >= 592
  bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
  bool is_sram_scatter = false;
#endif

  if (__is_ipu()) {
    // 1. Get token count
    getTokenCount((int *)token_count, (int *)expert_id, token_cur_core,
                  cur_token_start, num_expert);
    // 2. Get presum of token count
    getTokenCountPresum((int *)token_count_presum, (int *)token_count,
                        num_expert);
    // 3. Get expert position index after sorting
    getSortedIdx((int *)sorted_idx, (int *)expert_id,
                 (int *)token_count_presum, token_total_num, num_expert,
                 expert_cur_core, cur_expert_start, cur_expert_end);
  }

#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here:
  if (__is_ipu() && taskId == 0) {
    modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count,
                              token_total_num, num_expert);
  }
  __sync_cluster();
#endif

  // 4. Get gather index for expand and combine
  if (is_sram_scatter) {
    // Only use Union1 SRAM
    uint32_t scatter_idx_cur_core = token_total_num / 4;
    uint32_t scatter_idx_remain_num = token_total_num % 4;
    scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
    uint32_t cur_idx_start =
        (taskId < scatter_idx_remain_num)
            ? scatter_idx_cur_core * taskId
            : scatter_idx_cur_core * taskId + scatter_idx_remain_num;
    // Only Union1 task type:
    // deal_once_num is the same as deal_num in getGatherIdx,
    // which means only 1 repeat to generate both expand and combine idx on NRAM
    const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
    if (taskDim <= 4 || token_total_num < deal_once_num) {
      if (taskId < 4) {
        if (__is_ipu()) {
          getGatherIdx<true>((int *)gather_expand_idx,
                             (int *)gather_combine_idx, (int *)sorted_idx,
                             scatter_idx_cur_core, cur_idx_start, topk);
          // sync for ipu and mpu
          __sync_cluster();
        } else {
          // sync for ipu and mpu
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      }
    } else {
      // If taskDim > 4, use the first union to generate the combine idx,
      // and the other unions to generate the expand idx
      if (taskId < 4) {
        if (__is_ipu()) {
          // Scatter combine idx to SRAM
          getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core,
                            cur_idx_start);
          __sync_cluster();
        } else {
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      } else {
        // Other unions generate the expand idx
        if (__is_ipu()) {
          uint32_t expand_dim = taskDim - 4;
          uint32_t expand_id = taskId - 4;
          uint32_t expand_token_cur_core = token_total_num / expand_dim;
          uint32_t expand_token_remain_num = token_total_num % expand_dim;
          expand_token_cur_core +=
              (uint32_t)(expand_id < expand_token_remain_num);
          uint32_t expand_cur_token_start =
              (expand_id < expand_token_remain_num)
                  ? expand_token_cur_core * expand_id
                  : expand_token_cur_core * expand_id + expand_token_remain_num;
          getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx,
                       expand_token_cur_core, expand_cur_token_start, topk);
        }
      }
    }
  } else {
    // do not use SRAM; generate both expand and combine idx
    if (__is_ipu()) {
      getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx,
                          (int *)sorted_idx, token_cur_core, cur_token_start,
                          topk);
    }
  }
  // step 5 does not need MPU
  if (__is_mpu()) {
    return;
  }
}  // end of kernel

}  // namespace kernels

KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue, int *gather_expand_idx,
                                   int *gather_combine_idx, int *token_count,
                                   int *cusum_token_count, void *workspace,
                                   const void *expert_id, const int num_token,
                                   const int num_expert, const int topk) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  const int token_total_num = num_token * topk;
  // For the partition on num_token*topk, limit the core count so that each
  // core processes about single_core_num_limit elements
  const int single_core_num_limit = 1024;
  int need_core_num =
      std::ceil(float(token_total_num) / single_core_num_limit);
  // For the partition on num_expert, each core processes at least one expert
  need_core_num = std::max(num_expert, need_core_num);

  // When considering the UnionX cnrt func type, reset cluster_num
  if (token_total_num <= 4096) {
    // Block
    cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
    cnrtDim3_t k_dim{1, 1, 1};
    // Block kernel does not need workspace
    kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
        gather_expand_idx, gather_combine_idx, token_count, cusum_token_count,
        expert_id, num_token, num_expert, topk);
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  } else if (need_core_num <= 4) {
    // Union1
    cluster_num = 1;
  } else if (need_core_num <= 8) {
    // Union2
    cluster_num = std::min(cluster_num, 2);
  } else if (need_core_num <= 16) {
    // Union4
    cluster_num = std::min(cluster_num, 4);
  } else if (need_core_num <= 32) {
    // Union8
    cluster_num = std::min(cluster_num, 8);
  }

  cnrtFunctionType_t k_type;
  cnrtDim3_t k_dim{1, 1, 1};
  // Find max UnionX cnrt func type
  if (cluster_num == 1) {
    k_type = cnrtFuncTypeUnion1;
    k_dim.x = 4;
  } else if (cluster_num < 4) {
    // cluster num is 2 or 3
    k_type = cnrtFuncTypeUnion2;
    k_dim.x = 8;
  } else if (cluster_num < 8) {
    // cluster num is 4, 5, 6, 7
    k_type = cnrtFuncTypeUnion4;
    k_dim.x = 16;
  } else {
    // cluster num is 8 or larger
    k_type = cnrtFuncTypeUnion8;
    k_dim.x = 32;
  }

  // The expert_id is int data type
  kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
      gather_expand_idx, gather_combine_idx, token_count, cusum_token_count,
      workspace, expert_id, num_token, num_expert, topk);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

#undef EXPERT_AVG_COUNT_TEST  // undef test macro

}  // namespace tmo