add ops
This commit is contained in:
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
|
||||
#include <vector>
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Applies the MOE index generation operation, which performs the following
|
||||
* tasks:
|
||||
* - 1. Generate gather_expand_idx and gather_combine_idx.
|
||||
* - 2. Output token_count, the token number of each expert.
|
||||
* - 3. Prepare input and output addresses for group_gemm.
|
||||
* @param queue: The queue of mlu.
|
||||
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for expand hidden state operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for combine MOE operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param token_count: Output. Pointer to the MLU memory that stores the token
|
||||
* number of each expert, the shape must be [num_expert].
|
||||
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
|
||||
* cumulative sum of the token number of each expert, the shape must be
|
||||
* [num_expert + 1]. It can be set to nullptr if the cusum output is not needed.
|
||||
* @param workspace: Input. A pointer to the extra workspace required in the
|
||||
* operation, the size must be larger than
|
||||
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
|
||||
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
|
||||
* of each token, the shape must be [num_token, topk].
|
||||
* @param num_token: The number of tokens.
|
||||
* @param num_expert: The number of experts.
|
||||
* @param topk: The number of experts selected by each token.
|
||||
*/
|
||||
// Declaration only; the definition is not part of this header.
// The return type KernelStatus is not declared here (presumably it comes from
// ../kernel_utils.h — TODO confirm), and its possible values are not
// documented in this file.
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
||||
int *gather_expand_idx,  // output, shape [num_token * topk]
|
||||
int *gather_combine_idx,  // output, shape [num_token * topk]
|
||||
int *token_count,  // output, shape [num_expert]
|
||||
int *cusum_token_count,  // output, shape [num_expert + 1]; may be nullptr
|
||||
void *workspace,  // scratch space, see the size requirement in the doc block
|
||||
const void *expert_id,  // input, shape [num_token, topk]
|
||||
const int num_token,  // number of tokens
|
||||
const int num_expert,  // number of experts
|
||||
const int topk);  // experts selected per token
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
Reference in New Issue
Block a user