#include #include #include "cn_api.h" #include "cnnl.h" #include "cnrt.h" #include "generate_mask.mluh" // clang-format off #include // clang-format on namespace tmo { namespace kernels { template __mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) { __bang_write_value(dst, elem_count, value); } template <> __mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) { #if __BANG_ARCH__ >= 500 __bang_write_value(dst, elem_count, value); #endif } // [once_len, once_len] __nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)]; // [1 + once_len, 2 * once_len] __nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)]; // [once_len * 2 + 1] __nram__ int8_t nram_tiny[2048]; template class GenerateMask { constexpr static int once_len = sizeof(T) == 4 ? 160 : 256; // [once_len, once_len] T *nram_upper = (T *)(nram_small); // [1 + once_len, 2 * once_len] T *nram_buf = (T *)(nram_large); // [once_len, once_len], reuse upper part of nram_buf T *nram_filled = nram_buf; // [once_len, once_len], reuse lower part of nram_buf T *nram_zeros = nram_buf + once_len * once_len; // [once_len] T *nram_ones_zeros = (T *)nram_tiny; __mlu_func__ void initBuffers(T fill_value = -10000) { /* nram_buf: |---once_len---||---once_len---| 0, 1, 1, 1, ..., 1, 0, 0, 0, ... 0, 0, 1, 1, ..., 1, 1, 0, 0, ... 0, 0, 0, 1, ..., 1, 1, 1, 0, ... ... */ nram_buf[0] = 0; constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T); __memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1); __memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T), once_len * 2 * sizeof(T), once_len - 1); // nram_buf is nolonger needed write_value(nram_filled, once_len * once_len, (T)fill_value); write_value(nram_zeros, once_len * once_len, (T)0); } __mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len] int max_seq_len, int seq_len) { /* | once_len | +----------+-----------------------------------+ | | | | | upper | fill_value | | | | | | +----------+----------+ | | | | | | | | | upper | | fill | | | | | value | | +----------+----------+ | | | | | | | | | upper | | | | 0 | | | | | +----------+---+ | | | u | | |--------------------------------+---+ | | | | fill_value | | | +----------------------------------------------+ */ int tile_count = seq_len / once_len; int tile_remain = seq_len % once_len; int boarder_len = max_seq_len - seq_len; int row = 0; for (; row < tile_count * once_len; row += once_len) { // fill left with zeros // assume that max_seq_len <= once_len^2 if (row > 0) { __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, once_len - 1); } // fill middle with upper __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1); // fill right with fill_value if (row + once_len < max_seq_len) { __memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled, (max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, once_len - 1); } } if (tile_remain) { // fill left with zeros if (row > 0) { __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, tile_remain - 1); } // fill middle with upper __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1); // fill right with fill_value if (row + tile_remain < max_seq_len) { __memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled, (max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, tile_remain - 1); } } if (boarder_len) { // fill right boarder with fill_value __memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1); // fill bottom boarder with fill_value __memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T), NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1); } __sync_io(); } public: __mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len] int *batch_seq_len, int total_batch, int max_seq_len, T fill_value = -10000) { int batch_each = total_batch / taskDimY; int batch_remain = total_batch % taskDimY; int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain); int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0); write_value(nram_ones_zeros, once_len, (T)fill_value); write_value(nram_ones_zeros + once_len, once_len + 1, (T)0); initBuffers(); for (int n = batch_start; n < batch_start + batch_count; n++) { T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len; int seq_len = batch_seq_len[n]; dealOneBatch(output, max_seq_len, seq_len); } } }; template __mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len] int *batch_seq_len, int total_batch, int max_seq_len, T fill_value = -10000) { if (coreId != 0) { return; // we only use 1 core in a cluster } GenerateMask().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value); } } // namespace kernels KernelStatus invokeGenerateMask(cnnlHandle_t handle, void *output_ddr, int *batch_seq_len, int total_batch, int max_seq_len, cnnlDataType_t data_type, float fill_value) { cnrtQueue_t queue; cnnlGetQueue(handle, &queue); CNdev dev; cnnlGetDevice(handle, &dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); cnrtDim3_t dim; dim.x = 4; dim.y = cluster_num; dim.z = 1; if (data_type == CNNL_DTYPE_FLOAT) { kernels::MLUUnion1GenerateMask<<>>( static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, static_cast(fill_value)); } else if (data_type == CNNL_DTYPE_HALF) { kernels::MLUUnion1GenerateMask<<>>( static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, static_cast(fill_value)); } else if (data_type == CNNL_DTYPE_BFLOAT16) { if (!isBf16Supported()) { std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } kernels::MLUUnion1GenerateMask<<>>( static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, static_cast(fill_value)); } else { std::cerr << "[invokeGenerateMask]: invokeGenerateMask: data_type is not supported" << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo