Files
2026-02-04 17:39:32 +08:00

38 lines
1.7 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_
#define CSRC_KERNELS_GENERATE_MASK_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {

/**
 * @brief Invoke the kernel that builds the causal mask used by the context stage.
 *
 * @param[in]  handle         The cnnl handle used for this invocation.
 * @param[out] output_ddr     MLU memory that receives the generated mask.
 * @param[in]  batch_seq_len  MLU memory holding the sequence length(s).
 *                            NOTE(review): presumably one entry per batch
 *                            (total_batch entries) -- confirm with the kernel.
 * @param[in]  total_batch    The batch size.
 * @param[in]  max_seq_len    The maximum sequence length of the context.
 * @param[in]  data_type      The cnnl data type. NOTE(review): presumably the
 *                            element type of output_ddr -- confirm.
 * @param[in]  fill_value     Value written into the padded part of the mask.
 * @return A KernelStatus value from the invocation (KernelStatus is declared
 *         outside this header; presumably in kernel_utils.h -- confirm).
 *
 * NOTE(review): the required size/layout of output_ddr is not stated here;
 * presumably it is sized from total_batch and max_seq_len -- confirm against
 * the implementation before allocating.
 */
KernelStatus invokeGenerateMask(cnnlHandle_t handle, void *output_ddr, int *batch_seq_len,
                                int total_batch, int max_seq_len, cnnlDataType_t data_type,
                                float fill_value);

}  // namespace tmo
#endif // CSRC_KERNELS_GENERATE_MASK_MLUH_