// Files: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mluh
// Snapshot: 2026-02-04 17:39:32 +08:00 | 40 lines | 1.8 KiB | Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_SWAP_BLOCKS_MLUH_
#define CSRC_KERNELS_SWAP_BLOCKS_MLUH_
#include <cstdint>
#include <map>
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Perform the swap_blocks operation: copy selected blocks of @p src into
 *        selected blocks of @p dst, as described by @p block_mapping.
 *
 * Both tensors have shape [num_blocks, num_heads, block_size, head_size].
 * NOTE(review): a "block" is presumably one slice along the leading num_blocks
 * dimension, i.e. @p block_size_in_bytes contiguous bytes starting at
 * index * block_size_in_bytes — confirm against the implementation.
 *
 * @param handle: The cnnl handle.
 *        NOTE(review): presumably provides the queue on which the copies are
 *        issued — confirm against the implementation.
 * @param dst: Output. Base pointer of the destination tensor. It lives in MLU
 *        memory for h2d/d2d copies and in host memory for d2h copies
 *        (see @p memcpy_type).
 * @param src: Input. Base pointer of the source tensor. It lives in host memory
 *        for h2d copies and in MLU memory for d2h/d2d copies
 *        (see @p memcpy_type).
 * @param block_size_in_bytes: Size in bytes of one block, i.e. the amount of
 *        data moved by each copy.
 * @param memcpy_type: Copy direction; one of h2d, d2h or d2d.
 * @param block_mapping: Block correspondence between @p src and @p dst.
 *        NOTE(review): presumably key = source block index and
 *        value = destination block index (following the "src and dst" order
 *        of the original description) — confirm against the implementation.
 * @return A KernelStatus (declared in kernel_utils.h) describing the outcome
 *         of the call.
 */
KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle,
                                    void *dst,
                                    const void *src,
                                    const int64_t &block_size_in_bytes,
                                    const cnrtMemTransDir_t &memcpy_type,
                                    const std::map<int64_t, int64_t> &block_mapping);
}  // namespace tmo
#endif // CSRC_KERNELS_SWAP_BLOCKS_MLUH_