moe_gating_top_k (#5271)

1. What this PR does / why we need it?
This PR adds support for the moe_gating_top_k operator, which enables
post-positioned renormalization (renorm) on top of the softmax result.
2. Does this PR introduce any user-facing change?
No user-facing changes are required.
3. How was this patch tested?
This patch was tested with the test_npu_moe_gating_top_k test case.
vLLM version: release/v0.13.0
vLLM main:
ad32e3e19c

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
ZCG12345
2025-12-30 09:28:01 +08:00
committed by GitHub
parent 15d73f248e
commit 45c3c279e2
34 changed files with 4791 additions and 22 deletions

View File

@@ -0,0 +1,89 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file common.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_COMMON_H
#define MOE_GATING_TOP_K_COMMON_H
#include "kernel_operator.h"
namespace MoeGatingTopK {
using namespace AscendC;
// -inf as a float, obtained by reinterpreting the bit pattern of F32_NEG_INF
// (provided by kernel_operator.h); used to pad sort inputs so that real
// scores always outrank the padding lanes.
const float MIN_FP32 = *(float *)(&F32_NEG_INF);
// IEEE-754 single-precision -inf bit pattern (as a signed int32 this reads
// as -8388608); written through int32 views to poison padding lanes.
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000;
constexpr int64_t ONE_REPEAT_SORT_NUM = 32; // Sort32 sorts 32 elements per repeat
constexpr int64_t BLOCK_BYTES = 32;         // UB block granularity in bytes
constexpr int64_t REPEAT_BYTES = 256;       // bytes processed by one vector repeat
constexpr int64_t REPEAT_BLOCKS = 8;        // blocks per repeat (256 / 32)
// Named small constants to avoid magic numbers at call sites.
constexpr int32_t CONSTANT_TWO = 2;
constexpr int32_t CONSTANT_THREE = 3;
constexpr int32_t CONSTANT_FOUR = 4;
constexpr int32_t CONSTANT_EIGHT = 8;
// MrgSort list counts and list indices.
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;
constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;
// Values of the norm_type attribute (see ComputeX in the generalized kernel).
constexpr int64_t NORM_TYPE_SOFTMAX = 0;
constexpr int64_t NORM_TYPE_SIGMOID = 1;
/**
 * Ceiling division: smallest integer quotient not less than a / b.
 * Returns 0 when the divisor is 0 so callers never divide by zero.
 */
__aicore__ inline int64_t Ceil(int64_t a, int64_t b)
{
    return (b == 0) ? 0 : (a + b - 1) / b;
}
/**
 * Rounds elementNum up so that the total byte size (elementNum * bytes) is a
 * multiple of the 32-byte block, then converts back to an element count.
 * Returns 0 when bytes is 0.
 */
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes)
{
    if (bytes == 0) {
        return 0;
    }
    const int64_t totalBytes = elementNum * bytes;
    const int64_t alignedBytes = (totalBytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
    return alignedBytes / bytes;
}
/**
 * Byte size of elementNum elements of the given element size, rounded up to
 * a multiple of the 32-byte block.
 */
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes)
{
    const int64_t rawBytes = elementNum * bytes;
    return (rawBytes + (BLOCK_BYTES - 1)) / BLOCK_BYTES * BLOCK_BYTES;
}
// Returns the smaller of a and b (b wins ties, same as the ternary original).
template <typename T>
__aicore__ inline T Min(T a, T b)
{
    if (a > b) {
        return b;
    }
    return a;
}
// Returns the larger of a and b (b wins ties, same as the ternary original).
template <typename T>
__aicore__ inline T Max(T a, T b)
{
    if (a < b) {
        return b;
    }
    return a;
}
/**
 * Sign-aware ceiling division: rounds x / y away from zero only when x and y
 * share the same sign and the division is inexact; otherwise returns the
 * truncated quotient. When either operand is 0, x is returned unchanged
 * (so CeilDiv(x, 0) == x, matching the original guard).
 */
template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 x, T2 y)
{
    if (y == 0 || x == 0) {
        return x;
    }
    T1 result = x / y;
    const bool inexact = (x % y) != 0;
    const bool sameSign = (x ^ y) >= 0;
    if (inexact && sameSign) {
        result = result + 1;
    }
    return result;
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_COMMON_H

View File

@@ -0,0 +1,56 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
// Logging macros for op tiling. OP_LOGI / OP_LOGD compile to nothing; the
// warning/error variants print a severity tag, the op name, the caller's
// printf-style message, and a trailing newline.
//
// Bug fix: the previous macros passed the caller's format string and
// arguments as EXTRA arguments to printf("[WARN][%s] ", ...), so the message
// itself was never printed. The message format must be a string literal
// (it is concatenated with "" so an empty argument list is also legal).
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...)             \
    do {                                 \
        printf("[WARN][%s] ", (opname)); \
        printf("" __VA_ARGS__);          \
        printf("\n");                    \
    } while (0)
// "[ERRORx]" tag preserved to keep the no-report variant distinguishable in logs.
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do {                                    \
        printf("[ERRORx][%s] ", (opname));  \
        printf("" __VA_ARGS__);             \
        printf("\n");                       \
    } while (0)
#define OP_LOGE(opname, ...)              \
    do {                                  \
        printf("[ERROR][%s] ", (opname)); \
        printf("" __VA_ARGS__);           \
        printf("\n");                     \
    } while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
// Reports a tiling error without failing the graph. Name spelling
// ("TILIING") is preserved for existing call sites.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...)    \
    do {                                                          \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__);  \
    } while (0)
// Runs log_func and then expr (typically an early return) when cond holds.
#define OP_CHECK_IF(cond, log_func, expr) \
    do {                                  \
        if (cond) {                       \
            log_func;                     \
            expr;                         \
        }                                 \
    } while (0)
// Fails the graph when ptr is null, logging which pointer was missing.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr)                 \
    do {                                                         \
        if ((ptr) == nullptr) {                                  \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED;                             \
        }                                                        \
    } while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,63 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k.cpp
* \brief
*/
#include "moe_gating_top_k_e_k_fullload.h"
#include "moe_gating_top_k_without_group.h"
#include "moe_gating_top_k_generalized.h"
#include "error_log.h"
#define TILING_KEY_PER_GROUP_COUNT_32 0
#define TILING_KEY_WITHOUT_GROUP 1
#define TILING_KEY_GENERALIZED 2
using namespace AscendC;
using namespace MoeGatingTopK;
/**
 * Device entry point for the MoeGatingTopK operator (vector cores only).
 * Dispatches to one of three implementations based on the tiling key:
 *   0 - E/K full-load with 32 experts per group
 *   1 - no expert grouping
 *   2 - generalized fallback
 *
 * Consistency fix: validate workspace BEFORE fetching the tiling struct,
 * matching the ordering of the arch35 (regbase) entry point and avoiding a
 * pointless tiling load on the early-return path.
 */
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    if (g_coreType == AIC) {
        return; // cube cores have no work in this operator
    }
    if (workspace == nullptr) {
        return;
    }
    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }
    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling);
    const MoeGatingTopKTilingData *__restrict t = &tilingData;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) {
        MoeGatingTopKEKFullload<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) {
        MoeGatingTopKWithoutGroup<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) {
        MoeGatingTopKGenerlized<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    }
}

View File

@@ -0,0 +1,46 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_apt.cpp
* \brief
*/
#include "arch35/moe_gating_top_k_regbase.h"
using namespace AscendC;
using namespace MoeGatingTopK;
#define TILING_KEY_REGBASE 10000
/**
 * Device entry point for the arch35 (regbase) build of MoeGatingTopK.
 * Vector cores only; dispatches on the single regbase tiling key (10000).
 */
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    if (g_coreType == AIC) {
        return; // cube cores have no work in this operator
    }
    // Validate the workspace before touching tiling data.
    if (workspace == nullptr) {
        return;
    }
    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }
    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling);
    const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_REGBASE)) {
        MoeGatingTopKRegbase<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe);
        op.Process();
    }
}

View File

@@ -0,0 +1,404 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_e_k_fullload.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H
#define MOE_GATING_TOP_K_E_K_FULLLOAD_H
#include "kernel_operator.h"
#include "common.h"
namespace MoeGatingTopK {
using namespace AscendC;
/**
 * MoeGatingTopK kernel for the "E/K full-load" case (tiling key 0): a whole
 * expert row fits in UB and each group holds 32 experts, so per-group sorting
 * maps one-to-one onto Sort32 repeats.
 * Per row: sigmoid(x) -> (+ optional bias) -> sort inside each group ->
 * pick top-k groups by (top1 + top2) sum -> 4-way merge of chosen groups ->
 * gather the top-k expert scores, renormalize, scale, and store.
 */
template <typename T>
class MoeGatingTopKEKFullload {
public:
    __aicore__ inline MoeGatingTopKEKFullload(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInBias();              // load bias once, widened to fp32
    __aicore__ inline void CopyInX(int64_t progress); // load one row of x
    __aicore__ inline void ComputeX();                // sigmoid + optional bias
    __aicore__ inline void SortInGroup();             // Sort32 inside each expert group
    __aicore__ inline void SelectTopKGroupIndex();    // choose the top kGroup groups
    __aicore__ inline void SelectTopKExpertIdx();     // merge groups, extract expert ids
    __aicore__ inline void SelectTopKExpertScore();   // gather scores and renormalize
    __aicore__ inline void CopyOut(int64_t progress); // store y and expertIdx

private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;           // row input; also reused as merge scratch
    TBuf<TPosition::VECCALC> biasInQueue_;           // bias, widened to fp32
    TQue<QuePosition::VECOUT, 1> yOutQueue_;         // top-k scores output
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_; // top-k expert ids output
    TQue<QuePosition::VECOUT, 1> outOutQueue_;       // NOTE(review): allocated but not written in the visible path
    TQue<QuePosition::VECOUT, 1> xBiasQueue_;        // sigmoid(x) + bias (ranking tensor)
    TQue<QuePosition::VECOUT, 1> xSigmoidQueue_;     // sigmoid(x) (final score source)
    TQue<QuePosition::VECIN, 1> sigmoidTmpQueue_;    // reusable temporary buffer
    TQue<QuePosition::VECIN, 1> sortedInGroupQueue_; // per-group sorted (value, idx) pairs
    TQue<QuePosition::VECIN, 1> sortedGroupQueue_;   // group-level sort scratch
    TBuf<TPosition::VECCALC> calcTmpBuffer_;         // generic scratch (indices, reductions)
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<T> outGm_;
    int64_t blockIdx_;        // this core's index
    int64_t perCoreRowCount_; // rows per core (except possibly the last core)
    int64_t curCoreRowCount_; // rows actually handled by this core
    int64_t expertCount_;     // E: experts per row
    bool addBias_;            // whether the bias input is applied
    int64_t k_;               // number of experts to output
    int64_t kGroup_;          // number of groups to keep
    int64_t groupCount_;      // number of expert groups
    int64_t groupSelectMode_;
    int64_t renorm_;
    int64_t normType_;
    int64_t outFlag_;
    float routedScalingFactor_; // final scale applied to normalized scores
    float eps_;                 // added to the renorm denominator
    int64_t expertCountAlign_;  // E aligned to the 32-byte block (fp32 elements)
    int64_t kAlign_;            // aligned element count for top-k buffers
    int64_t perGroupExpertCount_;
    const MoeGatingTopKTilingData *tilingData_;
};
/**
 * Loads the bias vector (length expertCount_) into UB once as fp32. For
 * non-fp32 T the raw data is staged in the upper half of the buffer (offset
 * expertCountAlign_, viewed as T) and then Cast into the fp32 front.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInBias()
{
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
        // Make the copied data visible to vector ops before any use.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
    } else {
        // Stage the raw T data at offset expertCountAlign_ ...
        DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        // ... then widen it to fp32 at the front of the buffer.
        Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE, expertCount_);
    }
}
/**
 * Loads row `row` of x into UB as fp32. Non-fp32 inputs are staged at offset
 * expertCountAlign_ (viewed as T) and widened into the front of the buffer.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
        // Wait for MTE2 before the vector Cast reads the staged data.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
/**
 * Applies sigmoid to the fp32 row and produces two tensors:
 *  - xSigmoidQueue_: sigmoid(x), later gathered to form the final scores;
 *  - xBiasQueue_: sigmoid(x) + bias (or a plain copy), used only for ranking.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::ComputeX()
{
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.AllocTensor<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xBiasTensor = xBiasQueue_.AllocTensor<float>();
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    LocalTensor<uint8_t> sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor<uint8_t>(); // reusable temporary buffer
    Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_);
    PipeBarrier<PIPE_V>();
    if (addBias_) {
        Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_);
    } else {
        // No bias: copy via Adds(+0) so the ranking tensor is still produced.
        Adds(xBiasTensor, xSigmoidTensor, static_cast<float>(0), expertCount_);
    }
    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    xBiasQueue_.EnQue<float>(xBiasTensor);
    xInQueue_.FreeTensor(xInLocalTensor);
    sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer);
}
/**
 * Sorts each 32-expert group with Sort32, attaching the global expert index
 * to every score so the indices travel with the values through later merges.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SortInGroup()
{
    LocalTensor<float> xBiasTensor = xBiasQueue_.DeQue<float>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.AllocTensor<float>(); // per-group sort result, needed by the later merge
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>(); // indices carried through the sort
    ArithProgression(indexTensor.ReinterpretCast<int32_t>(), 0, 1, expertCount_); // generate indices 0 1 2 ...
    PipeBarrier<PIPE_V>();
    Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // sort within each group
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    xBiasQueue_.FreeTensor(xBiasTensor);
}
/**
 * Selects the top kGroup groups. Each group's score is the sum of its two
 * largest elements (top1 + top2 of the already-sorted group). The selected
 * group ids are then themselves sorted so that SelectTopKExpertIdx can read
 * them back in reverse order for the 4-way merge; the final ids end up as
 * int32 at the front of calcTmpBuffer_.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKGroupIndex()
{
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>();
    LocalTensor<float> top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor<float>(); // reusable temporary buffer
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Gather pattern b0101 selects the value halves of the first two
    // (value, index) pairs in every sorted group, i.e. its top-2 scores.
    indexTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
    indexTensor.SetValue(1, static_cast<uint32_t>(0));
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    uint64_t rsvdCnt = 0; // number of elements kept after filtering
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = 8;
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = 8;
    gatherMaskParams.src1RepeatStride = 0;
    GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast<uint32_t>(64),
               gatherMaskParams, rsvdCnt);
    PipeBarrier<PIPE_V>();
    LocalTensor<float> groupTop2SumTensor = top2ValueInGroupTensor;
    PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1,
                  1); // sum of the two largest values in each group
    PipeBarrier<PIPE_V>();
    LocalTensor<uint32_t> groupIndexTensor = indexTensor;
    ArithProgression(groupIndexTensor.ReinterpretCast<int32_t>(), 0, 1, groupCount_); // generate group indices
    PipeBarrier<PIPE_V>();
    // Pad up to 32 entries with -inf so Sort32 never selects missing groups.
    int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << groupCount_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    // Sort to pick out the top kGroup groups.
    LocalTensor<float> sortedGroupTensor = sortedGroupQueue_.AllocTensor<float>();
    Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1);
    PipeBarrier<PIPE_V>();
    LocalTensor<int32_t> sortedGroupIndexTensor = indexTensor.ReinterpretCast<int32_t>();
    // Extract the group ids from the sorted (value, index) pairs.
    uint8_t src1Pattern = 2; // built-in fixed gather pattern
    GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
    // Re-sort the selected group ids (descending here, so the MrgSort caller
    // later reads them back in reverse order 3, 2, 1, 0).
    Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_);
    PipeBarrier<PIPE_V>();
    duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << kGroup_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast<uint32_t>(), 1);
    PipeBarrier<PIPE_V>();
    src1Pattern = 1;
    GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast<uint32_t>(0), {1, 1, 0, 0},
               rsvdCnt);
    PipeBarrier<PIPE_V>();
    // Final int32 group ids land at the front of calcTmpBuffer_ for the merge.
    Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_);
    sortedGroupQueue_.FreeTensor(sortedGroupTensor);
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor);
}
/**
 * 4-way merges the sorted contents of the selected groups and extracts the
 * global expert ids of the k best entries.
 * Relies on SelectTopKGroupIndex having left the selected group ids (int32)
 * at the front of calcTmpBuffer_.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertIdx()
{
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> topKGroupIndexTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<float> sortedExpertTensor = xInQueue_.AllocTensor<float>(); // reuse the input queue as scratch
    // Merge at most k entries from each of the 4 lists; stop early once a
    // list is exhausted (ifExhaustedSuspension).
    AscendC::MrgSort4Info params;
    params.elementLengths[0] = k_;
    params.elementLengths[1] = k_;
    params.elementLengths[2] = k_;
    params.elementLengths[3] = k_;
    params.ifExhaustedSuspension = true;
    params.validBit = 0b1111;
    params.repeatTimes = 1;
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Group ids were stored in descending order, so read them back reversed
    // (3..0). Each sorted element is a (value, index) pair, hence the * 2.
    int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2;
    int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2;
    int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2;
    int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2;
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = sortedInGroupTensor[listOffset1];
    srcList.src2 = sortedInGroupTensor[listOffset2];
    srcList.src3 = sortedInGroupTensor[listOffset3];
    srcList.src4 = sortedInGroupTensor[listOffset4];
    MrgSort<float>(sortedExpertTensor, srcList, params);
    PipeBarrier<PIPE_V>();
    uint64_t rsvdCnt = 0; // number of elements kept after filtering
    uint8_t src1Pattern = 2; // built-in fixed gather pattern: index halves of the pairs
    GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
    xInQueue_.FreeTensor(sortedExpertTensor);
    expertIdxOutQueue_.EnQue(expertIdxTensor);
    sortedInGroupQueue_.FreeTensor(sortedInGroupTensor);
}
/**
 * Gathers the sigmoid scores of the selected experts, renormalizes them so
 * they sum to 1 (denominator guarded by eps_), scales by routedScalingFactor_,
 * and casts the result back to T for output.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertScore()
{
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    LocalTensor<int32_t> expertByteIdxTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
    LocalTensor<T> yTensor = yOutQueue_.AllocTensor<T>();
    LocalTensor<float> yOutTensor;
    if constexpr (!IsSameType<T, float>::value) {
        // Non-fp32: compute in an fp32 view placed kAlign_ elements into the
        // buffer, then Cast down into the front of yTensor at the end.
        yOutTensor = yTensor.template ReinterpretCast<float>()[kAlign_];
    } else {
        yOutTensor = yTensor;
    }
    // Convert element indices to byte offsets for Gather.
    Muls(expertByteIdxTensor, expertIdxTensor, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast<uint32_t>(),
           static_cast<uint32_t>(0), k_);
    LocalTensor<float> calTensor = calcTmpBuffer_.Get<float>();
    PipeBarrier<PIPE_V>();
    // Sum the k gathered scores; xSigmoidTensor serves as the work buffer.
    ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_);
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    float sumValue = calTensor.GetValue(0) + eps_; // eps_ guards a zero denominator
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    Duplicate(calTensor, sumValue, k_);
    PipeBarrier<PIPE_V>();
    Div(yOutTensor, yOutTensor, calTensor, k_);
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_);
    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_);
    }
    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxTensor);
    yOutQueue_.EnQue(yTensor);
}
/**
 * Writes one row's top-k scores (y) and expert ids (expertIdx) back to global
 * memory and releases this row's UB tensors.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyOut(int64_t row)
{
    LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>(); // dequeued only to be freed below
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
    // Same copy descriptor, re-targeted at the int32 expert ids.
    dataCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams);
    xSigmoidQueue_.FreeTensor(xSigmoidTensor);
    expertIdxOutQueue_.FreeTensor(expertIdxTensor);
    yOutQueue_.FreeTensor(yOutTensor);
}
/**
 * One-time setup: caches tiling fields, binds this core's slice of the global
 * buffers, and carves the UB queues/buffers.
 *
 * Fix: kAlign_ is the aligned element count for the top-k output buffers
 * (yOutQueue_ sizing and the fp32 staging offset in SelectTopKExpertScore),
 * so it must be derived from k_. The original aligned expertCount_ — an
 * apparent copy-paste of the line above — which over-allocated UB.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core picks up the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    routedScalingFactor_ = tilingData_->routedScalingFactor;
    eps_ = tilingData_->eps;
    expertCountAlign_ = Align(expertCount_, sizeof(float));
    kAlign_ = Align(k_, sizeof(float));
    // init input gm buf
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que; the (sizeof(float) / sizeof(T)) factor doubles non-fp32
    // buffers so raw T data can be staged alongside its fp32 widening.
    pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t)));
    pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    // * 2 because sorted buffers hold (value, index) pairs.
    pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2);
    pipe_->InitBuffer(sortedGroupQueue_, 2,
                      (groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM *
                          sizeof(float) * 2);
    pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize);
}
/**
 * Main loop: loads the bias once, then runs the complete per-row pipeline
 * (load -> sigmoid/bias -> group sort -> group top-k -> expert top-k ->
 * renormalize -> store) for every row owned by this core.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Process()
{
    CopyInBias();
    for (int64_t rowIdx = 0; rowIdx < curCoreRowCount_; ++rowIdx) {
        CopyInX(rowIdx);
        ComputeX();
        SortInGroup();
        SelectTopKGroupIndex();
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CopyOut(rowIdx);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H

View File

@@ -0,0 +1,669 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_generalized.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H
#define MOE_GATING_TOP_K_E_K_GENERALIZED_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
/**
 * Generalized MoeGatingTopK kernel (tiling key 2): handles arbitrary group
 * sizes by padding each group to an aligned length and poisoning the padding
 * lanes with the -inf bit pattern so they never survive a sort. Supports
 * softmax or sigmoid normalization (normType_).
 */
template <typename T>
class MoeGatingTopKGenerlized {
public:
    __aicore__ inline MoeGatingTopKGenerlized(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    // NOTE(review): "CopuOutXNorm" spelling kept to match its definition.
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SortInGroup();
    __aicore__ inline void SelectTopKGroupIndex();
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    // NOTE(review): "Cumpute" spelling kept to match its definition.
    __aicore__ inline void CumputeActualTopKExpertId();
    __aicore__ inline void CopyOut(int64_t row);

private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;
    TBuf<TPosition::VECCALC> biasBuf_;          // holds the input bias (widened to fp32)
    TBuf<TPosition::VECCALC> expertIdBuf_;      // expert ids
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // values after the bias was added
    TBuf<TPosition::VECCALC> xNormBuf_;         // sigmoid/softmax normalized values
    TBuf<TPosition::VECCALC> sortedInGroupBuf_; // per-group sorted result
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;
    TBuf<TPosition::VECCALC> sortedGroupIndexBuf_;
    TBuf<TPosition::VECCALC> calcTmpBuf_;
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> outGm_;
    int64_t blockIdx_ = 0;
    int64_t perCoreRowCount_ = 0; // rows assigned per core
    int64_t curCoreRowCount_ = 0; // rows this core actually handles
    int64_t expertCount_ = 0;     // E: experts per row
    bool addBias_ = false;
    int64_t k_ = 0;
    int64_t kGroup_ = 0;
    int64_t groupCount_ = 0;
    int64_t groupCountAlign_ = 0;
    int64_t perGroupExpertCount_ = 0;
    int64_t perGroupExpertCountAlign_ = 0; // padded per-group length
    int64_t groupSelectMode_ = 0;
    int64_t renorm_ = 0;
    int64_t normType_ = 0; // NORM_TYPE_SOFTMAX(0) / NORM_TYPE_SIGMOID(1)
    int64_t outFlag_ = 0;
    int64_t expertCountAlign_ = 0;
    int64_t kAlign_ = 0;
    bool isAlign_ = false; // when true, the padding/poisoning steps are skipped
    const MoeGatingTopKTilingData *tilingData_;
};
/**
 * One-time setup of per-row-invariant UB data: loads the bias in the grouped,
 * padded layout as fp32, poisons the padding lanes of each group with the
 * -inf bit pattern, and generates expert ids 0..expertCountAlign_-1.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInBiasAndInitExpertId()
{
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    // Copy group by group; dstStride re-spaces each group to its padded size.
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // Stage raw T data in the upper half of the buffer, then widen to fp32.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }
        if (!isAlign_) {
            // Poison each group's padding lanes with -inf so they lose every
            // later comparison/sort.
            int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
            int duplicateIndex = perGroupExpertCount_ - duplicateNum;
            if (duplicateNum > 0) {
                uint64_t mask0 = UINT64_MAX;
                mask0 = mask0 << duplicateNum;
                mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
                uint64_t mask[2] = {mask0, 0};
                Duplicate(biasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                          perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
            }
        }
    }
    // Expert ids 0, 1, 2, ... used to tag scores before sorting.
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCountAlign_);
}
/**
 * Loads row `row` of x group by group into the padded UB layout. For non-fp32
 * T the raw data is staged at offset expertCountAlign_; ComputeX performs the
 * fp32 widening.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    // Copy group by group; dstStride re-spaces each group to its padded size.
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
/**
 * Normalizes one row and builds the ranking tensor.
 * 1. Casts the staged T data to fp32 (non-fp32 inputs only).
 * 2. Poisons the padding lanes of every group with the -inf bit pattern.
 * 3. Applies the configured normalization: sigmoid, or numerically stable
 *    softmax (shift by the row max before Exp; -inf padding contributes
 *    exp(-inf) = 0 to the sum).
 * 4. Adds the bias into xNormWithBiasBuf_ (plain copy when addBias_ is off)
 *    and re-poisons the padding lanes so they never survive later sorts.
 *
 * Cleanup: compares normType_ against the named NORM_TYPE_* constants from
 * common.h instead of bare 0/1 literals.
 */
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::ComputeX()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    if constexpr (!IsSameType<T, float>::value) {
        // Widen the staged T data (upper half of the buffer) to fp32 in place.
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    // Lanes beyond perGroupExpertCount_ in each group are padding: fill them
    // with the -inf bit pattern so they lose every comparison.
    int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = perGroupExpertCount_ - duplicateNum;
    if (!isAlign_ && duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xInLocalTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                  (perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES);
        PipeBarrier<PIPE_V>();
    }
    if (normType_ == NORM_TYPE_SIGMOID) {
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    } else if (normType_ == NORM_TYPE_SOFTMAX) {
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // softmax(x) = exp(x - max) / sum(exp(x - max))
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }
    // Re-poison the padding lanes of the ranking tensor.
    if (!isAlign_ && duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                  perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
// Copy the normalized (pre-bias) scores of the current row out to GM,
// stripping the per-group alignment padding during the padded store.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopuOutXNorm(int64_t row)
{
    LocalTensor<float> normSrc = xNormBuf_.Get<float>();
    LocalTensor<float> stageTensor = outOutQueue_.AllocTensor<float>();
    DataCopy(stageTensor, normSrc, expertCountAlign_);
    // Round-trip through the queue so the MTE3 store waits for the copy above.
    outOutQueue_.EnQue<float>(stageTensor);
    stageTensor = outOutQueue_.DeQue<float>();
    // One burst per group; source skips the padded tail of each aligned group.
    DataCopyExtParams copyParams{
        static_cast<uint16_t>(groupCount_), static_cast<uint32_t>(perGroupExpertCount_ * sizeof(float)),
        static_cast<uint32_t>((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0};
    DataCopyPad(outGm_[row * expertCount_], stageTensor, copyParams);
    outOutQueue_.FreeTensor(stageTensor);
}
// Sort each expert group in descending score order, pairing every score with
// its expert id. Fast path: when a group is exactly one Sort32 repeat (32
// elements), a single multi-repeat Sort32 covers all groups at once.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SortInGroup()
{
    LocalTensor<float> scores = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIds = expertIdBuf_.Get<uint32_t>();
    LocalTensor<float> sortedOut = sortedInGroupBuf_.Get<float>();
    LocalTensor<float> sortScratch = calcTmpBuf_.Get<float>();
    if (perGroupExpertCountAlign_ != ONE_REPEAT_SORT_NUM) {
        // General path: run a multi-repeat Sort for each group individually.
        for (int64_t g = 0; g < groupCount_; g++) {
            PipeBarrier<PIPE_V>();
            Sort<float, true>(sortedOut[g * perGroupExpertCountAlign_ * CONSTANT_TWO],
                              scores[g * perGroupExpertCountAlign_],
                              expertIds[g * perGroupExpertCountAlign_], sortScratch,
                              perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM);
        }
    } else {
        PipeBarrier<PIPE_V>();
        Sort32(sortedOut, scores, expertIds, groupCount_);
    }
}
// Select the kGroup_ best groups. A group's score is either its top-1 value
// (groupSelectMode_ == 0) or the sum of its top-2 values (groupSelectMode_ == 1),
// read from each group's sorted (score, id) pair list. The winning group ids are
// then sorted into sortedGroupIndexBuf_ for the merge step that follows.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKGroupIndex()
{
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    // Scratch layout inside calcTmpBuf_, in units of groupCountAlign_ floats:
    // [0,2) gathered values, [2,3) gather mask, [3,4) per-group score,
    // [4,5) group index, [5,7) sorted (score,id) pairs, [7,9) sort scratch.
    LocalTensor<float> valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, 0);
    LocalTensor<uint32_t> maskTensor =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float));
    LocalTensor<float> topValueInGroupTensor =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float));
    LocalTensor<uint32_t> groupIndex =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float));
    LocalTensor<float> sortedTopValue =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float));
    LocalTensor<float> sortTmp =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float));
    // V->S sync before the scalar SetValue calls on maskTensor below.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    uint64_t rsvdCnt = 0; // number of elements GatherMask retains
    PipeBarrier<PIPE_V>();
    if (groupSelectMode_ == 1) { // top2 sum
        // Gather the first two score elements of each group's sorted pair list.
        maskTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101: score slots of ranks 0 and 1
        maskTensor.SetValue(1, static_cast<uint32_t>(0));
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        // One repeat per group: stride over a group's (score, id) pair region.
        gatherMaskParams.src0RepeatStride =
            Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
        PipeBarrier<PIPE_V>();
        // Pairwise-add adjacent gathered values: each group's top-2 sum.
        PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor,
                      Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1,
                      CONSTANT_EIGHT); // sum of the two largest scores in each group
    } else {
        // top1 mode: the group score is its single largest score.
        maskTensor.SetValue(0, static_cast<uint32_t>(1)); // b0001: score slot of rank 0 only
        maskTensor.SetValue(1, static_cast<uint32_t>(0));
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        uint64_t rsvdCnt = 0; // shadows the outer rsvdCnt; used only by this GatherMask
        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
    }
    PipeBarrier<PIPE_V>();
    // Generate the group index sequence 0, 1, ..., groupCount_ - 1.
    ArithProgression(groupIndex.ReinterpretCast<int32_t>(), static_cast<int32_t>(0), static_cast<int32_t>(1),
                     groupCount_); // generate group indices
    PipeBarrier<PIPE_V>();
    int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = groupCount_ - duplicateNum;
    if (duplicateNum > 0) {
        // Pad the score tail with -inf so padding lanes cannot win the group sort.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(topValueInGroupTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1,
                  REPEAT_BLOCKS);
        PipeBarrier<PIPE_V>();
    }
    PipeBarrier<PIPE_V>();
    // Sort groups by score (descending), carrying the group id as payload.
    Sort<float, true>(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32));
    PipeBarrier<PIPE_V>();
    // Extract the group ids out of the sorted (score, id) pairs.
    uint8_t src1Pattern = 2; // built-in fixed pattern: picks the id half of each pair
    GatherMask(groupIndex, sortedTopValue.template ReinterpretCast<uint32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0),
               {1, static_cast<uint8_t>(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt);
    PipeBarrier<PIPE_V>();
    duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        // Pad unused id slots beyond kGroup_ before the id sort below.
        duplicateIndex = kGroup_ - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(groupIndex.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS);
    }
    // Sort the selected group ids in descending order (id bits reinterpreted as floats).
    LocalTensor<float> sortedGroupIndex = sortedGroupIndexBuf_.Get<float>();
    PipeBarrier<PIPE_V>();
    Sort<float, true>(sortedGroupIndex, groupIndex.ReinterpretCast<float>(), groupIndex, sortTmp, Ceil(kGroup_, 32));
}
// Merge the sorted (score, id) lists of the kGroup_ selected groups into one
// fully sorted list and gather the padded-layout ids of the k_ best experts.
// Pass 0: 4-way MrgSort over the selected groups, visited in the descending-id
// order produced by SelectTopKGroupIndex. Later passes (tilingData_->vmsCount):
// repeated 4-way merges of the partial lists, ping-ponging between two buffers.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertIdx()
{
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    LocalTensor<int32_t> sortedGroupIndex = sortedGroupIndexBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> mrgSort0Tensor = calcTmpBuf_.Get<float>();
    uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0};
    uint16_t lenArr[CONSTANT_FOUR] = {
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_),
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_)};
    MrgSort4Info params{lenArr, false, 0b1111, 1};
    MrgSortSrcList<float> srcList;
    // V->S sync: the scalar GetValue reads below must see the sorted group ids.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // sortedGroupIndex holds (key, id) pairs, so rank i's id sits at element i*2;
    // each group's sorted list spans perGroupExpertCountAlign_*2 floats.
    for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) {
        int64_t mrgLen = Min(i + 1, CONSTANT_FOUR);
        if (mrgLen > 1) {
            if (mrgLen == MERGE_LIST_FOUR) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2;
            } else if (mrgLen == MERGE_LIST_THREE) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b111;
            } else {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = 0;
                offset[3] = 0;
                params.elementLengths[2] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b11;
            }
            srcList.src1 = sortedInGroupTensor[offset[0]];
            srcList.src2 = sortedInGroupTensor[offset[1]];
            srcList.src3 = sortedInGroupTensor[offset[2]];
            srcList.src4 = sortedInGroupTensor[offset[3]];
            PipeBarrier<PIPE_V>();
            MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params);
        } else {
            // Single remaining list: no merge needed, just relocate it.
            offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
            PipeBarrier<PIPE_V>();
            DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]],
                     perGroupExpertCountAlign_ * 2);
        }
    }
    // Multi-pass 4-way merge; baseLoop = number of original group lists already
    // merged into each partial list at the start of a pass.
    int32_t baseLoop = 4;
    LocalTensor<float> srcTensor = mrgSort0Tensor;
    LocalTensor<float> dstTensor = mrgSort0Tensor;
    for (int i = 0; i < tilingData_->vmsCount; i++) {
        // Ping-pong src/dst buffers between passes.
        if (i % 2 == 0) {
            srcTensor = mrgSort0Tensor;
            dstTensor = sortedInGroupTensor;
        } else {
            srcTensor = sortedInGroupTensor;
            dstTensor = mrgSort0Tensor;
        }
        int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR;
        int32_t quotient = kGroup_ / nextBaseRow;
        int32_t remainder = kGroup_ - quotient * nextBaseRow;
        if (quotient > 0) {
            // Full 4-way merges of equally-sized partial lists.
            MrgSort4Info params;
            MrgSortSrcList<float> srcList;
            params.ifExhaustedSuspension = false;
            params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
            params.validBit = 0b1111;
            params.repeatTimes = 1;
            for (int j = 0; j < quotient; j++) {
                srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j];
                srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)];
                srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)];
                srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)];
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params);
            }
        }
        if (remainder > 0) {
            // Tail merge: fewer than four partial lists, last one possibly short.
            int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2;
            int32_t mrgLen = CeilDiv(remainder, baseLoop);
            int32_t tailRow = remainder - (mrgLen - 1) * baseLoop;
            if (mrgLen > 1) {
                MrgSort4Info params;
                MrgSortSrcList<float> srcList;
                params.repeatTimes = 1;
                params.ifExhaustedSuspension = false;
                params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
                srcList.src1 = srcTensor[baseOffset];
                srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2];
                if (mrgLen == MERGE_LIST_FOUR) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3];
                    params.elementLengths[3] = perGroupExpertCount_ * tailRow;
                    params.validBit = 0b1111;
                } else if (mrgLen == MERGE_LIST_THREE) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    params.elementLengths[2] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b111;
                } else {
                    params.elementLengths[1] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[2] = 0;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b11;
                }
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[baseOffset], srcList, params);
            } else {
                // Single tail list: carry it over unchanged to the dst buffer.
                PipeBarrier<PIPE_V>();
                DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2);
            }
        }
        baseLoop = nextBaseRow;
    }
    // Gather the id field of each sorted pair; the first k_ entries are the winners.
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0; // number of elements GatherMask retains
    uint8_t src1Pattern = 2; // built-in fixed pattern: picks the id half of each pair
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, dstTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
}
// Gather the top-k normalized scores by (padded-layout) expert id, optionally
// renormalize them to sum to 1 (sigmoid always renormalizes; softmax only when
// renorm_ == 1), scale by routedScalingFactor, and cast back to T for output.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertScore()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
    PipeBarrier<PIPE_V>();
    // Gather takes byte offsets: id * sizeof(float).
    Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
           k_);
    bool needRenorm = (normType_ == 1) ||                // case 1: sigmoid + renorm
                      (normType_ == 0 && renorm_ == 1);  // case 3: softmax + renorm
    if (needRenorm) {
        LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
        // Element offset expressed via the named BLOCK_BYTES constant (value 32),
        // consistent with the without-group variant, instead of a magic literal.
        LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        PipeBarrier<PIPE_V>();
        // Sum of the k selected scores; eps guards the division below.
        ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Duplicate(tmpTensor, sumValue, k_);
        PipeBarrier<PIPE_V>();
        Div(yOutTensor, yOutTensor, tmpTensor, k_);
    }
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
    }
    yOutQueue_.EnQue<float>(yOutTensor);
}
// Convert padded-layout expert ids into real expert ids.
// The per-group sort worked on buffers padded to perGroupExpertCountAlign_, so
// a raw id equals groupIdx * perGroupExpertCountAlign_ + idInGroup. Subtracting
// groupIdx * (align - count) yields groupIdx * perGroupExpertCount_ + idInGroup.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CumputeActualTopKExpertId()
{
    LocalTensor<int32_t> realExpertId = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> paddedId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> idAsFloat = calcTmpBuf_.Get<float>();
    const float invGroupStride = 1.0f / (float)perGroupExpertCountAlign_;
    const int32_t padPerGroup = static_cast<int32_t>(perGroupExpertCountAlign_ - perGroupExpertCount_);
    PipeBarrier<PIPE_V>();
    // Recover every selected expert's group index: floor(paddedId / align).
    Cast(idAsFloat, paddedId, RoundMode::CAST_ROUND, k_);
    PipeBarrier<PIPE_V>();
    Muls(idAsFloat, idAsFloat, invGroupStride, k_);
    PipeBarrier<PIPE_V>();
    Cast(realExpertId, idAsFloat, RoundMode::CAST_TRUNC, k_);
    PipeBarrier<PIPE_V>();
    // groupIdx * padding-lanes-per-group ...
    Muls(realExpertId, realExpertId, padPerGroup, k_);
    PipeBarrier<PIPE_V>();
    // ... subtracted from the padded id gives the real expert id.
    Sub(realExpertId, paddedId, realExpertId, k_);
    expertIdxOutQueue_.EnQue<int32_t>(realExpertId);
}
// Write the top-k scores and their expert indices for one row back to GM.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyOut(int64_t row)
{
    LocalTensor<T> scoreTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> idxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    const int64_t dstOffset = row * k_;
    // Separate parameter structs per tensor instead of mutating one in place.
    DataCopyExtParams scoreParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[dstOffset], scoreTensor, scoreParams);
    DataCopyExtParams idxParams{1, static_cast<uint32_t>(k_ * sizeof(int32_t)), 0, 0, 0};
    DataCopyPad(expertIdxGm_[dstOffset], idxTensor, idxParams);
    yOutQueue_.FreeTensor(scoreTensor);
    expertIdxOutQueue_.FreeTensor(idxTensor);
}
// Read tiling parameters, bind this core's slice of the GM tensors and
// partition UB into queues/buffers. Buffers that stage 16-bit input reserve
// double space (factor sizeof(float)/sizeof(T)) so the raw data can be loaded
// into the upper half and cast down to fp32 in place.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    groupSelectMode_ = tilingData_->groupSelectMode;
    // Total padded expert count: every group padded to its aligned size.
    expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float));
    kAlign_ = Align(k_, sizeof(float));
    isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_;
    // init input gm buf (offset to this core's first row)
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
    // sorted lists store interleaved (score, id) pairs -> float + uint32 per element
    pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t)))
;
    pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO);
    pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10);
}
// Per-core main loop: load the bias/expert-id tables once, then run the full
// gating pipeline for every row owned by this core.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Process()
{
    CopyInBiasAndInitExpertId();
    const bool dumpNormScores = tilingData_->outFlag;
    for (int64_t rowIdx = 0; rowIdx < curCoreRowCount_; rowIdx++) {
        CopyInX(rowIdx);
        ComputeX();
        if (dumpNormScores) {
            CopuOutXNorm(rowIdx); // optional full normalized-score output (out tensor)
        }
        SortInGroup();
        SelectTopKGroupIndex();
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CumputeActualTopKExpertId();
        CopyOut(rowIdx);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H

View File

@@ -0,0 +1,338 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_without_group.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
// Top-k gating kernel for the ungrouped case: normalizes each row of logits
// (sigmoid or softmax), optionally adds a bias, sorts all experts of the row
// at once, and emits the top-k scores plus their expert indices.
template <typename T>
class MoeGatingTopKWithoutGroup {
public:
    __aicore__ inline MoeGatingTopKWithoutGroup(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();
private:
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CopyOut(int64_t row);
private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;
    TBuf<TPosition::VECCALC> biasBuf_;          // input bias staged in UB
    TBuf<TPosition::VECCALC> expertIdBuf_;      // expert index table 0..expertCount_-1
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // normalized scores after the bias add
    TBuf<TPosition::VECCALC> xNormBuf_;         // sigmoid/softmax result
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;
    TBuf<TPosition::VECCALC> calcTmpBuf_;
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> outGm_;
    int64_t blockIdx_ = 0;
    int64_t perCoreRowCount_ = 0;
    int64_t curCoreRowCount_ = 0;
    int64_t expertCount_ = 0;
    bool addBias_ = false;
    bool outFlag_ = false;
    int64_t k_ = 0;
    int64_t renorm_ = 0;
    int64_t normType_ = 0;
    int64_t expertCountAlign_ = 0;
    const MoeGatingTopKTilingData *tilingData_;
};
// One-time setup per core: stage the bias row into UB (for 16-bit T the raw
// data is loaded into the buffer's upper half and cast down to fp32 in place)
// and fill expertIdBuf_ with the sequence 0..expertCount_-1 used as sort payload.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInBiasAndInitExpertId()
{
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            // MTE2->V sync: later vector ops must see the completed GM load.
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // Load raw T data at element offset expertCountAlign_, then widen to fp32.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }
    }
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCount_);
}
// Load one row of gating logits from GM into UB via a padded copy. Non-fp32
// input is staged in the buffer's upper half so ComputeX can cast it down in place.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInX(int64_t row)
{
    LocalTensor<float> stageTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams copyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams padParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(stageTensor, xGm_[row * expertCount_], copyParams, padParams);
    } else {
        DataCopyPad(stageTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], copyParams,
                    padParams);
    }
    xInQueue_.EnQue(stageTensor);
}
// Normalize one row of gating logits: (1) widen non-fp32 input to fp32;
// (2) apply sigmoid (normType_ == 1) or a numerically stable softmax
// (normType_ == 0); (3) add the bias (or plain-copy) into xNormWithBiasTensor;
// (4) mask the unaligned tail lanes with -inf so the full-row sort ignores them.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::ComputeX()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    // CopyInX staged half/bf16 data at element offset expertCountAlign_; cast it down in place.
    if constexpr (!IsSameType<T, float>::value) {
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
        PipeBarrier<PIPE_V>();
    }
    if (normType_ == 1) { // sigmoid
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_);
        PipeBarrier<PIPE_V>();
    } else if (normType_ == 0) { // softmax
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[8];
        // Stable softmax: subtract the row max before exponentiation.
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_);
        // V->S sync: the scalar GetValue below must observe the finished reduction.
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        // S->V sync: the vector ops below consume the scalar maxValue.
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCount_);
        PipeBarrier<PIPE_V>();
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // Normalize: exp(x - max) / sum.
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_);
        PipeBarrier<PIPE_V>();
    }
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }
    int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = expertCount_ - duplicateNum;
    if (duplicateNum > 0) {
        // Fill the padded tail lanes [duplicateNum, 32) with -inf so they sort last.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1);
        PipeBarrier<PIPE_V>();
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
// Emit the normalized (pre-bias) scores of one row to the optional out tensor.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopuOutXNorm(int64_t row)
{
    LocalTensor<float> normSrc = xNormBuf_.Get<float>();
    LocalTensor<float> stageTensor = outOutQueue_.AllocTensor<float>();
    DataCopy(stageTensor, normSrc, expertCountAlign_);
    // Queue round-trip orders the MTE3 store after the vector copy above.
    outOutQueue_.EnQue<float>(stageTensor);
    stageTensor = outOutQueue_.DeQue<float>();
    DataCopyExtParams copyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(float)), 0, 0, 0};
    DataCopyPad(outGm_[row * expertCount_], stageTensor, copyParams);
    outOutQueue_.FreeTensor(stageTensor);
}
// Sort all experts of the row by (bias-adjusted) score in descending order and
// extract the ids of the k_ best into topKExpertIdBuf_ / expertIdxOutQueue_.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertIdx()
{
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> sortedScore = calcTmpBuf_.Get<float>();
    LocalTensor<float> sortTmp = calcTmpBuf_.Get<float>()[expertCountAlign_ * CONSTANT_TWO];
    PipeBarrier<PIPE_ALL>();
    // Full descending sort of the row, producing interleaved (score, id) pairs.
    Sort<float, true>(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp,
                      expertCountAlign_ / ONE_REPEAT_SORT_NUM);
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0; // number of elements GatherMask retains
    uint8_t src1Pattern = 2; // built-in fixed pattern: picks the id half of each pair
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, sortedScore.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
    // NOTE(review): this copies expertCountAlign_ elements, yet expertIdxOut and
    // topKExpertId are allocated for only Align(k_) entries in Init — presumably
    // only the first k_ values matter; confirm the copy length is intended.
    DataCopy(expertIdxOut, topKExpertId, expertCountAlign_);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
// Gather the top-k normalized scores by expert id, optionally renormalize them
// to sum to 1 (sigmoid always renormalizes; softmax only when renorm_ == 1),
// scale by routedScalingFactor, and cast back to T for output.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertScore()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
    PipeBarrier<PIPE_V>();
    // Gather takes byte offsets: id * sizeof(float).
    Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
           k_);
    bool needRenorm = (normType_ == 1) ||                // case 1: sigmoid + renorm
                      (normType_ == 0 && renorm_ == 1);  // case 3: softmax + renorm
    // Plain bool test, consistent with the grouped variant (was `needRenorm == 1`).
    if (needRenorm) {
        LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        PipeBarrier<PIPE_V>();
        // Sum of the k selected scores; eps guards the division below.
        ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Duplicate(tmpTensor, sumValue, k_);
        PipeBarrier<PIPE_V>();
        Div(yOutTensor, yOutTensor, tmpTensor, k_);
    }
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
    }
    yOutQueue_.EnQue<float>(yOutTensor);
}
// Write the top-k scores and their expert indices for one row back to GM.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyOut(int64_t row)
{
    LocalTensor<T> scoreTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> idxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    const int64_t dstOffset = row * k_;
    // Separate parameter structs per tensor instead of mutating one in place.
    DataCopyExtParams scoreParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[dstOffset], scoreTensor, scoreParams);
    DataCopyExtParams idxParams{1, static_cast<uint32_t>(k_ * sizeof(int32_t)), 0, 0, 0};
    DataCopyPad(expertIdxGm_[dstOffset], idxTensor, idxParams);
    yOutQueue_.FreeTensor(scoreTensor);
    expertIdxOutQueue_.FreeTensor(idxTensor);
}
// Read tiling parameters, bind this core's slice of the GM tensors and carve
// up UB. Buffers that stage 16-bit input reserve double space (factor
// sizeof(float)/sizeof(T)) for the in-place upcast to fp32.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                          GM_ADDR out, GM_ADDR workspace,
                                                          const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    outFlag_ = tilingData_->outFlag == 1;
    k_ = tilingData_->k;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    // Round the expert count up to a whole Sort32 repeat (32 elements).
    expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    // init input gm buf (offset to this core's first row)
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
    // init calc buf
    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t));
    // init tmp buf
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT);
}
// Per-core main loop for the ungrouped kernel: load the bias/expert-id tables
// once, then run the gating pipeline for every row owned by this core.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Process()
{
    CopyInBiasAndInitExpertId();
    for (int64_t rowIdx = 0; rowIdx < curCoreRowCount_; rowIdx++) {
        CopyInX(rowIdx);
        ComputeX();
        if (outFlag_) {
            CopuOutXNorm(rowIdx); // optional full normalized-score output (out tensor)
        }
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CopyOut(rowIdx);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H