220 lines
9.1 KiB
Plaintext
220 lines
9.1 KiB
Plaintext
#include <bang_device_functions_extra.h>
|
|
#include <mlu.h>
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "expand_input.mluh"
|
|
|
|
namespace tmo {
|
|
|
|
namespace kernels {
|
|
|
|
#define RESERVED_SIZE (32 * 1024)
|
|
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
|
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
|
#define MEMCPY_BURST_SIZE 128
|
|
|
|
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
|
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
|
|
|
// T_offset: uint32_t or uint64_t
|
|
template <size_t data_size, typename T_offset>
|
|
__mlu_func__ void ExpandInputKernel(void *output,
|
|
void *input,
|
|
int *index,
|
|
int num_token,
|
|
int hidden_size,
|
|
int num_index) {
|
|
void *input_ptr = input;
|
|
uint64_t input_size = data_size * num_token * hidden_size;
|
|
// whether SRAM_BUFFER_SIZE can hold the input data
|
|
// sram_enable is true ==> is_nram_output is true
|
|
bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
|
|
if (sram_enable) {
|
|
input_ptr = (void *)sram_buffer;
|
|
__memcpy(input_ptr, input, input_size, GDRAM2SRAM);
|
|
__sync_cluster();
|
|
}
|
|
if (__is_mpu()) {
|
|
return;
|
|
}
|
|
|
|
// Each ipu core processes no less than 128B of data, and the remaining cores can idle
|
|
int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
|
|
uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
|
|
uint32_t total_num = num_index;
|
|
uint32_t base = total_num / maxTaskDim;
|
|
uint32_t tail = total_num - base * maxTaskDim;
|
|
if (taskId >= maxTaskDim) {
|
|
return;
|
|
}
|
|
uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
|
|
uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
|
|
|
|
// nram
|
|
/*
|
|
* first: compute offset: index[i] * data_size
|
|
* second: gather data
|
|
* -------------------------------------------------
|
|
* addr || index/offset | output |
|
|
* type || int32_t/T_offset | T |
|
|
* num || n | n * hidden_size |
|
|
* -------------------------------------------------
|
|
*/
|
|
|
|
uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
|
|
// whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
|
|
bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
|
|
uint32_t per_num =
|
|
is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
|
|
|
|
int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
|
|
int *index_base = index + batch_step;
|
|
T_offset *nram_offset = (T_offset *)nram_buffer;
|
|
int32_t *nram_index;
|
|
if (std::is_same<T_offset, int64_t>::value) {
|
|
nram_index = (int32_t *)nram_offset + per_num;
|
|
} else {
|
|
nram_index = (int32_t *)nram_offset;
|
|
}
|
|
int8_t *nram_output = (int8_t *)(nram_offset + per_num);
|
|
|
|
uint32_t repeat = batch_per_core / per_num;
|
|
uint32_t remain = batch_per_core - repeat * per_num;
|
|
uint32_t deal_num = per_num;
|
|
uint32_t is_remain = remain != 0 ? 1 : 0;
|
|
|
|
for (int32_t i = 0; i < repeat + is_remain; i++) {
|
|
if (i == repeat) {
|
|
deal_num = remain;
|
|
}
|
|
int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
|
|
int32_t *index_ptr = index_base + i * per_num;
|
|
|
|
// index -> offset
|
|
__memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
|
|
if (std::is_same<T_offset, uint64_t>::value) {
|
|
#if __BANG_ARCH__ > 592
|
|
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
|
|
#else
|
|
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
|
|
#endif
|
|
}
|
|
__bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
|
|
|
|
// copy
|
|
if (is_nram_output) {
|
|
__bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
|
|
mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
|
|
// GDRAM or SRAM -> NRAM -> GDRAM
|
|
#if __BANG_ARCH__ >= 592 // gather requires
|
|
__gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
|
|
hidden_size * data_size, deal_num);
|
|
#else
|
|
for (int32_t j = 0; j < deal_num; j++) {
|
|
T_offset offset_value = *(nram_offset + j);
|
|
int8_t *input_offset = (int8_t *)input_ptr + offset_value;
|
|
__memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
|
|
dir);
|
|
}
|
|
#endif // __BANG_ARCH__
|
|
__memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
|
|
} else {
|
|
// GDRAM -> GDRAM
|
|
#if __BANG_ARCH__ >= 592 // gather requires
|
|
__gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
|
|
hidden_size * data_size, deal_num);
|
|
#else
|
|
for (int32_t j = 0; j < deal_num; j++) {
|
|
T_offset offset_value = *(nram_offset + j);
|
|
int8_t *input_offset = (int8_t *)input + offset_value;
|
|
__memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
|
|
hidden_size * data_size, GDRAM2GDRAM);
|
|
}
|
|
#endif // __BANG_ARCH__
|
|
}
|
|
}
|
|
}
|
|
|
|
// T_offset: uint32_t or uint64_t
|
|
template <size_t data_size, typename T_offset>
|
|
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
|
|
void *hidden_state,
|
|
int *gather_idx,
|
|
int *cusum_token_count,
|
|
int num_token,
|
|
int hidden_size,
|
|
int topk,
|
|
int total_expert_num,
|
|
int start_expert_id,
|
|
int expert_count) {
|
|
int32_t num_index = num_token * topk;
|
|
int *gather_start_idx = (int *)gather_idx;
|
|
if (cusum_token_count != nullptr) {
|
|
num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
|
|
*((int *)cusum_token_count + start_expert_id);
|
|
gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
|
|
}
|
|
ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
|
|
num_token, hidden_size, num_index);
|
|
}
|
|
// instantiate kernels
|
|
#define INSTANTIATE_ONE(data_size, T_offset) \
|
|
template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
|
|
void *, void *, int *, int *, int, int, int, int, int, int);
|
|
|
|
INSTANTIATE_ONE(1, uint32_t)
|
|
INSTANTIATE_ONE(2, uint32_t)
|
|
INSTANTIATE_ONE(4, uint32_t)
|
|
INSTANTIATE_ONE(8, uint32_t)
|
|
// large tensor
|
|
INSTANTIATE_ONE(1, uint64_t)
|
|
INSTANTIATE_ONE(2, uint64_t)
|
|
INSTANTIATE_ONE(4, uint64_t)
|
|
INSTANTIATE_ONE(8, uint64_t)
|
|
} // namespace kernels
|
|
|
|
// Host launcher for the MoE expand-input gather.
//
// Chooses the kernel instantiation from the element size of `data_type`.
// 64-bit offsets are used when type_size * hidden_size * num_token * topk
// exceeds INT32_MAX. The Union1 launch is sized so that each core moves at
// least about MEMCPY_BURST_SIZE bytes, capped by the device's cluster count.
// The kernel is then launched on `queue`.
//
// Returns KERNEL_STATUS_FAILED when:
//   * the current device cannot be queried,
//   * the element size of `data_type` cannot be determined, or
//   * that size is not 1, 2, 4 or 8 bytes (no kernel exists for it).
// Otherwise returns KERNEL_STATUS_SUCCESS after the launch is issued.
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
                                        void *expand_hidden_state,
                                        const void *hidden_state,
                                        const int *gather_idx,
                                        const int *cusum_token_count,
                                        int num_token,
                                        int hidden_size,
                                        int topk,
                                        cnnlDataType_t data_type,
                                        int total_expert_num,
                                        int start_expert_id,
                                        int expert_count) {
  CNdev dev;
  if (cnCtxGetDevice(&dev) != CN_SUCCESS) {
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  size_t type_size = 0;
  if (cnnlGetSizeOfDataType(data_type, &type_size) != CNNL_STATUS_SUCCESS) {
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Kernels exist only for 1/2/4/8-byte elements. Any other size would index
  // past the end of the dispatch table below.
  if (type_size != 1 && type_size != 2 && type_size != 4 && type_size != 8) {
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Bytes covered by all num_token * topk entries. The kernel may process
  // fewer rows when cusum_token_count selects a sub-range.
  const uint64_t total_bytes = (uint64_t)hidden_size * num_token * topk * type_size;

  // Clamp in 64 bits before narrowing so a huge size cannot wrap to a
  // negative int and collapse the launch to one cluster.
  const uint64_t max_cluster_num = total_bytes / ((uint64_t)core_num * MEMCPY_BURST_SIZE);
  cluster_num =
      (int)std::min<uint64_t>(std::max<uint64_t>(max_cluster_num, 1), (uint64_t)cluster_num);

  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};

  void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
      kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
      kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
      kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
      kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};

  bool is_large_tensor = total_bytes > INT32_MAX;
  // type_size 1/2/4/8 -> 0/1/2/3; the 64-bit-offset variants follow at +4.
  int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
  expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
      (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
      (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
      expert_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
|
|
} // namespace tmo
|