#include "reshape_linear_cache.mluh"

// clang-format off
#include <mlu.h>
// clang-format on

namespace tmo {

namespace kernels {

// [head_num, batch, seq_seg]
// Copies key/value activations for one (head, batch, seq-segment) tile from
// the linear "context" layout into the KV cache layout.
//
// Task mapping (see the host launcher's cnrtDim3_t):
//   taskIdX -> head index (or a single task covering all heads when
//              taskDimX == 1 — the heads are then handled by the __memcpy
//              repeat dimension),
//   taskIdY -> batch index (strided loop over `batch`),
//   taskIdZ -> sequence segment of SEQ_BLOCK tokens.
//
// All *_stride parameters are element counts (they are multiplied by
// dtype_size before use). `key`/`value` and `key_cache`/`value_cache` may
// independently be nullptr; a nullptr pair is simply skipped.
//
// When `packed` is true, `context_lens` is a prefix-sum array (length
// batch + 1): batch bs's tokens live at [context_lens[bs], context_lens[bs+1])
// along the sequence axis, and `context_seq_offsets` is ignored. When false,
// `context_lens[bs]` is that batch's token count and `context_seq_offsets`
// (if non-null) gives a per-batch starting token in the context buffer.
//
// `cache_bs_offsets` (if non-null) remaps the batch index into the cache's
// batch axis; `cache_seq_offsets` (if non-null) gives the write position
// along the cache's sequence axis. A negative value in either array skips
// that batch entirely.
//
// NOTE(review): `max_context_len` and `cache_mem_len` are not read inside
// this kernel body — presumably kept for signature symmetry with the host
// wrapper; confirm before removing.
__mlu_global__ void MLUReshapeLinearCacheKernel(int8_t *key_cache,
                                                int8_t *value_cache,
                                                int *cache_bs_offsets,
                                                int *cache_seq_offsets,
                                                int8_t *key,
                                                int8_t *value,
                                                int *context_seq_offsets,
                                                int *context_lens,
                                                int batch,
                                                int head_num,
                                                int head_size,
                                                int max_context_len,
                                                int cache_mem_len,
                                                size_t context_bs_stride,
                                                size_t context_head_stride,
                                                size_t context_seq_stride,
                                                size_t cache_bs_stride,
                                                size_t cache_head_stride,
                                                size_t cache_seq_stride,
                                                bool packed,
                                                int dtype_size,
                                                int SEQ_BLOCK) {
  // With multiple X tasks each task owns exactly one head; with a single X
  // task one __memcpy call sweeps every head via its repeat dimension.
  int head_repeat = taskDimX > 1 ? 1 : head_num;
  for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
    // Packed layout has no per-batch seq offset table to honor.
    int seq_offset = (packed || context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
    // First token of the SEQ_BLOCK segment this task is responsible for.
    int task_seq_begin = taskIdZ * SEQ_BLOCK;
    // Packed: token count comes from the prefix-sum difference.
    int seq_len = packed ? (context_lens[bs_idx + 1] - context_lens[bs_idx]) : context_lens[bs_idx];
    if (task_seq_begin >= seq_len) continue;

    // Tokens this task actually copies (last segment may be partial).
    int seq = std::min(seq_len - task_seq_begin, SEQ_BLOCK);
    // Byte offset of this task's head within the context buffer.
    size_t context_offset = taskIdX * context_head_stride * dtype_size;
    if (packed) {
      // Packed: batches are concatenated along the sequence axis, so the
      // batch start is context_lens[bs_idx] tokens in.
      context_offset += (context_lens[bs_idx] + task_seq_begin) * context_seq_stride * dtype_size;
    } else {
      context_offset +=
          (bs_idx * context_bs_stride + (task_seq_begin + seq_offset) * context_seq_stride) *
          dtype_size;
    }

    int cache_seq_offset = cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx];
    int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
    // Negative offsets mark batches that must not be written to the cache.
    if (cache_seq_offset < 0 || cache_bs_offset < 0) {
      continue;
    }

    cache_seq_offset += task_seq_begin;
    if (key != nullptr && key_cache != nullptr) {
      int8_t *key_cache_begin =
          key_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
                       cache_seq_offset * cache_seq_stride) *
                          dtype_size;
      int8_t *key_begin = key + context_offset;
      // 3D strided GDRAM->GDRAM copy: innermost chunk is one head's row
      // (head_size elements); level-1 repeats over `seq` tokens with
      // per-layout seq strides; level-2 repeats over `head_repeat` heads
      // with per-layout head strides.
      __memcpy(key_cache_begin, key_begin, head_size * dtype_size, GDRAM2GDRAM,
               cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
               head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
               context_head_stride * dtype_size, head_repeat - 1);
    }

    if (value != nullptr && value_cache != nullptr) {
      int8_t *value_cache_begin =
          value_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
                         cache_seq_offset * cache_seq_stride) *
                            dtype_size;
      int8_t *value_begin = value + context_offset;
      // Same strided copy as for the key path, on the value buffers.
      __memcpy(value_cache_begin, value_begin, head_size * dtype_size, GDRAM2GDRAM,
               cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
               head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
               context_head_stride * dtype_size, head_repeat - 1);
    }
  }
}
}  // namespace kernels

// Host-side launcher for MLUReshapeLinearCacheKernel: reshapes K/V
// activations from the linear context layout into the KV cache layout.
//
// Parameters mirror the kernel's (see its header comment for the layout and
// nullability contract); strides are element counts, `dtype` selects the
// element width. Launch shape:
//   X = head_num (1 in the small "decoder" case, where heads fit one task),
//   Y = batch (capped at 48 in the decoder case),
//   Z = ceil(max_context_len / SEQ_BLOCK) sequence segments.
//
// Returns KERNEL_STATUS_FAILED for an unsupported dtype, otherwise
// KERNEL_STATUS_SUCCESS. The launch is asynchronous on `queue`.
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
                                      void *key_cache,
                                      void *value_cache,
                                      const void *cache_bs_offsets,
                                      const void *cache_seq_offsets,
                                      void *key,
                                      void *value,
                                      const void *context_seq_offsets,
                                      const void *context_lens,
                                      const cnnlDataType_t dtype,
                                      const int batch,
                                      const int head_num,
                                      const int head_size,
                                      const int max_context_len,
                                      const int cache_mem_len,
                                      const int context_bs_stride,
                                      const int context_head_stride,
                                      const int context_seq_stride,
                                      const int cache_bs_stride,
                                      const int cache_head_stride,
                                      const int cache_seq_stride,
                                      const bool packed) {
  constexpr int SEQ_BLOCK = 512;

  // Empty problem: launching with a zero-sized grid dimension is invalid,
  // and there is nothing to copy anyway.
  if (batch <= 0 || head_num <= 0 || max_context_len <= 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  // Z dimension: number of SEQ_BLOCK-token segments (ceil division).
  int seq_seg = (max_context_len + SEQ_BLOCK - 1) / SEQ_BLOCK;
  // For small decoder-style workloads the whole job fits comfortably in few
  // tasks: collapse the head dimension into one task and cap Y parallelism.
  bool is_decoder_case = head_num * max_context_len < SEQ_BLOCK;
  uint32_t task_x_dim = is_decoder_case ? 1 : head_num;
  uint32_t task_y_dim = is_decoder_case ? std::min(batch, 48) : batch;
  cnrtDim3_t dim{task_x_dim, task_y_dim, (uint32_t)seq_seg};

  // Element width in bytes for the supported dtypes.
  int dtype_size = 0;
  switch (dtype) {
    case CNNL_DTYPE_HALF:
    case CNNL_DTYPE_BFLOAT16:
      dtype_size = 2;
      break;
    case CNNL_DTYPE_INT8:
      dtype_size = 1;
      break;
    case CNNL_DTYPE_FLOAT:
      dtype_size = 4;
      break;
    default:
      std::cerr << "invokeReshapeLinearCache: unsupported dtype" << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
  }

  kernels::MLUReshapeLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key_cache, (int8_t *)value_cache, (int *)cache_bs_offsets, (int *)cache_seq_offsets,
      (int8_t *)key, (int8_t *)value, (int *)context_seq_offsets, (int *)context_lens, batch,
      head_num, head_size, max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed, dtype_size,
      SEQ_BLOCK);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
}  // namespace tmo