#include <algorithm>
#include <cstdint>
#include <iostream>
#include "reshape_linear_cache.mluh"
// clang-format off
// NOTE(review): the header name on this line was lost in the copy under review;
// <mlu.h> is the assumed BANG device header -- confirm against the original.
#include <mlu.h>
// clang-format on

namespace tmo {
namespace kernels {

// Copies key/value activations, segment by segment, into a linear cache.
//
// Task layout (as configured by invokeReshapeLinearCache):
//   taskIdX -- head index when taskDimX > 1. When taskDimX == 1, a single task
//              copies all `head_num` heads with one 3-D __memcpy.
//   taskIdY -- batch index, advanced by taskDimY, so fewer Y tasks than
//              `batch` still visit every batch entry.
//   taskIdZ -- sequence segment. Task z copies tokens
//              [z * SEQ_BLOCK, min((z + 1) * SEQ_BLOCK, seq_len)).
//
// All strides are in elements; `dtype_size` converts them to bytes. Data is
// moved as raw bytes through int8_t pointers, so the kernel is dtype-agnostic.
//
// Context (source) addressing:
//   packed == true  : context_lens holds cumulative offsets (batch + 1 entries);
//                     sequence b occupies rows [lens[b], lens[b + 1]).
//   packed == false : context_lens[b] is the length of sequence b, starting at
//                     row context_seq_offsets[b] (0 when that array is null)
//                     of batch slot b.
// Cache (destination) addressing:
//   cache_bs_offsets[b]  -- destination batch slot (b itself when null).
//   cache_seq_offsets[b] -- destination start position (0 when null).
//   A negative value in either array skips batch entry b entirely.
//
// key/value copies are each skipped when either the source or the cache
// pointer is null.
//
// NOTE(review): max_context_len and cache_mem_len are not read here. Nothing
// checks that cache_seq_offsets[b] + seq_len <= cache_mem_len; callers must
// guarantee it. Tokens beyond the last launched Z segment are not copied.
__mlu_global__ void MLUReshapeLinearCacheKernel(int8_t *key_cache,
                                                int8_t *value_cache,
                                                int *cache_bs_offsets,
                                                int *cache_seq_offsets,
                                                int8_t *key,
                                                int8_t *value,
                                                int *context_seq_offsets,
                                                int *context_lens,
                                                int batch,
                                                int head_num,
                                                int head_size,
                                                int max_context_len,
                                                int cache_mem_len,
                                                size_t context_bs_stride,
                                                size_t context_head_stride,
                                                size_t context_seq_stride,
                                                size_t cache_bs_stride,
                                                size_t cache_head_stride,
                                                size_t cache_seq_stride,
                                                bool packed,
                                                int dtype_size,
                                                int SEQ_BLOCK) {
  // One head per task, or all heads in one task when the X dimension is collapsed.
  const int head_repeat = taskDimX > 1 ? 1 : head_num;
  const int task_seq_begin = taskIdZ * SEQ_BLOCK;

  // Loop-invariant byte sizes/strides for the 3-D __memcpy below.
  const int row_bytes = head_size * dtype_size;
  const size_t cache_seq_bytes = cache_seq_stride * dtype_size;
  const size_t cache_head_bytes = cache_head_stride * dtype_size;
  const size_t context_seq_bytes = context_seq_stride * dtype_size;
  const size_t context_head_bytes = context_head_stride * dtype_size;

  for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
    const int seq_len = packed ? (context_lens[bs_idx + 1] - context_lens[bs_idx])
                               : context_lens[bs_idx];
    // This Z segment starts past the end of the sequence: nothing to copy.
    if (task_seq_begin >= seq_len) {
      continue;
    }

    const int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
    int cache_seq_offset = cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx];
    // Negative slot/position marks a batch entry that must not be written.
    if (cache_seq_offset < 0 || cache_bs_offset < 0) {
      continue;
    }
    cache_seq_offset += task_seq_begin;

    const int seq = std::min(seq_len - task_seq_begin, SEQ_BLOCK);

    // Source offset in elements, converted to bytes once at the end.
    size_t context_offset = taskIdX * context_head_stride;
    if (packed) {
      context_offset += (context_lens[bs_idx] + task_seq_begin) * context_seq_stride;
    } else {
      const int seq_offset = context_seq_offsets == nullptr ? 0 : context_seq_offsets[bs_idx];
      context_offset +=
          bs_idx * context_bs_stride + (task_seq_begin + seq_offset) * context_seq_stride;
    }
    context_offset *= dtype_size;

    // Key and value caches share the same layout, so one destination offset serves both.
    const size_t cache_offset = (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride +
                                 cache_seq_offset * cache_seq_stride) *
                                dtype_size;

    // 3-D copy: `seq` rows of `row_bytes` per head, repeated for `head_repeat` heads.
    if (key != nullptr && key_cache != nullptr) {
      __memcpy(key_cache + cache_offset, key + context_offset, row_bytes, GDRAM2GDRAM,
               cache_seq_bytes, seq - 1, cache_head_bytes, head_repeat - 1,
               context_seq_bytes, seq - 1, context_head_bytes, head_repeat - 1);
    }
    if (value != nullptr && value_cache != nullptr) {
      __memcpy(value_cache + cache_offset, value + context_offset, row_bytes, GDRAM2GDRAM,
               cache_seq_bytes, seq - 1, cache_head_bytes, head_repeat - 1,
               context_seq_bytes, seq - 1, context_head_bytes, head_repeat - 1);
    }
  }
}

}  // namespace kernels

// Host launcher for kernels::MLUReshapeLinearCacheKernel on `queue`.
//
// Returns KERNEL_STATUS_FAILED for an unsupported dtype or negative sizes.
// Returns KERNEL_STATUS_SUCCESS without launching when there is nothing to
// copy (any of batch/head_num/head_size/max_context_len is zero). Otherwise it
// enqueues the kernel and returns KERNEL_STATUS_SUCCESS.
// NOTE(review): the launch status is not queried here; asynchronous errors
// surface at the next queue synchronization.
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
                                      void *key_cache,
                                      void *value_cache,
                                      const void *cache_bs_offsets,
                                      const void *cache_seq_offsets,
                                      void *key,
                                      void *value,
                                      const void *context_seq_offsets,
                                      const void *context_lens,
                                      const cnnlDataType_t dtype,
                                      const int batch,
                                      const int head_num,
                                      const int head_size,
                                      const int max_context_len,
                                      const int cache_mem_len,
                                      const int context_bs_stride,
                                      const int context_head_stride,
                                      const int context_seq_stride,
                                      const int cache_bs_stride,
                                      const int cache_head_stride,
                                      const int cache_seq_stride,
                                      const bool packed) {
  // Tokens copied per Z task.
  constexpr int SEQ_BLOCK = 512;
  // Cap on Y tasks when all heads of a batch entry are folded into one task;
  // the kernel's batch loop strides over the remaining entries.
  constexpr int MAX_FOLDED_BATCH_TASKS = 48;

  int dtype_size = 0;
  switch (dtype) {
    case CNNL_DTYPE_HALF:
    case CNNL_DTYPE_BFLOAT16:
      dtype_size = 2;
      break;
    case CNNL_DTYPE_INT8:
      dtype_size = 1;
      break;
    case CNNL_DTYPE_FLOAT:
      dtype_size = 4;
      break;
    default:
      std::cerr << "invokeReshapeLinearCache: unsupport dtype" << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Negative sizes would wrap to huge uint32_t task dimensions below.
  if (batch < 0 || head_num < 0 || head_size < 0 || max_context_len < 0) {
    std::cerr << "invokeReshapeLinearCache: batch, head_num, head_size and max_context_len "
                 "must be non-negative"
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Empty problem. Launching would use a zero-sized task dimension, or make the
  // kernel issue a __memcpy with head_repeat - 1 == -1 segments when head_num == 0.
  if (batch == 0 || head_num == 0 || head_size == 0 || max_context_len == 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }

  // Overflow-free ceil(max_context_len / SEQ_BLOCK).
  const int seq_seg = max_context_len / SEQ_BLOCK + (max_context_len % SEQ_BLOCK != 0 ? 1 : 0);
  // Small total work: fold all heads into one task per batch entry.
  // 64-bit product so large head_num * max_context_len cannot overflow int.
  const bool is_decoder_case = static_cast<int64_t>(head_num) * max_context_len < SEQ_BLOCK;
  const uint32_t task_x_dim = is_decoder_case ? 1 : static_cast<uint32_t>(head_num);
  const uint32_t task_y_dim = is_decoder_case
                                  ? static_cast<uint32_t>(std::min(batch, MAX_FOLDED_BATCH_TASKS))
                                  : static_cast<uint32_t>(batch);
  cnrtDim3_t dim{task_x_dim, task_y_dim, static_cast<uint32_t>(seq_seg)};

  kernels::MLUReshapeLinearCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key_cache, (int8_t *)value_cache, (int *)cache_bs_offsets,
      (int *)cache_seq_offsets, (int8_t *)key, (int8_t *)value, (int *)context_seq_offsets,
      (int *)context_lens, batch, head_num, head_size, max_context_len, cache_mem_len,
      context_bs_stride, context_head_stride, context_seq_stride, cache_bs_stride,
      cache_head_stride, cache_seq_stride, packed, dtype_size, SEQ_BLOCK);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo