/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_SWAP_BLOCKS_MLUH_
#define CSRC_KERNELS_SWAP_BLOCKS_MLUH_

#include <map>
#include "cnnl.h"
#include "kernel_utils.h"

namespace tmo {
/**
 * @brief Perform the swap_blocks operation: copy blocks from the src tensor
 *        to the dst tensor as directed by a block mapping table.
 *
 * @param handle: The handle of cnnl.
 * @param dst: Output. Pointer to the memory that stores the dst tensor, which
 *        has shape [num_blocks, num_heads, block_size, head_size].
 * @param src: Input. Pointer to the memory that stores the src tensor, which
 *        has shape [num_blocks, num_heads, block_size, head_size].
 * @param block_size_in_bytes: Data block size, in bytes, for each copy.
 * @param memcpy_type: Copy direction, including h2d, d2h and d2d.
 * @param block_mapping: Mapping table of src and dst.
 * @return A KernelStatus value; its possible values are not declared in this
 *         header (presumably in kernel_utils.h).
 *
 * NOTE(review): the original comment described both dst and src as MLU
 * memory, yet memcpy_type allows h2d and d2h; presumably one side is host
 * memory for those directions -- confirm against the implementation.
 * NOTE(review): block_mapping presumably maps a src block index (key) to a
 * dst block index (value) -- confirm against the implementation.
 */
KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle,
                                    void *dst,
                                    const void *src,
                                    const int64_t &block_size_in_bytes,
                                    const cnrtMemTransDir_t &memcpy_type,
                                    const std::map<int64_t, int64_t> &block_mapping);
}  // namespace tmo

#endif  // CSRC_KERNELS_SWAP_BLOCKS_MLUH_