/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_SWAP_BLOCKS_MLUH_ #define CSRC_KERNELS_SWAP_BLOCKS_MLUH_ #include #include "cnnl.h" #include "kernel_utils.h" namespace tmo { /** * @brief Perform swap_blocks operation. * @param handle: The handle of cnnl. * @param dst: Output. Pointer to the MLU memory that stores the dst tensor which has shape * [num_blocks, num_heads, block_size, head_size]. * @param src: Input. Pointer to the MLU memory that stores the src tensor which has shape * [num_blocks, num_heads, block_size, head_size]. * @param block_size_in_bytes: Data block size for each copy. * @param memcpy_type: Copy direction, including h2d, d2h and d2d. * @param block_mapping: Mapping table of src and dst. */ KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle, void *dst, const void *src, const int64_t &block_size_in_bytes, const cnrtMemTransDir_t &memcpy_type, const std::map &block_mapping); } // namespace tmo #endif // CSRC_KERNELS_SWAP_BLOCKS_MLUH_