forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
@@ -0,0 +1,935 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "gen_idx.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
|
||||
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
|
||||
#define ALIGN_16 (16)
|
||||
|
||||
#define EXPERT_AVG_COUNT_TEST (0)
|
||||
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
|
||||
|
||||
// Generate integer sequence data from 0 to length-1
|
||||
__mlu_func__ void generateIntSeq(int *dst, int length) {
|
||||
int count = 64;
|
||||
__bang_move(dst, range, std::min(count, length) * sizeof(int));
|
||||
while (count < length) {
|
||||
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
|
||||
count *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
// genIdx Block kernel, use only 1 core to process
|
||||
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
/* NRAM space */
|
||||
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
|
||||
// --------------------------------------------------------------
|
||||
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
|
||||
// | combine_idx | expand_idx | | scatter_offset |
|
||||
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
|
||||
// --------------------------------------------------------------
|
||||
// ------------------------------
|
||||
// |token_count|token_count_presum|
|
||||
// | | |
|
||||
// | num_expert| num_expert |
|
||||
// ------------------------------
|
||||
|
||||
uint32_t token_total_num = num_token * topk;
|
||||
// num align to 16, size align to 64B
|
||||
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
|
||||
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
|
||||
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
|
||||
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
|
||||
|
||||
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
|
||||
#if __BANG_ARCH__ >= 592
|
||||
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
|
||||
#endif
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
|
||||
|
||||
// Load current core input expert_id and generate int sequence
|
||||
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
||||
__sync();
|
||||
|
||||
// Initialize sort idx offset
|
||||
uint32_t sorted_idx_offset = 0;
|
||||
// Initialize token count first presum with 0
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
|
||||
|
||||
// Loop on each expert, eq, count, filter index
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
|
||||
token_total_num);
|
||||
// Use filter to sort gen_idx, output with sorted_idx_offset
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, token_total_num);
|
||||
|
||||
sorted_idx_offset += cur_expert_count;
|
||||
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
|
||||
|
||||
// Compute cusum token count and store
|
||||
if (need_cusum_token_count) {
|
||||
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
|
||||
}
|
||||
}
|
||||
|
||||
#if EXPERT_AVG_COUNT_TEST
|
||||
// NOTE: test avg expert code here:
|
||||
uint32_t token_count_avg = token_total_num / num_expert;
|
||||
uint32_t expert_remain_num = token_total_num % num_expert;
|
||||
|
||||
for (int i = 0; i < num_expert; i++) {
|
||||
((int *)token_count_onchip)[i] =
|
||||
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
||||
((int *)token_count_presum_onchip)[i + 1] =
|
||||
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
__sync_compute();
|
||||
// Store token_count and cusum token count
|
||||
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
if (need_cusum_token_count) {
|
||||
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
// Use sorted idx to generate gather idx for expand and combine
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_offset = sorted_idx mul_scalar sizeof(int);
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
token_total_num);
|
||||
#else
|
||||
// scatter dst GDRAM addr should align to 64B
|
||||
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
||||
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
||||
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
|
||||
#endif
|
||||
__sync_compute();
|
||||
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_async to NRAM
|
||||
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
// expand_idx = sorted_idx div(topk)
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
|
||||
|
||||
// Store expand_idx and combine_idx
|
||||
__sync_compute();
|
||||
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__sync_move();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
|
||||
token_total_num * sizeof(int), NRAM2GDRAM);
|
||||
#else
|
||||
// 370 directly scatter to GDRAM
|
||||
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Only MLU500 series support NRAM2SRAM scatter direction
|
||||
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
|
||||
seg_offset += 32768;
|
||||
}
|
||||
|
||||
if (seg_remain > 0) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Scatter sequence, transfer size is sizeof(int)
|
||||
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#endif
|
||||
seg_offset += 32768;
|
||||
}
|
||||
if (seg_remain > 0) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Get token count
|
||||
__mlu_func__ void getTokenCount(int *token_count,
|
||||
int *expert_id,
|
||||
int token_cur_core,
|
||||
int cur_token_start,
|
||||
int num_expert) {
|
||||
// 1. Partition on [num_token*topk],
|
||||
// each core for-loop on all expert_id, use eq and count instructions,
|
||||
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// ------------------------------------------------------
|
||||
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
|
||||
// | deal_num | deal_num | num_expert |
|
||||
// ------------------------------------------------------
|
||||
|
||||
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
|
||||
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
|
||||
|
||||
// Current core data loop
|
||||
uint32_t repeat = token_cur_core / deal_num;
|
||||
uint32_t remain = token_cur_core % deal_num;
|
||||
uint32_t total_repeat = repeat + (int)(remain > 0);
|
||||
uint32_t token_addr_offset = cur_token_start;
|
||||
|
||||
// Initialize token_count with 0
|
||||
if (taskId == 0) {
|
||||
__gdramset((int *)token_count, num_expert, 0);
|
||||
}
|
||||
// Sync for initialize token_count
|
||||
__sync_all_ipu();
|
||||
|
||||
// Initialize expert count onchip with 0
|
||||
if (token_cur_core > 0) {
|
||||
__bang_write_zero((int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
|
||||
// actual num in loop
|
||||
int cur_deal_num = deal_num;
|
||||
for (int i = 0; i < total_repeat; i++) {
|
||||
if (i == total_repeat - 1 && remain > 0) {
|
||||
cur_deal_num = remain;
|
||||
}
|
||||
// Load current core input expert_id
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
cur_deal_num * sizeof(int), GDRAM2NRAM);
|
||||
token_addr_offset += cur_deal_num;
|
||||
|
||||
// Loop on each expert, eq, count
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
|
||||
// NOTE: __bang_count() only support floating data type
|
||||
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
|
||||
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicAdd(reduce) all cores token count results
|
||||
if (token_cur_core > 0) {
|
||||
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
// Sync for all cores, get accumulate of token_count
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
// 2. Get token count presum, for each expert index start address after sorting
|
||||
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
|
||||
int *token_count,
|
||||
const int num_expert) {
|
||||
// 2. After first process, already get token_count.
|
||||
// Then use one core to pre-sum on token_count, consider size of int32,
|
||||
// first expert id start address should be zero.
|
||||
// to get each expert id start address after sorting, store to workspace,
|
||||
// token_count_presum.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// load token_count to token_count_presum[1~num_expert+1],
|
||||
// for i = 0 to num_expert:
|
||||
// token_count_presum[i+1] += token_count_presum[i]
|
||||
// store token_count_presum[0~num_expert]
|
||||
// -------------------------
|
||||
// |token_count_presum_onchip|
|
||||
// | {0}, num_expert |
|
||||
// -------------------------
|
||||
|
||||
if (taskId == 0) {
|
||||
// Initialize count presum onchip with a first 0
|
||||
int8_t *token_count_presum_onchip = nram_buffer;
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
// Load token_count with an offset of 1
|
||||
__memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// Calculate presum of token count by each expert
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
((int *)token_count_presum_onchip)[cur_expert + 1] +=
|
||||
((int *)token_count_presum_onchip)[cur_expert];
|
||||
}
|
||||
|
||||
// Store token count presum to workspace
|
||||
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
// Sync for all cores, get presum of token count
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
|
||||
int *token_count,
|
||||
const uint32_t token_total_num,
|
||||
const int num_expert) {
|
||||
uint32_t token_count_avg = token_total_num / num_expert;
|
||||
uint32_t expert_remain_num = token_total_num % num_expert;
|
||||
|
||||
int8_t *token_count_onchip = nram_buffer;
|
||||
int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int);
|
||||
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
|
||||
for (int i = 0; i < num_expert; i++) {
|
||||
((int *)token_count_onchip)[i] =
|
||||
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
||||
((int *)token_count_presum_onchip)[i + 1] =
|
||||
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
||||
}
|
||||
|
||||
__memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
|
||||
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
// 3. Get expert position index after sorting
|
||||
__mlu_func__ void getSortedIdx(int *sorted_idx,
|
||||
int *expert_id,
|
||||
int *token_count_presum,
|
||||
const int token_total_num,
|
||||
const int num_expert,
|
||||
const int expert_cur_core,
|
||||
const int cur_expert_start,
|
||||
const int cur_expert_end) {
|
||||
// 3. Partition on num_expert, each core generate position index from 0,
|
||||
// and for-loop on all expert_id data, use eq with own each expert_id,
|
||||
// and filter on index, stores to each expert_id start address of
|
||||
// sorted_idx on workspace.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// -------------------------------------------------------------------
|
||||
// |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
|
||||
// | deal_num | deal_num | deal_num | deal_num |
|
||||
// -------------------------------------------------------------------
|
||||
// |expert_start_addr|
|
||||
// | num_expert |
|
||||
// -----------------
|
||||
|
||||
// Calculate new deal_num of sorting process
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
|
||||
|
||||
// Each core deal with whole token expert_id data
|
||||
int repeat = token_total_num / deal_num;
|
||||
int remain = token_total_num % deal_num;
|
||||
int token_addr_offset = 0;
|
||||
|
||||
int8_t *expert_id_onchip = nram_buffer;
|
||||
int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
|
||||
int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
|
||||
int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
|
||||
|
||||
// When num_expert < taskDim, not all cores need to sort
|
||||
if (expert_cur_core > 0) {
|
||||
// Generate position index from 0
|
||||
if (deal_num <= token_total_num) {
|
||||
generateIntSeq((int *)gen_idx_onchip, deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
||||
}
|
||||
|
||||
// Initialize expert start address with presum of token count
|
||||
__memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core expert_id
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
token_addr_offset += deal_num;
|
||||
|
||||
// Loop for current core expert, eq, filter position index
|
||||
// use filter, store to sorted_idx[expert_start_addr]
|
||||
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
|
||||
|
||||
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
||||
|
||||
// NOTE: __bang_filter() only support floating data type
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, deal_num);
|
||||
|
||||
// Store to the corresponding address of sorted_idx
|
||||
if (cur_expert_count > 0) {
|
||||
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
||||
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
||||
|
||||
// Update address offset of current expert
|
||||
((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
|
||||
}
|
||||
}
|
||||
|
||||
// Update position index for each data loop
|
||||
__bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
|
||||
|
||||
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
||||
|
||||
// NOTE: __bang_filter() only support floating data type
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, remain);
|
||||
// Store to the corresponding address of sorted_idx
|
||||
if (cur_expert_count > 0) {
|
||||
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
||||
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sync for all cores, get position index after sorting
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
// 4. Get gather index for expand and combine
|
||||
template <bool is_sram_scatter>
|
||||
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start,
|
||||
const int topk) {
|
||||
// 4. Partition on [num_token*topk],
|
||||
// load sorted_idx onchip,
|
||||
// generate sequence according to position index from 0, add token offset
|
||||
// gather_combine_idx = scatter(seq, sorted_idx)
|
||||
// gather_expand_idx = sorted_idx / topk
|
||||
// update sequence
|
||||
// NRAM:
|
||||
// -------------------------------------------------------------------
|
||||
// |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
|
||||
// | deal_num | deal_num | deal_num | deal_num |
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
// scatter dst GDRAM addr should align to 64B
|
||||
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
||||
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
||||
|
||||
int8_t *sorted_idx_onchip = nram_buffer;
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
||||
|
||||
// Generate position index from 0
|
||||
// Add base offset to sequence according to current core token start address
|
||||
if (token_cur_core > 0) {
|
||||
if (deal_num <= token_cur_core) {
|
||||
generateIntSeq((int *)scatter_sequence, deal_num);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
token_cur_core);
|
||||
}
|
||||
}
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
if (is_sram_scatter) {
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
deal_num);
|
||||
} else {
|
||||
// GDRAM addr should align to 64B
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), deal_num);
|
||||
}
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
if (is_sram_scatter) {
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
deal_num);
|
||||
} else {
|
||||
// Scatter to output gather_combine_idx
|
||||
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
||||
(uint32_t *)scatter_offset, deal_num);
|
||||
}
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
deal_num * sizeof(int), NRAM2GDRAM);
|
||||
if (is_sram_scatter) {
|
||||
// if scatter to SRAM, need to sync compute with mv
|
||||
__sync_move();
|
||||
}
|
||||
// Add offset to sequence and token_address
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
if (is_sram_scatter) {
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
remain);
|
||||
} else {
|
||||
// GDRAM addr should align to 64B
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), remain);
|
||||
}
|
||||
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
if (is_sram_scatter) {
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
remain);
|
||||
} else {
|
||||
// Scatter to output gather_combine_idx
|
||||
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
||||
(uint32_t *)scatter_offset, remain);
|
||||
}
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
remain * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
// 4.1 Get gather combine index on SRAM
|
||||
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start) {
|
||||
// 4.1 Partition on [num_token*topk], with only 1 union
|
||||
// load sorted_idx onchip,
|
||||
// generate sequence according to position index from 0, add token offset
|
||||
// gather_combine_idx = scatter(seq, sorted_idx)
|
||||
// update sequence
|
||||
// NRAM:
|
||||
// -------------------------------
|
||||
// |scatter_offset|scatter_sequence|
|
||||
// | deal_num | deal_num |
|
||||
// -------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
int8_t *scatter_offset = nram_buffer;
|
||||
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
||||
|
||||
// Generate position index from 0
|
||||
// Add base offset to sequence according to current core token start address
|
||||
if (token_cur_core > 0) {
|
||||
if (deal_num <= token_cur_core) {
|
||||
generateIntSeq((int *)scatter_sequence, deal_num);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
token_cur_core);
|
||||
}
|
||||
}
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
// Scatter to SRAM
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
deal_num);
|
||||
__sync_move();
|
||||
|
||||
// Add offset to sequence and token_address
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
|
||||
}
|
||||
}
|
||||
|
||||
// 4.2 Get gather expand index
|
||||
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
|
||||
int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start,
|
||||
const int topk) {
|
||||
// 4.2 Partition on [num_token*topk],
|
||||
// load sorted_idx onchip,
|
||||
// gather_expand_idx = sorted_idx / topk
|
||||
// NRAM:
|
||||
// -----------------------------------
|
||||
// |sorted_idx_onchip|expand_idx_onchip|
|
||||
// | deal_num | deal_num |
|
||||
// -----------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
int8_t *sorted_idx_onchip = nram_buffer;
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
deal_num * sizeof(int), NRAM2GDRAM);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
remain * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
// Store token count presum result, shape [num_expert + 1]
|
||||
int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
|
||||
// Store position index after sorting, shape [num_token*topk]
|
||||
int *sorted_idx = ((int *)workspace) + num_expert + 1;
|
||||
|
||||
// Calculate partition information for different processes
|
||||
// Partition on [num_token*topk]
|
||||
uint32_t token_total_num = num_token * topk;
|
||||
uint32_t token_cur_core = token_total_num / taskDim;
|
||||
uint32_t token_remain_num = token_total_num % taskDim;
|
||||
token_cur_core += (uint32_t)(taskId < token_remain_num);
|
||||
// Current core range according to partition on [num_token*topk]
|
||||
uint32_t cur_token_start = (taskId < token_remain_num)
|
||||
? token_cur_core * taskId
|
||||
: token_cur_core * taskId + token_remain_num;
|
||||
|
||||
// Partition on [num_expert]
|
||||
uint32_t expert_cur_core = num_expert / taskDim;
|
||||
uint32_t expert_remain_num = num_expert % taskDim;
|
||||
expert_cur_core += (uint32_t)(taskId < expert_remain_num);
|
||||
// Current core range according to partition on [num_expert]
|
||||
uint32_t cur_expert_start = (taskId < expert_remain_num)
|
||||
? expert_cur_core * taskId
|
||||
: expert_cur_core * taskId + expert_remain_num;
|
||||
uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
|
||||
|
||||
// Use Union1 SRAM to scatter, only MLU500 series support now
|
||||
#if __BANG_ARCH__ >= 592
|
||||
bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
|
||||
#else
|
||||
bool is_sram_scatter = false;
|
||||
#endif
|
||||
|
||||
if (__is_ipu()) {
|
||||
// 1. Get token count
|
||||
getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
|
||||
num_expert);
|
||||
// 2. Get presum of token count
|
||||
getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
|
||||
|
||||
// 3. Get expert position index after sorting
|
||||
getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
|
||||
num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
|
||||
}
|
||||
|
||||
#if EXPERT_AVG_COUNT_TEST
|
||||
// NOTE: test avg expert code here:
|
||||
if (__is_ipu() && taskId == 0) {
|
||||
modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
|
||||
num_expert);
|
||||
}
|
||||
__sync_cluster();
|
||||
#endif
|
||||
|
||||
// 4. Get gather index for expand and combine
|
||||
if (is_sram_scatter) {
|
||||
// Only use Union1 SRAM
|
||||
uint32_t scatter_idx_cur_core = token_total_num / 4;
|
||||
uint32_t scatter_idx_remain_num = token_total_num % 4;
|
||||
scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
|
||||
uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
|
||||
? scatter_idx_cur_core * taskId
|
||||
: scatter_idx_cur_core * taskId + scatter_idx_remain_num;
|
||||
|
||||
// Only Union1 task type,
|
||||
// deal once num is same with deal_num in getGatherIdx,
|
||||
// which means only 1 repeat to generate both expand and combine idx on NRAM
|
||||
const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
||||
if (taskDim <= 4 || token_total_num < deal_once_num) {
|
||||
if (taskId < 4) {
|
||||
if (__is_ipu()) {
|
||||
getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
||||
scatter_idx_cur_core, cur_idx_start, topk);
|
||||
// sync for ipu and mpu
|
||||
__sync_cluster();
|
||||
} else {
|
||||
// sync for ipu and mpu
|
||||
__sync_cluster();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
||||
token_total_num * sizeof(int), SRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If taskDim > 4, use first union to generate combine idx,
|
||||
// use other union to generate expand idx
|
||||
if (taskId < 4) {
|
||||
if (__is_ipu()) {
|
||||
// Scatter combine idx to SRAM
|
||||
getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
|
||||
__sync_cluster();
|
||||
} else {
|
||||
__sync_cluster();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
||||
token_total_num * sizeof(int), SRAM2GDRAM);
|
||||
}
|
||||
} else {
|
||||
// Other union generate expand idx
|
||||
if (__is_ipu()) {
|
||||
uint32_t expand_dim = taskDim - 4;
|
||||
uint32_t expand_id = taskId - 4;
|
||||
uint32_t expand_token_cur_core = token_total_num / expand_dim;
|
||||
uint32_t expand_token_remain_num = token_total_num % expand_dim;
|
||||
expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
|
||||
|
||||
uint32_t expand_cur_token_start =
|
||||
(expand_id < expand_token_remain_num)
|
||||
? expand_token_cur_core * expand_id
|
||||
: expand_token_cur_core * expand_id + expand_token_remain_num;
|
||||
getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
|
||||
expand_cur_token_start, topk);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// not use SRAM to generate both expand and combine idx
|
||||
if (__is_ipu()) {
|
||||
getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
||||
token_cur_core, cur_token_start, topk);
|
||||
}
|
||||
}
|
||||
|
||||
// step 5 does not need MPU
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
}
|
||||
} // end of kernel
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
||||
int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
const int token_total_num = num_token * topk;
|
||||
|
||||
// For partition on num_token*topk, single core processes at least 128 num
|
||||
const int single_core_num_limit = 1024;
|
||||
int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
|
||||
// When partition on num_expert, each core at least processes one expert
|
||||
need_core_num = std::max(num_expert, need_core_num);
|
||||
|
||||
// When consider UnionX cnrt func type, reset cluster_num
|
||||
if (token_total_num <= 4096) { // Block
|
||||
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
|
||||
cnrtDim3_t k_dim{1, 1, 1};
|
||||
// Block kernel does not need workspace
|
||||
kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
|
||||
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
|
||||
num_expert, topk);
|
||||
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
} else if (need_core_num <= 4) { // Union1
|
||||
cluster_num = 1;
|
||||
} else if (need_core_num <= 8) { // Union2
|
||||
cluster_num = std::min(cluster_num, 2);
|
||||
} else if (need_core_num <= 16) { // Union4
|
||||
cluster_num = std::min(cluster_num, 4);
|
||||
} else if (need_core_num <= 32) { // Union8
|
||||
cluster_num = std::min(cluster_num, 8);
|
||||
}
|
||||
|
||||
cnrtFunctionType_t k_type;
|
||||
cnrtDim3_t k_dim{1, 1, 1};
|
||||
|
||||
// Find max UnionX cnrt func type
|
||||
if (cluster_num == 1) {
|
||||
k_type = cnrtFuncTypeUnion1;
|
||||
k_dim.x = 4;
|
||||
} else if (cluster_num < 4) { // cluster num is 2 or 3
|
||||
k_type = cnrtFuncTypeUnion2;
|
||||
k_dim.x = 8;
|
||||
} else if (cluster_num < 8) { // cluster num is 4,5,6,7
|
||||
k_type = cnrtFuncTypeUnion4;
|
||||
k_dim.x = 16;
|
||||
} else { // cluster num larger than 8
|
||||
k_type = cnrtFuncTypeUnion8;
|
||||
k_dim.x = 32;
|
||||
}
|
||||
|
||||
// The expert_id is int data type
|
||||
kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
|
||||
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
|
||||
num_token, num_expert, topk);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
#undef EXPERT_AVG_COUNT_TEST // undef test macro
|
||||
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user