215 lines
8.6 KiB
Plaintext
215 lines
8.6 KiB
Plaintext
#include <cstddef>
|
|
#include <iostream>
|
|
#include "cn_api.h"
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "generate_mask.mluh"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
namespace tmo {
|
|
namespace kernels {
|
|
template <typename T>
// Fill `elem_count` elements of the NRAM buffer `dst` with `value`.
// Thin wrapper over the BANG intrinsic so that bfloat16 can be handled
// by a specialization (below) instead of #ifdef-ing every call site.
__mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) {
  __bang_write_value(dst, elem_count, value);
}
|
|
|
|
template <>
// bfloat16 specialization: __bang_write_value is only invoked when
// compiling for BANG arch >= 500. On older architectures this compiles
// to a silent no-op; the host-side launcher in this file rejects
// bfloat16 before launching (see the isBf16Supported() check), so the
// no-op path is never reached at runtime — TODO confirm on all targets.
__mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) {
#if __BANG_ARCH__ >= 500
  __bang_write_value(dst, elem_count, value);
#endif
}
|
|
|
|
// Static NRAM scratch buffers shared by GenerateMask<T> below.
// __MLU_NRAM_SIZE__ is presumably the NRAM capacity in KiB (hence the
// "* 1024" to bytes) — TODO confirm against the toolchain docs.
// [once_len, once_len] tile: holds the staircase "upper" pattern.
__nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)];
// [1 + once_len, 2 * once_len] staging buffer; after init it is reused
// as an [once_len, once_len] fill tile plus an [once_len, once_len]
// zero tile (see GenerateMask::initBuffers).
__nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)];
// [once_len * 2 + 1] elements: one seed row of fill_value then zeros.
__nram__ int8_t nram_tiny[2048];
|
|
template <typename T>
|
|
class GenerateMask {
|
|
constexpr static int once_len = sizeof(T) == 4 ? 160 : 256;
|
|
// [once_len, once_len]
|
|
T *nram_upper = (T *)(nram_small);
|
|
// [1 + once_len, 2 * once_len]
|
|
T *nram_buf = (T *)(nram_large);
|
|
// [once_len, once_len], reuse upper part of nram_buf
|
|
T *nram_filled = nram_buf;
|
|
// [once_len, once_len], reuse lower part of nram_buf
|
|
T *nram_zeros = nram_buf + once_len * once_len;
|
|
// [once_len]
|
|
T *nram_ones_zeros = (T *)nram_tiny;
|
|
|
|
__mlu_func__ void initBuffers(T fill_value = -10000) {
|
|
/* nram_buf:
|
|
|---once_len---||---once_len---|
|
|
0, 1, 1, 1, ..., 1, 0, 0, 0, ...
|
|
0, 0, 1, 1, ..., 1, 1, 0, 0, ...
|
|
0, 0, 0, 1, ..., 1, 1, 1, 0, ...
|
|
... */
|
|
nram_buf[0] = 0;
|
|
constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T);
|
|
__memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1);
|
|
__memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T),
|
|
once_len * 2 * sizeof(T), once_len - 1);
|
|
// nram_buf is nolonger needed
|
|
write_value(nram_filled, once_len * once_len, (T)fill_value);
|
|
write_value(nram_zeros, once_len * once_len, (T)0);
|
|
}
|
|
|
|
__mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len]
|
|
int max_seq_len,
|
|
int seq_len) {
|
|
/*
|
|
| once_len |
|
|
+----------+-----------------------------------+
|
|
| | | |
|
|
| upper | fill_value | |
|
|
| | | |
|
|
+----------+----------+ | |
|
|
| | | | |
|
|
| | upper | | fill |
|
|
| | | | value |
|
|
| +----------+----------+ | |
|
|
| | | | |
|
|
| | upper | | |
|
|
| 0 | | | |
|
|
| +----------+---+ |
|
|
| | u | |
|
|
|--------------------------------+---+ |
|
|
| |
|
|
| fill_value |
|
|
| |
|
|
+----------------------------------------------+
|
|
*/
|
|
int tile_count = seq_len / once_len;
|
|
int tile_remain = seq_len % once_len;
|
|
int boarder_len = max_seq_len - seq_len;
|
|
int row = 0;
|
|
for (; row < tile_count * once_len; row += once_len) {
|
|
// fill left with zeros
|
|
// assume that max_seq_len <= once_len^2
|
|
if (row > 0) {
|
|
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
|
|
max_seq_len * sizeof(T), 0, once_len - 1);
|
|
}
|
|
// fill middle with upper
|
|
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T),
|
|
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1);
|
|
// fill right with fill_value
|
|
if (row + once_len < max_seq_len) {
|
|
__memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled,
|
|
(max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM,
|
|
max_seq_len * sizeof(T), 0, once_len - 1);
|
|
}
|
|
}
|
|
|
|
if (tile_remain) {
|
|
// fill left with zeros
|
|
if (row > 0) {
|
|
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
|
|
max_seq_len * sizeof(T), 0, tile_remain - 1);
|
|
}
|
|
// fill middle with upper
|
|
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T),
|
|
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1);
|
|
// fill right with fill_value
|
|
if (row + tile_remain < max_seq_len) {
|
|
__memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled,
|
|
(max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM,
|
|
max_seq_len * sizeof(T), 0, tile_remain - 1);
|
|
}
|
|
}
|
|
|
|
if (boarder_len) {
|
|
// fill right boarder with fill_value
|
|
__memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM,
|
|
max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1);
|
|
// fill bottom boarder with fill_value
|
|
__memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T),
|
|
NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1);
|
|
}
|
|
__sync_io();
|
|
}
|
|
|
|
public:
|
|
__mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
|
|
int *batch_seq_len,
|
|
int total_batch,
|
|
int max_seq_len,
|
|
T fill_value = -10000) {
|
|
int batch_each = total_batch / taskDimY;
|
|
int batch_remain = total_batch % taskDimY;
|
|
int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain);
|
|
int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0);
|
|
write_value(nram_ones_zeros, once_len, (T)fill_value);
|
|
write_value(nram_ones_zeros + once_len, once_len + 1, (T)0);
|
|
initBuffers();
|
|
|
|
for (int n = batch_start; n < batch_start + batch_count; n++) {
|
|
T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len;
|
|
int seq_len = batch_seq_len[n];
|
|
dealOneBatch(output, max_seq_len, seq_len);
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
__mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
|
|
int *batch_seq_len,
|
|
int total_batch,
|
|
int max_seq_len,
|
|
T fill_value = -10000) {
|
|
if (coreId != 0) {
|
|
return; // we only use 1 core in a cluster
|
|
}
|
|
GenerateMask<T>().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value);
|
|
}
|
|
} // namespace kernels
|
|
|
|
// Host-side launcher for the mask-generation kernel.
//
// Dispatches MLUUnion1GenerateMask<T> on the queue bound to `handle`,
// with one Union1 task group per cluster of the target device.
//
//   output_ddr    device buffer, [total_batch, max_seq_len, max_seq_len]
//   batch_seq_len device buffer of per-batch sequence lengths
//   data_type     element type of output_ddr (float / half / bfloat16)
//   fill_value    value written to masked positions
//
// Returns KERNEL_STATUS_FAILED for unsupported data types (including
// bfloat16 on devices without bf16 support), KERNEL_STATUS_SUCCESS
// after a successful enqueue. The launch is asynchronous.
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
                                void *output_ddr,
                                int *batch_seq_len,
                                int total_batch,
                                int max_seq_len,
                                cnnlDataType_t data_type,
                                float fill_value) {
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Union1 launch: dim.x is the cores per cluster (presumably 4 on the
  // target architecture — TODO confirm); the kernel itself only uses
  // core 0 of each cluster. One cluster per taskIdY.
  cnrtDim3_t dim;
  dim.x = 4;
  dim.y = cluster_num;
  dim.z = 1;
  if (data_type == CNNL_DTYPE_FLOAT) {
    kernels::MLUUnion1GenerateMask<float><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<float>(fill_value));
  } else if (data_type == CNNL_DTYPE_HALF) {
    kernels::MLUUnion1GenerateMask<half><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<half>(fill_value));
  } else if (data_type == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUUnion1GenerateMask<bfloat16_t><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<bfloat16_t>(fill_value));
  } else {
    // BUGFIX: the message previously repeated the function name twice
    // ("[invokeGenerateMask]: invokeGenerateMask: ...").
    std::cerr << "[invokeGenerateMask]: data_type is not supported" << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
|
|
} // namespace tmo
|