diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index 422595aa..462004d2 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -653,7 +653,7 @@ private: m_prevSumBeforeRank = prevSumBeforeRank; } int prevSum = prevSumBeforeRank; - uint32_t prevGroupSum1 = 0; + uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0; uint32_t dequantSum = 0; int32_t syncLoopIdx = -1; uint32_t n = params.problemShape.n(); @@ -661,6 +661,7 @@ private: for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { // The ith core reads data from the ith rank's peermem groupIdxDeq = groupIdx - 2; + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; if (rowStart < params.maxOutputSize) { @@ -687,57 +688,55 @@ private: } } - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { - syncLoopIdx++; - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); - } AscendC::SyncAll(); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete - syncgmm1Idx++; - - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { - uint32_t rowStartThisCore = 0; - MatrixCoord offsetC{0U, 0}; - uint32_t dequantLen = prevGroupSum1 - dequantSum; - if (dequantLen >= params.maxOutputSize) { - dequantLen = dequantLen - params.maxOutputSize; + prevGroupSum1 += currentM; + syncgmm1Idx ++; + if (groupIdx + 1 <= params.epilogueGranularity) { + if (dequantSum1 + currentM <= params.maxOutputSize) { + dequantSum1 += currentM; + } else if (dequantSum1 < params.maxOutputSize) { + dequantSum1 = params.maxOutputSize; } - + } + if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) { + if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) { + dequantSum2 += currentM; + } else if (dequantSum1 + dequantSum2 < params.maxOutputSize) { + dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2; + } + } + } + + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); + AscendC::SyncAll(); + + if (dequantSum1 > 0) { + uint32_t rowStartThisCore = 0; + MatrixCoord offsetC{0U, 0}; + MatrixCoord shapeC{dequantSum1, params.problemShape.n()}; + LayoutC layoutC{dequantSum1, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) { + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); + AscendC::SyncAll(); + if (dequantSum2 > 0) { + uint32_t rowStartThisCore = dequantSum1; + MatrixCoord offsetC{rowStartThisCore, 0}; + uint32_t dequantLen = dequantSum2; MatrixCoord shapeC{dequantLen, params.problemShape.n()}; LayoutC layoutC{dequantLen, params.problemShape.n()}; int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); } - prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) { - dequantSum = 0; - } - } - syncLoopIdx ++; - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); - AscendC::SyncAll(); - - uint32_t lastDequantExpertNum = params.expertPerRank; - if (params.epilogueGranularity < params.expertPerRank) { - lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; - } - if (lastDequantExpertNum < params.expertPerRank) { - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); - } - if (prevGroupSum1 - dequantSum < params.maxOutputSize) { - uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; - MatrixCoord offsetC{rowStartThisCore, 0}; - uint32_t dequantLen = dequantSum; - if (prevGroupSum1 >= params.maxOutputSize) { - dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize); - } - MatrixCoord shapeC{dequantLen, params.problemShape.n()}; - LayoutC layoutC{dequantLen, params.problemShape.n()}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } blockEpilogue.Finalize(); }