/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/

#include "reshape_paged_cache.mluh"

namespace tmo {
namespace kernels {

#define NRAM_BUFFER_SIZE (480 * 1024)
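// Static NRAM scratch shared by all stages of the kernel; the host launcher
// sizes seq_block against a smaller 224 KiB working-set budget, leaving headroom.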
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Iota table 0..31, expanded below to build the per-head stride offsets.
__nram__ int nram_range_32[32] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};

#define sizeof_(T) (uint32_t)sizeof(T)
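
// Scatters per-token key/value rows of shape (num_tokens, num_heads, head_size)
// into paged caches laid out as (block_num, num_heads, block_size, head_size).
// slot_mapping[i] holds the flat slot of token i, i.e.
// block_index * block_size + offset_within_block; each task handles up to
// seq_block consecutive tokens.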
__mlu_global__ void MLUReshapePagedCacheKernel(int8_t *key,
                                               int8_t *value,
                                               int8_t *key_cache,
                                               int8_t *value_cache,
                                               int *slot_mapping,
                                               size_t key_stride0,
                                               size_t value_stride0,
                                               int num_tokens,
                                               int num_heads,
                                               int block_size,
                                               int head_size,
                                               int dtype_size,
                                               int seq_block) {
#if __BANG_ARCH__ > 500
  int seq_begin = taskId * seq_block;
  if (seq_begin >= num_tokens) return;
  int seq = std::min(seq_block, num_tokens - seq_begin);

  int head_bytes = head_size * dtype_size;
  int head_stride = block_size * head_bytes;
  int block_stride = num_heads * head_stride;
  int hidden_bytes = num_heads * head_bytes;

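  // Carve the NRAM buffer: seq rows of hidden_bytes first, then four int
  // scratch arrays (in-block token offset, block base offset, combined
  // offset, scatter mask), each padded to a multiple of 8 elements.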
  int8_t *nram_input = nram_buffer;
  int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes);
  int pad_8_size = (num_heads * seq + 7) / 8 * 8;
  int *nram_block_offset = nram_token_offset + pad_8_size;
  int *nram_offset = nram_block_offset + pad_8_size;
  int *nram_mask = nram_offset + pad_8_size;

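  // Decompose each slot id into an in-block byte offset
  // ((slot % block_size) * head_bytes) and a block base byte offset
  // ((slot / block_size) * block_stride).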
  __memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int), GDRAM2NRAM);
  __bang_rem(nram_token_offset, nram_offset, (int)block_size, seq);
  __bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq);
  __bang_div(nram_block_offset, nram_offset, (int)block_size, seq);
  __bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq);
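
  // Replicate each per-token offset across all heads: broadcast to a
  // (num_heads, seq) layout, then transpose to (seq, num_heads) so each
  // token's per-head offsets land contiguously.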
  // (num_heads, seq)
  __memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
           num_heads - 1);
  // (num_heads, seq) -> (seq, num_heads)
  __bang_transpose(nram_token_offset, nram_offset, num_heads, seq);

  // (num_heads, seq)
  __memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
           num_heads - 1);
  // (num_heads, seq) -> (seq, num_heads)
  __bang_transpose(nram_block_offset, nram_offset, num_heads, seq);
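
  // Build an all-true bit mask for __scatter: every byte offset is
  // non-negative, so the >= 0 compare sets every mask bit.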
  __bang_write_zero(nram_offset, pad_8_size);
  __bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset, (float *)nram_offset,
                     pad_8_size);

  // generate range: (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride)
  __memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int), NRAM2NRAM);
  int begin = 32;
  while (begin < num_heads) {
    int count = std::min(begin, num_heads - begin);
    __bang_add_scalar(nram_offset + begin, nram_offset, begin, count);
    begin += count;
  }
  __bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads);
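
  // Combine the components: the per-head base (head * head_stride) is
  // cycle-added to each token's in-block offset, then the block base is
  // added, giving the absolute byte offset of every (token, head) row.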
  __bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset, seq * num_heads, num_heads);
  __bang_add(nram_offset, nram_token_offset, nram_block_offset, seq * num_heads);

  if (key != nullptr && key_cache != nullptr) {
    // (seq, num_heads, head_size)
    __memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
             hidden_bytes, key_stride0 * dtype_size, seq - 1);
    __scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
              head_bytes, seq * num_heads);
  }

  if (value != nullptr && value_cache != nullptr) {
    __memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
             hidden_bytes, value_stride0 * dtype_size, seq - 1);
    __scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
              head_bytes, seq * num_heads);
  }
#endif
}

} // namespace kernels
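
// Host-side entry point: validates the architecture and dtype, derives
// seq_block from the NRAM budget, sizes the task grid, and launches
// MLUReshapePagedCacheKernel on the given queue.
// Minimal usage sketch, assuming a valid cnrtQueue_t and device pointers
// (all names below are hypothetical):
//
//   KernelStatus status = tmo::invokeReshapePagedCache(
//       queue, CNNL_DTYPE_HALF, key_dev, value_dev, key_cache_dev,
//       value_cache_dev, slot_mapping_dev, key_stride0, value_stride0,
//       num_tokens, num_heads, block_num, block_size, head_size);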
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
                                     cnnlDataType_t data_type,
                                     void *key,
                                     void *value,
                                     void *key_cache,
                                     void *value_cache,
                                     void *slot_mapping,
                                     size_t key_stride0,
                                     size_t value_stride0,
                                     int num_tokens,
                                     int num_heads,
                                     int block_num,
                                     int block_size,
                                     int head_size) {
  if (is_arch300()) {
    std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int dtype_size = 1;
  if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
    dtype_size = 2;
  } else if (data_type == CNNL_DTYPE_INT8) {
    dtype_size = 1;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    dtype_size = 4;
  } else {
    std::cerr << "[invokeReshapePagedCache]: unsupported data type\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Widen before multiplying so the range check cannot itself overflow.
  int64_t kv_cache_range =
      (int64_t)block_num * block_size * num_heads * head_size * dtype_size;
  if (kv_cache_range > UINT32_MAX) {
    std::cerr << "[invokeReshapePagedCache]: the addressing range of kv_cache cannot exceed 4 GiB."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
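
  // seq_block = number of tokens whose working set fits the 224 KiB NRAM
  // budget: one hidden row (num_heads * head_size * dtype_size bytes) plus
  // four int scratch entries per head per token.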
  constexpr int nram_size = 224 * 1024;
  int hidden_bytes = num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int);
  int seq_block = nram_size / hidden_bytes;
  if (seq_block <= 0) {
    std::cerr << "[invokeReshapePagedCache]: "
              << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
              << "must not exceed 224 KiB.\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
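
  // Round seq_block down to a multiple of 16, then size the task grid:
  // ceil(num_tokens / seq_block) tasks, raised to at least 8 for parallelism
  // and capped at one task per token; surplus tasks exit early in the kernel.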
  if (seq_block > 16) {
    seq_block = seq_block / 16 * 16;
  }
  uint32_t task_dim = (num_tokens + seq_block - 1) / seq_block;
  task_dim = std::max(task_dim, (uint32_t)8);
  task_dim = std::min(task_dim, (uint32_t)num_tokens);
  cnrtDim3_t dim{task_dim, 1, 1};

  kernels::MLUReshapePagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key, (int8_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
      (int *)slot_mapping, key_stride0, value_stride0, num_tokens, num_heads, block_size, head_size,
      dtype_size, seq_block);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

} // namespace tmo