/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_CAST_GATING_MLUH_ #define CSRC_KERNELS_CAST_GATING_MLUH_ #include "../kernel_utils.h" #include "cnnl.h" namespace tmo { /** * @brief Convert input to float32 and do gating operation. * @param queue: The queue for mlu. * @param input: Input. Pointer to the MLU memory that stores the input, * the shape must be [input_row, hidden_size]. * @param filter: Input. Pointer to the MLU memory that stores the weight, * the shape must be [expert_num, hidden_size]. * @param output: Output. Pointer to the MLU memory that stores the output, * the shape must be [input_row, expert_num]. * @param input_row: Input. * @param expert_num: Input. * @param hidden_size: Input. * @param a_dtype: Input. The data-type of input. * @param workspace: Input. Pointer to the MLU workspace. * @param workspace_size_bytes: Input. The size of workspace in bytes. * @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF. * expert_num must be in range [1, 128]. * If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024. * The data-type of filter and output must be float. * cast_gating only supports MLU500 device or higher. */ KernelStatus invokeCastGating(cnrtQueue_t queue, void *input, void *filter, void *output, int input_row, int expert_num, int hidden_size, cnnlDataType_t a_dtype, void *workspace, size_t workspace_size_bytes); } // namespace tmo #endif // CSRC_KERNELS_CAST_GATING_MLUH_