Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh
2026-02-04 17:39:32 +08:00

82 lines
3.5 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Expands the MoE input: gathers slices of length hidden_size from hidden_state
* according to gather_idx, optionally restricted to a range of experts through
* cusum_token_count.
*
* NOTE(review): the previous wording said the gather happens "at axis 1", but the example
* below selects whole rows of hidden_state (i.e. along the token dimension) -- confirm
* against the kernel implementation.
*
* @example
* hidden_state:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [9, 10, 11, 12]]
* gather_idx:
* [1, 0, 2, 2, 1, 0]
* cusum_token_count: NULL
* num_token = 3
* hidden_size = 4
* topk = 2
* expand_hidden_state:
* [[5, 6, 7, 8],
* [1, 2, 3, 4],
* [9, 10, 11, 12],
* [9, 10, 11, 12],
* [5, 6, 7, 8],
* [1, 2, 3, 4]]
*
* The gather is performed as follows:
*   if cusum_token_count is not NULL:
*     begin = cusum_token_count[start_expert_id]
*     end   = cusum_token_count[start_expert_id + expert_count]
*     index = gather_idx[begin:end]
*   else:
*     index = gather_idx[:]
*   expand_hidden_state = hidden_state[index]
*
* @param queue: The queue for mlu.
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output.
* If cusum_token_count is not NULL, the shape should be [num_index * topk, hidden_size], in
* which num_index =
* cusum_token_count[start_expert_id + expert_count] - cusum_token_count[start_expert_id].
* Otherwise, the shape should be [num_token * topk, hidden_size].
* NOTE(review): the pseudo-code above selects num_index indices, which suggests num_index
* output rows; confirm whether the "* topk" factor in the first shape is intended.
* @param hidden_state: Input. Pointer to the MLU memory that stores the input,
* the shape must be [num_token, hidden_size].
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
* the shape must be [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
* token_count. May be NULL; if it is not NULL, the shape must be [total_expert_num + 1].
* @param num_token: the number of tokens.
* @param hidden_size: the slice size.
* @param topk: the number of experts selected per token (top-k).
* @param data_type: Data type of hidden_state.
* @param total_expert_num: the total number of experts.
* @param start_expert_id: the first expert id.
* @param expert_count: the number of experts currently being processed.
* @return A KernelStatus value. KernelStatus is declared outside this header (presumably in
* kernel_utils.h); see its definition for the possible values -- TODO confirm.
*/
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_