#include <algorithm>
#include <type_traits>

#include "cnnl.h"
#include "cnrt.h"

#include "expand_input.mluh"

namespace tmo {
namespace kernels {

#define RESERVED_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define MEMCPY_BURST_SIZE 128

__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

// T_offset: uint32_t or uint64_t
template <int data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output, void *input, int *index, int num_token,
                                    int hidden_size, int num_index) {
  void *input_ptr = input;
  uint64_t input_size = data_size * num_token * hidden_size;
  // Whether SRAM_BUFFER_SIZE can hold the input data.
  // sram_enable == true implies is_nram_output == true.
  bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
  if (sram_enable) {
    input_ptr = (void *)sram_buffer;
    __memcpy(input_ptr, input, input_size, GDRAM2SRAM);
    __sync_cluster();
  }
  if (__is_mpu()) {
    return;
  }

  // Each IPU core processes at least 128 B of data; any cores beyond that can idle.
  int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
  uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
  uint32_t total_num = num_index;
  uint32_t base = total_num / maxTaskDim;
  uint32_t tail = total_num - base * maxTaskDim;
  if (taskId >= maxTaskDim) {
    return;
  }
  uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
  uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);

  // NRAM layout:
  /*
   * first:  compute offset: index[i] * data_size
   * second: gather data
   * -------------------------------------------------
   * addr || index/offset     | output          |
   * type || int32_t/T_offset | T               |
   * num  || n                | n * hidden_size |
   * -------------------------------------------------
   */
  uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
  // Whether NRAM can hold two pixels: if so, copy GDRAM->NRAM->GDRAM; otherwise GDRAM->GDRAM.
  bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
  uint32_t per_num =
      is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
  int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
  int *index_base = index + batch_step;

  T_offset *nram_offset = (T_offset *)nram_buffer;
  int32_t *nram_index;
  if (std::is_same<T_offset, uint64_t>::value) {
    nram_index = (int32_t *)nram_offset + per_num;
  } else {
    nram_index = (int32_t *)nram_offset;
  }
  int8_t *nram_output = (int8_t *)(nram_offset + per_num);

  uint32_t repeat = batch_per_core / per_num;
  uint32_t remain = batch_per_core - repeat * per_num;
  uint32_t deal_num = per_num;
  uint32_t is_remain = remain != 0 ? 1 : 0;
  for (int32_t i = 0; i < repeat + is_remain; i++) {
    if (i == repeat) {
      deal_num = remain;
    }
    int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
    int32_t *index_ptr = index_base + i * per_num;
    // index -> offset
    __memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
    if (std::is_same<T_offset, uint64_t>::value) {
#if __BANG_ARCH__ > 592
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
    }
    __bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
    // copy
    if (is_nram_output) {
      __bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
      mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;  // GDRAM or SRAM -> NRAM -> GDRAM
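      // Example (illustrative numbers, not from this file): with data_size = 2
      // (fp16) and hidden_size = 5120, each gathered row is 10240 B. On
      // __BANG_ARCH__ >= 592 the __gather below moves all deal_num rows in one
      // call; older architectures fall back to one __memcpy per row.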
#if __BANG_ARCH__ >= 592  // __gather requires __BANG_ARCH__ >= 592
      __gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input_ptr + offset_value;
        __memcpy(nram_output + j * hidden_size * data_size, input_offset,
                 hidden_size * data_size, dir);
      }
#endif  // __BANG_ARCH__
      __memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
    } else {  // GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // __gather requires __BANG_ARCH__ >= 592
      __gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input + offset_value;
        __memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
                 hidden_size * data_size, GDRAM2GDRAM);
      }
#endif  // __BANG_ARCH__
    }
  }
}

// T_offset: uint32_t or uint64_t
template <int data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state, void *hidden_state,
                                         int *gather_idx, int *cusum_token_count, int num_token,
                                         int hidden_size, int topk, int total_expert_num,
                                         int start_expert_id, int expert_count) {
  int32_t num_index = num_token * topk;
  int *gather_start_idx = (int *)gather_idx;
  if (cusum_token_count != nullptr) {
    num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
                *((int *)cusum_token_count + start_expert_id);
    gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
  }
  ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
                                         num_token, hidden_size, num_index);
}

// instantiate kernels
#define INSTANTIATE_ONE(data_size, T_offset)                              \
  template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
      void *, void *, int *, int *, int, int, int, int, int, int);

INSTANTIATE_ONE(1, uint32_t)
INSTANTIATE_ONE(2, uint32_t)
INSTANTIATE_ONE(4, uint32_t)
INSTANTIATE_ONE(8, uint32_t)
// large tensor
INSTANTIATE_ONE(1, uint64_t)
INSTANTIATE_ONE(2, uint64_t)
INSTANTIATE_ONE(4, uint64_t)
INSTANTIATE_ONE(8, uint64_t)

}  // namespace kernels

KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue, void *expand_hidden_state,
                                        const void *hidden_state, const int *gather_idx,
                                        const int *cusum_token_count, int num_token,
                                        int hidden_size, int topk, cnnlDataType_t data_type,
                                        int total_expert_num, int start_expert_id,
                                        int expert_count) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);
  // Cap the cluster count so that each core still moves at least MEMCPY_BURST_SIZE bytes.
  int max_cluster_num =
      (uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
  cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
      kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
      kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
      kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
      kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
  bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
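  // Kernel selection example: type_size 1/2/4 maps to index 0/1/2 via type_size >> 1
  // (8 is special-cased to 3), and the uint64_t-offset variants live at index + 4.
  // E.g. fp32 (type_size == 4) on a small tensor picks MLUExpandInputKernel<4, uint32_t>.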
  int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
  // Union1 task type (assumed here): the kernel uses __sync_cluster() and shared
  // SRAM, which require at least a Union1 launch.
  expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
      (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
      (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
      expert_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo
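// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, compiled out): one possible host-side call
// into invokeMoeExpandInputKernel. The shapes, the fp16 data type, and the
// queue setup below are assumptions for the example, not requirements of this
// file.
// ---------------------------------------------------------------------------
#if 0
void ExampleInvokeMoeExpandInput() {
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtQueueCreate(&queue));

  // Hypothetical MoE shapes: 256 tokens, hidden size 5120, top-2 routing.
  const int num_token = 256, hidden_size = 5120, topk = 2;
  void *hidden_state = nullptr;         // [num_token, hidden_size], fp16, device
  void *expand_hidden_state = nullptr;  // [num_token * topk, hidden_size], fp16, device
  int *gather_idx = nullptr;            // [num_token * topk], device
  CNRT_CHECK(cnrtMalloc(&hidden_state, (size_t)num_token * hidden_size * 2));
  CNRT_CHECK(cnrtMalloc(&expand_hidden_state, (size_t)num_token * topk * hidden_size * 2));
  CNRT_CHECK(cnrtMalloc((void **)&gather_idx, (size_t)num_token * topk * sizeof(int)));
  // ... fill hidden_state and gather_idx on the device ...

  // cusum_token_count == nullptr: expand all num_token * topk rows.
  tmo::invokeMoeExpandInputKernel(queue, expand_hidden_state, hidden_state, gather_idx,
                                  /*cusum_token_count=*/nullptr, num_token, hidden_size, topk,
                                  CNNL_DTYPE_HALF, /*total_expert_num=*/8,
                                  /*start_expert_id=*/0, /*expert_count=*/8);
  CNRT_CHECK(cnrtQueueSync(queue));
}
#endif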