Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
2026-02-04 17:39:32 +08:00

59 lines
3.0 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#include <vector>
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Launch the MOE index-generation operation on the given queue.
*
* According to this header's contract, the operation performs the following
* tasks:
* - 1. Generate gather_expand_idx and gather_combine_idx.
* - 2. Output token_count, the number of tokens of each expert.
* - 3. Prepare input and output addresses for group_gemm.
*      NOTE(review): no parameter of this declaration visibly carries
*      group_gemm addresses; confirm against the implementation whether this
*      item is still accurate.
*
* @param queue: Input. The MLU queue the operation is issued on.
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
* gather index for the expand-hidden-state operation; the shape must be
* [num_token * topk].
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
* gather index for the combine-MOE operation; the shape must be
* [num_token * topk].
* @param token_count: Output. Pointer to the MLU memory that stores the number
* of tokens of each expert; the shape must be [num_expert].
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
* cumulative sum of the per-expert token numbers; the shape must be
* [num_expert + 1]. May be nullptr if the cumulative-sum output is not needed.
* @param workspace: Input. Pointer to the extra workspace required by the
* operation; its size must be larger than
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
* NOTE(review): "larger than" is kept from the original text; confirm whether
* exactly that many bytes is already sufficient.
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
* of each token; the shape must be [num_token, topk]. The pointer is untyped
* (const void *), so the element type is not visible here -- TODO confirm
* against the implementation.
* @param num_token: Input. The number of tokens.
* @param num_expert: Input. The number of experts.
* @param topk: Input. The number of experts selected by each token.
* @return A KernelStatus value. The type is not declared in this header
* (presumably it comes from ../kernel_utils.h); see its definition for the
* possible values -- TODO confirm.
*/
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_