Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu
2026-02-04 17:39:32 +08:00

82 lines
2.3 KiB
Plaintext

#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "preload.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Byte capacity of sram_buffer: (__MLU_SRAM_SIZE__ - 32) * 1024.
// NOTE(review): presumably __MLU_SRAM_SIZE__ is a compiler-supplied SRAM size in KiB
// and the 32 subtracted here is KiB held back for other uses — confirm.
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)
// Shared staging buffer; MLUUnion1Preload uses it as the destination of its
// GDRAM2SRAM copies.
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
// Divides `total` items among `num` workers as evenly as possible.
//   total  - number of items to distribute
//   num    - number of workers; used as a divisor, so it must be non-zero
//   id     - index of the worker being queried
//   every  - [out] number of items assigned to worker `id`
//   offset - [out] index of the first item assigned to worker `id`
// The remainder of total / num is handed out one item at a time to the
// lowest-numbered workers, so their ranges are one item longer.
__mlu_func__ void split(const int64_t total,
                        const int64_t num,
                        const int64_t id,
                        size_t &every,
                        size_t &offset) {
  const int64_t quotient = total / num;
  const int64_t leftover = total - quotient * num;
  if (id < leftover) {
    // This worker is one of those that absorb an extra item, as are all
    // workers before it.
    every = quotient + 1;
    offset = quotient * id + id;
  } else {
    // Every extra item was already handed out to earlier workers.
    every = quotient;
    offset = quotient * id + leftover;
  }
}
// Union1 kernel that streams a GDRAM region through the shared SRAM buffer.
//   filter_ptr   - base address of the GDRAM region to read
//   preload_size - total number of bytes to read, divided across taskDimY
// Each taskIdY value gets its own contiguous byte range from split() and
// copies it into sram_buffer in pieces of at most SRAM_SIZE bytes. Every piece
// is written to the start of sram_buffer, overwriting the previous one;
// nothing in this kernel reads the buffer back.
// The body is compiled only when __BANG_ARCH__ > 372; otherwise the kernel is
// empty.
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
#if __BANG_ARCH__ > 372
  size_t slice_bytes = 0;
  size_t slice_begin = 0;
  split(preload_size, taskDimY, taskIdY, slice_bytes, slice_begin);

  const size_t max_chunk = SRAM_SIZE;
  int8_t *cursor = (int8_t *)filter_ptr + slice_begin;
  size_t bytes_left = slice_bytes;
  while (bytes_left > 0) {
    // Full SRAM_SIZE pieces first; the final piece carries whatever is left.
    const size_t chunk = bytes_left < max_chunk ? bytes_left : max_chunk;
    __memcpy(sram_buffer, cursor, chunk, GDRAM2SRAM);
    cursor += chunk;
    bytes_left -= chunk;
  }
#endif
}
} // namespace kernels
// Launches MLUUnion1Preload on `queue` to read the leading bytes of a filter.
//   queue        - queue the kernel is enqueued on
//   filter_ptr   - address of the filter; passed to the kernel unchanged, which
//                  reads it as GDRAM
//   filter_size  - size of the filter in bytes; upper bound for the preload
//   preload_size - requested byte count; must be > 0 and is clamped to
//                  filter_size
// Returns KERNEL_STATUS_FAILED when preload_size is 0, when the current device
// cannot be queried, or when the device reports fewer than one cluster;
// otherwise KERNEL_STATUS_SUCCESS. The outcome of the kernel launch itself is
// not checked here.
KernelStatus invokePreload(cnrtQueue_t queue,
                           void *filter_ptr,
                           size_t filter_size,
                           size_t preload_size) {
  if (preload_size == 0) {
    std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (preload_size > filter_size) {
    preload_size = filter_size;
  }
  // An empty filter clamps the request to zero bytes. The kernel would do
  // nothing in that case, so skip the device query and the launch.
  if (preload_size == 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  // Do not read `dev` unless the driver actually filled it in.
  CNdev dev;
  if (cnCtxGetDevice(&dev) != CN_SUCCESS) {
    std::cerr << "[invokePreload]: failed to query the current device." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  int cluster_num = 0;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  if (cluster_num < 1) {
    // A non-positive count would otherwise become a zero or wrapped-around
    // dim.y, which the kernel uses as a divisor.
    std::cerr << "[invokePreload]: device reported an invalid cluster count." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // x is fixed at 4; y is capped at two clusters.
  cnrtDim3_t dim{.x = 4, .y = (uint32_t)(cluster_num >= 2 ? 2 : 1), .z = 1};
  kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo