/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "reshape_paged_cache.mluh"

namespace tmo {
namespace kernels {

// Per-core NRAM scratch buffer; all kernel-side staging areas are carved
// out of this single allocation.
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

// Iota seed (0..31) used to build the per-head offset ramp; extended by
// doubling inside the kernel when num_heads > 32.
__nram__ int nram_range_32[32] = {0,  1,  2,  3,  4,  5,  6,  7,
                                  8,  9,  10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23,
                                  24, 25, 26, 27, 28, 29, 30, 31};

#define sizeof_(T) (uint32_t)sizeof(T)

// Scatters contiguous key/value rows into the paged KV caches.
//
// key/value hold one row of num_heads * head_size elements per token, with a
// row pitch of key_stride0 / value_stride0 elements. slot_mapping[t] is the
// flat destination slot of token t (block_index * block_size + index inside
// the block). The caches are addressed as
// [block, head, block_size, head_size] in bytes via the strides computed
// below. Each task (taskId) processes `seq_block` consecutive tokens.
__mlu_global__ void MLUReshapePagedCacheKernel(
    int8_t *key, int8_t *value, int8_t *key_cache, int8_t *value_cache,
    int *slot_mapping, size_t key_stride0, size_t value_stride0,
    int num_tokens, int num_heads, int block_size, int head_size,
    int dtype_size, int seq_block) {
#if __BANG_ARCH__ > 500
  int seq_begin = taskId * seq_block;
  if (seq_begin >= num_tokens) return;
  // The last task may own fewer than seq_block tokens.
  int seq = std::min(seq_block, num_tokens - seq_begin);
  int head_bytes = head_size * dtype_size;
  int head_stride = block_size * head_bytes;   // bytes per head within a block
  int block_stride = num_heads * head_stride;  // bytes per cache block
  int hidden_bytes = num_heads * head_bytes;   // bytes per key/value token row

  // NRAM layout: [seq * hidden_bytes] input staging area, followed by four
  // int arrays of pad_8_size elements each (token offsets, block offsets,
  // scratch/final offsets, scatter bit mask).
  int8_t *nram_input = nram_buffer;
  int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes);
  int pad_8_size = (num_heads * seq + 7) / 8 * 8;
  int *nram_block_offset = nram_token_offset + pad_8_size;
  int *nram_offset = nram_block_offset + pad_8_size;
  int *nram_mask = nram_offset + pad_8_size;

  __memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int),
           GDRAM2NRAM);
  // slot % block_size -> byte offset of the token inside its block row.
  __bang_rem(nram_token_offset, nram_offset, (int)block_size, seq);
  __bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq);
  // slot / block_size -> byte offset of the owning cache block.
  __bang_div(nram_block_offset, nram_offset, (int)block_size, seq);
  __bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq);

  // Broadcast per-token offsets across heads: (num_heads, seq)
  __memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM,
           seq * sizeof_(int), 0, num_heads - 1);
  // (num_heads, seq) -> (seq, num_heads)
  __bang_transpose(nram_token_offset, nram_offset, num_heads, seq);
  // Same broadcast + transpose for the block offsets: (num_heads, seq)
  __memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM,
           seq * sizeof_(int), 0, num_heads - 1);
  // (num_heads, seq) -> (seq, num_heads)
  __bang_transpose(nram_block_offset, nram_offset, num_heads, seq);

  // Bit mask for __scatter: token byte offsets are >= 0, so comparing them
  // against zero yields an all-true mask over the valid entries.
  __bang_write_zero(nram_offset, pad_8_size);
  __bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset,
                     (float *)nram_offset, pad_8_size);

  // Generate the per-head ramp:
  // (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride)
  __memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int),
           NRAM2NRAM);
  int begin = 32;
  while (begin < num_heads) {
    // Double the materialized ramp each iteration until num_heads is covered.
    int count = std::min(begin, num_heads - begin);
    __bang_add_scalar(nram_offset + begin, nram_offset, begin, count);
    begin += count;
  }
  __bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads);

  // Final per-(token, head) scatter offset =
  //   block_offset + head_offset + in-block token offset.
  __bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset,
                   seq * num_heads, num_heads);
  __bang_add(nram_offset, nram_token_offset, nram_block_offset,
             seq * num_heads);

  if (key != nullptr && key_cache != nullptr) {
    // Stage (seq, num_heads, head_size) key rows, then scatter in
    // head_bytes-sized pieces to the computed cache offsets.
    __memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size,
             hidden_bytes, GDRAM2NRAM, hidden_bytes, key_stride0 * dtype_size,
             seq - 1);
    __scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask,
              head_bytes, NRAM2GDRAM, head_bytes, seq * num_heads);
  }
  if (value != nullptr && value_cache != nullptr) {
    __memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size,
             hidden_bytes, GDRAM2NRAM, hidden_bytes,
             value_stride0 * dtype_size, seq - 1);
    __scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask,
              head_bytes, NRAM2GDRAM, head_bytes, seq * num_heads);
  }
#endif
}

}  // namespace kernels

// Host-side launcher: validates the device/dtype/addressing constraints,
// sizes the per-task token chunk from the available NRAM budget, and
// enqueues MLUReshapePagedCacheKernel on `queue`.
//
// Returns KERNEL_STATUS_FAILED (with a message on stderr) when the device
// is MLU300, the dtype is unsupported, the cache exceeds 32-bit addressing,
// or a single token row does not fit in NRAM.
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
                                     cnnlDataType_t data_type, void *key,
                                     void *value, void *key_cache,
                                     void *value_cache, void *slot_mapping,
                                     size_t key_stride0, size_t value_stride0,
                                     int num_tokens, int num_heads,
                                     int block_num, int block_size,
                                     int head_size) {
  if (is_arch300()) {
    std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300"
                 " devices."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int dtype_size = 1;
  if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
    dtype_size = 2;
  } else if (data_type == CNNL_DTYPE_INT8) {
    dtype_size = 1;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    dtype_size = 4;
  } else {
    std::cerr << "invokeReshapePagedCache: unsupport data type\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // BUGFIX: widen to int64_t BEFORE multiplying — the original evaluated the
  // product in int, which can overflow and defeat this 4G addressing check
  // (the kernel scatters through uint32_t byte offsets).
  int64_t kv_cache_range =
      (int64_t)block_num * block_size * num_heads * head_size * dtype_size;
  if (kv_cache_range > UINT32_MAX) {
    std::cerr << "[invokeReshapePagedCache]: The addressing range of kv_cache"
                 " cannot exceed 4G."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Conservative NRAM budget used to size a token chunk: one staged hidden
  // row plus the four int offset arrays per token.
  constexpr int nram_size = 224 * 1024;
  int hidden_bytes =
      num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int);
  int seq_block = nram_size / hidden_bytes;
  if (seq_block <= 0) {
    std::cerr << "invokeReshapePagedCache: "
              << "num_heads * head_size * dtype_size + 4 * num_heads * "
                 "sizeof(int) "
              << "should be less than 224KB.\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Round large chunks down to a multiple of 16 tokens.
  if (seq_block > 16) {
    seq_block = seq_block / 16 * 16;
  }
  uint32_t task_dim = (num_tokens + seq_block - 1) / seq_block;
  task_dim = std::max(task_dim, (uint32_t)8);
  task_dim = std::min(task_dim, (uint32_t)num_tokens);
  cnrtDim3_t dim{task_dim, 1, 1};
  // NOTE(review): the launch configuration was garbled ("<<>>") in the
  // original source. A block (single-core) task type matches the per-taskId
  // partitioning inside the kernel — confirm the function-type constant
  // against the CNRT version this project builds with.
  kernels::MLUReshapePagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key, (int8_t *)value, (int8_t *)key_cache,
      (int8_t *)value_cache, (int *)slot_mapping, key_stride0, value_stride0,
      num_tokens, num_heads, block_size, head_size, dtype_size, seq_block);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo