/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_

#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
 * @brief Gathers rows (slices of hidden_size elements) from hidden_state according to
 *        gather_idx, optionally restricted to the index range that cusum_token_count
 *        assigns to a contiguous group of experts.
 *
 * The gather is described by the following pseudo-code:
 * @code
 *   if cusum_token_count is not NULL:
 *     begin = cusum_token_count[start_expert_id]
 *     end   = cusum_token_count[start_expert_id + expert_count]
 *     index = gather_idx[begin:end]
 *   else:
 *     index = gather_idx[:]
 *   expand_hidden_state = hidden_state[index]
 * @endcode
 *
 * @example
 *  hidden_state:
 *    [[1, 2, 3, 4],
 *     [5, 6, 7, 8],
 *     [9, 10, 11, 12]]
 *  gather_idx:
 *    [1, 0, 2, 2, 1, 0]
 *  cusum_token_count: NULL
 *  num_token = 3
 *  hidden_size = 4
 *  topk = 2
 *  expand_hidden_state:
 *    [[5, 6, 7, 8],
 *     [1, 2, 3, 4],
 *     [9, 10, 11, 12],
 *     [9, 10, 11, 12],
 *     [5, 6, 7, 8],
 *     [1, 2, 3, 4]]
 *
 * @param queue: The queue for mlu.
 * @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output.
 *        If cusum_token_count is not NULL, the shape should be
 *        [num_index * topk, hidden_size], in which
 *        num_index = cusum_token_count[start_expert_id + expert_count] -
 *                    cusum_token_count[start_expert_id].
 *        Otherwise, the shape should be [num_token * topk, hidden_size].
 *        NOTE(review): the pseudo-code above selects exactly num_index indices, which would
 *        give [num_index, hidden_size]; confirm against the kernel implementation whether
 *        the "* topk" factor in the first case is intended.
 * @param hidden_state: Input. Pointer to the MLU memory that stores the input,
 *        the shape must be [num_token, hidden_size].
 * @param gather_idx: Input. Pointer to the MLU memory that stores the index,
 *        the shape must be [num_token * topk].
 * @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
 *        token_count, or NULL to gather with all of gather_idx. If it is not NULL,
 *        the shape must be [total_expert_num + 1].
 * @param num_token: Input. The number of tokens.
 * @param hidden_size: Input. The slice size, i.e. the number of elements in one row of
 *        hidden_state.
 * @param topk: Input. The number of topk.
 * @param data_type: Input. Data type of hidden_state.
 * @param total_expert_num: Input. The total number of experts.
 * @param start_expert_id: Input. The first expert id.
 * @param expert_count: Input. The number of experts currently being processed.
 * @return A KernelStatus value. The type is not declared in this header (presumably it comes
 *         from kernel_utils.h); see its declaration for the possible values.
 */
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
                                        void *expand_hidden_state,
                                        const void *hidden_state,
                                        const int *gather_idx,
                                        const int *cusum_token_count,
                                        int num_token,
                                        int hidden_size,
                                        int topk,
                                        cnnlDataType_t data_type,
                                        int total_expert_num,
                                        int start_expert_id,
                                        int expert_count);
}  // namespace tmo

#endif  // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_