diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h index 4e73b832..d96e5931 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -224,7 +224,7 @@ __aicore__ inline void DispatchFFNCombine::Process() constexpr uint32_t ubStages = 2; using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant; - + using ScaleType = Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using ElementMulType = Gemm::GemmType; @@ -234,7 +234,8 @@ __aicore__ inline void DispatchFFNCombine::Process() using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; - using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2; + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant; + using TileCopy2 = Epilogue::Tile::TileCopy; using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; @@ -254,9 +255,11 @@ __aicore__ inline void DispatchFFNCombine::Process() GemmCoord problemShape{static_cast(m), static_cast(n), static_cast(k)}; - uint32_t epilogueCoreNum = aivNum / 2; - uint32_t epilogueGranularity = expertPerRank - 1; - + uint32_t epilogueCoreNum = aivNum; + uint32_t epilogueGranularity = expertPerRank - 3; + if (expertPerRank <= 4) { + epilogueGranularity = expertPerRank - 1; + } typename MatmulKernel::Params params{ problemShape, static_cast(EP), static_cast(listLen), static_cast(expertPerRank), static_cast(maxOutputSize), static_cast(rank), static_cast(rankSize), @@ -277,4 +280,4 @@ __aicore__ inline void DispatchFFNCombine::Process() } } // DispatchFFNCombineImpl -#endif // DISPATCH_FFN_COMBINE_H +#endif // DISPATCH_FFN_COMBINE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index df7d88f5..20e91bc7 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -571,6 +571,7 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } + blockMmad.Finalize(params.expertPerRank - 1, 0); } @@ -727,19 +728,6 @@ private: } - CATLASS_DEVICE - void CombineSetFlag() { - AscendC::SetFlag(EVENT_ID0); - AscendC::SetFlag(EVENT_ID1); - AscendC::SetFlag(EVENT_ID2); - AscendC::SetFlag(EVENT_ID3); - AscendC::SetFlag(EVENT_ID2); - AscendC::SetFlag(EVENT_ID3); - AscendC::SetFlag(EVENT_ID0); - AscendC::SetFlag(EVENT_ID1); - } - - CATLASS_DEVICE void DispatchAndCombine(Params const ¶ms) { icache_preload(8); @@ -800,13 +788,17 @@ private: GM_ADDR otherRankPtr = shmem(0, dstEpIdx); AscendC::GlobalTensor gmRemoteA; gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); - + AscendC::GlobalTensor gmRemotePerTokenScale; + gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale)); MatrixCoord offsetA{rowStart, 0}; MatrixCoord offsetPeer{rowSrc, 0}; int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); - int64_t gmOffsetPeer = rowSrc * (params.problemShape.k() + ALIGN_512); + int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); + // Communication data - CopyGMToGMPerToken(gmA[gmOffsetA], gmPerTokenScale1[rowStart], gmRemoteA[gmOffsetPeer], rows, params.problemShape.k()); + CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); + // Communication scale + CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows); } } @@ -837,16 +829,12 @@ private: uint32_t n2 = params.problemShape.k(); + typename BlockEpilogue2::Params epilogueParams{ static_cast(params.EP), static_cast(params.expertPerRank), - static_cast(params.rank), reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert), - params.layoutD2, - static_cast(n2), - static_cast(L1TileShape::N), - shmem, - static_cast(peermemInfo.offsetD) + static_cast(n2) }; uint32_t n = params.problemShape.n(); @@ -890,109 +878,65 @@ private: blockEpilogue1.Finalize(); - - CombineSetFlag(); - - CombineV2(params, blockEpilogue2); - + blockEpilogue2.SetFlag(); + CombineV1(params, blockEpilogue2); AscendC::SyncAll(); #ifndef __CROSSRANKSYNCANDALLGATHERV1__ ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128)); #endif - shmem.InitStatusTargetSum(); - if (get_subblockid() == 0) { - AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0); - shmem.CrossRankSyncV2Set(ctrBuffer); - } else { - uint32_t uboffset = 0; - uint32_t aicCoreNum = coreNum / 2; - uint32_t aicCoreIdx = get_block_idx(); - uint32_t sendRankNum_ = params.EP / aicCoreNum; - uint32_t remainderRankNum = params.EP % aicCoreNum; - if (aicCoreIdx < remainderRankNum) { - sendRankNum_++; - } - AscendC::LocalTensor statusTensor = resource.ubBuf.template GetBufferByByte(uboffset); - uboffset += sendRankNum_ * UB_ALIGN; - AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); - uboffset += AlignUp(params.EP * sizeof(float), 32); - AscendC::LocalTensor gatherTmpTensor = resource.ubBuf.template GetBufferByByte(uboffset); - uboffset += AlignUp(sizeof(uint32_t), 32); - AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); - uboffset += AlignUp(sizeof(float), 32); - shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor); - MoeTokenUnpermuteTilingData tilingData; - MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2); - KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; - kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); - kernelMoeTokenUnpermuteOp.Process(); - } - + + shmem.CrossRankSync(); + + MoeTokenUnpermuteTilingData tilingData; + MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); + KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; + kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); + kernelMoeTokenUnpermuteOp.Process(); } - CATLASS_DEVICE - void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) { - BlockScheduler blockScheduler; - int32_t syncLoopIdx = 0; - uint32_t startCoreIdx = 0; - uint32_t aicCoreNum = coreNum / 2; - uint32_t aicCoreIdx = get_block_idx(); - uint32_t aivSubCoreIdx = get_subblockid(); - uint32_t preSrcExpertSum = 0; + void CombineV1(Params const ¶ms, BlockEpilogue2 & blockEpilogue) { uint32_t n2 = params.problemShape.k(); - uint32_t k2 = params.problemShape.n() / 2; + int32_t prevGroupSum2 = 0; + icache_preload(8); - for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { - uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - if (preSrcExpertSum >= params.maxOutputSize) { - currentExpertM = 0; - } else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) { - currentExpertM = params.maxOutputSize - preSrcExpertSum; - } - GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K - blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); - uint32_t coreLoops = blockScheduler.GetCoreLoops(); - uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx; + for (uint32_t t_groupIdx = 0; t_groupIdx < params.expertPerRank; ++t_groupIdx) { + int32_t flagId = t_groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; + AscendC::CrossCoreWaitFlag<0x2>(flagId); + AscendC::SyncAll(); - for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) { - GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); - GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); - int32_t m0 = 16; - // Block count, the shape of each block is (m0, actualBlockShape.n()) - int32_t m_rows = (actualBlockShape.m() + m0 - 1) / m0; - int32_t aiv_m_rows = m_rows / 2; - if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) { - aiv_m_rows += 1; - } - uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset - if(aivSubCoreIdx == 1) { - m_offset += (m_rows / 2) * m0; - } + uint32_t groupIdx = t_groupIdx; - - for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) { - int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; - AscendC::CrossCoreWaitFlag<0x2>(flag_id); - } - - for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) { - GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1}; - uint32_t actualm = m0; - if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){ - actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); + AscendC::GlobalTensor gmRemotePeer; + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); + uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; + if (srcRowOffset < params.maxOutputSize) { + uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + if (srcRowOffset + dataRows > params.maxOutputSize) { + dataRows = params.maxOutputSize - srcRowOffset; + } + //uint32_t dstRowOffset = preSumBeforeRank(2 * dstEpIdx * FLAGSTRIDE + groupIdx); + int32_t tmpBlock = AlignUp(params.expertPerRank, FLAGSTRIDE); + //uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * tmpBlock + groupIdx); + uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx); + MatrixCoord offsetC{srcRowOffset, 0}; + MatrixCoord offsetPeer{dstRowOffset, 0}; + MatrixCoord shapeC{dataRows, n2}; + int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC); + int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer); + if constexpr (std::is_same_v) { + blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]); + } else { + blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); } - GemmCoord realTileShape{actualm, actualBlockShape.n(), 1}; - blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank); - m_offset += m0; } } - preSrcExpertSum += currentExpertM; - startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum; + prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); } blockEpilogue.Finalize(); } - private: struct WorkspaceInfo { GM_ADDR ptrA; @@ -1096,4 +1040,4 @@ private: } // namespace Catlass::Gemm::Kernel -#endif // DISPATCH_FFN_COMBINE_KERNEL_HPP +#endif // DISPATCH_FFN_COMBINE_KERNEL_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h index dfff54cc..187027c9 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h @@ -35,6 +35,7 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { __aicore__ inline void CopyOutIdx(); __aicore__ inline void CopyOutEmpty(); __aicore__ inline void CopyOutXQuant1H(); + __aicore__ inline void CopyOutXQuantEH(); __aicore__ inline void ComputeExpertTokenCountOrCumsum(); __aicore__ inline void Compute(LocalTensor& smoothLocal); @@ -48,7 +49,6 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { int64_t k_; int64_t n_; int64_t cols_; - int64_t cols_scale_; int64_t activateRows_; int64_t expertNum; int64_t expertCapacity; @@ -63,10 +63,12 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { TQue smoothInQueue; TQue calcQueue; TQue inputXOutQueue; + TQue scaleOutQueue; GlobalTensor xGm_; GlobalTensor expertIdxGm_; GlobalTensor quantSmoothGm; + GlobalTensor dynamicQuantScaleGm; GlobalTensor expandedXGm_; GlobalTensor expandedRowIdxGm_; @@ -223,7 +225,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - LocalTensor dynamicQuantLocal = outLocal[this->cols_].template ReinterpretCast(); + LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[colsAlign], RoundMode::CAST_NONE, this->cols_); @@ -257,6 +259,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); + scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -272,7 +275,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; - DataCopyExtParams intriParams{1, static_cast((this->cols_ + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; LocalTensor smoothLocal; if (smoothType == 1) { @@ -292,6 +295,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { xCopyInQueue_.EnQue(xLocal); Compute(smoothLocal); + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) { int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); @@ -299,15 +303,74 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) { continue; } - DataCopyPad(expandedXGm_[outIndex * this->cols_scale_], outLocal, intriParams); + DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams); + DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); } xCopyInQueue_.FreeTensor(xLocal); inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + if (smoothType == 1) { + smoothInQueue.FreeTensor(smoothLocal); } expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); } +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuantEH() { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + Muls(expandDstToSrcRowLocal.ReinterpretCast(), expandDstToSrcRowLocal.ReinterpretCast(), (float)-1, + this->totalLength); + pipe_barrier(PIPE_V); + LocalTensor sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast(); + Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->totalLength); + + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1; + + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + + for (int64_t row = curRowsStart; row <= curRowsEnd; row++) { + if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) { + break; + } + int32_t srcIdx = sortedRowIdx.GetValue(row); + int32_t expertIdx = expandedExpertIdxLocal.GetValue(row); + + LocalTensor inLocal = xCopyInQueue_.AllocTensor(); + LocalTensor smoothLocal = smoothInQueue.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0}); + xCopyInQueue_.EnQue(inLocal); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + + Compute(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0}); + + LocalTensor outLocal = inputXOutQueue.DeQue(); + DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams); + + xCopyInQueue_.FreeTensor(inLocal); + smoothInQueue.FreeTensor(smoothLocal); + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); + expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal); +} + template __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, @@ -321,7 +384,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp this->k_ = tilingData->k; this->n_ = tilingData->n; this->cols_ = tilingData->cols; - this->cols_scale_ = this->cols_ + ALIGN_512; this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; this->activateRows_ = this->gatherOutTilingData_->activateRows; @@ -352,6 +414,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp Align(this->expertNum, sizeof(int32_t))); } quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); + dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); int64_t kvFactor = 2; int64_t buffSize = this->sortNum_ * sizeof(int32_t); @@ -375,7 +438,8 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp } pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float))); - pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_scale_, sizeof(int8_t))); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t))); + pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); } template @@ -391,7 +455,11 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Process() { } else { CopyOutEmpty(); } - CopyOutXQuant1H(); + if (smoothType == 2) { + CopyOutXQuantEH(); + } else { + CopyOutXQuant1H(); + } } } } // namespace MoeInitRoutingQuantV2 diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h index 6eced5e9..90a76942 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h @@ -66,7 +66,6 @@ class MoeV2GatherDynamicQuant { int64_t needCoreNum; int64_t blockIdx; int64_t cols; - int64_t cols_scale_; int64_t n; int64_t k; int64_t totalLength; @@ -118,7 +117,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - LocalTensor dynamicQuantLocal = outLocal[this->cols].template ReinterpretCast(); + LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); @@ -152,6 +151,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); + scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -163,7 +163,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr int64_t currentLoopStartRow = initialRow / this->k; int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; DataCopyExtParams copyInParams{1, static_cast(this->cols * sizeof(T)), 0, 0, 0}; - DataCopyExtParams copyOutParams{1, static_cast((this->cols + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; LocalTensor smoothLocal; @@ -187,6 +187,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr // Compute quantization Compute(smoothLocal); + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { @@ -196,11 +197,15 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { continue; } - // Scale is placed after the data position - DataCopyPad(expandedXGm[outIndex * cols_scale_], outLocal, copyOutParams); + DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams); + DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); } inputXInQueue.FreeTensor(inLocal); inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + if (smoothType == 1) { + smoothInQueue.FreeTensor(smoothLocal); } expandRowIdxInQueue.FreeTensor(indicesLocal); } @@ -458,7 +463,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR this->needCoreNum = this->gatherOutTilingData->needCoreNum; this->activateRows = this->gatherOutTilingData->activateRows; this->cols = tilingData->cols; - this->cols_scale_ = this->cols + ALIGN_512; this->n = tilingData->n; this->k = tilingData->k; this->totalLength = tilingData->n * tilingData->k; @@ -514,15 +518,33 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); + pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); } template __aicore__ inline void MoeV2GatherDynamicQuant::Process() { if (this->blockIdx < this->needCoreNum) { currentLoopRows = perLoopRows; - if (colLoops > 1) { // Cannot fit all data in one row, workspace is required - trap(); // Not supported - } else { // All data can fit in one row + + if (colLoops > 1) { // A single row cannot be fully loaded; workspace is required + if (smoothType == 2) { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedExpertIdx(loop); + CopyOutPartialXQuantEH(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedExpertIdx(this->rowLoops - 1); + CopyOutPartialXQuantEH(this->rowLoops - 1); + } else { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedRowIdx(loop); + CopyOutPartialXQuant1H(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedRowIdx(this->rowLoops - 1); + CopyOutPartialXQuant1H(this->rowLoops - 1); + } + } else { // A single row can be fully loaded if (smoothType == 2) { for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { CopyInExpandedExpertIdx(loop); diff --git a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h index adb805b8..40cab34e 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h +++ b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h @@ -85,9 +85,8 @@ KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADD GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data) { - this->blockIdx = get_block_idx(); - this->blockNum = get_block_num(); - + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->blockNum = get_block_num() * get_subblockdim(); if (blockIdx >= blockNum) { return; } diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp index 5616b1f8..8ee8ed44 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp @@ -99,12 +99,20 @@ public: eventUbDMTE3VList[i] = eventMTE3V++; eventUbDVMTE3List[i] = eventVMTE3++; - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); + ubCFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += blockN * sizeof(float); } } + CATLASS_DEVICE + void SetFlag() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + } + CATLASS_DEVICE void Finalize() {