Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh
2026-02-04 17:39:32 +08:00

67 lines
3.7 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
 * @brief Execute the MOE Softmax Top-K kernel.
 *
 * Applies softmax to the input data and then selects the Top-K values, together
 * with their indices, along the lowest dimension, i.e. the num_expert dimension
 * of an input of shape [num_token, num_expert]. The kernel is specifically
 * designed for reduction along this lowest dimension.
 *
 * @param queue CNRT queue on which the kernel is executed.
 * @param reduce_weight Output pointer that receives the Top-K values, stored as
 *   float regardless of dtype. The shape must be [num_token, topk].
 * @param expert_id Output pointer that receives the indices of the Top-K values.
 *   The shape must be [num_token, topk].
 * @param input Pointer to the input data to be computed.
 *   The shape must be [num_token, num_expert].
 * @param mask Optional pointer to mask data that is applied after softmax has
 *   been computed. Pass nullptr to skip masking; otherwise the shape and data
 *   type of mask must be the same as those of input.
 * @param num_token Number of tokens, i.e. the size of the first dimension of
 *   input.
 * @param num_expert Size of the reduced (lowest) dimension of input.
 *   num_expert must not exceed 32768.
 * @param num_mask Number of channels in the mask data.
 *   NOTE(review): mask is documented above to have the same shape as input;
 *   confirm how num_mask relates to num_token.
 * @param topk Number of Top-K values to compute. topk must not be larger than
 *   num_expert.
 * @param num_expert_group Number of groups the experts are divided into. If
 *   num_expert_group > 0, num_expert must be divisible by num_expert_group.
 *   Otherwise, num_expert_group and topk_group are not used.
 * @param topk_group Number of Top-K groups to compute. topk_group must not be
 *   larger than num_expert_group.
 * @param dtype Data type of input (and of mask, if provided). It must match the
 *   actual data type; float, half and bfloat16 are supported.
 * @param normalize_mode Whether and how to normalize the output:
 *   - 0: no normalization is performed;
 *   - 1: the normalization denominator is the sum of the Top-K values;
 *   - 2: the normalization denominator is the sum of the products of the
 *     softmax result and mask.
 * @return A KernelStatus describing the result of the call.
 */
// Top-level const on by-value parameters is intentionally omitted here: it is
// not part of the function type, so it has no meaning in a declaration. The
// definition may still mark these parameters const.
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        int num_token,
                                        int num_expert,
                                        int num_mask,
                                        int topk,
                                        int num_expert_group,
                                        int topk_group,
                                        cnnlDataType_t dtype,
                                        int normalize_mode);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_