#include #include #include #include #include "cnnl.h" #include "cnrt.h" #include "copy_blocks.mluh" #include "kernel_utils.h" // clang-format off #include // clang-format on namespace tmo { #define NRAM_REMAIN_SIZE (32 * 1024) #define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) #define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753 #define LAYER_SIZE 128 #define BLOCK_PAIR_SIZE 512 #define ALIGN_BYTES 64 struct CopyBlocksInfo { void *key_addrs[LAYER_SIZE]; void *value_addrs[LAYER_SIZE]; unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2]; bool has_value_cache = true; }; namespace kernels { __nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; __mlu_func__ void copyBlocksNodld(CopyBlocksInfo info, uint32_t num_per_core, uint32_t block_mapping_offset, int32_t num_layers, uint32_t block_size_in_bytes) { for (uint32_t i = 0; i < num_per_core; i++) { uint32_t map_offset = block_mapping_offset + i * 2; uint32_t src_idx = info.mapping_addrs[map_offset]; uint32_t dst_idx = info.mapping_addrs[map_offset + 1]; int64_t src_offset = block_size_in_bytes * src_idx; int64_t dst_offset = block_size_in_bytes * dst_idx; for (uint32_t j = 0; j < num_layers; j++) { __memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM); if (info.has_value_cache) { __memcpy((int8_t *)info.value_addrs[j] + dst_offset, (int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM); } } } } __mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info, int32_t num_pairs, int32_t num_layers, uint32_t block_size_in_bytes) { uint32_t num_per_core = num_pairs / taskDim; uint32_t remain_for_core = num_pairs % taskDim; num_per_core += ((taskId < remain_for_core) ? 1 : 0); uint32_t block_mapping_offset = num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core); block_mapping_offset *= 2; #if (__BANG_ARCH__ >= 592) if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) { auto num_pair_data_width = sizeof(int32_t); uint32_t align_num = ALIGN_BYTES / num_pair_data_width; unsigned int num_per_core_2 = num_per_core * 2; unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num; unsigned int *gather_src_offset = (unsigned int *)nram_buffer; unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align; int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align); uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2; unsigned int *scatter_dst_offset = gather_src_offset + num_per_core; uint32_t num_per_loop = nram_remain / block_size_in_bytes; uint32_t repeat = num_per_core / num_per_loop; uint32_t remain = num_per_core % num_per_loop; for (int i = 0; i < num_per_core; i++) { unsigned int mapping_addrs_idx = block_mapping_offset + i * 2; block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx]; block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1]; } __bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes, num_per_core_2); __sync(); for (uint32_t k = 0; k < num_layers; k++) { for (uint32_t i = 0; i < repeat; i++) { __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop, (unsigned int)block_size_in_bytes, GDRAM2NRAM, (unsigned int)block_size_in_bytes, num_per_loop); __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop, (unsigned int)block_size_in_bytes, NRAM2GDRAM, (unsigned int)block_size_in_bytes, num_per_loop); if (info.has_value_cache) { __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop, (unsigned int)block_size_in_bytes, GDRAM2NRAM, (unsigned int)block_size_in_bytes, num_per_loop); __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop, (unsigned int)block_size_in_bytes, NRAM2GDRAM, (unsigned int)block_size_in_bytes, num_per_loop); } } if (remain != 0) { uint32_t repeat_nums = repeat * num_per_loop; __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums, (unsigned int)block_size_in_bytes, GDRAM2NRAM, (unsigned int)block_size_in_bytes, remain); __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums, (unsigned int)block_size_in_bytes, NRAM2GDRAM, (unsigned int)block_size_in_bytes, remain); if (info.has_value_cache) { __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums, (unsigned int)block_size_in_bytes, GDRAM2NRAM, (unsigned int)block_size_in_bytes, remain); __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums, (unsigned int)block_size_in_bytes, NRAM2GDRAM, (unsigned int)block_size_in_bytes, remain); } } } } else { copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes); } #else copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes); #endif } } // namespace kernels KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue, const std::vector &key_caches, const std::vector &value_caches, const std::vector &block_mapping_vec, const size_t block_size_in_bytes) { CNdev dev; cnCtxGetDevice(&dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); int core_num; CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); cnrtFunctionType_t k_type = cnrtFuncTypeBlock; if (key_caches.empty()) { std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } if (!value_caches.empty() && key_caches.size() != value_caches.size()) { std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches " << "size if value_caches is not empty." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } int32_t mapping_size = block_mapping_vec.size(); int32_t num_pairs = mapping_size / 2; uint32_t task_dim = std::min(num_pairs, cluster_num * core_num); cnrtDim3_t k_dim{task_dim, 1, 1}; int32_t num_layers = key_caches.size(); int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE); int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num); int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE); int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num); CopyBlocksInfo info; if (value_caches.empty()) { info.has_value_cache = false; } for (int32_t i = 0; i < layer_loop_num; i++) { int32_t sub_num_layers = std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop); for (int32_t l = 0; l < sub_num_layers; l++) { info.key_addrs[l] = key_caches[l + i * layer_num_per_loop]; if (info.has_value_cache) { info.value_addrs[l] = value_caches[l + i * layer_num_per_loop]; } } for (int32_t j = 0; j < pair_loop_num; j++) { int32_t sub_num_pairs = std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop); int32_t lens_block_mapping = sub_num_pairs * 2; int32_t block_vec_offset = j * pair_num_per_loop * 2; for (int32_t m = 0; m < lens_block_mapping; m++) { info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset]; } kernels::launchCopyBlocksKernel<<>>(info, sub_num_pairs, sub_num_layers, block_size_in_bytes); } } return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo