/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_

#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Perform reshape_paged_cache operation.
 *        NOTE(review): presumably this stores each token's key/value into key_cache/value_cache
 *        at the slot given by slot_mapping; only the declaration is visible in this header, so
 *        confirm against the kernel implementation.
 * @param queue: The CNRT queue passed to the operation. NOTE(review): presumably the queue on
 *        which the kernel is launched - confirm against the implementation.
 * @param data_type: A cnnlDataType_t value. NOTE(review): presumably the element data type of
 *        the key/value and key_cache/value_cache tensors - confirm.
 * @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
 *        num_heads, head_size].
 * @param value: Pointer to the MLU memory that stores the value tensor which has shape
 *        [num_tokens, num_heads, head_size].
 * @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has shape
 *        [num_blocks, num_heads, block_size, head_size].
 * @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has
 *        shape [num_blocks, num_heads, block_size, head_size].
 * @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has
 *        shape [num_tokens]. Data type of slot mapping must be int32_t.
 * @param key_stride0: The first dimension stride length of key_cache tensor.
 *        NOTE(review): the name suggests this may instead be the dim-0 stride of the key
 *        tensor; the unit (elements or bytes) is not stated here - confirm.
 * @param value_stride0: The first dimension stride length of value_cache tensor.
 *        NOTE(review): the name suggests this may instead be the dim-0 stride of the value
 *        tensor; the unit (elements or bytes) is not stated here - confirm.
 * @param num_tokens: Total number of tokens.
 * @param num_heads: Head number.
 * @param block_num: Total number of blocks (the num_blocks dimension in the cache shapes above).
 * @param block_size: Number of tokens per block.
 * @param head_size: Size of each head, i.e. the last dimension in the tensor shapes above.
 * @return A KernelStatus value. The type is not declared in this header (presumably it comes
 *         from kernel_utils.h); see its definition for the possible values.
 * @note: reshape_paged_cache does not support MLU300 device.
 */
KernelStatus invokeReshapePagedCache(cnrtQueue_t queue,
                                     cnnlDataType_t data_type,
                                     void *key,
                                     void *value,
                                     void *key_cache,
                                     void *value_cache,
                                     void *slot_mapping,
                                     size_t key_stride0,
                                     size_t value_stride0,
                                     int num_tokens,
                                     int num_heads,
                                     int block_num,
                                     int block_size,
                                     int head_size);
}  // namespace tmo

#endif  // CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_