86 lines
3.8 KiB
Plaintext
86 lines
3.8 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
|
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
|
|
|
#include "../kernel_utils.h"
|
|
#include "cnnl.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Sort tokens grouped by different experts based on index. Each token
|
|
* selects the topk hidden vectors, multiplies them by corresponding weights,
|
|
* and finally reduces the topk vectors for each token. This process involves
|
|
* bias and residual, calculated as (x + bias) * weight + residual.
|
|
* @example
|
|
* input:
|
|
* [[[1, 2, 1, 1],
|
|
* [1, 1, 1, 2]],
|
|
* [[2, 1, 1, 1],
|
|
* [1, 1, 1, 1]]]
|
|
* num_token = 2, topk = 2
|
|
* cusum_token_count = [0, 2, 4]
|
|
* index:
|
|
* [0, 1, 2, 3]
|
|
* weight:
|
|
* [0, 0, 1, 1]
|
|
* bias:
|
|
* [[0, 0, 0, 0],
|
|
* [1, 1, 1, 1]]
|
|
* residual:
|
|
* [[1, 1, 1, 1],
|
|
* [0, 0, 0, 0]]
|
|
* output:
|
|
* [[1, 1, 1, 1],
|
|
* [5, 4, 4, 4]]
|
|
* @param queue: The queue for mlu.
|
|
* @param output: Output. Pointer to the MLU memory that stores the result.
|
|
* The shape is [num_token, hidden_size].
|
|
* @param input: Input. Pointer to the MLU memory that stores input tokens.
|
|
* The shape is [num_token * topk, hidden_size].
|
|
* @param bias: Input. Pointer to the MLU memory that stores bias.
|
|
* The shape is [num_expert, hidden_size].
|
|
* @param residual: Input. Pointer to the MLU memory that stores residual.
|
|
* The shape is [num_token, hidden_size].
|
|
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
|
|
* The shape is [num_token * topk].
|
|
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
|
|
* token number of each expert. The shape is [num_expert + 1].
|
|
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
|
|
* The shape is [num_token * topk].
|
|
* @param num_token: The total number of tokens.
|
|
* @param topk: The number of expert.
|
|
* @param num_expert: The number of expert.
|
|
* @param hidden_size: The size of lowest dimension.
|
|
* @param start_expert_id: The id of the first processed expert.
|
|
* @param expert_size: The number of processed experts.
|
|
* @param dtype: Data type.
|
|
* @note Currently does not support bias.
|
|
*/
|
|
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
|
|
void *output,
|
|
const void *input,
|
|
const void *bias,
|
|
const void *residual,
|
|
const float *reduce_weight,
|
|
const int *cusum_token_count,
|
|
const int *gather_idx,
|
|
int num_token,
|
|
int topk,
|
|
int num_expert,
|
|
int hidden_size,
|
|
int start_expert_id,
|
|
int expert_size,
|
|
cnnlDataType_t dtype);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|