// Files: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu
// Snapshot: 2026-02-04 17:39:32 +08:00 · 220 lines · 9.1 KiB · Plaintext

#include <bang_device_functions_extra.h>
#include <mlu.h>
#include "cnnl.h"
#include "cnrt.h"
#include "expand_input.mluh"
namespace tmo {
namespace kernels {
// Bytes held back from each on-chip scratch buffer (32 KiB).
// NOTE(review): presumably headroom for compiler/runtime use of NRAM/SRAM — confirm.
#define RESERVED_SIZE (32 * 1024)
// Usable scratch sizes in bytes. __MLU_NRAM_SIZE__ / __MLU_SRAM_SIZE__ are multiplied by
// 1024 here, i.e. treated as KiB.
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
// Minimum number of bytes of work assigned per core when sizing the launch and the
// per-core task split (host launcher and kernel both divide by this).
#define MEMCPY_BURST_SIZE 128
// Cluster-shared scratch: used to stage a small input before gathering from it.
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
// Per-core scratch: holds the index/offset list and, when it fits, the gathered rows.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
/*
 * Row gather: output row r is a copy of input row index[r], for r in [0, num_index).
 *
 * A row is hidden_size * data_size bytes. The index list is split into contiguous
 * slices, one per participating IPU core, and each core writes the matching
 * contiguous slice of `output`. The MPU core only helps stage the input in SRAM.
 *
 * Template parameters:
 *   data_size - element size in bytes.
 *   T_offset  - type holding byte offsets into `input`: uint32_t, or uint64_t when
 *               offsets may not fit in 32 bits.
 * Parameters:
 *   output      - destination, num_index rows.
 *   input       - source, num_token rows.
 *   index       - num_index int32 row indices into `input`.
 *   num_token   - number of rows in `input` (used only to size the SRAM staging copy).
 *   hidden_size - elements per row.
 *   num_index   - number of rows to gather.
 * Indices are not bounds-checked here.
 * NOTE(review): callers presumably guarantee 0 <= index[r] < num_token.
 */
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output,
                                    void *input,
                                    int *index,
                                    int num_token,
                                    int hidden_size,
                                    int num_index) {
  // 64-bit offsets must be widened from the int32 indices. This has to test the type the
  // kernel is actually instantiated with (uint64_t), not int64_t.
  constexpr bool kWideOffset = std::is_same<T_offset, uint64_t>::value;

  void *input_ptr = input;
  uint64_t input_size = data_size * num_token * hidden_size;
  // Stage the whole input in SRAM only for hidden_size == 1 when it fits.
  // sram_enable is true ==> is_nram_output is true (a one-element row is tiny).
  bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
  if (sram_enable) {
    input_ptr = (void *)sram_buffer;
    __memcpy(input_ptr, input, input_size, GDRAM2SRAM);
    __sync_cluster();
  }
  // Past this point only IPU cores do work.
  if (__is_mpu()) {
    return;
  }

  // Each IPU core gets at least MEMCPY_BURST_SIZE bytes of output; surplus cores idle.
  // Computed and clamped in 64 bits so a huge tensor cannot wrap to a negative count.
  uint64_t burst_tasks = (uint64_t)data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
  uint32_t maxTaskDim =
      burst_tasks < (uint64_t)taskDim ? (uint32_t)burst_tasks : (uint32_t)taskDim;
  if (maxTaskDim == 0) {
    maxTaskDim = 1;
  }
  uint32_t total_num = num_index;
  uint32_t base = total_num / maxTaskDim;
  uint32_t tail = total_num - base * maxTaskDim;
  if (taskId >= maxTaskDim) {
    return;
  }
  // Contiguous split: the first `tail` cores take one extra row.
  uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
  uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);

  /*
   * NRAM layout (per_num rows per pass):
   * ---------------------------------------------------------------------
   * region || offsets                  | gathered rows (NRAM path only)
   * type   || T_offset                 | int8_t
   * bytes  || per_num*sizeof(T_offset) | per_num * hidden_size * data_size
   * ---------------------------------------------------------------------
   * The int32 indices are loaded into the offset region and converted in place to
   * byte offsets: offset = index * hidden_size * data_size.
   */
  uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
  // If NRAM holds at least two rows, stage GDRAM->NRAM->GDRAM; otherwise copy GDRAM->GDRAM.
  bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
  uint32_t per_num =
      is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
  int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
  int *index_base = index + batch_step;
  T_offset *nram_offset = (T_offset *)nram_buffer;
  // With 64-bit offsets, load the int32 indices into the upper half of the offset region.
  // If they started at the same address, widening element j would overwrite indices 2j and
  // 2j+1 before they are read (assuming the conversion proceeds front to back).
  int32_t *nram_index = kWideOffset ? (int32_t *)nram_offset + per_num : (int32_t *)nram_offset;
  int8_t *nram_output = (int8_t *)(nram_offset + per_num);

  uint32_t repeat = batch_per_core / per_num;
  uint32_t remain = batch_per_core - repeat * per_num;
  uint32_t deal_num = per_num;
  uint32_t is_remain = remain != 0 ? 1 : 0;
  for (uint32_t i = 0; i < repeat + is_remain; i++) {
    if (i == repeat) {
      deal_num = remain;
    }
    int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
    int32_t *index_ptr = index_base + i * per_num;
    // index -> byte offset
    __memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
    if (kWideOffset) {
#if __BANG_ARCH__ > 592
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
    }
    __bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
    // copy
    if (is_nram_output) {
      // No pre-zeroing needed: the gather writes every byte of the deal_num contiguous rows
      // that are copied out below.
      mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
      // GDRAM or SRAM -> NRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // gather requires
      __gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
               hidden_size * data_size, deal_num);
#else
      for (uint32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = nram_offset[j];
        int8_t *input_offset = (int8_t *)input_ptr + offset_value;
        __memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
                 dir);
      }
#endif  // __BANG_ARCH__
      __memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
    } else {
      // GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // gather requires
      // Pass the offsets with their real element type, as the NRAM path does. Casting
      // uint32_t offsets to uint64_t* would fuse adjacent offsets and read past deal_num.
      __gather(output_ptr, input, nram_offset, hidden_size * data_size, GDRAM2GDRAM,
               hidden_size * data_size, deal_num);
#else
      for (uint32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = nram_offset[j];
        int8_t *input_offset = (int8_t *)input + offset_value;
        __memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
                 hidden_size * data_size, GDRAM2GDRAM);
      }
#endif  // __BANG_ARCH__
    }
  }
}
/*
 * Device entry point: picks the slice of gather_idx to process, then runs the row gather.
 *
 * Without cusum_token_count, all num_token * topk indices are gathered.
 * With it, the slice starts at cusum_token_count[start_expert_id]. Its length is
 * cusum_token_count[start_expert_id + expert_count] - cusum_token_count[start_expert_id].
 * NOTE(review): cusum_token_count is presumably a cumulative per-expert token count — confirm.
 * total_expert_num is accepted but not used here.
 */
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
                                         void *hidden_state,
                                         int *gather_idx,
                                         int *cusum_token_count,
                                         int num_token,
                                         int hidden_size,
                                         int topk,
                                         int total_expert_num,
                                         int start_expert_id,
                                         int expert_count) {
  int *rows_begin = gather_idx;
  int32_t rows_total = num_token * topk;
  if (cusum_token_count) {
    const int first = cusum_token_count[start_expert_id];
    const int last = cusum_token_count[start_expert_id + expert_count];
    rows_total = last - first;
    rows_begin = gather_idx + first;
  }
  ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, rows_begin, num_token,
                                         hidden_size, rows_total);
}
// Explicit instantiations of MLUExpandInputKernel for every element size the host launcher
// dispatches to (1, 2, 4, 8 bytes). Both 32-bit and 64-bit offset variants are generated.
#define INSTANTIATE_ONE(data_size, T_offset)                              \
  template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
      void *, void *, int *, int *, int, int, int, int, int, int);
#define INSTANTIATE_ALL_SIZES(T_offset) \
  INSTANTIATE_ONE(1, T_offset)          \
  INSTANTIATE_ONE(2, T_offset)          \
  INSTANTIATE_ONE(4, T_offset)          \
  INSTANTIATE_ONE(8, T_offset)
INSTANTIATE_ALL_SIZES(uint32_t)
// large tensor
INSTANTIATE_ALL_SIZES(uint64_t)
#undef INSTANTIATE_ALL_SIZES
} // namespace kernels
/*
 * Host launcher for the MoE input-expansion gather.
 *
 * Launches MLUExpandInputKernel on `queue` as cnrtFuncTypeUnion1 with
 * core_num x cluster_num tasks. The cluster count is sized so each cluster has at least
 * core_num * MEMCPY_BURST_SIZE bytes of output, clamped to [1, device cluster count].
 * The 64-bit-offset kernel variant is chosen when the gathered byte count exceeds
 * INT32_MAX.
 *
 * The byte count uses num_token * topk rows, which is an upper bound when
 * cusum_token_count restricts the kernel to a subset of experts.
 * NOTE(review): data_type must have an element size of 1, 2, 4 or 8 bytes. Other sizes
 * select an out-of-range or wrong kernel, and cnnlGetSizeOfDataType's status is not checked.
 * Returns KERNEL_STATUS_SUCCESS after enqueueing. Launch errors are not checked here.
 */
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
                                        void *expand_hidden_state,
                                        const void *hidden_state,
                                        const int *gather_idx,
                                        const int *cusum_token_count,
                                        int num_token,
                                        int hidden_size,
                                        int topk,
                                        cnnlDataType_t data_type,
                                        int total_expert_num,
                                        int start_expert_id,
                                        int expert_count) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);

  // Upper bound on bytes written by the gather.
  const uint64_t total_bytes = (uint64_t)hidden_size * num_token * topk * type_size;
  // Clamp in 64 bits before narrowing: the quotient can exceed INT_MAX for huge tensors,
  // which previously wrapped negative and collapsed the launch to a single cluster.
  const uint64_t wanted_clusters = total_bytes / ((uint64_t)core_num * MEMCPY_BURST_SIZE);
  if (wanted_clusters < (uint64_t)cluster_num) {
    cluster_num = std::max((int)wanted_clusters, 1);
  }
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};

  // Indexed by log2(type_size), plus 4 for the 64-bit-offset ("large tensor") variants.
  void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
      kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
      kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
      kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
      kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
  const bool is_large_tensor = total_bytes > (uint64_t)INT32_MAX;
  int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
  expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
      (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
      (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
      expert_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo