82 lines
2.3 KiB
Plaintext
82 lines
2.3 KiB
Plaintext
#include <cassert>
|
|
#include <iostream>
|
|
#include <map>
|
|
#include <ostream>
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "preload.mluh"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
namespace tmo {

namespace kernels {

// Size in bytes of the SRAM staging buffer used by the preload kernel.
// Computed as (__MLU_SRAM_SIZE__ - 32) * 1024; presumably __MLU_SRAM_SIZE__ is
// expressed in KB and 32 KB are held back for other users — TODO confirm.
// Kept as a macro (not constexpr) so it is only expanded where it is used.
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)

// Destination of every preload copy. Nothing in this file ever reads it back;
// the buffer exists only so that the GDRAM->SRAM transfers have a target.
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];

// Evenly partitions `total` items across `num` workers and reports the share of
// worker `id`:
//   every  - number of items assigned to `id`
//   offset - index of the first item assigned to `id`
// The first (total % num) workers receive one extra item, so the shares are
// contiguous, non-overlapping and together cover [0, total).
// Precondition: num > 0 and 0 <= id < num.
__mlu_func__ void split(const int64_t total,
                        const int64_t num,
                        const int64_t id,
                        size_t &every,
                        size_t &offset) {
  int64_t base = total / num;
  int64_t tail = total - base * num;
  every = base + (id < tail ? 1 : 0);
  offset = base * id + (id < tail ? id : tail);
}

// Streams `preload_size` bytes starting at `filter_ptr` from GDRAM into the
// cluster's SRAM, SRAM_SIZE bytes at a time. The copied data is never consumed
// here; presumably the point is to pull the filter through the memory
// hierarchy ahead of a later kernel — confirm against callers.
//
// Launch layout: a Union1 task whose Y dimension enumerates clusters. The byte
// range is split across clusters with split(taskDimY, taskIdY), so each cluster
// handles one contiguous slice.
//
// NOTE(review): there is no core-id guard, so presumably every core of a
// cluster issues the same copies into the same shared buffer; confirm whether
// restricting this to the memory core is intended.
//
// On __BANG_ARCH__ <= 372 the body compiles to nothing and the kernel is a no-op.
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
#if __BANG_ARCH__ > 372
  size_t cluster_preload_size = 0;
  size_t cluster_preload_offset = 0;
  split(preload_size, taskDimY, taskIdY, cluster_preload_size, cluster_preload_offset);

  const size_t chunk_bytes = SRAM_SIZE;
  const size_t full_chunks = cluster_preload_size / chunk_bytes;
  const size_t tail_bytes = cluster_preload_size % chunk_bytes;
  int8_t *cluster_src = (int8_t *)filter_ptr + cluster_preload_offset;

  // Full SRAM-sized chunks first ...
  for (size_t i = 0; i < full_chunks; ++i) {
    __memcpy(sram_buffer, cluster_src + i * chunk_bytes, chunk_bytes, GDRAM2SRAM);
  }
  // ... then the partial tail, if any.
  if (tail_bytes > 0) {
    __memcpy(sram_buffer, cluster_src + full_chunks * chunk_bytes, tail_bytes, GDRAM2SRAM);
  }
#endif
}

}  // namespace kernels

// Enqueues MLUUnion1Preload on `queue` for the first `preload_size` bytes of the
// device buffer `filter_ptr` (which holds `filter_size` bytes).
//
// Returns KERNEL_STATUS_FAILED when:
//   - preload_size is 0,
//   - filter_ptr is null while there is something to preload,
//   - the current device cannot be queried or reports no clusters,
//   - the kernel launch reports an error.
// preload_size is clamped to filter_size. If that leaves nothing to copy, no
// kernel is launched and KERNEL_STATUS_SUCCESS is returned.
// The launch is asynchronous: success means "enqueued", not "finished".
KernelStatus invokePreload(cnrtQueue_t queue,
                           void *filter_ptr,
                           size_t filter_size,
                           size_t preload_size) {
  if (preload_size == 0) {
    std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Never read past the end of the filter buffer.
  if (preload_size > filter_size) {
    preload_size = filter_size;
  }

  // An empty filter leaves nothing to preload; the launch would be a no-op.
  if (preload_size == 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  if (filter_ptr == nullptr) {
    std::cerr << "[invokePreload]: filter_ptr must not be null." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  CNdev dev;
  if (cnCtxGetDevice(&dev) != CN_SUCCESS) {
    std::cerr << "[invokePreload]: failed to query the current device." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  int cluster_num = 0;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  if (cluster_num <= 0) {
    std::cerr << "[invokePreload]: device reports no available clusters." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Union1 launch: x = 4 (presumably the core count of one cluster, as Union1
  // requires — confirm for the target device); y = number of clusters used,
  // capped at 2.
  cnrtDim3_t dim{.x = 4, .y = (uint32_t)(cluster_num >= 2 ? 2 : 1), .z = 1};

  kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);

  // Launch-configuration errors are not reported by the launch expression itself.
  cnrtRet_t launch_ret = cnrtGetLastError();
  if (launch_ret != cnrtSuccess) {
    std::cerr << "[invokePreload]: kernel launch failed: " << cnrtGetErrorStr(launch_ret)
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo
|