Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu
2026-02-04 17:39:32 +08:00

167 lines
7.5 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "reshape_paged_cache.mluh"
namespace tmo {
namespace kernels {
// Size, in bytes, of the NRAM scratch area the kernel carves all of its buffers out of (480 KiB).
// NOTE(review): the host launcher sizes each task's work for only 224 KiB of it — confirm the
// headroom is intentional.
#define NRAM_BUFFER_SIZE (480 << 10)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

// The integers 0..31. The kernel copies this table and extends it by doubling to build the
// sequence 0..num_heads-1 in NRAM.
__nram__ int nram_range_32[32] = {
    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
};

// sizeof(T) as a 32-bit unsigned value, so byte counts such as `seq * sizeof_(int)` are
// computed in uint32_t instead of being promoted to size_t.
#define sizeof_(T) ((uint32_t)sizeof(T))
/**
 * Scatters per-token key/value rows into paged key/value caches.
 *
 * Work split: task `taskId` handles tokens [seq_begin, seq_begin + seq) with
 * seq_begin = taskId * seq_block and seq = min(seq_block, num_tokens - seq_begin). Tasks whose
 * seq_begin is past num_tokens return immediately. The body is only compiled when
 * __BANG_ARCH__ > 500; on other architectures the kernel does nothing.
 *
 * For token t with slot s = slot_mapping[t] and head h, the destination byte offset is
 *   (s / block_size) * block_stride + h * head_stride + (s % block_size) * head_bytes
 * where head_bytes = head_size * dtype_size, head_stride = block_size * head_bytes and
 * block_stride = num_heads * head_stride. In other words, the caches are addressed as
 * (num_blocks, num_heads, block_size, head_size) elements of dtype_size bytes each.
 *
 * @param key, value      Source rows. Consecutive tokens are key_stride0 / value_stride0
 *                        elements apart, and the first num_heads * head_size elements of each
 *                        row are copied. A side is skipped when its source or cache pointer is
 *                        nullptr.
 * @param key_cache, value_cache  Destination caches (byte-addressed).
 * @param slot_mapping    One int slot index per token, indexed from the first token globally.
 * @param dtype_size      Element size in bytes; all addressing below is done in bytes.
 * @param seq_block       Tokens per task. The host launcher derives it from a per-task NRAM
 *                        budget so the buffers below fit in nram_buffer.
 *
 * Offsets are passed to __scatter as uint32_t, so every cache byte offset must fit in 32 bits.
 * The host launcher rejects caches larger than UINT32_MAX bytes.
 */
__mlu_global__ void MLUReshapePagedCacheKernel(int8_t *key,
int8_t *value,
int8_t *key_cache,
int8_t *value_cache,
int *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_size,
int head_size,
int dtype_size,
int seq_block) {
#if __BANG_ARCH__ > 500
int seq_begin = taskId * seq_block;
// Trailing tasks with no tokens left do nothing.
if (seq_begin >= num_tokens) return;
int seq = std::min(seq_block, num_tokens - seq_begin);
int head_bytes = head_size * dtype_size;
int head_stride = block_size * head_bytes;
int block_stride = num_heads * head_stride;
int hidden_bytes = num_heads * head_bytes;
// NRAM layout:
//   [seq * hidden_bytes input rows][token_offset][block_offset][offset][mask]
// Each of the four int buffers holds pad_8_size ints: num_heads * seq rounded up to a
// multiple of 8, because the bit-index compare below works on multiples of 8 elements.
// NOTE(review): nram_token_offset starts directly after the input rows. If seq * hidden_bytes
// is not a multiple of 4 (possible for int8 with odd num_heads * head_size), the int buffers
// are misaligned. Confirm the alignment requirements of the __bang_* ops used below.
int8_t *nram_input = nram_buffer;
int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes);
int pad_8_size = (num_heads * seq + 7) / 8 * 8;
int *nram_block_offset = nram_token_offset + pad_8_size;
int *nram_offset = nram_block_offset + pad_8_size;
int *nram_mask = nram_offset + pad_8_size;
// Load this task's slot indices. nram_offset serves as a temporary here.
__memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int), GDRAM2NRAM);
// token_offset = (slot % block_size) * head_bytes: byte offset of the token inside a head row.
__bang_rem(nram_token_offset, nram_offset, (int)block_size, seq);
__bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq);
// block_offset = (slot / block_size) * block_stride: byte offset of the token's cache block.
__bang_div(nram_block_offset, nram_offset, (int)block_size, seq);
__bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq);
// Replicate the seq values num_heads times (source stride 0) -> (num_heads, seq)
__memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
num_heads - 1);
// (num_heads, seq) -> (seq, num_heads): each token now has num_heads consecutive copies.
__bang_transpose(nram_token_offset, nram_offset, num_heads, seq);
// Same replication for the block offsets -> (num_heads, seq)
__memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
num_heads - 1);
// (num_heads, seq) -> (seq, num_heads)
__bang_transpose(nram_block_offset, nram_offset, num_heads, seq);
// Scatter mask: one bit per (token, head), set when token_offset >= 0.
// The int32 offsets are reinterpreted as float and compared against 0.0f. Non-negative ints
// compare >= 0. Small negative ints (all high bits set) reinterpret as NaN, so they compare
// false. pad_8_size elements are compared so the bit count is a multiple of 8.
// NOTE(review): only the sign of slot % block_size is tested. A negative slot that is a multiple
// of block_size (e.g. -1 with block_size == 1) yields token_offset 0 and is NOT masked, which
// produces a wrapped destination offset. Presumably callers only pass -1 for skipped tokens with
// block_size > 1 — confirm.
__bang_write_zero(nram_offset, pad_8_size);
__bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset, (float *)nram_offset,
pad_8_size);
// generate range: (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride)
// First build 0..num_heads-1. Copy up to 32 entries from nram_range_32, then repeatedly append
// the existing prefix shifted by its length, doubling the filled length each pass.
__memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int), NRAM2NRAM);
int begin = 32;
while (begin < num_heads) {
int count = std::min(begin, num_heads - begin);
__bang_add_scalar(nram_offset + begin, nram_offset, begin, count);
begin += count;
}
// Scale to per-head byte offsets h * head_stride.
__bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads);
// token_offset[t, h] += h * head_stride (the num_heads-long vector is cycled over every token).
__bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset, seq * num_heads, num_heads);
// Final destination byte offset for every (token, head) pair.
__bang_add(nram_offset, nram_token_offset, nram_block_offset, seq * num_heads);
if (key != nullptr && key_cache != nullptr) {
// (seq, num_heads, head_size)
// Load seq rows of hidden_bytes (source rows key_stride0 * dtype_size bytes apart). Then scatter
// each head_bytes chunk to key_cache + offset, using nram_mask as the per-chunk mask.
__memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
hidden_bytes, key_stride0 * dtype_size, seq - 1);
__scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
head_bytes, seq * num_heads);
}
if (value != nullptr && value_cache != nullptr) {
// Same as the key path. nram_input is reused, and the offsets and mask are shared.
__memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
hidden_bytes, value_stride0 * dtype_size, seq - 1);
__scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
head_bytes, seq * num_heads);
}
#endif
}
} // namespace kernels
/**
 * Host launcher for kernels::MLUReshapePagedCacheKernel.
 *
 * Validates the arguments, splits the num_tokens tokens into per-task chunks of seq_block
 * tokens, and enqueues the kernel on `queue` as cnrtFuncTypeBlock tasks. The launch result is
 * not checked here.
 *
 * @param queue          Queue the kernel is enqueued on.
 * @param data_type      Element type: CNNL_DTYPE_HALF, CNNL_DTYPE_BFLOAT16, CNNL_DTYPE_INT8 or
 *                       CNNL_DTYPE_FLOAT. Any other type fails.
 * @param key, value     Source rows (either may be nullptr, in which case that side is skipped
 *                       by the kernel).
 * @param key_cache, value_cache  Destination paged caches.
 * @param slot_mapping   Device array with one int slot index per token.
 * @param key_stride0, value_stride0  Distance between consecutive tokens, in elements.
 * @param num_tokens     Number of tokens to write. 0 is a successful no-op.
 * @param num_heads, head_size  Per-token layout; must be positive.
 * @param block_num      Number of cache blocks; used only for the 32-bit addressing check.
 * @param block_size     Tokens per cache block; must be positive.
 * @return KERNEL_STATUS_SUCCESS when the kernel was enqueued (or there was nothing to do).
 *         KERNEL_STATUS_FAILED on MLU300 devices, unsupported dtypes, invalid sizes, caches
 *         larger than UINT32_MAX bytes, or rows that do not fit the per-task NRAM budget.
 */
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
                                     cnnlDataType_t data_type,
                                     void *key,
                                     void *value,
                                     void *key_cache,
                                     void *value_cache,
                                     void *slot_mapping,
                                     size_t key_stride0,
                                     size_t value_stride0,
                                     int num_tokens,
                                     int num_heads,
                                     int block_num,
                                     int block_size,
                                     int head_size) {
  if (is_arch300()) {
    std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int dtype_size = 1;
  if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
    dtype_size = 2;
  } else if (data_type == CNNL_DTYPE_INT8) {
    dtype_size = 1;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    dtype_size = 4;
  } else {
    std::cerr << "invokeReshapePagedCache: unsupport data type\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // The kernel uses block_size as a divisor, and num_heads * head_size sizes every NRAM row and
  // the per-token budget below. Non-positive values would divide by zero or yield garbage offsets.
  if (num_tokens < 0 || block_num < 0 || num_heads <= 0 || head_size <= 0 || block_size <= 0) {
    std::cerr << "[invokeReshapePagedCache]: num_tokens and block_num must be non-negative; "
              << "num_heads, head_size and block_size must be positive." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Widen before multiplying. This product is exactly the quantity that can exceed 32 bits, so
  // evaluating it in int would overflow (UB) precisely in the cases this check must reject.
  const int64_t kv_cache_range = static_cast<int64_t>(block_num) * block_size * num_heads *
                                 head_size * dtype_size;
  // The kernel hands cache offsets to __scatter as uint32_t.
  if (kv_cache_range > UINT32_MAX) {
    std::cerr << "[invokeReshapePagedCache]: The addressing range of kv_cache cannot exceed 4G."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Per-task NRAM budget. Each token needs its input row (num_heads * head_size * dtype_size
  // bytes) plus one int per head in each of the kernel's four offset/mask buffers.
  constexpr int nram_size = 224 * 1024;
  const int64_t hidden_bytes =
      static_cast<int64_t>(num_heads) * head_size * dtype_size +
      4 * static_cast<int64_t>(num_heads) * static_cast<int64_t>(sizeof(int));
  int seq_block = static_cast<int>(nram_size / hidden_bytes);
  if (seq_block <= 0) {
    std::cerr << "invokeReshapePagedCache: "
              << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
              << "should be less than 224KB.\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_tokens == 0) {
    // Nothing to write; avoid enqueuing a launch with task_dim == 0.
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }
  if (seq_block > 16) {
    seq_block = seq_block / 16 * 16;
  }
  uint32_t task_dim =
      static_cast<uint32_t>((static_cast<int64_t>(num_tokens) + seq_block - 1) / seq_block);
  // Request at least kMinTasks tasks, but never more than one per token. Raising task_dim alone
  // would leave the extra tasks idle (their seq_begin would be >= num_tokens), so the tokens are
  // re-split evenly across them. The new seq_block is ceil(num_tokens / task_dim), which never
  // exceeds the NRAM-derived one because task_dim only grew.
  // NOTE(review): the choice of 8 is inherited from the original code; presumably it targets
  // using several cores — confirm.
  constexpr uint32_t kMinTasks = 8;
  if (task_dim < kMinTasks) {
    task_dim = std::min(kMinTasks, static_cast<uint32_t>(num_tokens));
    const int tasks = static_cast<int>(task_dim);
    seq_block = (num_tokens + tasks - 1) / tasks;
    task_dim = static_cast<uint32_t>((num_tokens + seq_block - 1) / seq_block);
  }
  cnrtDim3_t dim{task_dim, 1, 1};
  kernels::MLUReshapePagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
      (int8_t *)key, (int8_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
      (int *)slot_mapping, key_stride0, value_stride0, num_tokens, num_heads, block_size, head_size,
      dtype_size, seq_block);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo