/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
#ifndef CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
|
|
#define CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
|
|
|
|
#include "kernel_utils.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Perform reshape_paged_cache operation.
|
|
* @param handle: The handle of cnnl.
|
|
* @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
|
|
* num_heads, head_size].
|
|
* @param value: Pointer to the MLU memory that stores the value tensor which has shape [num_tokens,
|
|
* num_heads, head_size].
|
|
* @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has shape
|
|
* [num_blocks, num_heads, block_size, head_size].
|
|
* @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has shape
|
|
* [num_blocks, num_heads, block_size, head_size].
|
|
* @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has
|
|
* shape [num_tokens]. Data type of slot mapping must be int32_t.
|
|
* @param key_stride0: The first dimension stride length of key_cache tensor.
|
|
* @param value_stride0: The first dimension stride length of value_cache tensor.
|
|
* @param num_tokens: Total number of tokens.
|
|
* @param num_heads: Head number.
|
|
* @param block_num: Total number of blocks.
|
|
* @param block_size: Number of tokens per block.
|
|
* @note: reshape_paged_cache does not support MLU300 device.
|
|
*/
|
|
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
|
|
cnnlDataType_t data_type,
|
|
void *key,
|
|
void *value,
|
|
void *key_cache,
|
|
void *value_cache,
|
|
void *slot_mapping,
|
|
size_t key_stride0,
|
|
size_t value_stride0,
|
|
int num_tokens,
|
|
int num_heads,
|
|
int block_num,
|
|
int block_size,
|
|
int head_size);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
|