/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_ #define CSRC_KERNELS_GENERATE_MASK_MLUH_ #include "cnnl.h" #include "kernel_utils.h" namespace tmo { /** * @brief Generate causal mask for context stage. * @param handle: The handle of cnnl. * @param output_ddr: Output. Pointer to the MLU memory that stores the output. * @param batch_seq_len: Input. Pointer to the MLU memory that stores the sequence length. * @param total_batch: Batch size. * @param max_seq_len: The maximum sequence length of context. * @param data_type: Data type. * @param fill_value: The fill value of the pad part. */ KernelStatus invokeGenerateMask(cnnlHandle_t handle, void *output_ddr, int *batch_seq_len, int total_batch, int max_seq_len, cnnlDataType_t data_type, float fill_value); } // namespace tmo #endif // CSRC_KERNELS_GENERATE_MASK_MLUH_