forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
81
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu
Normal file
81
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu
Normal file
@@ -0,0 +1,81 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "preload.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
|
||||
|
||||
__mlu_func__ void split(const int64_t total,
|
||||
const int64_t num,
|
||||
const int64_t id,
|
||||
size_t &every,
|
||||
size_t &offset) {
|
||||
int64_t base = total / num;
|
||||
int64_t tail = total - base * num;
|
||||
every = base + (id < tail ? 1 : 0);
|
||||
offset = base * id + (id < tail ? id : tail);
|
||||
}
|
||||
|
||||
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
|
||||
#if __BANG_ARCH__ > 372
|
||||
size_t cluster_preload_size = 0;
|
||||
size_t cluster_preload_offset = 0;
|
||||
split(preload_size, taskDimY, taskIdY, cluster_preload_size, cluster_preload_offset);
|
||||
|
||||
size_t load_repeat = cluster_preload_size / SRAM_SIZE;
|
||||
size_t load_remain = cluster_preload_size % SRAM_SIZE;
|
||||
|
||||
for (size_t i = 0; i < load_repeat + 1; i++) {
|
||||
if (i == load_repeat && load_remain == 0) {
|
||||
break;
|
||||
}
|
||||
size_t loop_load_size = (i < load_repeat ? SRAM_SIZE : load_remain);
|
||||
int8_t *gdram_ptr = (int8_t *)filter_ptr + cluster_preload_offset + i * SRAM_SIZE;
|
||||
if (loop_load_size > 0) {
|
||||
__memcpy(sram_buffer, gdram_ptr, loop_load_size, GDRAM2SRAM);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokePreload(cnrtQueue_t queue,
|
||||
void *filter_ptr,
|
||||
size_t filter_size,
|
||||
size_t preload_size) {
|
||||
if (preload_size == 0) {
|
||||
std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if (preload_size > filter_size) {
|
||||
preload_size = filter_size;
|
||||
}
|
||||
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
cnrtDim3_t dim{.x = 4, .y = (uint32_t)cluster_num, .z = 1};
|
||||
if (cluster_num == 1) {
|
||||
dim.y = 1;
|
||||
} else if (cluster_num >= 2) {
|
||||
dim.y = 2;
|
||||
}
|
||||
|
||||
kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user