/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_

// NOTE(review): the header name of the next #include is missing in the text
// as received (possibly an angle-bracket include lost in extraction).
// Restore it from the upstream file; do not guess.
#include
#include "../kernel_utils.h"
#include "cnnl.h"

namespace tmo {
/**
 * @brief Apply the generate-MOE-index operation, which performs the following
 *   tasks:
 *   - 1. Generate gather_expand_idx and gather_combine_idx.
 *   - 2. Output token_count, the number of tokens of each expert.
 *   - 3. Prepare the input and output addresses for group_gemm.
 * @param queue: The queue of mlu.
 * @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
 *   gather index for the expand hidden state operation, the shape must be
 *   [num_token * topk].
 * @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
 *   gather index for the combine MOE operation, the shape must be
 *   [num_token * topk].
 * @param token_count: Output. Pointer to the MLU memory that stores the
 *   number of tokens of each expert, the shape must be [num_expert].
 * @param cusum_token_count: Output. Pointer to the MLU memory that stores the
 *   cumulative sum of the number of tokens of each expert, the shape must be
 *   [num_expert + 1]. It can be set to nullptr if the cumulative-sum output
 *   is not needed.
 * @param workspace: Input. Pointer to the extra workspace required by the
 *   operation, the size must be larger than
 *   (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
 * @param expert_id: Input. Pointer to the MLU memory that stores the expert id
 *   of each token, the shape must be [num_token, topk].
 *   NOTE(review): declared as const void *, so the element type is not
 *   visible in this header; confirm against the implementation.
 * @param num_token: The number of tokens.
 * @param num_expert: The number of experts.
 * @param topk: The number of experts selected by each token.
 * @return A KernelStatus value. The type is not defined in this file
 *   (presumably it comes from ../kernel_utils.h); its possible values are
 *   not visible here.
 */
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
                                   int *gather_expand_idx,
                                   int *gather_combine_idx,
                                   int *token_count,
                                   int *cusum_token_count,
                                   void *workspace,
                                   const void *expert_id,
                                   const int num_token,
                                   const int num_expert,
                                   const int topk);
}  // namespace tmo

#endif  // CSRC_KERNELS_MOE_GEN_IDX_MLUH_