forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
166
torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu
Normal file
166
torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu
Normal file
@@ -0,0 +1,166 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "reshape_paged_cache.mluh"
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define NRAM_BUFFER_SIZE (480 * 1024)
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
#define sizeof_(T) (uint32_t)sizeof(T)
|
||||
|
||||
__mlu_global__ void MLUReshapePagedCacheKernel(int8_t *key,
|
||||
int8_t *value,
|
||||
int8_t *key_cache,
|
||||
int8_t *value_cache,
|
||||
int *slot_mapping,
|
||||
size_t key_stride0,
|
||||
size_t value_stride0,
|
||||
int num_tokens,
|
||||
int num_heads,
|
||||
int block_size,
|
||||
int head_size,
|
||||
int dtype_size,
|
||||
int seq_block) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
int seq_begin = taskId * seq_block;
|
||||
if (seq_begin >= num_tokens) return;
|
||||
int seq = std::min(seq_block, num_tokens - seq_begin);
|
||||
|
||||
int head_bytes = head_size * dtype_size;
|
||||
int head_stride = block_size * head_bytes;
|
||||
int block_stride = num_heads * head_stride;
|
||||
int hidden_bytes = num_heads * head_bytes;
|
||||
|
||||
int8_t *nram_input = nram_buffer;
|
||||
int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes);
|
||||
int pad_8_size = (num_heads * seq + 7) / 8 * 8;
|
||||
int *nram_block_offset = nram_token_offset + pad_8_size;
|
||||
int *nram_offset = nram_block_offset + pad_8_size;
|
||||
int *nram_mask = nram_offset + pad_8_size;
|
||||
|
||||
__memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int), GDRAM2NRAM);
|
||||
__bang_rem(nram_token_offset, nram_offset, (int)block_size, seq);
|
||||
__bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq);
|
||||
__bang_div(nram_block_offset, nram_offset, (int)block_size, seq);
|
||||
__bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq);
|
||||
// (num_heads, seq)
|
||||
__memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
|
||||
num_heads - 1);
|
||||
// (num_heads, seq) -> (seq, num_heads)
|
||||
__bang_transpose(nram_token_offset, nram_offset, num_heads, seq);
|
||||
|
||||
// (num_heads, seq)
|
||||
__memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0,
|
||||
num_heads - 1);
|
||||
// (num_heads, seq) -> (seq, num_heads)
|
||||
__bang_transpose(nram_block_offset, nram_offset, num_heads, seq);
|
||||
|
||||
__bang_write_zero(nram_offset, pad_8_size);
|
||||
__bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset, (float *)nram_offset,
|
||||
pad_8_size);
|
||||
|
||||
// generate range: (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride)
|
||||
__memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int), NRAM2NRAM);
|
||||
int begin = 32;
|
||||
while (begin < num_heads) {
|
||||
int count = std::min(begin, num_heads - begin);
|
||||
__bang_add_scalar(nram_offset + begin, nram_offset, begin, count);
|
||||
begin += count;
|
||||
}
|
||||
__bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads);
|
||||
|
||||
__bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset, seq * num_heads, num_heads);
|
||||
__bang_add(nram_offset, nram_token_offset, nram_block_offset, seq * num_heads);
|
||||
|
||||
if (key != nullptr && key_cache != nullptr) {
|
||||
// (seq, num_heads, head_size)
|
||||
__memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
|
||||
hidden_bytes, key_stride0 * dtype_size, seq - 1);
|
||||
__scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
|
||||
head_bytes, seq * num_heads);
|
||||
}
|
||||
|
||||
if (value != nullptr && value_cache != nullptr) {
|
||||
__memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM,
|
||||
hidden_bytes, value_stride0 * dtype_size, seq - 1);
|
||||
__scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM,
|
||||
head_bytes, seq * num_heads);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
|
||||
cnnlDataType_t data_type,
|
||||
void *key,
|
||||
void *value,
|
||||
void *key_cache,
|
||||
void *value_cache,
|
||||
void *slot_mapping,
|
||||
size_t key_stride0,
|
||||
size_t value_stride0,
|
||||
int num_tokens,
|
||||
int num_heads,
|
||||
int block_num,
|
||||
int block_size,
|
||||
int head_size) {
|
||||
if (is_arch300()) {
|
||||
std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300 devices." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
int dtype_size = 1;
|
||||
if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
|
||||
dtype_size = 2;
|
||||
} else if (data_type == CNNL_DTYPE_INT8) {
|
||||
dtype_size = 1;
|
||||
} else if (data_type == CNNL_DTYPE_FLOAT) {
|
||||
dtype_size = 4;
|
||||
} else {
|
||||
std::cerr << "invokeReshapePagedCache: unsupport data type\n";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
int64_t kv_cache_range = block_num * block_size * num_heads * head_size * dtype_size;
|
||||
if (kv_cache_range > UINT32_MAX) {
|
||||
std::cerr << "[invokeReshapePagedCache]: The addressing range of kv_cache cannot exceed 4G."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
constexpr int nram_size = 224 * 1024;
|
||||
int hidden_bytes = num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int);
|
||||
int seq_block = nram_size / hidden_bytes;
|
||||
if (seq_block <= 0) {
|
||||
std::cerr << "invokeReshapePagedCache: "
|
||||
<< "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
|
||||
<< "should be less than 224KB.\n";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (seq_block > 16) {
|
||||
seq_block = seq_block / 16 * 16;
|
||||
}
|
||||
uint32_t task_dim = (num_tokens + seq_block - 1) / seq_block;
|
||||
task_dim = std::max(task_dim, (uint32_t)8);
|
||||
task_dim = std::min(task_dim, (uint32_t)num_tokens);
|
||||
cnrtDim3_t dim{task_dim, 1, 1};
|
||||
|
||||
kernels::MLUReshapePagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(int8_t *)key, (int8_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
|
||||
(int *)slot_mapping, key_stride0, value_stride0, num_tokens, num_heads, block_size, head_size,
|
||||
dtype_size, seq_block);
|
||||
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user