diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp index 30c890c0..309e179a 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp @@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale, + const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, - const aclTensor *expertScalesOptional, + const aclTensor *xActiveMaskOptional, char *groupEp, int64_t epRankSize, int64_t epRankId, @@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale, + const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, - const aclTensor *expertScalesOptional, + const aclTensor *xActiveMaskOptional, char *groupEp, int64_t epRankSize, int64_t epRankId, @@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( aclOpExecutor **executor) { return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, - gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, + gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, output, epRecvCount, workspaceSize, executor); } diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h index bf7fc18b..6601e2a1 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h @@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale, + const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, - const aclTensor *expertScalesOptional, + const aclTensor *xActiveMaskOptional, char *groupEp, int64_t epRankSize, int64_t epRankId, diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 0a0737b2..520c64c4 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -45,14 +45,19 @@ public: .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("expert_scales") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_smooth_scales") .ParamType(OPTIONAL) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("expert_scales") + this->Input("x_active_mask") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, 
ge::DT_FLOAT}) + .DataType({ge::DT_BOOL, ge::DT_BOOL}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("output") diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index 977b3c6e..699b9e87 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2; constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3; constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4; constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5; -constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6; -constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7; +constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6; +constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7; +constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; @@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1; constexpr uint32_t MAX_BATCH_SIZE = 256; constexpr uint32_t MAX_MOE_EXERT_NUM = 512; constexpr uint32_t SUPPORT_TOP_K = 12; +constexpr uint32_t ONE_DIMS = 1; constexpr uint32_t TWO_DIMS = 2; constexpr uint32_t MIN_TOKEN_LENGTH = 512; constexpr uint32_t MAX_TOKEN_LENGTH = 7168; @@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; @@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h), return ge::GRAPH_FAILED); + const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( + INPUT_SHARE_X_ACTIVE_MASK_INDEX); + if (xActiveMaskStorageShape != nullptr) { + OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS, + OPS_LOG_E(nodeName, "xActiveMask shape dim num must be 1, but current dim num is %lu.", + xActiveMaskStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0); + OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName, + "xActiveMask Dim0 must be batchSize(%u), but current dim0 is %ld.", batchSize, xActiveMaskDim0), + return ge::GRAPH_FAILED); + } return ge::GRAPH_SUCCESS; } @@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum; OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED); + OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E( + nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED); OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, 
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); SetHcommCfg(context, tilingData, groupEp); - if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) { - context->SetTilingKey(0); - } else { - context->SetTilingKey(EXEC_FLAG_DEEP_FUSE); + const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( + INPUT_SHARE_X_ACTIVE_MASK_INDEX); + bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr); + uint64_t tilingKey = 0; + if (xActiveMaskEnable) { + tilingKey |= EXEC_FLAG_X_ACTIVE_MASK; } + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) { + tilingKey |= EXEC_FLAG_DEEP_FUSE; + } + context->SetTilingKey(tilingKey); context->SetBlockDim(aicNum); return ge::GRAPH_SUCCESS; } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index f5a9ba67..40d8e82e 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -14,7 +14,8 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, - GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, + GM_ADDR x_active_mask, // output GM_ADDR output, GM_ADDR outputRecvCount, // system @@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); - if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) { + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) { DispatchGmmCombineDecode op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, - expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data); + expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data); op.Process(); } } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index cb7dabb7..74691f94 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, - GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, + GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) @@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord 
problemShape, uint32_t groupCoun using GemmKernel = typename std::conditional< (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< - XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; @@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun gmexpertIds, gmExpandIdx, gmEpSendCount, + xActiveMask, gmResvered, gmOutputRecvCount, epRankSize, @@ -241,7 +242,7 @@ public: __aicore__ inline void Init( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, - GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output GM_ADDR output, GM_ADDR outputRecvCount, // system @@ -260,6 +261,7 @@ private: GM_ADDR workspaceGM_; GM_ADDR gmSmoothScales_; GM_ADDR gmexpertScales_; + GM_ADDR xActiveMask_; uint32_t maxTokenNum_{0}; uint32_t gmm1OutputDim_{0}; @@ -291,7 +293,8 @@ template __aicore__ inline void DispatchGmmCombineDecode::Init( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, - GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, + GM_ADDR x_active_mask, // output GM_ADDR output, GM_ADDR outputRecvCount, // system @@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( gmOutputRecvCount_ = outputRecvCount; workspaceGM_ = workspaceGM; gmexpertScales_ = expert_scales; + xActiveMask_ = x_active_mask; tilingData_ = tilingData; epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; @@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; workspaceOffset += RoundUp(resveredWorkSpaceSize); - if constexpr (EXEC_FLAG == 0) { + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { if constexpr (g_coreType == AscendC::AIV) { AscendC::TPipe tpipe; - MoeDistributeDispatchImpl::CamMoeDistributeDispatch + MoeDistributeDispatchImpl::CamMoeDistributeDispatch dispatcher; - dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, + dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); dispatcher.Process(); tpipe.Destroy(); @@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() Gmm1BlockScheduler>( gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, - layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered, + layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered, 
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_); AscendC::PipeBarrier(); @@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() MoeDistributeCombineImpl::CamMoeDistributeCombine combiner; if (g_coreType == AscendC::AIV) { - combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_, + combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_, workspaceGM_, nullptr, tilingData_); } GmmDeq class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace { @@ -411,6 +411,7 @@ public: GM_ADDR gmX; GM_ADDR debugGm; GM_ADDR gmexpertIds; + GM_ADDR gmXActiveMask; GM_ADDR gmExpandIdx; GM_ADDR gmEpSendCount; @@ -438,7 +439,7 @@ public: LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, - GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_, GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, @@ -465,6 +466,7 @@ public: gmExpandIdx(gmExpandIdx_), gmEpSendCount(gmEpSendCount_), gmOutputRecvCount(gmOutputRecvCount_), + gmXActiveMask(gmXActiveMask_), gmResvered(gmResvered_), epRankSize(epRankSize_), epRankId(epRankId_), @@ -635,6 +637,39 @@ public: AscendC::SyncAll(); } + CATLASS_DEVICE + void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset) + { + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor maskInputTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + AscendC::LocalTensor maskInputInt8Tensor = maskInputTensor.template ReinterpretCast(); + subUbOffset += CEIL_UP(axisBS * sizeof(bool)); + AscendC::LocalTensor maskTmpTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(axisBS * sizeof(half)); + AscendC::LocalTensor sumOutTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + + AscendC::GlobalTensor xActiveMaskGMTensor; + xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask); + uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool)); + + AscendC::DataCopyExtParams maskParams = {1U, static_cast(axisBS * sizeof(bool)), 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams maskCopyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS); + AscendC::PipeBarrier(); + AscendC::SumParams params{1, axisBsAlignSize, axisBS}; + AscendC::Sum(sumOutTensor, maskTmpTensor, params); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + activeMaskBsCnt = static_cast(sumOutTensor.GetValue(0)); + } + CATLASS_DEVICE void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) { @@ -778,8 +813,8 @@ 
public: void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) { uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; - uint32_t sendTokenNum = axisBS / sendToShareAivNum; - uint32_t remainderTokenNum = axisBS % sendToShareAivNum; + uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum; + uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum; uint32_t startTokenId = sendTokenNum * newAivId; if (newAivId < remainderTokenNum) { sendTokenNum += 1; @@ -788,7 +823,7 @@ public: startTokenId += remainderTokenNum; } uint32_t endTokenId = startTokenId + sendTokenNum; - if (startTokenId >= axisBS) { + if (startTokenId >= activeMaskBsCnt) { return; } @@ -962,10 +997,13 @@ public: } CATLASS_DEVICE void - SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx) + SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask) { ubOffset = 0; - expertIdsCnt = axisBS * axisK; + if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) { + TokenActiveMaskCal(gmXActiveMask, ubOffset); + } + expertIdsCnt = activeMaskBsCnt * axisK; AscendC::GlobalTensor expertIdsGMTensor_; expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); @@ -1372,6 +1410,7 @@ public: hCommuSize = hOutSize + scaleParamPad; axisHCommu = hCommuSize / sizeof(int8_t); axisBS = params.bs; + activeMaskBsCnt = axisBS; axisK = params.topK; uint32_t maxAxisBs = params.globalBs / epRankSize; @@ -1489,7 +1528,7 @@ public: AivInitState(); if (isSendCore) { SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA, - (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx); + (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask); } if (isRecvCore) { RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount, @@ -1596,6 +1635,7 @@ private: uint32_t hCommuSize{0}; uint32_t axisHCommu{0}; uint32_t axisBS{0}; + uint32_t activeMaskBsCnt{0}; uint32_t axisK{0}; uint32_t totalTokenCount{0}; uint32_t expertIdsCnt{0}; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h index bad72869..c8fa2e1f 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h @@ -62,7 +62,7 @@ class CamMoeDistributeCombine public: __aicore__ inline CamMoeDistributeCombine(){}; __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, - GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, + GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); __aicore__ inline void AllToAllSend(); @@ -145,6 +145,7 @@ private: GlobalTensor epSendCountGM_; GlobalTensor tpSendCountGM_; GlobalTensor expandScalesGM_; + GlobalTensor xActiveMaskGM_; GlobalTensor expandOutGlobal_; GlobalTensor rankWindow_; GlobalTensor rankStates_; @@ -169,6 +170,8 @@ private: CombineCalcInfo calcInfo_; uint32_t axisBS_{0}; uint32_t axisMaxBs_{0}; + uint32_t axisBsAlignSize_{0}; + 
uint64_t activeMaskBsCnt_{0}; uint32_t axisH_{0}; uint32_t axisK_{0}; uint32_t aivNum_{0}; @@ -224,6 +227,9 @@ private: TBuf<> gatherMaskOutBuf_; // gather mask output buf TBuf<> gatherTmpBuf_; TBuf<> statusSumOutBuf_; + TBuf<> xActMaskTBuf_; + TBuf<> xActMaskCastTBuf_; + TBuf<> xActMaskSumTBuf_; float sumTarget_{0.0}; int32_t epStateValue_; bool isShardExpert_{false}; @@ -231,7 +237,7 @@ private: template __aicore__ inline void CamMoeDistributeCombine::Init( - GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, + GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { tpipe_ = pipe; @@ -264,8 +270,10 @@ __aicore__ inline void CamMoeDistributeCombine::Init( expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount); expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales); + xActiveMaskGM_.SetGlobalBuffer((__gm__ bool*)xActiveMask); expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + activeMaskBsCnt_ = axisBS_; axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { @@ -395,6 +403,27 @@ __aicore__ inline void CamMoeDistributeCombine::AlltoAllBuf tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float)); tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t)); tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); + + if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) { + axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN; + tpipe_->InitBuffer(xActMaskTBuf_, axisBsAlignSize_); + tpipe_->InitBuffer(xActMaskCastTBuf_, axisBsAlignSize_ * sizeof(half)); + tpipe_->InitBuffer(xActMaskSumTBuf_, axisBsAlignSize_ * sizeof(half)); + LocalTensor xActiveMaskTensor = xActMaskTBuf_.Get(); + LocalTensor tempTensor = xActMaskCastTBuf_.Get(); + LocalTensor sumOutTensor = xActMaskSumTBuf_.Get(); + DataCopyExtParams xActiveMaskParams{1U, static_cast(axisBS_ * sizeof(bool)), 0U, 0U, 0U}; + DataCopyPadExtParams xActiveMaskCopyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(xActiveMaskTensor, xActiveMaskGM_, xActiveMaskParams, xActiveMaskCopyPadParams); + SyncFunc(); + LocalTensor xActiveMaskInt8Tensor = xActiveMaskTensor.ReinterpretCast(); + Cast(tempTensor, xActiveMaskInt8Tensor, RoundMode::CAST_NONE, axisBS_); + PipeBarrier(); + SumParams params{1, axisBsAlignSize_, axisBS_}; + Sum(sumOutTensor, tempTensor, params); + SyncFunc(); + activeMaskBsCnt_ = static_cast(sumOutTensor.GetValue(0)); + } } template @@ -645,13 +674,16 @@ __aicore__ inline void CamMoeDistributeCombine::WaitDispatc template __aicore__ inline void CamMoeDistributeCombine::LocalWindowCopy() { + if (activeMaskBsCnt_ == 0U) { + return; + } uint32_t beginIndex = 0; uint32_t endIndex = 0; uint32_t processLen = 0; uint32_t tokenOffset = 0; - if (axisBS_ < aivNum_) { - uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_ - if (coreIdx_ >= (axisBS_ * aivNumPerToken)) { + if (activeMaskBsCnt_ < aivNum_) { + uint32_t aivNumPerToken = aivNum_ / activeMaskBsCnt_; // activeMaskBsCnt_ < aivNum_ + if (coreIdx_ >= (activeMaskBsCnt_ * aivNumPerToken)) { return; } uint32_t tokenIndex = coreIdx_ / aivNumPerToken; @@ -663,8 +695,8 @@ __aicore__ 
inline void CamMoeDistributeCombine::LocalWindow beginIndex = tokenIndex; endIndex = beginIndex + 1U; } else { - uint32_t tokenPerAivNum = axisBS_ / aivNum_; - uint32_t remainderToken = axisBS_ % aivNum_; + uint32_t tokenPerAivNum = activeMaskBsCnt_ / aivNum_; + uint32_t remainderToken = activeMaskBsCnt_ % aivNum_; beginIndex = tokenPerAivNum * coreIdx_; if (coreIdx_ < remainderToken) { tokenPerAivNum++; @@ -723,10 +755,10 @@ __aicore__ inline void CamMoeDistributeCombine::LocalWindow } LocalTensor rowTmpLocal = tokenBuf_.Get(); if (sharedExpertRankNum_ > 0U) { - uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; - uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; - uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - - epRankId_ * axisBS_ / sharedExpertRankNum_; + uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_; + uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ - + epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_; __gm__ ExpandXType *shareAddr = (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) + (tokenIndex - preCnt) * axisH_ + tokenOffset; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h index dc01adda..f73e2c60 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h @@ -22,6 +22,7 @@ constexpr uint8_t BUFFER_NUM = 2; constexpr uint32_t STATE_OFFSET = 512; // state space offset constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M constexpr uint32_t UB_ALIGN = 32; +constexpr uint64_t ALIGNED_LEN_256 = 256UL; constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; constexpr uint8_t COMM_NUM = 2; constexpr uint8_t COMM_EP_IDX = 0; @@ -47,18 +48,13 @@ __aicore__ inline void SyncFunc() AscendC::WaitFlag(eventID); } -#define TemplateDispatchTypeClass \ - typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ - bool IsNeedAllgater -#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater - using namespace AscendC; template class CamMoeDistributeDispatch { public: __aicore__ inline CamMoeDistributeDispatch(){}; - __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); @@ -80,6 +76,7 @@ private: __aicore__ inline void AllGatherSetStatusAndWait(); __aicore__ inline void ResetStatus(); __aicore__ inline void QuantInit(GM_ADDR scales); + __aicore__ inline void TokenActiveMaskCal(); __aicore__ inline void AllgatherProcessOut(); __aicore__ inline void UpdataMultiMoeTokenNumsOut(); __aicore__ inline void 
UpdataTokenNumsOut(); @@ -112,6 +109,7 @@ private: GlobalTensor xGMTensor_; GlobalTensor expertIdsGMTensor_; GlobalTensor scalesGMTensor_; + GlobalTensor xActiveMaskGMTensor_; GlobalTensor expandXOutGMTensor_; GlobalTensor dynamicScalesOutGMTensor_; GlobalTensor expertTokenNumsOutGMTensor_; @@ -144,6 +142,9 @@ private: TBuf<> rowMaxBuf_; TBuf<> receiveDataCastFloatBuf_; TBuf<> smoothScalesBuf_; + TBuf<> dstExpBuf_; + TBuf<> subExpBuf_; + TBuf<> gatherMaskTBuf_; TQueBind xQueue_; TQue xInQueue_; TQue xOutQueue_; @@ -161,7 +162,10 @@ private: GM_ADDR tpLocalStatusWindowGM_; GlobalTensor peerMemsAddrGm_; uint32_t axisBS_{0}; + uint32_t axisBsAlignSize_{0}; + uint64_t activeMaskBsCnt_{0}; uint32_t axisMaxBS_{0}; + uint64_t sendToMoeExpTokenCnt_{0}; uint32_t axisH_{0}; uint32_t axisK_{0}; uint32_t aivNum_{0}; @@ -183,6 +187,7 @@ private: uint32_t hSize_{0}; uint32_t hOutSize_{0}; uint32_t hCommuSize_{0}; + uint32_t maxSize_{0}; uint32_t scaleParamPad_{0}; uint32_t axisHCommu_{0}; uint32_t startExpertId_; @@ -216,7 +221,7 @@ private: template __aicore__ inline void CamMoeDistributeDispatch::Init( - GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { @@ -268,6 +273,7 @@ __aicore__ inline void CamMoeDistributeDispatch::Init( tpWorldSize_ = 1; xGMTensor_.SetGlobalBuffer((__gm__ XType *)x); expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); + xActiveMaskGMTensor_.SetGlobalBuffer((__gm__ bool*)xActiveMask); expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut); dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut); @@ -329,7 +335,8 @@ __aicore__ inline void CamMoeDistributeDispatch::Init( if (isQuant_) { QuantInit(scales); } - uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; + uint32_t expertIdsCnt = axisBS_ * axisK_; + uint32_t expertIdsSize = Ceil(expertIdsCnt * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize); expertIdsTensor_ = expertIdsBuf_.Get(); tpipe_->InitBuffer(expertCountBuf_, expertIdsSize); @@ -339,7 +346,22 @@ __aicore__ inline void CamMoeDistributeDispatch::Init( tpipe_->InitBuffer(getTotalBuf_, epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t)); tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2); - + uint32_t hFp32Size = axisH_ * sizeof(float); + uint32_t xActivateMaskSize = axisBS_ * (Ceil(axisK_ * sizeof(bool), UB_ALIGN) * UB_ALIGN) * sizeof(half); + uint32_t bsAlign256 = Ceil(axisBS_ * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half); + uint32_t bsKAlign256 = Ceil(expertIdsCnt * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half); + expertIdsSize = Ceil(expertIdsSize, UB_ALIGN) * UB_ALIGN; + maxSize_ = hFp32Size > expertIdsSize ? hFp32Size : expertIdsSize; + maxSize_ = maxSize_ > xActivateMaskSize ? maxSize_ : xActivateMaskSize; + maxSize_ = maxSize_ > bsKAlign256 ? 
maxSize_ : bsKAlign256; + tpipe_->InitBuffer(gatherMaskTBuf_, maxSize_); + if constexpr (DynamicQuant || StaticQuant) { + dstExpBuf_ = receiveDataCastFloatBuf_; // memory reuse of an existing buffer + subExpBuf_ = smoothScalesBuf_; // memory reuse of an existing buffer + } else { + tpipe_->InitBuffer(dstExpBuf_, maxSize_); // BS * K * 4 = 32K + tpipe_->InitBuffer(subExpBuf_, maxSize_); // BS * K * 4 = 32K + } moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK; if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK && axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) { @@ -370,11 +392,36 @@ __aicore__ inline void CamMoeDistributeDispatch::Quant dynamicScalesTensor_ = dynamicScalesBuf_.Get(); } +template +__aicore__ inline void CamMoeDistributeDispatch::TokenActiveMaskCal() +{ + // Copy in x_active_mask; currently only used to compute the total valid-token count + LocalTensor maskTmpTensor; + LocalTensor sumOutTensor; + LocalTensor maskInputTensor; + axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN; + maskInputTensor = dstExpBuf_.Get(); + maskTmpTensor = subExpBuf_.Get(); + sumOutTensor = gatherMaskTBuf_.Get(); + DataCopyExtParams maskParams = {1U, static_cast(axisBS_ * sizeof(bool)), 0U, 0U, 0U}; + DataCopyPadExtParams maskCopyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(maskInputTensor, xActiveMaskGMTensor_, maskParams, maskCopyPadParams); + SyncFunc(); + LocalTensor maskInputInt8Tensor = maskInputTensor.ReinterpretCast(); + Cast(maskTmpTensor, maskInputInt8Tensor, RoundMode::CAST_NONE, axisBS_); + PipeBarrier(); + SumParams params{1, axisBsAlignSize_, axisBS_}; + Sum(sumOutTensor, maskTmpTensor, params); + SyncFunc(); + activeMaskBsCnt_ = static_cast(sumOutTensor.GetValue(0)); + sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_; +} + template __aicore__ inline void CamMoeDistributeDispatch::SendToSharedExpert() { - uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_; - uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_; + uint32_t sendTokenNum = activeMaskBsCnt_ / sharedUsedAivNum_; + uint32_t remainderTokenNum = activeMaskBsCnt_ % sharedUsedAivNum_; uint32_t newAivId = aivId_ - moeUsedAivNum_; uint32_t startTokenId = sendTokenNum * newAivId; if (newAivId < remainderTokenNum) { @@ -383,16 +430,16 @@ __aicore__ inline void CamMoeDistributeDispatch::SendT } else { startTokenId += remainderTokenNum; } - if (startTokenId >= axisBS_) { + if (startTokenId >= activeMaskBsCnt_) { return; } uint32_t endTokenId = startTokenId + sendTokenNum; for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) { uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum); - uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; - uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; - uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - - epRankId_ * axisBS_ / sharedExpertRankNum_; + uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_; + uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ - + epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_; GlobalTensor dstWinGMTensor; dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) + expertPerSizeOnWin_ * epRankId_)); @@ -440,7 +487,7 @@ __aicore__ 
inline void CamMoeDistributeDispatch::SendToMoeExpert() { - uint32_t expertIdsCnt = axisBS_ * axisK_; + uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_; uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; uint32_t startTokenId = sendTokenNum * aivId_; @@ -496,7 +543,12 @@ __aicore__ inline void CamMoeDistributeDispatch::SendT template __aicore__ inline void CamMoeDistributeDispatch::AlltoAllDispatch() { - uint32_t expertIdsCnt = axisBS_ * axisK_; + activeMaskBsCnt_ = axisBS_; + sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_; + if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) { + TokenActiveMaskCal(); + } + uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_; DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index f93c15a6..0c57896c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -14,5 +14,8 @@ #define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG #define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG - +#define TemplateDispatchTypeClass \ + typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ + bool IsNeedAllgater, uint32_t EXEC_FLAG +#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG #endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h index b8a831ac..328538f2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -70,5 +70,6 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0; constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); +constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2); #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index e0c976c4..96b6205e 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -640,8 +640,9 @@ std::tuple dispatch_gmm_combine_decode( const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, + const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, - const c10::optional &expert_scales, + const c10::optional &x_active_mask, c10::string_view group_ep, int64_t ep_rank_size, int64_t ep_rank_id, @@ -674,8 +675,9 @@ std::tuple dispatch_gmm_combine_decode( gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, - expert_smooth_scales, expert_scales, + expert_smooth_scales, + x_active_mask, //input attrs group_ep_ptr, ep_rank_size, @@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," " Tensor 
gmm1_permuted_weight_scale," " Tensor gmm2_weight, Tensor gmm2_weight_scale," - " Tensor? expert_smooth_scales=None, Tensor? expert_scales=None," + " Tensor expert_scales, Tensor? expert_smooth_scales=None," + " Tensor? x_active_mask=None," " str group_ep=''," " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0," " int shared_expert_num=1, int shared_expert_rank_num=0," diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 3223801b..e114a981 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -161,8 +161,9 @@ std::tuple dispatch_gmm_combine_decode_meta( const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, + const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, - const c10::optional &expert_scales, + const c10::optional &x_active_mask, c10::string_view group_ep, int64_t ep_rank_size, int64_t ep_rank_id, diff --git a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py index d11254cc..03b4b964 100644 --- a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py @@ -16,6 +16,19 @@ torch.manual_seed(42) torch_npu.npu.config.allow_internal_format = True enable_custom_op() LOG_NAME = "dispatch_gmm_combine_decode_test_logs" +BASE_KWARGS = { + "batch_size": 64, + "token_hidden_size": 7168, + "moe_intermediate_size": 2048, + "ep_world_size": 16, + "moe_expert_num": 64, + "shared_expert_rank_num": 0, + "top_k": 8, + "test_bfloat16": True, + "enable_dynamic_bs": False, + "test_graph": False, + "with_mc2_mask": False +} def redirect_output(log_file_path): @@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module): self.gmm2_weight_scale_fp32 = torch.nn.Parameter( gmm2_weight_scale.float(), requires_grad=False) - def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, + x_active_mask): raise NotImplementedError("To be implemented in subclass") - def forward(self, x, expert_ids, smooth_scales, expert_scales): - return self._apply_ops(x, expert_ids, smooth_scales, expert_scales) + def forward(self, x, expert_ids, smooth_scales, expert_scales, + x_active_mask): + return self._apply_ops(x, expert_ids, smooth_scales, expert_scales, + x_active_mask) class SmallOps(DecodeMoeOps): @@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps): shared_expert_rank_num) self.tp_hcomm_info = "" - def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, + x_active_mask): outputs = torch_npu.npu_moe_distribute_dispatch_v2( x=x, expert_ids=expert_ids, expert_scales=expert_scales, + x_active_mask=x_active_mask, group_ep=self.ep_hcomm_info, ep_world_size=self.ep_world_size, ep_rank_id=self.global_rank_id, @@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps): assist_info_for_combine=assist_info_for_combine, ep_send_counts=ep_send_counts, expert_scales=expert_scales, + x_active_mask=x_active_mask, group_ep=self.ep_hcomm_info, ep_world_size=self.ep_world_size, ep_rank_id=self.global_rank_id, @@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps): ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) - def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, + x_active_mask): output = 
torch.ops._C_ascend.dispatch_gmm_combine_decode( x=x, expert_ids=expert_ids, @@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps): gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32, gmm2_weight=self.gmm2_weight, gmm2_weight_scale=self.gmm2_weight_scale_fp32, - expert_smooth_scales=smooth_scales, expert_scales=expert_scales, + expert_smooth_scales=smooth_scales, + x_active_mask=x_active_mask, group_ep=self.ep_hcomm_info, ep_rank_size=self.ep_world_size, ep_rank_id=self.global_rank_id, @@ -267,12 +288,13 @@ def generate_datas(batch_size, shared_expert_rank_num=0, top_k=8, test_bfloat16=True, - enable_dynamic_bs=False): + enable_dynamic_bs=False, + with_mc2_mask=False): is_shared_expert = global_rank_id < shared_expert_rank_num moe_expert_num_per_rank = moe_expert_num // (ep_world_size - shared_expert_rank_num) actual_bs = int( - torch.randint(1, batch_size, [1]).item( + torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item( ) if enable_dynamic_bs else batch_size) local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank gmm1_input_dim = token_hidden_size @@ -317,9 +339,16 @@ def generate_datas(batch_size, else: x = x.half() smooth_sales = None - return (x, expert_ids, smooth_sales, expert_scales), \ + x_active_mask = None + valid_token_num = actual_bs + if with_mc2_mask: + valid_token_num = int(torch.randint(1, actual_bs, [1]).item()) + x_active_mask = torch.cat( + (torch.ones(valid_token_num), + torch.zeros(actual_bs - valid_token_num))).bool() + return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \ (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \ - actual_bs + actual_bs, valid_token_num def run_once(local_rank_id, @@ -332,7 +361,8 @@ def run_once(local_rank_id, top_k=8, test_bfloat16=True, enable_dynamic_bs=False, - test_graph=False): + test_graph=False, + with_mc2_mask=False): log_file = redirect_output(f"local_rank_{local_rank_id}.log" ) if output_to_file(local_rank_id) else None global_rank_id = local_rank_id # single node @@ -358,8 +388,8 @@ def run_once(local_rank_id, parameter = (batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) - input_datas, weight_datas, actual_bs = generate_datas( - *parameter, top_k, test_bfloat16, enable_dynamic_bs) + input_datas, weight_datas, actual_bs, valid_token_num = generate_datas( + *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask) input_datas = [ data.npu() if data is not None else None for data in input_datas ] @@ -382,8 +412,8 @@ def run_once(local_rank_id, if log_file is not None: log_file.close() small_op_count_output = from_inclusive_prefix_sum(small_op_count_output) - torch.testing.assert_close(small_op_token_output.cpu(), - fused_op_token_output.cpu(), + torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(), + fused_op_token_output[0:valid_token_num].cpu(), atol=2.0, rtol=0.02) torch.testing.assert_close(small_op_count_output.cpu(), @@ -394,18 +424,16 @@ @torch.inference_mode() -def test(): - batch_size = 64 - token_hidden_size = 7168 - moe_intermediate_size = 2048 - ep_world_size = 16 - moe_expert_num = 64 - shared_expert_rank_num = 0 - top_k = 8 - test_bfloat16 = True - enable_dynamic_bs = False - test_graph = False - args = (batch_size, token_hidden_size, moe_intermediate_size, - ep_world_size, moe_expert_num, shared_expert_rank_num, top_k, - test_bfloat16, enable_dynamic_bs, test_graph) - mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True) +def 
test_dispatch_gmm_combine_decode_base(): + custom_kwargs = dict(BASE_KWARGS) + ep_world_size = custom_kwargs["ep_world_size"] + custom_args = tuple(custom_kwargs.values()) + mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) + + +def test_dispatch_gmm_combine_decode_with_mc2_mask(): + custom_kwargs = dict(BASE_KWARGS) + custom_kwargs["with_mc2_mask"] = True + ep_world_size = custom_kwargs["ep_world_size"] + custom_args = tuple(custom_kwargs.values()) + mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
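
Usage note: a minimal caller-side sketch of the updated binding. It assumes the extension is built and torch_npu is initialized, that the input tensors (x, expert_ids, the GMM weights/scales, expert_scales) have been prepared as in generate_datas() above, and that ep_hcomm_info and rank are placeholders for an initialized HCCL expert-parallel group; only the new x_active_mask argument is constructed explicitly, and the keyword names follow the registered schema.

import torch

batch_size = 64
valid_token_num = 48  # illustrative; only the leading tokens are active

# The new optional input: a bool tensor of shape [batch_size]. The e2e test
# builds it as a contiguous run of ones followed by zeros, and the kernels only
# sum the mask to obtain the active-token count, so a prefix-style mask appears
# to be the supported layout. Passing None leaves the EXEC_FLAG_X_ACTIVE_MASK
# tiling-key bit unset, i.e. the original code path.
x_active_mask = torch.cat(
    (torch.ones(valid_token_num),
     torch.zeros(batch_size - valid_token_num))).bool().npu()

# Returns (output, ep_recv_count), matching the std::tuple binding above.
output, ep_recv_count = torch.ops._C_ascend.dispatch_gmm_combine_decode(
    x=x,
    expert_ids=expert_ids,
    gmm1_permuted_weight=gmm1_permuted_weight,
    gmm1_permuted_weight_scale=gmm1_permuted_weight_scale,
    gmm2_weight=gmm2_weight,
    gmm2_weight_scale=gmm2_weight_scale,
    expert_scales=expert_scales,   # now required, ordered before the optionals
    expert_smooth_scales=None,     # still optional
    x_active_mask=x_active_mask,   # new optional input
    group_ep=ep_hcomm_info,
    ep_rank_size=16,
    ep_rank_id=rank,
    moe_expert_num=64,
    shared_expert_num=1,
    shared_expert_rank_num=0)

# As in the updated e2e test, when a mask is supplied only output rows
# [0:valid_token_num] are meaningful and compared against the reference.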