[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode gains an optional input
`x_active_mask`; when it is supplied, only tokens whose mask value is
True are dispatched and processed, and masked-out tokens are skipped.


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
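
A minimal usage sketch (shapes and dtypes follow the updated test at the end of this commit; the weight tensors and the HCCL group name are elided, so this is schematic rather than runnable as-is):

```python
import torch
import torch_npu  # Ascend extension providing the custom op

bs, hidden, top_k = 64, 7168, 8
x = torch.randn(bs, hidden, dtype=torch.bfloat16).npu()
expert_ids = torch.randint(0, 64, (bs, top_k), dtype=torch.int32).npu()
expert_scales = torch.rand(bs, top_k, dtype=torch.float32).npu()

# Mark the first 48 tokens active; the trailing 16 are skipped by the op.
x_active_mask = torch.cat((torch.ones(48), torch.zeros(16))).bool().npu()

output, ep_recv_count = torch.ops._C_ascend.dispatch_gmm_combine_decode(
    x=x,
    expert_ids=expert_ids,
    gmm1_permuted_weight=...,        # elided; see FusionOp in the test below
    gmm1_permuted_weight_scale=...,
    gmm2_weight=...,
    gmm2_weight_scale=...,
    expert_scales=expert_scales,     # now a required input
    expert_smooth_scales=None,
    x_active_mask=x_active_mask,     # new optional input; None keeps old behavior
    group_ep=...,                    # elided; HCCL EP group name
    ep_rank_size=16,
    ep_rank_id=0,
    moe_expert_num=64)
```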

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Author: wangqiankun13
Date: 2025-12-19 16:31:48 +08:00
Committed by: GitHub
Parent: 636265be6d
Commit: 118b0ed346
14 changed files with 292 additions and 96 deletions


@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
+const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
-const aclTensor *expertScalesOptional,
+const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
+const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
-const aclTensor *expertScalesOptional,
+const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
-gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
+gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
}


@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
+const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
-const aclTensor *expertScalesOptional,
+const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,


@@ -45,14 +45,19 @@ public:
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
this->Input("x_active_mask")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.DataType({ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")


@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
-constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
-constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
+constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
+constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
+constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
constexpr uint32_t SUPPORT_TOP_K = 12;
+constexpr uint32_t ONE_DIMS = 1;
constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
+uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);
+const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
+INPUT_SHARE_X_ACTIVE_MASK_INDEX);
+if (xActiveMaskStorageShape != nullptr) {
+OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
+OPS_LOG_E(nodeName, "xActiveMask shape dims must be 1, but current dim num is %lu.",
+xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
+return ge::GRAPH_FAILED);
+const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
+OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
+"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
+return ge::GRAPH_FAILED);
+}
return ge::GRAPH_SUCCESS;
}
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
-if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
-context->SetTilingKey(0);
-} else {
-context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
+const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
+INPUT_SHARE_X_ACTIVE_MASK_INDEX);
+bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
+uint64_t tilingKey = 0;
+if (xActiveMaskEnable) {
+tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
+}
+if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
+tilingKey |= EXEC_FLAG_DEEP_FUSE;
+}
+context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;
}
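
For reference, a small sketch of how the reworked tiling key composes its bits (constant values as added in this commit); the four reachable keys are exactly the TILING_KEY_IS(0)/(1)/(4)/(5) branches in the kernel entry below:

```python
EXEC_FLAG_DEEP_FUSE = 1 << 0      # set when moe_expert_num_per_rank != 1
EXEC_FLAG_X_ACTIVE_MASK = 1 << 2  # set when the optional x_active_mask is supplied

def tiling_key(moe_expert_num_per_rank: int, has_x_active_mask: bool) -> int:
    key = 0
    if has_x_active_mask:
        key |= EXEC_FLAG_X_ACTIVE_MASK
    if moe_expert_num_per_rank != 1:
        key |= EXEC_FLAG_DEEP_FUSE
    return key

# reachable values: {0, 1, 4, 5}
assert {tiling_key(m, f) for m in (1, 4) for f in (False, True)} == {0, 1, 4, 5}
```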


@@ -14,7 +14,8 @@
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
+GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
@@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
-if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
+if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
-expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
+expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
op.Process();
}
}


@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
-GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
+GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
@@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
-XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
+EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
@@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
gmexpertIds,
gmExpandIdx,
gmEpSendCount,
+xActiveMask,
gmResvered,
gmOutputRecvCount,
epRankSize,
@@ -241,7 +242,7 @@ public:
__aicore__ inline void Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
@@ -260,6 +261,7 @@ private:
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
+GM_ADDR xActiveMask_;
uint32_t maxTokenNum_{0};
uint32_t gmm1OutputDim_{0};
@@ -291,7 +293,8 @@ template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
+GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
@@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
gmOutputRecvCount_ = outputRecvCount;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
+xActiveMask_ = x_active_mask;
tilingData_ = tilingData;
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
@@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
-if constexpr (EXEC_FLAG == 0) {
+if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
if constexpr (g_coreType == AscendC::AIV) {
AscendC::TPipe tpipe;
-MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
+MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
dispatcher;
-dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
+dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
@@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
-layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
+layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>();
@@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
if (g_coreType == AscendC::AIV) {
-combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
+combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
workspaceGM_, nullptr, tilingData_);
}
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,


@@ -353,7 +353,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
}
-template <typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
+template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
{
@@ -411,6 +411,7 @@ public:
GM_ADDR gmX;
GM_ADDR debugGm;
GM_ADDR gmexpertIds;
+GM_ADDR gmXActiveMask;
GM_ADDR gmExpandIdx;
GM_ADDR gmEpSendCount;
@@ -438,7 +439,7 @@ public:
LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
-GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_,
+GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
@@ -465,6 +466,7 @@ public:
gmExpandIdx(gmExpandIdx_),
gmEpSendCount(gmEpSendCount_),
gmOutputRecvCount(gmOutputRecvCount_),
+gmXActiveMask(gmXActiveMask_),
gmResvered(gmResvered_),
epRankSize(epRankSize_),
epRankId(epRankId_),
@@ -635,6 +637,39 @@ public:
AscendC::SyncAll<false>();
}
+CATLASS_DEVICE
+void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset)
+{
+int64_t subUbOffset = ubOffset;
+AscendC::LocalTensor<bool> maskInputTensor = (resource.ubBuf.template
+GetBufferByByte<bool>(subUbOffset));
+AscendC::LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.template ReinterpretCast<int8_t>();
+subUbOffset += CEIL_UP(axisBS * sizeof(bool));
+AscendC::LocalTensor<half> maskTmpTensor = (resource.ubBuf.template
+GetBufferByByte<half>(subUbOffset));
+subUbOffset += CEIL_UP(axisBS * sizeof(half));
+AscendC::LocalTensor<half> sumOutTensor = (resource.ubBuf.template
+GetBufferByByte<half>(subUbOffset));
+subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
+AscendC::GlobalTensor<bool> xActiveMaskGMTensor;
+xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask);
+uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool));
+AscendC::DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS * sizeof(bool)), 0U, 0U, 0U};
+AscendC::DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
+AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams);
+AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
+AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
+AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS);
+AscendC::PipeBarrier<PIPE_V>();
+AscendC::SumParams params{1, axisBsAlignSize, axisBS};
+AscendC::Sum(sumOutTensor, maskTmpTensor, params);
+AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
+AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
+activeMaskBsCnt = static_cast<int32_t>(sumOutTensor.GetValue(0));
+}
CATLASS_DEVICE
void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset)
{
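
TokenActiveMaskCal reduces the bool mask to the number of active tokens: the mask is copied to UB, reinterpreted as int8, cast to fp16, and summed. A rough NumPy equivalent (illustrative only; device synchronization is omitted):

```python
import numpy as np

def token_active_mask_cal(x_active_mask: np.ndarray) -> int:
    # bool -> int8 -> fp16 -> scalar sum, mirroring ReinterpretCast + Cast + Sum
    return int(x_active_mask.astype(np.int8).astype(np.float16).sum())

mask = np.array([True] * 48 + [False] * 16)
assert token_active_mask_cal(mask) == 48  # activeMaskBsCnt
```

The senders below then iterate over the first activeMaskBsCnt tokens only, which assumes active tokens are packed at the front of the batch, as the test constructs the mask (ones followed by zeros).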
@@ -778,8 +813,8 @@ public:
void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale)
{
uint32_t newAivId = sendCoreIdx - sendToMoeAivNum;
-uint32_t sendTokenNum = axisBS / sendToShareAivNum;
-uint32_t remainderTokenNum = axisBS % sendToShareAivNum;
+uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum;
+uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum;
uint32_t startTokenId = sendTokenNum * newAivId;
if (newAivId < remainderTokenNum) {
sendTokenNum += 1;
@@ -788,7 +823,7 @@ public:
startTokenId += remainderTokenNum;
}
uint32_t endTokenId = startTokenId + sendTokenNum;
-if (startTokenId >= axisBS) {
+if (startTokenId >= activeMaskBsCnt) {
return;
}
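
The send-side split now partitions activeMaskBsCnt instead of axisBS tokens across the sending cores. A sketch of the start/length computation above, with a worked example:

```python
def token_range(core_id: int, num_cores: int, active_cnt: int):
    send = active_cnt // num_cores   # base tokens per core
    rem = active_cnt % num_cores     # the first `rem` cores take one extra
    start = send * core_id
    if core_id < rem:
        send += 1
        start += core_id
    else:
        start += rem
    if start >= active_cnt:
        return None                  # this core has nothing to send
    return (start, start + send)

# 10 active tokens over 4 cores -> (0, 3), (3, 6), (6, 8), (8, 10)
assert [token_range(i, 4, 10) for i in range(4)] == [(0, 3), (3, 6), (6, 8), (8, 10)]
```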
@@ -962,10 +997,13 @@ public:
}
CATLASS_DEVICE void
-SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx)
+SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask)
{
ubOffset = 0;
-expertIdsCnt = axisBS * axisK;
+if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+TokenActiveMaskCal(gmXActiveMask, ubOffset);
+}
+expertIdsCnt = activeMaskBsCnt * axisK;
AscendC::GlobalTensor<int32_t> expertIdsGMTensor_;
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds);
@@ -1372,6 +1410,7 @@ public:
hCommuSize = hOutSize + scaleParamPad;
axisHCommu = hCommuSize / sizeof(int8_t);
axisBS = params.bs;
+activeMaskBsCnt = axisBS;
axisK = params.topK;
uint32_t maxAxisBs = params.globalBs / epRankSize;
@@ -1489,7 +1528,7 @@ public:
AivInitState();
if (isSendCore) {
SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
-(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx);
+(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
}
if (isRecvCore) {
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
@@ -1596,6 +1635,7 @@ private:
uint32_t hCommuSize{0};
uint32_t axisHCommu{0};
uint32_t axisBS{0};
+uint32_t activeMaskBsCnt{0};
uint32_t axisK{0};
uint32_t totalTokenCount{0};
uint32_t expertIdsCnt{0};


@@ -62,7 +62,7 @@ class CamMoeDistributeCombine
public:
__aicore__ inline CamMoeDistributeCombine(){};
__aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
-GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
+GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process();
__aicore__ inline void AllToAllSend();
@@ -145,6 +145,7 @@ private:
GlobalTensor<ExpandIdxType> epSendCountGM_;
GlobalTensor<ExpandIdxType> tpSendCountGM_;
GlobalTensor<float> expandScalesGM_;
+GlobalTensor<bool> xActiveMaskGM_;
GlobalTensor<ExpandXType> expandOutGlobal_;
GlobalTensor<ExpandXType> rankWindow_;
GlobalTensor<int32_t> rankStates_;
@@ -169,6 +170,8 @@ private:
CombineCalcInfo calcInfo_;
uint32_t axisBS_{0};
uint32_t axisMaxBs_{0};
+uint32_t axisBsAlignSize_{0};
+uint64_t activeMaskBsCnt_{0};
uint32_t axisH_{0};
uint32_t axisK_{0};
uint32_t aivNum_{0};
@@ -224,6 +227,9 @@ private:
TBuf<> gatherMaskOutBuf_; // gather mask output buf
TBuf<> gatherTmpBuf_;
TBuf<> statusSumOutBuf_;
+TBuf<> xActMaskTBuf_;
+TBuf<> xActMaskCastTBuf_;
+TBuf<> xActMaskSumTBuf_;
float sumTarget_{0.0};
int32_t epStateValue_;
bool isShardExpert_{false};
@@ -231,7 +237,7 @@ private:
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
-GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales,
+GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask,
GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
tpipe_ = pipe;
@@ -264,8 +270,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
+xActiveMaskGM_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
+activeMaskBsCnt_ = axisBS_;
axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
@@ -395,6 +403,27 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuf
tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
+if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
+tpipe_->InitBuffer(xActMaskTBuf_, axisBsAlignSize_);
+tpipe_->InitBuffer(xActMaskCastTBuf_, axisBsAlignSize_ * sizeof(half));
+tpipe_->InitBuffer(xActMaskSumTBuf_, axisBsAlignSize_ * sizeof(half));
+LocalTensor<bool> xActiveMaskTensor = xActMaskTBuf_.Get<bool>();
+LocalTensor<half> tempTensor = xActMaskCastTBuf_.Get<half>();
+LocalTensor<half> sumOutTensor = xActMaskSumTBuf_.Get<half>();
+DataCopyExtParams xActiveMaskParams{1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
+DataCopyPadExtParams<bool> xActiveMaskCopyPadParams{false, 0U, 0U, 0U};
+DataCopyPad(xActiveMaskTensor, xActiveMaskGM_, xActiveMaskParams, xActiveMaskCopyPadParams);
+SyncFunc<AscendC::HardEvent::MTE2_V>();
+LocalTensor<int8_t> xActiveMaskInt8Tensor = xActiveMaskTensor.ReinterpretCast<int8_t>();
+Cast(tempTensor, xActiveMaskInt8Tensor, RoundMode::CAST_NONE, axisBS_);
+PipeBarrier<PIPE_V>();
+SumParams params{1, axisBsAlignSize_, axisBS_};
+Sum(sumOutTensor, tempTensor, params);
+SyncFunc<AscendC::HardEvent::V_S>();
+activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
+}
}
template <TemplateMC2TypeClass>
@@ -645,13 +674,16 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatc
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
{
+if (activeMaskBsCnt_ == 0U) {
+return;
+}
uint32_t beginIndex = 0;
uint32_t endIndex = 0;
uint32_t processLen = 0;
uint32_t tokenOffset = 0;
-if (axisBS_ < aivNum_) {
-uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_
-if (coreIdx_ >= (axisBS_ * aivNumPerToken)) {
+if (activeMaskBsCnt_ < aivNum_) {
+uint32_t aivNumPerToken = aivNum_ / activeMaskBsCnt_; // activeMaskBsCnt_ < aivNum_
+if (coreIdx_ >= (activeMaskBsCnt_ * aivNumPerToken)) {
return;
}
uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
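
The new early return in LocalWindowCopy matters because an all-False mask drives activeMaskBsCnt_ to zero, and the per-token core split that follows divides by the token count. A sketch of the guarded branch (hypothetical helper, not the kernel API):

```python
def local_window_split(core_idx: int, aiv_num: int, active_cnt: int):
    if active_cnt == 0:
        return None  # early return added in this commit: nothing to combine
    if active_cnt < aiv_num:
        # several AIV cores cooperate on one token; without the guard above,
        # active_cnt == 0 would divide by zero here
        aiv_per_token = aiv_num // active_cnt
        if core_idx >= active_cnt * aiv_per_token:
            return None  # surplus cores stay idle
        token_index = core_idx // aiv_per_token
        return (token_index, token_index + 1)
    # active_cnt >= aiv_num: contiguous per-core ranges, as in token_range above
    ...
```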
@@ -663,8 +695,8 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
beginIndex = tokenIndex;
endIndex = beginIndex + 1U;
} else {
-uint32_t tokenPerAivNum = axisBS_ / aivNum_;
-uint32_t remainderToken = axisBS_ % aivNum_;
+uint32_t tokenPerAivNum = activeMaskBsCnt_ / aivNum_;
+uint32_t remainderToken = activeMaskBsCnt_ % aivNum_;
beginIndex = tokenPerAivNum * coreIdx_;
if (coreIdx_ < remainderToken) {
tokenPerAivNum++;
@@ -723,10 +755,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
}
LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
if (sharedExpertRankNum_ > 0U) {
-uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
-uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
-uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
-epRankId_ * axisBS_ / sharedExpertRankNum_;
+uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
+uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
+uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
+epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
__gm__ ExpandXType *shareAddr =
(__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
(tokenIndex - preCnt) * axisH_ + tokenOffset;


@@ -22,6 +22,7 @@ constexpr uint8_t BUFFER_NUM = 2;
constexpr uint32_t STATE_OFFSET = 512; // state space offset
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
constexpr uint32_t UB_ALIGN = 32;
+constexpr uint64_t ALIGNED_LEN_256 = 256UL;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint8_t COMM_NUM = 2;
constexpr uint8_t COMM_EP_IDX = 0;
@@ -47,18 +48,13 @@ __aicore__ inline void SyncFunc()
AscendC::WaitFlag<event>(eventID);
}
-#define TemplateDispatchTypeClass \
-typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
-bool IsNeedAllgater
-#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater
using namespace AscendC;
template <TemplateDispatchTypeClass>
class CamMoeDistributeDispatch
{
public:
__aicore__ inline CamMoeDistributeDispatch(){};
-__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut,
+__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut,
GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
@@ -80,6 +76,7 @@ private:
__aicore__ inline void AllGatherSetStatusAndWait();
__aicore__ inline void ResetStatus();
__aicore__ inline void QuantInit(GM_ADDR scales);
+__aicore__ inline void TokenActiveMaskCal();
__aicore__ inline void AllgatherProcessOut();
__aicore__ inline void UpdataMultiMoeTokenNumsOut();
__aicore__ inline void UpdataTokenNumsOut();
@@ -112,6 +109,7 @@ private:
GlobalTensor<XType> xGMTensor_;
GlobalTensor<int32_t> expertIdsGMTensor_;
GlobalTensor<float> scalesGMTensor_;
+GlobalTensor<bool> xActiveMaskGMTensor_;
GlobalTensor<ExpandXOutType> expandXOutGMTensor_;
GlobalTensor<float> dynamicScalesOutGMTensor_;
GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;
@@ -144,6 +142,9 @@ private:
TBuf<> rowMaxBuf_;
TBuf<> receiveDataCastFloatBuf_;
TBuf<> smoothScalesBuf_;
+TBuf<> dstExpBuf_;
+TBuf<> subExpBuf_;
+TBuf<> gatherMaskTBuf_;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_;
TQue<QuePosition::VECIN, 1> xInQueue_;
TQue<QuePosition::VECOUT, 1> xOutQueue_;
@@ -161,7 +162,10 @@ private:
GM_ADDR tpLocalStatusWindowGM_;
GlobalTensor<GM_ADDR> peerMemsAddrGm_;
uint32_t axisBS_{0};
+uint32_t axisBsAlignSize_{0};
+uint64_t activeMaskBsCnt_{0};
uint32_t axisMaxBS_{0};
+uint64_t sendToMoeExpTokenCnt_{0};
uint32_t axisH_{0};
uint32_t axisK_{0};
uint32_t aivNum_{0};
@@ -183,6 +187,7 @@ private:
uint32_t hSize_{0};
uint32_t hOutSize_{0};
uint32_t hCommuSize_{0};
+uint32_t maxSize_{0};
uint32_t scaleParamPad_{0};
uint32_t axisHCommu_{0};
uint32_t startExpertId_;
@@ -216,7 +221,7 @@ private:
template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
-GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
+GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
@@ -268,6 +273,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
tpWorldSize_ = 1;
xGMTensor_.SetGlobalBuffer((__gm__ XType *)x);
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds);
+xActiveMaskGMTensor_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut);
dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut);
@@ -329,7 +335,8 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
if (isQuant_) {
QuantInit(scales);
}
-uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
+uint32_t expertIdsCnt = axisBS_ * axisK_;
+uint32_t expertIdsSize = Ceil(expertIdsCnt * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);
expertIdsTensor_ = expertIdsBuf_.Get<int32_t>();
tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);
@@ -339,7 +346,22 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
tpipe_->InitBuffer(getTotalBuf_,
epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));
tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);
+uint32_t hFp32Size = axisH_ * sizeof(float);
+uint32_t xActivateMaskSize = axisBS_ * (Ceil(axisK_ * sizeof(bool), UB_ALIGN) * UB_ALIGN) * sizeof(half);
+uint32_t bsAlign256 = Ceil(axisBS_ * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
+uint32_t bsKAlign256 = Ceil(expertIdsCnt * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
+expertIdsSize = Ceil(expertIdsSize, UB_ALIGN) * UB_ALIGN;
+maxSize_ = hFp32Size > expertIdsSize ? hFp32Size : expertIdsSize;
+maxSize_ = maxSize_ > xActivateMaskSize ? maxSize_ : xActivateMaskSize;
+maxSize_ = maxSize_ > bsKAlign256 ? maxSize_ : bsKAlign256;
+tpipe_->InitBuffer(gatherMaskTBuf_, maxSize_);
+if constexpr (DynamicQuant || StaticQuant) {
+dstExpBuf_ = receiveDataCastFloatBuf_; // memory reuse
+subExpBuf_ = smoothScalesBuf_; // memory reuse
+} else {
+tpipe_->InitBuffer(dstExpBuf_, maxSize_); // BS * K * 4 = 32K
+tpipe_->InitBuffer(subExpBuf_, maxSize_); // BS * K * 4 = 32K
+}
moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
@@ -370,11 +392,36 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Quant
dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>();
}
template <TemplateDispatchTypeClass>
+__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::TokenActiveMaskCal()
+{
+// Copy in x_active_mask; currently it is only used to count the number of valid tokens.
+LocalTensor<half> maskTmpTensor;
+LocalTensor<half> sumOutTensor;
+LocalTensor<bool> maskInputTensor;
+axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
+maskInputTensor = dstExpBuf_.Get<bool>();
+maskTmpTensor = subExpBuf_.Get<half>();
+sumOutTensor = gatherMaskTBuf_.Get<half>();
+DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
+DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
+DataCopyPad(maskInputTensor, xActiveMaskGMTensor_, maskParams, maskCopyPadParams);
+SyncFunc<AscendC::HardEvent::MTE2_V>();
+LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.ReinterpretCast<int8_t>();
+Cast(maskTmpTensor, maskInputInt8Tensor, RoundMode::CAST_NONE, axisBS_);
+PipeBarrier<PIPE_V>();
+SumParams params{1, axisBsAlignSize_, axisBS_};
+Sum(sumOutTensor, maskTmpTensor, params);
+SyncFunc<AscendC::HardEvent::V_S>();
+activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
+sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
+}
+template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
{
-uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_;
-uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_;
+uint32_t sendTokenNum = activeMaskBsCnt_ / sharedUsedAivNum_;
+uint32_t remainderTokenNum = activeMaskBsCnt_ % sharedUsedAivNum_;
uint32_t newAivId = aivId_ - moeUsedAivNum_;
uint32_t startTokenId = sendTokenNum * newAivId;
if (newAivId < remainderTokenNum) {
@@ -383,16 +430,16 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
} else {
startTokenId += remainderTokenNum;
}
-if (startTokenId >= axisBS_) {
+if (startTokenId >= activeMaskBsCnt_) {
return;
}
uint32_t endTokenId = startTokenId + sendTokenNum;
for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) {
uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum);
-uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
-uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
-uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
-epRankId_ * axisBS_ / sharedExpertRankNum_;
+uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
+uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
+uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
+epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
GlobalTensor<ExpandXOutType> dstWinGMTensor;
dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) +
expertPerSizeOnWin_ * epRankId_));
@@ -440,7 +487,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert()
{
-uint32_t expertIdsCnt = axisBS_ * axisK_;
+uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_;
uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_;
uint32_t startTokenId = sendTokenNum * aivId_;
@@ -496,7 +543,12 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch()
{
-uint32_t expertIdsCnt = axisBS_ * axisK_;
+activeMaskBsCnt_ = axisBS_;
+sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
+if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+TokenActiveMaskCal();
+}
+uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U};
DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams);


@@ -14,5 +14,8 @@
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
+#define TemplateDispatchTypeClass \
+typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
+bool IsNeedAllgater, uint32_t EXEC_FLAG
+#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H


@@ -70,5 +70,6 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
constexpr uint32_t WORKSPACE_STAGES = 4;
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
+constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H


@@ -640,8 +640,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
+const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales,
-const c10::optional<at::Tensor> &expert_scales,
+const c10::optional<at::Tensor> &x_active_mask,
c10::string_view group_ep,
int64_t ep_rank_size,
int64_t ep_rank_id,
@@ -674,8 +675,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
gmm1_permuted_weight_scale,
gmm2_weight,
gmm2_weight_scale,
-expert_smooth_scales,
expert_scales,
+expert_smooth_scales,
+x_active_mask,
//input attrs
group_ep_ptr,
ep_rank_size,
@@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
" Tensor gmm1_permuted_weight_scale,"
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
" Tensor expert_scales, Tensor? expert_smooth_scales=None,"
" Tensor? x_active_mask=None,"
" str group_ep='',"
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
" int shared_expert_num=1, int shared_expert_rank_num=0,"


@@ -161,8 +161,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
+const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales,
-const c10::optional<at::Tensor> &expert_scales,
+const c10::optional<at::Tensor> &x_active_mask,
c10::string_view group_ep,
int64_t ep_rank_size,
int64_t ep_rank_id,


@@ -16,6 +16,19 @@ torch.manual_seed(42)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
+BASE_KWARGS = {
+"batch_size": 64,
+"token_hidden_size": 7168,
+"moe_intermediate_size": 2048,
+"ep_world_size": 16,
+"moe_expert_num": 64,
+"shared_expert_rank_num": 0,
+"top_k": 8,
+"test_bfloat16": True,
+"enable_dynamic_bs": False,
+"test_graph": False,
+"with_mc2_mask": False
+}
def redirect_output(log_file_path):
@@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module):
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False)
-def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+x_active_mask):
raise NotImplementedError("To be implemented in subclass")
-def forward(self, x, expert_ids, smooth_scales, expert_scales):
-return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
+def forward(self, x, expert_ids, smooth_scales, expert_scales,
+x_active_mask):
+return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
+x_active_mask)
class SmallOps(DecodeMoeOps):
@@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num)
self.tp_hcomm_info = ""
-def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+x_active_mask):
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x,
expert_ids=expert_ids,
expert_scales=expert_scales,
+x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
@@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps):
assist_info_for_combine=assist_info_for_combine,
ep_send_counts=ep_send_counts,
expert_scales=expert_scales,
+x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
@@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps):
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
-def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+x_active_mask):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x,
expert_ids=expert_ids,
@@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps):
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
-expert_smooth_scales=smooth_scales,
expert_scales=expert_scales,
+expert_smooth_scales=smooth_scales,
+x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_rank_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
@@ -267,12 +288,13 @@ def generate_datas(batch_size,
shared_expert_rank_num=0,
top_k=8,
test_bfloat16=True,
-enable_dynamic_bs=False):
+enable_dynamic_bs=False,
+with_mc2_mask=False):
is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num)
actual_bs = int(
-torch.randint(1, batch_size, [1]).item(
+torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
) if enable_dynamic_bs else batch_size)
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
gmm1_input_dim = token_hidden_size
@@ -317,9 +339,16 @@ def generate_datas(batch_size,
else:
x = x.half()
smooth_sales = None
-return (x, expert_ids, smooth_sales, expert_scales), \
+x_active_mask = None
+valid_token_num = actual_bs
+if with_mc2_mask:
+valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
+x_active_mask = torch.cat(
+(torch.ones(valid_token_num),
+torch.zeros(actual_bs - valid_token_num))).bool()
+return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
-actual_bs
+actual_bs, valid_token_num
def run_once(local_rank_id,
@@ -332,7 +361,8 @@ def run_once(local_rank_id,
top_k=8,
test_bfloat16=True,
enable_dynamic_bs=False,
-test_graph=False):
+test_graph=False,
+with_mc2_mask=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id  # single node
@@ -358,8 +388,8 @@ def run_once(local_rank_id,
parameter = (batch_size, token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
-input_datas, weight_datas, actual_bs = generate_datas(
-*parameter, top_k, test_bfloat16, enable_dynamic_bs)
+input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
+*parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
input_datas = [
data.npu() if data is not None else None for data in input_datas
]
@@ -382,8 +412,8 @@ def run_once(local_rank_id,
if log_file is not None:
log_file.close()
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
-torch.testing.assert_close(small_op_token_output.cpu(),
-fused_op_token_output.cpu(),
+torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
+fused_op_token_output[0:valid_token_num].cpu(),
atol=2.0,
rtol=0.02)
torch.testing.assert_close(small_op_count_output.cpu(),
@@ -394,18 +424,16 @@ def run_once(local_rank_id,
@torch.inference_mode()
-def test():
-batch_size = 64
-token_hidden_size = 7168
-moe_intermediate_size = 2048
-ep_world_size = 16
-moe_expert_num = 64
-shared_expert_rank_num = 0
-top_k = 8
-test_bfloat16 = True
-enable_dynamic_bs = False
-test_graph = False
-args = (batch_size, token_hidden_size, moe_intermediate_size,
-ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
-test_bfloat16, enable_dynamic_bs, test_graph)
-mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)
+def test_dispatch_gmm_combine_decode_base():
+custom_kwargs = BASE_KWARGS
+ep_world_size = custom_kwargs["ep_world_size"]
+custom_args = tuple(custom_kwargs.values())
+mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
+def test_dispatch_gmm_combine_decode_with_mc2_mask():
+custom_kwargs = BASE_KWARGS
+custom_kwargs["with_mc2_mask"] = True
+ep_world_size = custom_kwargs["ep_world_size"]
+custom_args = tuple(custom_kwargs.values())
+mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)