Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mlu
2026-02-04 17:39:32 +08:00

135 lines
6.6 KiB
MLU BANG C source (Cambricon device code)

#include "reshape_linear_cache.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// [head_num, batch, seq_seg]
// Scatter per-token K/V activations from a contiguous ("linear") context
// layout into the pre-allocated linear KV cache, one strided GDRAM->GDRAM
// copy per (batch, head, seq-segment) task.
//
// Task-grid mapping (set up by the launcher): taskIdX -> head,
// taskIdY -> batch (grid-stride loop below), taskIdZ -> segment of
// SEQ_BLOCK tokens. In the decoder case the launcher uses taskDimX == 1,
// and a single task covers every head via the head repeat dimension of
// __memcpy (head_repeat below).
//
// Input layouts (element strides; multiply by dtype_size for bytes):
//   packed:   key/value hold all batches' tokens back to back, and
//             context_lens is a prefix-sum table — batch b owns tokens
//             [context_lens[b], context_lens[b+1]).
//   unpacked: key/value are batch-strided (context_bs_stride); per-batch
//             token count is context_lens[b], with an optional start
//             offset from context_seq_offsets[b].
// The cache side is addressed through optional indirection tables
// cache_bs_offsets / cache_seq_offsets (slot assignment per batch); a
// negative entry in either table marks a batch to skip.
//
// NOTE(review): max_context_len and cache_mem_len are accepted but unused
// in this kernel body — presumably kept for interface symmetry; confirm.
__mlu_global__ void MLUReshapeLinearCacheKernel(int8_t *key_cache,
                                                int8_t *value_cache,
                                                int *cache_bs_offsets,
                                                int *cache_seq_offsets,
                                                int8_t *key,
                                                int8_t *value,
                                                int *context_seq_offsets,
                                                int *context_lens,
                                                int batch,
                                                int head_num,
                                                int head_size,
                                                int max_context_len,
                                                int cache_mem_len,
                                                size_t context_bs_stride,
                                                size_t context_head_stride,
                                                size_t context_seq_stride,
                                                size_t cache_bs_stride,
                                                size_t cache_head_stride,
                                                size_t cache_seq_stride,
                                                bool packed,
                                                int dtype_size,
                                                int SEQ_BLOCK) {
  // [head_num, batch, seq_seg]
  // One task per head when the X dimension is split; otherwise a single
  // task iterates all heads through the __memcpy repeat count.
  int head_repeat = taskDimX > 1 ? 1 : head_num;
  for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
    // Per-batch start offset into the source only applies in unpacked mode.
    int seq_offset = (packed || context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
    int task_seq_begin = taskIdZ * SEQ_BLOCK;
    // packed: token count = prefix-sum difference; unpacked: direct length.
    int seq_len = packed ? (context_lens[bs_idx + 1] - context_lens[bs_idx]) : context_lens[bs_idx];
    // This Z-segment falls entirely past this batch's tokens — nothing to do.
    if (task_seq_begin >= seq_len) continue;
    // Tokens this task actually copies (tail segment may be short).
    int seq = std::min(seq_len - task_seq_begin, SEQ_BLOCK);
    // Source byte offset: head slice first, then token position.
    size_t context_offset = taskIdX * context_head_stride * dtype_size;
    if (packed) {
      // Packed source: global token index = batch prefix + segment start.
      context_offset += (context_lens[bs_idx] + task_seq_begin) * context_seq_stride * dtype_size;
    } else {
      // Unpacked source: batch stride plus (segment start + per-batch offset).
      context_offset +=
          (bs_idx * context_bs_stride + (task_seq_begin + seq_offset) * context_seq_stride) *
          dtype_size;
    }
    // Resolve cache slot through the optional indirection tables.
    int cache_seq_offset = cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx];
    int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
    // Negative slot values flag "no cache entry for this batch" — skip it.
    if (cache_seq_offset < 0 || cache_bs_offset < 0) {
      continue;
    }
    cache_seq_offset += task_seq_begin;
    if (key != nullptr && key_cache != nullptr) {
      int8_t *key_cache_begin =
          key_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
                       cache_seq_offset * cache_seq_stride) *
                          dtype_size;
      int8_t *key_begin = key + context_offset;
      // Multi-level strided copy: head_size*dtype_size bytes per token,
      // repeated over `seq` tokens and `head_repeat` heads with separate
      // destination (cache) and source (context) strides.
      // NOTE(review): assumes MLU __memcpy argument order
      // (dst, src, size, dir, dst_stride0, seg0, dst_stride1, seg1,
      //  src_stride0, seg0, src_stride1, seg1) — confirm against mlu.h.
      __memcpy(key_cache_begin, key_begin, head_size * dtype_size, GDRAM2GDRAM,
               cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
               head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
               context_head_stride * dtype_size, head_repeat - 1);
    }
    if (value != nullptr && value_cache != nullptr) {
      // Same copy as the key path, on the value tensors.
      int8_t *value_cache_begin =
          value_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
                         cache_seq_offset * cache_seq_stride) *
                            dtype_size;
      int8_t *value_begin = value + context_offset;
      __memcpy(value_cache_begin, value_begin, head_size * dtype_size, GDRAM2GDRAM,
               cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size,
               head_repeat - 1, context_seq_stride * dtype_size, seq - 1,
               context_head_stride * dtype_size, head_repeat - 1);
    }
  }
}
} // namespace kernels
// Host-side launcher for MLUReshapeLinearCacheKernel.
//
// Builds a task grid of [heads (X), batch (Y), sequence segments (Z)],
// where each segment covers up to SEQ_BLOCK tokens. For small
// decoder-style workloads (head_num * max_context_len < SEQ_BLOCK) the
// head dimension is collapsed to 1 — the kernel then sweeps all heads via
// its memcpy repeat count — and the batch dimension is capped at 48 tasks.
//
// Returns KERNEL_STATUS_FAILED for an unsupported dtype; otherwise
// enqueues the kernel on `queue` and returns KERNEL_STATUS_SUCCESS
// (the launch itself is asynchronous).
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
                                      void *key_cache,
                                      void *value_cache,
                                      const void *cache_bs_offsets,
                                      const void *cache_seq_offsets,
                                      void *key,
                                      void *value,
                                      const void *context_seq_offsets,
                                      const void *context_lens,
                                      const cnnlDataType_t dtype,
                                      const int batch,
                                      const int head_num,
                                      const int head_size,
                                      const int max_context_len,
                                      const int cache_mem_len,
                                      const int context_bs_stride,
                                      const int context_head_stride,
                                      const int context_seq_stride,
                                      const int cache_bs_stride,
                                      const int cache_head_stride,
                                      const int cache_seq_stride,
                                      const bool packed) {
  constexpr int SEQ_BLOCK = 512;
  // Ceil-divide the longest context into SEQ_BLOCK-token segments (Z dim).
  const int num_seq_segments = (max_context_len + SEQ_BLOCK - 1) / SEQ_BLOCK;
  // Tiny workloads: fold heads into one task and limit batch parallelism.
  const bool decoder_case = head_num * max_context_len < SEQ_BLOCK;
  const uint32_t dim_x = decoder_case ? 1 : head_num;
  const uint32_t dim_y = decoder_case ? std::min(batch, 48) : batch;
  cnrtDim3_t dim{dim_x, dim_y, (uint32_t)num_seq_segments};
  // Element width in bytes for the supported cache dtypes.
  int dtype_size;
  switch (dtype) {
    case CNNL_DTYPE_HALF:
    case CNNL_DTYPE_BFLOAT16:
      dtype_size = 2;
      break;
    case CNNL_DTYPE_INT8:
      dtype_size = 1;
      break;
    case CNNL_DTYPE_FLOAT:
      dtype_size = 4;
      break;
    default:
      std::cerr << "invokeReshapeLinearCache: unsupport dtype" << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
  }
  kernels::MLUReshapeLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key_cache, (int8_t *)value_cache, (int *)cache_bs_offsets, (int *)cache_seq_offsets,
      (int8_t *)key, (int8_t *)value, (int *)context_seq_offsets, (int *)context_lens, batch,
      head_num, head_size, max_context_len, cache_mem_len, context_bs_stride, context_head_stride,
      context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed, dtype_size,
      SEQ_BLOCK);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo