67 lines
3.7 KiB
Plaintext
67 lines
3.7 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
|
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
|
|
|
#include "../kernel_utils.h"
|
|
#include "cnnl.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Execute MOE Softmax Top-K Kernel.
|
|
*
|
|
* This function executes the MOE Softmax Top-K Kernel, which computes
|
|
* the Top-K values along a specified dimension after applying softmax to the input data.
|
|
* It is specifically designed for reduction along the lowest dimension.
|
|
*
|
|
* @param queue CNRT queue used to specify the queue for execution.
|
|
* @param reduce_weight Pointer to store the Top-K values.
|
|
* The shape must be [num_token, topk].
|
|
* @param expert_id Pointer to store the indices of the Top-K values.
|
|
* The shape must be [num_token, topk].
|
|
* @param input Pointer to the input data containing the values to be computed.
|
|
* The shape must be [num_token, num_expert].
|
|
* @param mask Pointer to the input data containing the mask value to be computed after
|
|
* computing softmax, Mask can be nullptr, which means no need to compute,
|
|
* otherwise the shape and datatype of mask should be the same as input.
|
|
* @param num_token Number of channels in the input data.
|
|
* @param num_expert Specified dimension. Note that num_expert should not exceed 32768.
|
|
* @param num_mask Number of channels in the mask data.
|
|
* @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
|
|
* @param num_expert_group Group numbers of num_expert. If num_expert_group > 0, num_expert
|
|
* should be divisible by num_expert_group. Otherwise, num_expert_group and topk_group
|
|
* is not valid.
|
|
* @param topk_group Number of Top-K group values to compute. Topk_group should not be larger
|
|
* than num_expert_group.
|
|
* @param dtype Data type of the input data, should match the actual data type.
|
|
* float, half, bfloat16 is supported.
|
|
* @param normalize_mode Whether and how to normalize the output, if normalize_mode == 0, no
|
|
normalization is performed; if normalize_mode == 1, the normalized denominator is
|
|
the sum of topk; if normalize_mode == 2, the normalized denominator is the sum of
|
|
* the products of softmax_result mask.
|
|
*/
|
|
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
|
|
float *reduce_weight,
|
|
int *expert_id,
|
|
const void *input,
|
|
const void *mask,
|
|
const int num_token,
|
|
const int num_expert,
|
|
const int num_mask,
|
|
const int topk,
|
|
const int num_expert_group,
|
|
const int topk_group,
|
|
const cnnlDataType_t dtype,
|
|
const int normalize_mode);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|