#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

#include "cnnl.h"
#include "cnrt.h"
#include "copy_blocks.mluh"
#include "kernel_utils.h"

// clang-format off
#include <mlu.h>
// clang-format on
|
|
|
|
namespace tmo {
|
|
#define NRAM_REMAIN_SIZE (32 * 1024)
|
|
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
|
|
#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753
|
|
#define LAYER_SIZE 128
|
|
#define BLOCK_PAIR_SIZE 512
|
|
#define ALIGN_BYTES 64
|
|
|
|
// Argument bundle passed by value to the copy-blocks kernels. The host fills
// in up to LAYER_SIZE per-layer cache base pointers and up to BLOCK_PAIR_SIZE
// (src, dst) block-index pairs per kernel launch.
struct CopyBlocksInfo {
  // Per-layer base GDRAM addresses of the key caches (first num_layers used).
  void *key_addrs[LAYER_SIZE];
  // Per-layer base GDRAM addresses of the value caches; only read when
  // has_value_cache is true.
  void *value_addrs[LAYER_SIZE];
  // Flattened (src_block, dst_block) index pairs: pair i occupies
  // slots [2*i] (source) and [2*i + 1] (destination).
  unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2];
  // False when the caller supplies no value caches (key-only copy).
  bool has_value_cache = true;
};
|
|
|
|
namespace kernels {
|
|
// Per-core on-chip NRAM scratch space shared by the kernels in this file;
// NRAM_REMAIN_SIZE bytes are held back for the compiler/runtime.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
|
|
|
// Fallback path: copy each mapped cache block with a plain GDRAM-to-GDRAM DMA,
// one __memcpy per (layer, pair), without gather/scatter batching.
//
// info:                 per-launch addresses and (src, dst) mapping pairs.
// num_per_core:         number of pairs this core handles.
// block_mapping_offset: starting slot in info.mapping_addrs (already doubled,
//                       since pairs are stored flattened).
// num_layers:           number of valid entries in key_addrs/value_addrs.
// block_size_in_bytes:  size of one cache block.
__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info,
                                  uint32_t num_per_core,
                                  uint32_t block_mapping_offset,
                                  int32_t num_layers,
                                  uint32_t block_size_in_bytes) {
  for (uint32_t i = 0; i < num_per_core; i++) {
    uint32_t map_offset = block_mapping_offset + i * 2;
    uint32_t src_idx = info.mapping_addrs[map_offset];
    uint32_t dst_idx = info.mapping_addrs[map_offset + 1];
    // Widen BEFORE multiplying: block_size_in_bytes * idx is a uint32_t
    // product and silently wraps for caches of 4 GiB or more.
    int64_t src_offset = (int64_t)block_size_in_bytes * src_idx;
    int64_t dst_offset = (int64_t)block_size_in_bytes * dst_idx;
    // Signed loop index to match the signed num_layers parameter.
    for (int32_t j = 0; j < num_layers; j++) {
      __memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset,
               block_size_in_bytes, GDRAM2GDRAM);
      if (info.has_value_cache) {
        __memcpy((int8_t *)info.value_addrs[j] + dst_offset,
                 (int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM);
      }
    }
  }
}
|
|
|
|
// Entry kernel: each core copies its share of (src, dst) block pairs across
// all layers. On BANG arch >= 5.92 with blocks below the gather threshold it
// batches copies through NRAM with gather/scatter DMA; otherwise it falls back
// to plain GDRAM-to-GDRAM memcpy (copyBlocksNodld).
__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info,
                                           int32_t num_pairs,
                                           int32_t num_layers,
                                           uint32_t block_size_in_bytes) {
  // Distribute pairs across cores; the first remain_for_core cores each take
  // one extra pair.
  uint32_t num_per_core = num_pairs / taskDim;
  uint32_t remain_for_core = num_pairs % taskDim;
  num_per_core += ((taskId < remain_for_core) ? 1 : 0);
  uint32_t block_mapping_offset =
      num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core);
  block_mapping_offset *= 2;  // mapping_addrs stores flattened (src, dst) pairs
#if (__BANG_ARCH__ >= 592)
  if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) {
    // NRAM layout: [gather_src_offset | block_mapping_src_dst | n_buffer].
    auto num_pair_data_width = sizeof(int32_t);
    uint32_t align_num = ALIGN_BYTES / num_pair_data_width;
    unsigned int num_per_core_2 = num_per_core * 2;
    unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num;
    unsigned int *gather_src_offset = (unsigned int *)nram_buffer;
    unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align;
    int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align);
    // Fix: subtract the ELEMENT size (sizeof(unsigned int)), not the pointer
    // size — the original sizeof(unsigned int *) doubled the reservation on a
    // 64-bit toolchain and halved the usable staging buffer.
    uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int) * num_per_core_2_align * 2;
    // gather_src_offset holds src byte offsets in [0, num_per_core) and dst
    // byte offsets in [num_per_core, 2 * num_per_core).
    unsigned int *scatter_dst_offset = gather_src_offset + num_per_core;

    uint32_t num_per_loop = nram_remain / block_size_in_bytes;
    if (num_per_loop == 0) {
      // Not enough NRAM left to stage even one block: take the memcpy path
      // instead of dividing by zero below.
      copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
      return;
    }
    uint32_t repeat = num_per_core / num_per_loop;
    uint32_t remain = num_per_core % num_per_loop;

    // De-interleave the (src, dst) pairs so srcs and dsts are contiguous.
    for (uint32_t i = 0; i < num_per_core; i++) {
      unsigned int mapping_addrs_idx = block_mapping_offset + i * 2;
      block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx];
      block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1];
    }
    // Convert block indices to byte offsets in one vector multiply.
    // NOTE(review): offsets stay 32-bit here, so this path assumes each cache
    // tensor is < 4 GiB — confirm against allocator limits.
    __bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes,
                      num_per_core_2);
    __sync();
    for (int32_t k = 0; k < num_layers; k++) {
      for (uint32_t i = 0; i < repeat; i++) {
        // Stage num_per_loop source blocks into NRAM, then scatter them to
        // their destinations. Async ops on the same queue execute in order,
        // preserving the gather-before-scatter dependency on n_buffer.
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, num_per_loop);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, num_per_loop);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, num_per_loop);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, num_per_loop);
        }
      }
      if (remain != 0) {
        // Tail: pairs left over after the full num_per_loop batches.
        uint32_t repeat_nums = repeat * num_per_loop;
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, remain);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, remain);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, remain);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, remain);
        }
      }
    }
    // Drain outstanding async DMA before the kernel returns.
    __sync();
  } else {
    copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
  }
#else
  copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
#endif
}
|
|
} // namespace kernels
|
|
|
|
// Host-side launcher for the copy-blocks kernel.
//
// queue:               CNRT queue the kernels are enqueued on (asynchronous;
//                      the caller is responsible for synchronizing the queue).
// key_caches:          per-layer device base pointers of the key caches.
// value_caches:        per-layer device base pointers of the value caches;
//                      may be empty for a key-only copy, otherwise must match
//                      key_caches in length.
// block_mapping_vec:   flattened (src_block, dst_block) index pairs.
// block_size_in_bytes: size of one cache block.
//
// Because CopyBlocksInfo has fixed capacity (LAYER_SIZE layers,
// BLOCK_PAIR_SIZE pairs), the work is tiled over layers and pairs and one
// kernel is launched per (layer-tile, pair-tile).
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
                                    const std::vector<void *> &key_caches,
                                    const std::vector<void *> &value_caches,
                                    const std::vector<int32_t> &block_mapping_vec,
                                    const size_t block_size_in_bytes) {
  if (key_caches.empty()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (!value_caches.empty() && key_caches.size() != value_caches.size()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches "
              << "size if value_caches is not empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // An odd-length mapping has no valid pair interpretation; the original code
  // silently dropped the trailing element.
  if (block_mapping_vec.size() % 2 != 0) {
    std::cerr << "[invokeCopyBlocksKernel]: block_mapping_vec size must be even." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // The kernel takes the block size as uint32_t; reject values that truncate.
  if (block_size_in_bytes > UINT32_MAX) {
    std::cerr << "[invokeCopyBlocksKernel]: block_size_in_bytes exceeds uint32_t range."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  int32_t mapping_size = block_mapping_vec.size();
  int32_t num_pairs = mapping_size / 2;
  if (num_pairs == 0) {
    // Nothing to copy; also avoids computing a zero task dimension below.
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  CNdev dev;
  cnCtxGetDevice(&dev);

  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtFunctionType_t k_type = cnrtFuncTypeBlock;

  // One task (core) per pair, capped at the device's core count.
  uint32_t task_dim = std::min(num_pairs, cluster_num * core_num);
  cnrtDim3_t k_dim{task_dim, 1, 1};

  // Tile layers and pairs into CopyBlocksInfo-sized chunks, balancing the
  // chunk sizes across loop iterations.
  int32_t num_layers = key_caches.size();
  int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE);
  int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num);
  int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE);
  int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num);

  CopyBlocksInfo info;
  if (value_caches.empty()) {
    info.has_value_cache = false;
  }
  for (int32_t i = 0; i < layer_loop_num; i++) {
    int32_t sub_num_layers =
        std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop);
    for (int32_t l = 0; l < sub_num_layers; l++) {
      info.key_addrs[l] = key_caches[l + i * layer_num_per_loop];
      if (info.has_value_cache) {
        info.value_addrs[l] = value_caches[l + i * layer_num_per_loop];
      }
    }
    for (int32_t j = 0; j < pair_loop_num; j++) {
      int32_t sub_num_pairs =
          std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop);
      int32_t lens_block_mapping = sub_num_pairs * 2;
      int32_t block_vec_offset = j * pair_num_per_loop * 2;
      for (int32_t m = 0; m < lens_block_mapping; m++) {
        info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset];
      }
      kernels::launchCopyBlocksKernel<<<k_dim, k_type, queue>>>(info, sub_num_pairs, sub_num_layers,
                                                                block_size_in_bytes);
    }
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
} // namespace tmo
|