diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index a0fe0ad8..422595aa 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -42,6 +42,10 @@ using namespace AscendC; namespace Catlass::Gemm::Kernel { +constexpr uint16_t SYNCFLAGC2V = 9; +constexpr uint16_t SYNCFLAGV2C = 10; +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; + template < class BlockMmad_, class BlockScheduler_, @@ -189,7 +193,7 @@ public: { GMM1(params); - AscendC::CrossCoreWaitFlag<0x2>(2); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); GMM2(params); } @@ -201,7 +205,7 @@ public: { Dispatch(params); AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); Combine(params); } @@ -324,9 +328,12 @@ private: int64_t gmGroupOffsetC = 0; uint32_t startCoreIdx = 0; uint32_t syncGroupIdx = 0; - AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul int64_t preCurrentmSum = 0; int32_t syncLoopIdx = -1; + + uint16_t syncgmmIdx = 0; + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul + syncgmmIdx++; constexpr uint32_t MAX_EXPERTS_PER_RANK = 32; __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK]; @@ -370,7 +377,8 @@ private: for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { for(;syncGroupIdx <= groupIdx; syncGroupIdx++) { - AscendC::CrossCoreWaitFlag<0x2>(0); + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmmIdx ++; } // Compute block location GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); @@ -399,7 +407,7 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(syncLoopIdx, 1); + blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V); } preCurrentmSum += currentM; @@ -410,10 +418,16 @@ private: gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } + + for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) { + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmmIdx ++; + } + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(syncLoopIdx + 1, 1); + blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V); } CATLASS_DEVICE @@ -482,7 +496,7 @@ private: uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; // Loop through the matmul of each groupIdx if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { - AscendC::CrossCoreWaitFlag<0x2>(2); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); } for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { if (loopIdx + coreNum >= coreLoops) { @@ -508,7 +522,7 @@ private: gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, gmC2[gmGroupOffsetC + gmOffsetC], layoutC, gmS2[gmOffsetS], layoutScale, - actualBlockShape, syncLoopIdx, 3 + actualBlockShape, syncLoopIdx, 0 ); } } @@ -526,7 +540,7 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(params.expertPerRank - 1, 3); + blockMmad.Finalize(params.expertPerRank - 1, 0); } CATLASS_DEVICE @@ -625,7 +639,9 @@ private: GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + uint16_t syncgmm1Idx = 0; + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmm1Idx++; uint32_t curGroupOffset = 0; int32_t prevSumBeforeRank = 0; @@ -673,10 +689,11 @@ private: if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { syncLoopIdx++; - AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V notifies C that the current communication round is complete + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete + syncgmm1Idx++; if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { uint32_t rowStartThisCore = 0; @@ -699,7 +716,7 @@ private: } } syncLoopIdx ++; - AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::SyncAll(); uint32_t lastDequantExpertNum = params.expertPerRank; @@ -707,7 +724,7 @@ private: lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; } if (lastDequantExpertNum < params.expertPerRank) { - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } if (prevGroupSum1 - dequantSum < params.maxOutputSize) { uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; @@ -746,7 +763,7 @@ private: BlockEpilogue2 blockEpilogue(resource, epilogueParams); int32_t prevGroupSum2 = 0; for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { - AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3); + AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); AscendC::SyncAll(); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp index c15e11b2..3b435f26 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -22,6 +22,8 @@ namespace Catlass::Gemm::Block { +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; + template __aicore__ inline void SyncFlagFunc(int32_t eventID) { @@ -271,7 +273,7 @@ public: void Finalize(int32_t target, int32_t flag = 0) { for(;syncGroupIdx <= target; syncGroupIdx++) { - int32_t flagId = syncGroupIdx / 8 + flag; + int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag; AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); } } diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp index b1c74aa8..5877a370 100644 --- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp @@ -42,6 +42,10 @@ using namespace AscendC; namespace Catlass::Gemm::Kernel { +constexpr uint16_t SYNCFLAGC2V = 9; +constexpr uint16_t SYNCFLAGV2C = 10; +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; + template < class BlockMmad_, class BlockScheduler_, @@ -198,7 +202,7 @@ public: { GMM1(params); - AscendC::CrossCoreWaitFlag<0x2>(2); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); GMM2(params); } @@ -210,7 +214,7 @@ public: { Dispatch(params); AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); CombineV2(params); } @@ -507,7 +511,9 @@ CATLASS_DEVICE int64_t gmGroupOffsetC = 0; uint32_t startCoreIdx = 0; uint32_t syncGroupIdx = 0; - AscendC::CrossCoreWaitFlag<0x2>(0); + uint16_t syncgmmIdx = 0; + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul + syncgmmIdx++; AicSyncAll(); int64_t preCurrentmSum = 0; int32_t syncLoopIdx = -1; @@ -553,7 +559,8 @@ CATLASS_DEVICE for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { for(;syncGroupIdx <= groupIdx; syncGroupIdx++) { - AscendC::CrossCoreWaitFlag<0x2>(0); + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmmIdx ++; } // Compute block location GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); @@ -592,7 +599,7 @@ CATLASS_DEVICE if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(syncLoopIdx, 1); + blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V); } preCurrentmSum += currentM; @@ -603,10 +610,16 @@ CATLASS_DEVICE gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } + + for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) { + AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmmIdx ++; + } + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(syncLoopIdx + 1, 1); + blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V); } CATLASS_DEVICE @@ -674,7 +687,7 @@ CATLASS_DEVICE uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; // Loop through the matmul of each groupIdx if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { - AscendC::CrossCoreWaitFlag<0x2>(2); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); } for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { if (loopIdx + coreNum >= coreLoops) { @@ -701,7 +714,7 @@ CATLASS_DEVICE gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, gmC2[gmGroupOffsetC + gmOffsetC], layoutC, gmS2[gmOffsetS], layoutScale, - actualBlockShape, syncLoopIdx, 3 + actualBlockShape, syncLoopIdx, 0 ); } else { blockMmad( @@ -709,7 +722,7 @@ CATLASS_DEVICE gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, gmC2[gmGroupOffsetC + gmOffsetC], layoutC, gmS2, layoutScale, - actualBlockShape, syncLoopIdx, 3 + actualBlockShape, syncLoopIdx, 0 ); } } @@ -749,7 +762,9 @@ CATLASS_DEVICE GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + uint16_t syncgmm1Idx = 0; + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmm1Idx++; uint32_t curGroupOffset = 0; int32_t prevSumBeforeRank = 0; @@ -791,10 +806,11 @@ CATLASS_DEVICE if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { syncLoopIdx++; - AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete + syncgmm1Idx++; if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { uint32_t rowStartThisCore = 0; @@ -829,9 +845,9 @@ CATLASS_DEVICE lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; } if (lastDequantExpertNum < params.expertPerRank) { - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } - AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1); + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::SyncAll(); if (prevGroupSum1 - dequantSum < params.maxOutputSize) { uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; @@ -907,7 +923,7 @@ CATLASS_DEVICE } if (loopIdx == startLoopIdx) { for (;syncLoopIdx <= groupIdx; syncLoopIdx++) { - int32_t flag_id = 3 + syncLoopIdx / 8; + int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; AscendC::CrossCoreWaitFlag<0x2>(flag_id); } } diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp index c15e11b2..3b435f26 100644 --- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -22,6 +22,8 @@ namespace Catlass::Gemm::Block { +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; + template __aicore__ inline void SyncFlagFunc(int32_t eventID) { @@ -271,7 +273,7 @@ public: void Finalize(int32_t target, int32_t flag = 0) { for(;syncGroupIdx <= target; syncGroupIdx++) { - int32_t flagId = syncGroupIdx / 8 + flag; + int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag; AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); } }