/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_ #define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_ #include "../kernel_utils.h" #include "cnnl.h" namespace tmo { /** * @brief Sort tokens grouped by different experts based on index. Each token * selects the topk hidden vectors, multiplies them by corresponding weights, * and finally reduces the topk vectors for each token. This process involves * bias and residual, calculated as (x + bias) * weight + residual. * @example * input: * [[[1, 2, 1, 1], * [1, 1, 1, 2]], * [[2, 1, 1, 1], * [1, 1, 1, 1]]] * num_token = 2, topk = 2 * cusum_token_count = [0, 2, 4] * index: * [0, 1, 2, 3] * weight: * [0, 0, 1, 1] * bias: * [[0, 0, 0, 0], * [1, 1, 1, 1]] * residual: * [[1, 1, 1, 1], * [0, 0, 0, 0]] * output: * [[1, 1, 1, 1], * [5, 4, 4, 4]] * @param queue: The queue for mlu. * @param output: Output. Pointer to the MLU memory that stores the result. * The shape is [num_token, hidden_size]. * @param input: Input. Pointer to the MLU memory that stores input tokens. * The shape is [num_token * topk, hidden_size]. * @param bias: Input. Pointer to the MLU memory that stores bias. * The shape is [num_expert, hidden_size]. * @param residual: Input. Pointer to the MLU memory that stores residual. * The shape is [num_token, hidden_size]. * @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight. * The shape is [num_token * topk]. * @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the * token number of each expert. The shape is [num_expert + 1]. * @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx. * The shape is [num_token * topk]. * @param num_token: The total number of tokens. * @param topk: The number of expert. * @param num_expert: The number of expert. * @param hidden_size: The size of lowest dimension. * @param start_expert_id: The id of the first processed expert. * @param expert_size: The number of processed experts. * @param dtype: Data type. * @note Currently does not support bias. */ KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue, void *output, const void *input, const void *bias, const void *residual, const float *reduce_weight, const int *cusum_token_count, const int *gather_idx, int num_token, int topk, int num_expert, int hidden_size, int start_expert_id, int expert_size, cnnlDataType_t dtype); } // namespace tmo #endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_