Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu
2026-02-04 17:39:32 +08:00

193 lines
8.7 KiB
Plaintext

#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "cnnl.h"
#include "cnrt.h"
#include "copy_blocks.mluh"
#include "kernel_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
// NRAM bytes held back from the scratch buffer (reserved tail, not used here).
#define NRAM_REMAIN_SIZE (32 * 1024)
// Usable per-core NRAM scratch: total NRAM (KiB macro) minus the reserved tail.
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
// Blocks strictly smaller than this byte count take the gather/scatter fast
// path on __BANG_ARCH__ >= 592; larger blocks use the plain memcpy fallback.
// Value is 448 KiB + 1. (Name keeps the original "THRESHHOLD" spelling, which
// callers elsewhere may reference.)
#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753
// Max cache layers carried per kernel launch (arrays below are sized by this).
#define LAYER_SIZE 128
// Max (src, dst) block pairs carried per kernel launch.
#define BLOCK_PAIR_SIZE 512
// Alignment in bytes used when carving offset tables out of the NRAM buffer.
#define ALIGN_BYTES 64
// Per-launch argument bundle, passed to the kernel by value.
struct CopyBlocksInfo {
// Device base addresses of the key caches, one entry per layer in this launch.
void *key_addrs[LAYER_SIZE];
// Device base addresses of the value caches; ignored when has_value_cache is false.
void *value_addrs[LAYER_SIZE];
// Flattened, interleaved (src_block_idx, dst_block_idx) pairs.
unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2];
// False when the caller supplied no value caches (key-only copy).
bool has_value_cache = true;
};
namespace kernels {
// Per-core NRAM scratch area; the fast path in launchCopyBlocksKernel carves
// its offset tables and block staging buffer out of this.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Fallback path: copy each assigned (src, dst) block pair with plain
// GDRAM->GDRAM memcpys, one per layer (and once more per layer for the value
// cache when present).
//
// info                  per-launch cache addresses and interleaved pair mapping
// num_per_core          number of pairs this core handles
// block_mapping_offset  this core's starting index (already doubled) into
//                       info.mapping_addrs
// num_layers            number of valid entries in info.key_addrs/value_addrs
// block_size_in_bytes   bytes per cache block
__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info,
uint32_t num_per_core,
uint32_t block_mapping_offset,
int32_t num_layers,
uint32_t block_size_in_bytes) {
for (uint32_t i = 0; i < num_per_core; i++) {
uint32_t map_offset = block_mapping_offset + i * 2;
uint32_t src_idx = info.mapping_addrs[map_offset];
uint32_t dst_idx = info.mapping_addrs[map_offset + 1];
// Widen BEFORE multiplying: a uint32 product block_size_in_bytes * idx can
// wrap past 2^32 for large caches, and the wrapped value would then be
// assigned to int64_t. Casting one operand forces a 64-bit multiply.
int64_t src_offset = (int64_t)block_size_in_bytes * src_idx;
int64_t dst_offset = (int64_t)block_size_in_bytes * dst_idx;
// int32_t loop index matches num_layers' type (avoids the signed/unsigned
// comparison of the original).
for (int32_t j = 0; j < num_layers; j++) {
__memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset,
block_size_in_bytes, GDRAM2GDRAM);
if (info.has_value_cache) {
__memcpy((int8_t *)info.value_addrs[j] + dst_offset,
(int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM);
}
}
}
}
// Kernel: copy (src -> dst) cache blocks for up to LAYER_SIZE layers.
// Pairs are split across cores; each core handles a contiguous pair range.
// On __BANG_ARCH__ >= 592 with small enough blocks, whole blocks are staged
// through NRAM via gather/scatter DMA; otherwise the memcpy fallback runs.
__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info,
int32_t num_pairs,
int32_t num_layers,
uint32_t block_size_in_bytes) {
// Even split of pairs; the first (num_pairs % taskDim) cores take one extra.
uint32_t num_per_core = num_pairs / taskDim;
uint32_t remain_for_core = num_pairs % taskDim;
num_per_core += ((taskId < remain_for_core) ? 1 : 0);
// Start of this core's range in pair units, then doubled because
// mapping_addrs stores interleaved (src, dst) entries.
uint32_t block_mapping_offset =
num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core);
block_mapping_offset *= 2;
#if (__BANG_ARCH__ >= 592)
if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) {
// Fast path: build byte-offset tables in NRAM, then batch-move blocks with
// gather (GDRAM->NRAM) followed by scatter (NRAM->GDRAM).
auto num_pair_data_width = sizeof(int32_t);
uint32_t align_num = ALIGN_BYTES / num_pair_data_width;
unsigned int num_per_core_2 = num_per_core * 2;
// Round the table length up to the 64-byte alignment granule.
unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num;
// NRAM layout: [gather_src_offset | block_mapping_src_dst | n_buffer ...].
unsigned int *gather_src_offset = (unsigned int *)nram_buffer;
unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align;
int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align);
// NOTE(review): sizeof(unsigned int *) is pointer width, but the tables hold
// 4-byte unsigned ints — likely meant sizeof(unsigned int). Errs in the safe
// direction (reserves too much, shrinking n_buffer), so left as is.
uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2;
// First num_per_core entries of the scaled table are src byte offsets,
// the next num_per_core are dst byte offsets.
unsigned int *scatter_dst_offset = gather_src_offset + num_per_core;
// Blocks staged per DMA batch. Assumes nram_remain >= block_size_in_bytes
// (the threshold above is meant to guarantee this — a zero here would
// divide by zero below); TODO confirm for all NRAM sizes.
uint32_t num_per_loop = nram_remain / block_size_in_bytes;
uint32_t repeat = num_per_core / num_per_loop;
uint32_t remain = num_per_core % num_per_loop;
// De-interleave this core's (src, dst) block indices into the two halves.
for (int i = 0; i < num_per_core; i++) {
unsigned int mapping_addrs_idx = block_mapping_offset + i * 2;
block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx];
block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1];
}
// Scale block indices to byte offsets (both halves in one vector multiply).
__bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes,
num_per_core_2);
// Ensure the offset tables are complete before the async DMA reads them.
__sync();
for (uint32_t k = 0; k < num_layers; k++) {
for (uint32_t i = 0; i < repeat; i++) {
// NOTE(review): each gather and the following scatter share n_buffer with
// no explicit sync; this relies on same-queue async DMA executing in issue
// order — confirm that guarantee for this architecture.
__gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
__scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
if (info.has_value_cache) {
__gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
__scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, num_per_loop);
}
}
// Tail batch: the pairs that did not fill a whole num_per_loop chunk.
if (remain != 0) {
uint32_t repeat_nums = repeat * num_per_loop;
__gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, remain);
__scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, remain);
if (info.has_value_cache) {
__gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums,
(unsigned int)block_size_in_bytes, GDRAM2NRAM,
(unsigned int)block_size_in_bytes, remain);
__scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
(unsigned int)block_size_in_bytes, NRAM2GDRAM,
(unsigned int)block_size_in_bytes, remain);
}
}
}
} else {
// Block too large for the NRAM staging path: direct GDRAM->GDRAM copies.
copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
}
#else
// Older architectures: no gather/scatter fast path.
copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
#endif
}
} // namespace kernels
// Host-side launcher: enqueues device copies of KV-cache blocks for every
// (src, dst) pair in block_mapping_vec, across all layers. Because the
// argument struct is passed to the kernel by value with fixed-size arrays,
// work is chunked into at most LAYER_SIZE layers and BLOCK_PAIR_SIZE pairs
// per launch.
//
// queue                MLU queue the kernels are enqueued on (asynchronous).
// key_caches           per-layer device base pointers; must be non-empty.
// value_caches         per-layer device base pointers; empty means key-only
//                      copy, otherwise size must match key_caches.
// block_mapping_vec    flattened (src_idx, dst_idx) pairs; an odd trailing
//                      element is ignored (integer halving, as before).
// block_size_in_bytes  bytes per cache block. NOTE(review): narrowed to
//                      uint32_t at the kernel boundary — confirm blocks stay
//                      below 4 GiB.
//
// Returns KERNEL_STATUS_FAILED on invalid arguments, otherwise
// KERNEL_STATUS_SUCCESS after all launches are enqueued.
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
const std::vector<void *> &key_caches,
const std::vector<void *> &value_caches,
const std::vector<int32_t> &block_mapping_vec,
const size_t block_size_in_bytes) {
if (key_caches.empty()) {
std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (!value_caches.empty() && key_caches.size() != value_caches.size()) {
std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches "
<< "size if value_caches is not empty." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t mapping_size = block_mapping_vec.size();
int32_t num_pairs = mapping_size / 2;
// Nothing to copy. Bailing out here also avoids the undefined
// float(0)/0 -> NaN -> int32_t conversion in the per-loop pair sizing below,
// and a zero-sized task dimension.
if (num_pairs == 0) {
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
// One task per pair at most; never exceed the device's core count.
uint32_t task_dim = std::min(num_pairs, cluster_num * core_num);
cnrtDim3_t k_dim{task_dim, 1, 1};
int32_t num_layers = key_caches.size();
// Split layers and pairs into near-equal chunks that fit the fixed-size
// arrays in CopyBlocksInfo.
int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE);
int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num);
int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE);
int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num);
CopyBlocksInfo info;
if (value_caches.empty()) {
info.has_value_cache = false;
}
for (int32_t i = 0; i < layer_loop_num; i++) {
int32_t sub_num_layers =
std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop);
// Stage this chunk's layer base addresses into the by-value struct.
for (int32_t l = 0; l < sub_num_layers; l++) {
info.key_addrs[l] = key_caches[l + i * layer_num_per_loop];
if (info.has_value_cache) {
info.value_addrs[l] = value_caches[l + i * layer_num_per_loop];
}
}
for (int32_t j = 0; j < pair_loop_num; j++) {
int32_t sub_num_pairs =
std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop);
int32_t lens_block_mapping = sub_num_pairs * 2;
int32_t block_vec_offset = j * pair_num_per_loop * 2;
// Stage this chunk's interleaved (src, dst) indices.
for (int32_t m = 0; m < lens_block_mapping; m++) {
info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset];
}
kernels::launchCopyBlocksKernel<<<k_dim, k_type, queue>>>(info, sub_num_pairs, sub_num_layers,
block_size_in_bytes);
}
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo