936 lines
38 KiB
Plaintext
936 lines
38 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <type_traits>
|
|
#include <vector>
|
|
#include "gen_idx.mluh"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
namespace tmo {
|
|
namespace kernels {
|
|
|
|
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
|
|
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
|
|
#define ALIGN_16 (16)
|
|
|
|
#define EXPERT_AVG_COUNT_TEST (0)
|
|
|
|
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
|
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
|
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
|
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
|
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
|
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
|
|
|
|
// Generate integer sequence data from 0 to length-1
|
|
__mlu_func__ void generateIntSeq(int *dst, int length) {
|
|
int count = 64;
|
|
__bang_move(dst, range, std::min(count, length) * sizeof(int));
|
|
while (count < length) {
|
|
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
|
|
count *= 2;
|
|
}
|
|
}
|
|
|
|
// genIdx Block kernel, use only 1 core to process
|
|
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
|
|
int *gather_combine_idx,
|
|
int *token_count,
|
|
int *cusum_token_count,
|
|
const void *expert_id,
|
|
const int num_token,
|
|
const int num_expert,
|
|
const int topk) {
|
|
/* NRAM space */
|
|
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
|
|
// --------------------------------------------------------------
|
|
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
|
|
// | combine_idx | expand_idx | | scatter_offset |
|
|
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
|
|
// --------------------------------------------------------------
|
|
// ------------------------------
|
|
// |token_count|token_count_presum|
|
|
// | | |
|
|
// | num_expert| num_expert |
|
|
// ------------------------------
|
|
|
|
uint32_t token_total_num = num_token * topk;
|
|
// num align to 16, size align to 64B
|
|
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
|
|
|
|
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
|
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
|
|
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
|
|
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
|
|
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
|
|
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
|
|
|
|
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
|
|
#if __BANG_ARCH__ >= 592
|
|
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
|
|
#endif
|
|
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
|
|
|
|
// Load current core input expert_id and generate int sequence
|
|
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
|
|
GDRAM2NRAM);
|
|
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
|
__sync();
|
|
|
|
// Initialize sort idx offset
|
|
uint32_t sorted_idx_offset = 0;
|
|
// Initialize token count first presum with 0
|
|
((int *)token_count_presum_onchip)[0] = 0;
|
|
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
|
|
|
|
// Loop on each expert, eq, count, filter index
|
|
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
|
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
|
|
token_total_num);
|
|
// Use filter to sort gen_idx, output with sorted_idx_offset
|
|
uint32_t cur_expert_count =
|
|
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
|
|
(float *)cur_expert_result, token_total_num);
|
|
|
|
sorted_idx_offset += cur_expert_count;
|
|
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
|
|
|
|
// Compute cusum token count and store
|
|
if (need_cusum_token_count) {
|
|
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
|
|
}
|
|
}
|
|
|
|
#if EXPERT_AVG_COUNT_TEST
|
|
// NOTE: test avg expert code here:
|
|
uint32_t token_count_avg = token_total_num / num_expert;
|
|
uint32_t expert_remain_num = token_total_num % num_expert;
|
|
|
|
for (int i = 0; i < num_expert; i++) {
|
|
((int *)token_count_onchip)[i] =
|
|
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
|
((int *)token_count_presum_onchip)[i + 1] =
|
|
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
|
}
|
|
#endif
|
|
|
|
__sync_compute();
|
|
// Store token_count and cusum token count
|
|
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
|
|
NRAM2GDRAM);
|
|
if (need_cusum_token_count) {
|
|
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
|
|
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
|
}
|
|
|
|
// Use sorted idx to generate gather idx for expand and combine
|
|
#if __BANG_ARCH__ >= 592
|
|
// scatter_offset = sorted_idx mul_scalar sizeof(int);
|
|
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
|
token_total_num);
|
|
#else
|
|
// scatter dst GDRAM addr should align to 64B
|
|
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
|
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
|
|
|
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
|
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
|
|
#endif
|
|
__sync_compute();
|
|
|
|
#if __BANG_ARCH__ >= 592
|
|
// scatter_async to NRAM
|
|
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
|
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
|
|
#endif
|
|
// expand_idx = sorted_idx div(topk)
|
|
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
|
|
|
|
// Store expand_idx and combine_idx
|
|
__sync_compute();
|
|
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
|
|
NRAM2GDRAM);
|
|
#if __BANG_ARCH__ >= 592
|
|
__sync_move();
|
|
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
|
|
token_total_num * sizeof(int), NRAM2GDRAM);
|
|
#else
|
|
// 370 directly scatter to GDRAM
|
|
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
|
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
|
|
#endif
|
|
}
|
|
|
|
// Only MLU500 series support NRAM2SRAM scatter direction
|
|
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
|
|
#if __BANG_ARCH__ >= 592
|
|
// When length larger than 65535(maximum segnum in bang_scatter),
|
|
// and src/offset address should align to 64B
|
|
int seg_repeat = length / 32768;
|
|
int seg_remain = length % 32768;
|
|
int seg_offset = 0;
|
|
|
|
for (int seg = 0; seg < seg_repeat; seg++) {
|
|
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
|
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
|
|
seg_offset += 32768;
|
|
}
|
|
|
|
if (seg_remain > 0) {
|
|
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
|
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
// Scatter sequence, transfer size is sizeof(int)
|
|
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
|
|
// When length larger than 65535(maximum segnum in bang_scatter),
|
|
// and src/offset address should align to 64B
|
|
int seg_repeat = length / 32768;
|
|
int seg_remain = length % 32768;
|
|
int seg_offset = 0;
|
|
|
|
for (int seg = 0; seg < seg_repeat; seg++) {
|
|
#if __BANG_ARCH__ >= 592
|
|
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
|
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
|
#else
|
|
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
|
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
|
#endif
|
|
seg_offset += 32768;
|
|
}
|
|
if (seg_remain > 0) {
|
|
#if __BANG_ARCH__ >= 592
|
|
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
|
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
|
#else
|
|
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
|
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
|
#endif
|
|
}
|
|
}
|
|
|
|
// 1. Get token count
|
|
__mlu_func__ void getTokenCount(int *token_count,
|
|
int *expert_id,
|
|
int token_cur_core,
|
|
int cur_token_start,
|
|
int num_expert) {
|
|
// 1. Partition on [num_token*topk],
|
|
// each core for-loop on all expert_id, use eq and count instructions,
|
|
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
|
|
// And sync for all cores.
|
|
// NRAM:
|
|
// ------------------------------------------------------
|
|
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
|
|
// | deal_num | deal_num | num_expert |
|
|
// ------------------------------------------------------
|
|
|
|
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
|
|
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
|
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
|
|
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
|
|
|
|
// Current core data loop
|
|
uint32_t repeat = token_cur_core / deal_num;
|
|
uint32_t remain = token_cur_core % deal_num;
|
|
uint32_t total_repeat = repeat + (int)(remain > 0);
|
|
uint32_t token_addr_offset = cur_token_start;
|
|
|
|
// Initialize token_count with 0
|
|
if (taskId == 0) {
|
|
__gdramset((int *)token_count, num_expert, 0);
|
|
}
|
|
// Sync for initialize token_count
|
|
__sync_all_ipu();
|
|
|
|
// Initialize expert count onchip with 0
|
|
if (token_cur_core > 0) {
|
|
__bang_write_zero((int *)expert_count_onchip, num_expert);
|
|
}
|
|
|
|
// actual num in loop
|
|
int cur_deal_num = deal_num;
|
|
for (int i = 0; i < total_repeat; i++) {
|
|
if (i == total_repeat - 1 && remain > 0) {
|
|
cur_deal_num = remain;
|
|
}
|
|
// Load current core input expert_id
|
|
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
|
cur_deal_num * sizeof(int), GDRAM2NRAM);
|
|
token_addr_offset += cur_deal_num;
|
|
|
|
// Loop on each expert, eq, count
|
|
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
|
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
|
|
// NOTE: __bang_count() only support floating data type
|
|
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
|
|
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
|
|
}
|
|
}
|
|
|
|
// AtomicAdd(reduce) all cores token count results
|
|
if (token_cur_core > 0) {
|
|
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
|
|
}
|
|
// Sync for all cores, get accumulate of token_count
|
|
__sync_all_ipu();
|
|
}
|
|
|
|
// 2. Get token count presum, for each expert index start address after sorting
|
|
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
|
|
int *token_count,
|
|
const int num_expert) {
|
|
// 2. After first process, already get token_count.
|
|
// Then use one core to pre-sum on token_count, consider size of int32,
|
|
// first expert id start address should be zero.
|
|
// to get each expert id start address after sorting, store to workspace,
|
|
// token_count_presum.
|
|
// And sync for all cores.
|
|
// NRAM:
|
|
// load token_count to token_count_presum[1~num_expert+1],
|
|
// for i = 0 to num_expert:
|
|
// token_count_presum[i+1] += token_count_presum[i]
|
|
// store token_count_presum[0~num_expert]
|
|
// -------------------------
|
|
// |token_count_presum_onchip|
|
|
// | {0}, num_expert |
|
|
// -------------------------
|
|
|
|
if (taskId == 0) {
|
|
// Initialize count presum onchip with a first 0
|
|
int8_t *token_count_presum_onchip = nram_buffer;
|
|
((int *)token_count_presum_onchip)[0] = 0;
|
|
// Load token_count with an offset of 1
|
|
__memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
|
|
GDRAM2NRAM);
|
|
|
|
// Calculate presum of token count by each expert
|
|
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
|
((int *)token_count_presum_onchip)[cur_expert + 1] +=
|
|
((int *)token_count_presum_onchip)[cur_expert];
|
|
}
|
|
|
|
// Store token count presum to workspace
|
|
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
|
|
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
|
}
|
|
// Sync for all cores, get presum of token count
|
|
__sync_all_ipu();
|
|
}
|
|
|
|
// Test helper: overwrite token_count and token_count_presum in GDRAM with an even
// distribution of token_total_num over num_expert experts.
//
// The first (token_total_num % num_expert) experts get one extra token.
// token_count receives num_expert ints, token_count_presum num_expert + 1 ints
// (leading 0 followed by the running sum). No synchronization is done here.
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
                                            int *token_count,
                                            const uint32_t token_total_num,
                                            const int num_expert) {
  const uint32_t base_count = token_total_num / num_expert;
  const uint32_t extra_experts = token_total_num % num_expert;

  // NRAM: [count : num_expert][presum : num_expert + 1]
  int *count_onchip = (int *)nram_buffer;
  int *presum_onchip = count_onchip + num_expert;

  presum_onchip[0] = 0;
  for (int i = 0; i < num_expert; ++i) {
    const int cur_count = (i < extra_experts) ? base_count + 1 : base_count;
    count_onchip[i] = cur_count;
    presum_onchip[i + 1] = presum_onchip[i] + cur_count;
  }

  __memcpy(token_count, count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
  __memcpy(token_count_presum, presum_onchip, (num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
|
|
|
|
// 3. Get expert position index after sorting
|
|
__mlu_func__ void getSortedIdx(int *sorted_idx,
|
|
int *expert_id,
|
|
int *token_count_presum,
|
|
const int token_total_num,
|
|
const int num_expert,
|
|
const int expert_cur_core,
|
|
const int cur_expert_start,
|
|
const int cur_expert_end) {
|
|
// 3. Partition on num_expert, each core generate position index from 0,
|
|
// and for-loop on all expert_id data, use eq with own each expert_id,
|
|
// and filter on index, stores to each expert_id start address of
|
|
// sorted_idx on workspace.
|
|
// And sync for all cores.
|
|
// NRAM:
|
|
// -------------------------------------------------------------------
|
|
// |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
|
|
// | deal_num | deal_num | deal_num | deal_num |
|
|
// -------------------------------------------------------------------
|
|
// |expert_start_addr|
|
|
// | num_expert |
|
|
// -----------------
|
|
|
|
// Calculate new deal_num of sorting process
|
|
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
|
|
|
|
// Each core deal with whole token expert_id data
|
|
int repeat = token_total_num / deal_num;
|
|
int remain = token_total_num % deal_num;
|
|
int token_addr_offset = 0;
|
|
|
|
int8_t *expert_id_onchip = nram_buffer;
|
|
int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
|
|
int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
|
|
int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
|
|
int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
|
|
|
|
// When num_expert < taskDim, not all cores need to sort
|
|
if (expert_cur_core > 0) {
|
|
// Generate position index from 0
|
|
if (deal_num <= token_total_num) {
|
|
generateIntSeq((int *)gen_idx_onchip, deal_num);
|
|
} else { // only remainder part
|
|
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
|
}
|
|
|
|
// Initialize expert start address with presum of token count
|
|
__memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
|
|
GDRAM2NRAM);
|
|
|
|
// repeat part
|
|
for (int i = 0; i < repeat; i++) {
|
|
// Load current core expert_id
|
|
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
|
deal_num * sizeof(int), GDRAM2NRAM);
|
|
token_addr_offset += deal_num;
|
|
|
|
// Loop for current core expert, eq, filter position index
|
|
// use filter, store to sorted_idx[expert_start_addr]
|
|
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
|
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
|
|
|
|
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
|
|
|
// NOTE: __bang_filter() only support floating data type
|
|
uint32_t cur_expert_count =
|
|
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
|
(float *)cur_expert_result, deal_num);
|
|
|
|
// Store to the corresponding address of sorted_idx
|
|
if (cur_expert_count > 0) {
|
|
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
|
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
|
|
|
// Update address offset of current expert
|
|
((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
|
|
}
|
|
}
|
|
|
|
// Update position index for each data loop
|
|
__bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
|
|
}
|
|
|
|
// remainder part
|
|
if (remain > 0) {
|
|
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
|
remain * sizeof(int), GDRAM2NRAM);
|
|
|
|
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
|
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
|
|
|
|
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
|
|
|
// NOTE: __bang_filter() only support floating data type
|
|
uint32_t cur_expert_count =
|
|
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
|
(float *)cur_expert_result, remain);
|
|
// Store to the corresponding address of sorted_idx
|
|
if (cur_expert_count > 0) {
|
|
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
|
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sync for all cores, get position index after sorting
|
|
__sync_all_ipu();
|
|
}
|
|
|
|
// 4. Get gather index for expand and combine
|
|
template <bool is_sram_scatter>
|
|
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
|
|
int *gather_combine_idx,
|
|
int *sorted_idx,
|
|
const int token_cur_core,
|
|
const int cur_token_start,
|
|
const int topk) {
|
|
// 4. Partition on [num_token*topk],
|
|
// load sorted_idx onchip,
|
|
// generate sequence according to position index from 0, add token offset
|
|
// gather_combine_idx = scatter(seq, sorted_idx)
|
|
// gather_expand_idx = sorted_idx / topk
|
|
// update sequence
|
|
// NRAM:
|
|
// -------------------------------------------------------------------
|
|
// |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
|
|
// | deal_num | deal_num | deal_num | deal_num |
|
|
// -------------------------------------------------------------------
|
|
|
|
// Calculate new deal_num of generate gather index
|
|
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
|
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
|
int repeat = token_cur_core / deal_num;
|
|
int remain = token_cur_core % deal_num;
|
|
int token_addr_offset = cur_token_start;
|
|
|
|
// scatter dst GDRAM addr should align to 64B
|
|
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
|
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
|
|
|
int8_t *sorted_idx_onchip = nram_buffer;
|
|
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
|
int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
|
|
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
|
|
|
// Generate position index from 0
|
|
// Add base offset to sequence according to current core token start address
|
|
if (token_cur_core > 0) {
|
|
if (deal_num <= token_cur_core) {
|
|
generateIntSeq((int *)scatter_sequence, deal_num);
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
|
deal_num);
|
|
} else { // only remainder part
|
|
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
|
token_cur_core);
|
|
}
|
|
}
|
|
|
|
// repeat part
|
|
for (int i = 0; i < repeat; i++) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
|
deal_num * sizeof(int), GDRAM2NRAM);
|
|
|
|
// offset = sorted_idx * sizeof(int), counted in bytes
|
|
if (is_sram_scatter) {
|
|
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
|
deal_num);
|
|
} else {
|
|
// GDRAM addr should align to 64B
|
|
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
|
combine_idx_align_offset, (int)(sizeof(int)), deal_num);
|
|
}
|
|
// Sync for scatter
|
|
__sync_compute();
|
|
|
|
if (is_sram_scatter) {
|
|
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
|
deal_num);
|
|
} else {
|
|
// Scatter to output gather_combine_idx
|
|
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
|
(uint32_t *)scatter_offset, deal_num);
|
|
}
|
|
|
|
// expand_idx_onchip = sorted_idx / topk
|
|
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
|
// Store expand idx
|
|
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
|
deal_num * sizeof(int), NRAM2GDRAM);
|
|
if (is_sram_scatter) {
|
|
// if scatter to SRAM, need to sync compute with mv
|
|
__sync_move();
|
|
}
|
|
// Add offset to sequence and token_address
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
|
token_addr_offset += deal_num;
|
|
}
|
|
|
|
// remainder part
|
|
if (remain > 0) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
|
remain * sizeof(int), GDRAM2NRAM);
|
|
|
|
// offset = sorted_idx * sizeof(int), counted in bytes
|
|
if (is_sram_scatter) {
|
|
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
|
remain);
|
|
} else {
|
|
// GDRAM addr should align to 64B
|
|
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
|
combine_idx_align_offset, (int)(sizeof(int)), remain);
|
|
}
|
|
|
|
// Sync for scatter
|
|
__sync_compute();
|
|
|
|
if (is_sram_scatter) {
|
|
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
|
remain);
|
|
} else {
|
|
// Scatter to output gather_combine_idx
|
|
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
|
(uint32_t *)scatter_offset, remain);
|
|
}
|
|
|
|
// expand_idx_onchip = sorted_idx / topk
|
|
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
|
// Store expand idx
|
|
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
|
remain * sizeof(int), NRAM2GDRAM);
|
|
}
|
|
}
|
|
|
|
// 4.1 Get gather combine index on SRAM
|
|
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
|
|
const int token_cur_core,
|
|
const int cur_token_start) {
|
|
// 4.1 Partition on [num_token*topk], with only 1 union
|
|
// load sorted_idx onchip,
|
|
// generate sequence according to position index from 0, add token offset
|
|
// gather_combine_idx = scatter(seq, sorted_idx)
|
|
// update sequence
|
|
// NRAM:
|
|
// -------------------------------
|
|
// |scatter_offset|scatter_sequence|
|
|
// | deal_num | deal_num |
|
|
// -------------------------------
|
|
|
|
// Calculate new deal_num of generate gather index
|
|
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
|
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
|
int repeat = token_cur_core / deal_num;
|
|
int remain = token_cur_core % deal_num;
|
|
int token_addr_offset = cur_token_start;
|
|
|
|
int8_t *scatter_offset = nram_buffer;
|
|
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
|
|
|
// Generate position index from 0
|
|
// Add base offset to sequence according to current core token start address
|
|
if (token_cur_core > 0) {
|
|
if (deal_num <= token_cur_core) {
|
|
generateIntSeq((int *)scatter_sequence, deal_num);
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
|
deal_num);
|
|
} else { // only remainder part
|
|
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
|
token_cur_core);
|
|
}
|
|
}
|
|
|
|
// repeat part
|
|
for (int i = 0; i < repeat; i++) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
|
|
GDRAM2NRAM);
|
|
|
|
// offset = sorted_idx * sizeof(int), counted in bytes
|
|
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
|
|
// Sync for scatter
|
|
__sync_compute();
|
|
|
|
// Scatter to SRAM
|
|
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
|
deal_num);
|
|
__sync_move();
|
|
|
|
// Add offset to sequence and token_address
|
|
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
|
token_addr_offset += deal_num;
|
|
}
|
|
|
|
// remainder part
|
|
if (remain > 0) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
|
|
GDRAM2NRAM);
|
|
|
|
// offset = sorted_idx * sizeof(int), counted in bytes
|
|
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
|
|
// Sync for scatter
|
|
__sync_compute();
|
|
|
|
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
|
|
}
|
|
}
|
|
|
|
// 4.2 Get gather expand index
|
|
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
|
|
int *sorted_idx,
|
|
const int token_cur_core,
|
|
const int cur_token_start,
|
|
const int topk) {
|
|
// 4.2 Partition on [num_token*topk],
|
|
// load sorted_idx onchip,
|
|
// gather_expand_idx = sorted_idx / topk
|
|
// NRAM:
|
|
// -----------------------------------
|
|
// |sorted_idx_onchip|expand_idx_onchip|
|
|
// | deal_num | deal_num |
|
|
// -----------------------------------
|
|
|
|
// Calculate new deal_num of generate gather index
|
|
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
|
int repeat = token_cur_core / deal_num;
|
|
int remain = token_cur_core % deal_num;
|
|
int token_addr_offset = cur_token_start;
|
|
|
|
int8_t *sorted_idx_onchip = nram_buffer;
|
|
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
|
|
|
// repeat part
|
|
for (int i = 0; i < repeat; i++) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
|
deal_num * sizeof(int), GDRAM2NRAM);
|
|
|
|
// expand_idx_onchip = sorted_idx / topk
|
|
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
|
// Store expand idx
|
|
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
|
deal_num * sizeof(int), NRAM2GDRAM);
|
|
token_addr_offset += deal_num;
|
|
}
|
|
|
|
// remainder part
|
|
if (remain > 0) {
|
|
// Load current core sorted_idx
|
|
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
|
remain * sizeof(int), GDRAM2NRAM);
|
|
// expand_idx_onchip = sorted_idx / topk
|
|
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
|
// Store expand idx
|
|
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
|
remain * sizeof(int), NRAM2GDRAM);
|
|
}
|
|
}
|
|
|
|
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
|
|
int *gather_combine_idx,
|
|
int *token_count,
|
|
int *cusum_token_count,
|
|
void *workspace,
|
|
const void *expert_id,
|
|
const int num_token,
|
|
const int num_expert,
|
|
const int topk) {
|
|
// Store token count presum result, shape [num_expert + 1]
|
|
int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
|
|
// Store position index after sorting, shape [num_token*topk]
|
|
int *sorted_idx = ((int *)workspace) + num_expert + 1;
|
|
|
|
// Calculate partition information for different processes
|
|
// Partition on [num_token*topk]
|
|
uint32_t token_total_num = num_token * topk;
|
|
uint32_t token_cur_core = token_total_num / taskDim;
|
|
uint32_t token_remain_num = token_total_num % taskDim;
|
|
token_cur_core += (uint32_t)(taskId < token_remain_num);
|
|
// Current core range according to partition on [num_token*topk]
|
|
uint32_t cur_token_start = (taskId < token_remain_num)
|
|
? token_cur_core * taskId
|
|
: token_cur_core * taskId + token_remain_num;
|
|
|
|
// Partition on [num_expert]
|
|
uint32_t expert_cur_core = num_expert / taskDim;
|
|
uint32_t expert_remain_num = num_expert % taskDim;
|
|
expert_cur_core += (uint32_t)(taskId < expert_remain_num);
|
|
// Current core range according to partition on [num_expert]
|
|
uint32_t cur_expert_start = (taskId < expert_remain_num)
|
|
? expert_cur_core * taskId
|
|
: expert_cur_core * taskId + expert_remain_num;
|
|
uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
|
|
|
|
// Use Union1 SRAM to scatter, only MLU500 series support now
|
|
#if __BANG_ARCH__ >= 592
|
|
bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
|
|
#else
|
|
bool is_sram_scatter = false;
|
|
#endif
|
|
|
|
if (__is_ipu()) {
|
|
// 1. Get token count
|
|
getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
|
|
num_expert);
|
|
// 2. Get presum of token count
|
|
getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
|
|
|
|
// 3. Get expert position index after sorting
|
|
getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
|
|
num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
|
|
}
|
|
|
|
#if EXPERT_AVG_COUNT_TEST
|
|
// NOTE: test avg expert code here:
|
|
if (__is_ipu() && taskId == 0) {
|
|
modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
|
|
num_expert);
|
|
}
|
|
__sync_cluster();
|
|
#endif
|
|
|
|
// 4. Get gather index for expand and combine
|
|
if (is_sram_scatter) {
|
|
// Only use Union1 SRAM
|
|
uint32_t scatter_idx_cur_core = token_total_num / 4;
|
|
uint32_t scatter_idx_remain_num = token_total_num % 4;
|
|
scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
|
|
uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
|
|
? scatter_idx_cur_core * taskId
|
|
: scatter_idx_cur_core * taskId + scatter_idx_remain_num;
|
|
|
|
// Only Union1 task type,
|
|
// deal once num is same with deal_num in getGatherIdx,
|
|
// which means only 1 repeat to generate both expand and combine idx on NRAM
|
|
const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
|
if (taskDim <= 4 || token_total_num < deal_once_num) {
|
|
if (taskId < 4) {
|
|
if (__is_ipu()) {
|
|
getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
|
scatter_idx_cur_core, cur_idx_start, topk);
|
|
// sync for ipu and mpu
|
|
__sync_cluster();
|
|
} else {
|
|
// sync for ipu and mpu
|
|
__sync_cluster();
|
|
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
|
token_total_num * sizeof(int), SRAM2GDRAM);
|
|
}
|
|
}
|
|
} else {
|
|
// If taskDim > 4, use first union to generate combine idx,
|
|
// use other union to generate expand idx
|
|
if (taskId < 4) {
|
|
if (__is_ipu()) {
|
|
// Scatter combine idx to SRAM
|
|
getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
|
|
__sync_cluster();
|
|
} else {
|
|
__sync_cluster();
|
|
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
|
token_total_num * sizeof(int), SRAM2GDRAM);
|
|
}
|
|
} else {
|
|
// Other union generate expand idx
|
|
if (__is_ipu()) {
|
|
uint32_t expand_dim = taskDim - 4;
|
|
uint32_t expand_id = taskId - 4;
|
|
uint32_t expand_token_cur_core = token_total_num / expand_dim;
|
|
uint32_t expand_token_remain_num = token_total_num % expand_dim;
|
|
expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
|
|
|
|
uint32_t expand_cur_token_start =
|
|
(expand_id < expand_token_remain_num)
|
|
? expand_token_cur_core * expand_id
|
|
: expand_token_cur_core * expand_id + expand_token_remain_num;
|
|
getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
|
|
expand_cur_token_start, topk);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// not use SRAM to generate both expand and combine idx
|
|
if (__is_ipu()) {
|
|
getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
|
token_cur_core, cur_token_start, topk);
|
|
}
|
|
}
|
|
|
|
// step 5 does not need MPU
|
|
if (__is_mpu()) {
|
|
return;
|
|
}
|
|
} // end of kernel
|
|
|
|
} // namespace kernels
|
|
|
|
// Host-side launcher for the MoE index-generation kernels.
//
// Picks between a single-core Block kernel and a multi-cluster UnionX kernel
// from the amount of work (num_token * topk and num_expert), then enqueues the
// launch on `queue`.
//
//   queue               queue the kernel launch is enqueued on.
//   gather_expand_idx   forwarded to the kernel.
//   gather_combine_idx  forwarded to the kernel.
//   token_count         forwarded to the kernel.
//   cusum_token_count   forwarded to the kernel.
//   workspace           forwarded to the Union kernel only; the Block kernel
//                       does not take a workspace.
//   expert_id           forwarded to the kernel; holds int data.
//   num_token, num_expert, topk
//                       problem sizes; num_token * topk is the entry count
//                       used to size the launch.
//
// Returns KernelStatus::KERNEL_STATUS_SUCCESS once the launch has been issued.
// The outcome of the launch itself is not inspected here.
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
                                   int *gather_expand_idx,
                                   int *gather_combine_idx,
                                   int *token_count,
                                   int *cusum_token_count,
                                   void *workspace,
                                   const void *expert_id,
                                   const int num_token,
                                   const int num_expert,
                                   const int topk) {
  // Launch dimensions assume 4 cores per cluster; the Union kernel hard-codes
  // the same value when it splits work across the first cluster.
  constexpr int kCoresPerCluster = 4;
  // Up to this many entries are handled by one core (Block task).
  constexpr int64_t kBlockTokenLimit = 4096;
  // For the partition on num_token*topk, a single core is budgeted 1024 entries.
  constexpr int64_t kTokensPerCore = 1024;

  // Widen before multiplying so that a large num_token * topk cannot overflow
  // int and wrongly fall into the Block path below.
  const int64_t token_total_num = static_cast<int64_t>(num_token) * topk;

  if (token_total_num <= kBlockTokenLimit) {  // Block
    cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
    cnrtDim3_t k_dim{1, 1, 1};
    // Block kernel does not need workspace
    kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
        gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
        num_expert, topk);

    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  // The device is only queried on the Union path; the Block path never used it.
  // NOTE(review): the result of cnCtxGetDevice is not checked, so `dev` is
  // indeterminate if the call fails -- confirm whether a failure status should
  // be returned here.
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));

  // Exact integer ceil-div: cores wanted for the partition on num_token*topk.
  const int64_t token_core_num = (token_total_num + kTokensPerCore - 1) / kTokensPerCore;
  // The kernel also partitions on num_expert, so never request fewer cores
  // than there are experts.
  const int64_t need_core_num = std::max<int64_t>(num_expert, token_core_num);

  // Cap the cluster count by the amount of work. Beyond 8 clusters' worth of
  // work the device cluster count is used as-is.
  if (need_core_num <= kCoresPerCluster) {  // Union1
    // Cannot trigger with the current limits (more than 4096 entries already
    // need at least 5 cores); kept so the ladder stays complete if they change.
    cluster_num = 1;
  } else if (need_core_num <= kCoresPerCluster * 2) {  // Union2
    cluster_num = std::min(cluster_num, 2);
  } else if (need_core_num <= kCoresPerCluster * 4) {  // Union4
    cluster_num = std::min(cluster_num, 4);
  } else if (need_core_num <= kCoresPerCluster * 8) {  // Union8
    cluster_num = std::min(cluster_num, 8);
  }

  cnrtFunctionType_t k_type;
  cnrtDim3_t k_dim{1, 1, 1};

  // Find the largest UnionX function type that fits in cluster_num clusters.
  if (cluster_num == 1) {
    k_type = cnrtFuncTypeUnion1;
    k_dim.x = kCoresPerCluster;
  } else if (cluster_num < 4) {  // cluster num is 2 or 3
    k_type = cnrtFuncTypeUnion2;
    k_dim.x = kCoresPerCluster * 2;
  } else if (cluster_num < 8) {  // cluster num is 4, 5, 6 or 7
    k_type = cnrtFuncTypeUnion4;
    k_dim.x = kCoresPerCluster * 4;
  } else {  // cluster num is 8 or more
    k_type = cnrtFuncTypeUnion8;
    k_dim.x = kCoresPerCluster * 8;
  }

  // The expert_id is int data type
  kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
      gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
      num_token, num_expert, topk);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
|
|
#undef EXPERT_AVG_COUNT_TEST // undef test macro
|
|
|
|
} // namespace tmo
|