// File: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh
// (File-listing metadata preserved from the original capture: 2026-02-04 17:39:32 +08:00,
//  86 lines, 3.8 KiB, shown as "Plaintext" by the viewer. These lines were not valid C++
//  and have been converted to comments so the header compiles.)
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
 * @brief Combine the per-expert MoE outputs back into per-token results.
 * Tokens arrive grouped by expert (rows sorted by expert assignment); for each
 * token, its topk hidden vectors are gathered via `gather_idx`, each is
 * combined with bias and residual as (x + bias) * weight + residual, and the
 * topk vectors are reduced (summed) into one output row per token.
 * @example
 * input:
 * [[[1, 2, 1, 1],
 * [1, 1, 1, 2]],
 * [[2, 1, 1, 1],
 * [1, 1, 1, 1]]]
 * num_token = 2, topk = 2
 * cusum_token_count = [0, 2, 4]
 * index:
 * [0, 1, 2, 3]
 * weight:
 * [0, 0, 1, 1]
 * bias:
 * [[0, 0, 0, 0],
 * [1, 1, 1, 1]]
 * residual:
 * [[1, 1, 1, 1],
 * [0, 0, 0, 0]]
 * output:
 * [[1, 1, 1, 1],
 * [5, 4, 4, 4]]
 * // Token 0: both weights are 0, so output = residual row 0 = [1,1,1,1].
 * // Token 1: rows 2,3 belong to expert 1 (bias [1,1,1,1]); (x+bias)*1 summed
 * //          = [3,2,2,2] + [2,2,2,2] = [5,4,4,4], residual row 1 is zero.
 * @param queue: The queue for mlu.
 * @param output: Output. Pointer to the MLU memory that stores the result.
 * The shape is [num_token, hidden_size].
 * @param input: Input. Pointer to the MLU memory that stores input tokens
 * grouped by expert. The shape is [num_token * topk, hidden_size].
 * @param bias: Input. Pointer to the MLU memory that stores the per-expert bias.
 * The shape is [num_expert, hidden_size]. May be ignored; see @note below.
 * @param residual: Input. Pointer to the MLU memory that stores residual,
 * added once per token after the weighted reduction.
 * The shape is [num_token, hidden_size].
 * @param reduce_weight: Input. Pointer to the MLU memory that stores the
 * per-(token, expert) combine weights applied before the topk reduction.
 * The shape is [num_token * topk].
 * @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
 * token number of each expert; row range [cusum_token_count[e], cusum_token_count[e+1])
 * of `input` belongs to expert e. The shape is [num_expert + 1].
 * @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx,
 * mapping each (token, k) slot to its row in `input`.
 * The shape is [num_token * topk].
 * @param num_token: The total number of tokens.
 * @param topk: The number of experts selected per token (the "k" in top-k).
 * @param num_expert: The total number of experts.
 * @param hidden_size: The size of the lowest (innermost) dimension.
 * @param start_expert_id: The id of the first processed expert
 * (for expert-parallel partitions that handle a contiguous expert slice).
 * @param expert_size: The number of processed experts in this partition.
 * @param dtype: Data type of input/output tensors (cnnlDataType_t).
 * @return KernelStatus indicating whether the kernel launch succeeded.
 * @note Currently does not support bias.
 */
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_