38 lines
1.8 KiB
Plaintext
38 lines
1.8 KiB
Plaintext
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
#ifndef CSRC_KERNELS_COPY_BLOCKS_MLUH_
|
|
#define CSRC_KERNELS_COPY_BLOCKS_MLUH_
|
|
|
|
#include <vector>
|
|
#include "cnnl.h"
|
|
#include "kernel_utils.h"
|
|
|
|
namespace tmo {
|
|
/**
|
|
* @brief Perform copy_blocks operation.
|
|
* @param queue: The queue for mlu.
|
|
* @param key_caches: Output/Input. Pointer to the MLU memory that stores the key_caches
|
|
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
|
|
* @param value_caches: Output/Input. Pointer to the MLU memory that stores the value_caches
|
|
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
|
|
* @param block_mapping_vec: block_mapping vector.
|
|
* @param block_size_in_bytes: one block data size.
|
|
*/
|
|
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
|
|
const std::vector<void *> &key_caches,
|
|
const std::vector<void *> &value_caches,
|
|
const std::vector<int32_t> &block_mapping_vec,
|
|
const size_t block_size_in_bytes);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_COPY_BLOCKS_MLUH_
|