/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
|
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
|
|
|
#include <vector>
|
|
#include "../kernel_utils.h"
|
|
#include "cnnl.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Apply generate MOE index operation, which performs the following
|
|
* tasks:
|
|
* - 1. Generate gather_expand_idx and gather_combine_idx.
|
|
* - 2. Output token_count, the token number of each expert.
|
|
* - 3. Prepare inputs and outputs address for group_gemm.
|
|
* @param queue: The queue of mlu.
|
|
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
|
|
* gather index for expand hidden state operation, the shape must be
|
|
* [num_token * topk].
|
|
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
|
|
* gather index for combine MOE operation, the shape must be
|
|
* [num_token * topk].
|
|
* @param token_count: Output. Pointer to the MLU memory that stores the token
|
|
* number of each expert, the shape must be [num_expert].
|
|
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
|
|
* cumulative sum of the token number of each expert, the shape must be
|
|
* [num_expert + 1]. It can be set to nullptr if don't need cusum output.
|
|
* @param workspace: Input. A pointer to the extra workspace required in the
|
|
* operation, the size must be larger than
|
|
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
|
|
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
|
|
* of each token, the shape must be [num_token, topk].
|
|
* @param num_token: The number of tokens.
|
|
* @param num_expert: The number of experts.
|
|
* @param topk: The number of expert selected by each token.
|
|
*/
|
|
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
|
int *gather_expand_idx,
|
|
int *gather_combine_idx,
|
|
int *token_count,
|
|
int *cusum_token_count,
|
|
void *workspace,
|
|
const void *expert_id,
|
|
const int num_token,
|
|
const int num_expert,
|
|
const int topk);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|