diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h index faca1356..4e73b832 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -224,7 +224,7 @@ __aicore__ inline void DispatchFFNCombine::Process() constexpr uint32_t ubStages = 2; using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant; - + using ScaleType = Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using ElementMulType = Gemm::GemmType; @@ -234,8 +234,7 @@ __aicore__ inline void DispatchFFNCombine::Process() using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; - using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant; - + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2; using TileCopy2 = Epilogue::Tile::TileCopy; using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index bdec03d6..df7d88f5 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -571,7 +571,6 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(params.expertPerRank - 1, 0); } @@ -838,16 +837,20 @@ private: uint32_t n2 = params.problemShape.k(); - typename BlockEpilogue2::Params epilogueParams{ static_cast(params.EP), static_cast(params.expertPerRank), + static_cast(params.rank), reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert), - static_cast(n2) + params.layoutD2, + static_cast(n2), + static_cast(L1TileShape::N), + shmem, + static_cast(peermemInfo.offsetD) }; uint32_t n = params.problemShape.n(); - + BlockEpilogue2 blockEpilogue2(resource, epilogueParams); BlockEpilogue1 blockEpilogue1(resource, n); // Synchronous wait: SwiGLU waits for GMM1 [1] @@ -886,13 +889,16 @@ private: } blockEpilogue1.Finalize(); - BlockEpilogue2 blockEpilogue2(resource, epilogueParams); - CombineV1(params, blockEpilogue2); + + + CombineSetFlag(); + + CombineV2(params, blockEpilogue2); + AscendC::SyncAll(); #ifndef __CROSSRANKSYNCANDALLGATHERV1__ ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128)); #endif - shmem.InitStatusTargetSum(); if (get_subblockid() == 0) { AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0); @@ -923,49 +929,6 @@ private: } } - CATLASS_DEVICE - void CombineV1(Params const ¶ms, BlockEpilogue2 & blockEpilogue) { - uint32_t n2 = params.problemShape.k(); - int32_t prevGroupSum2 = 0; - - icache_preload(8); - for (uint32_t t_groupIdx = 0; t_groupIdx < params.expertPerRank; ++t_groupIdx) { - int32_t flagId = t_groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; - AscendC::CrossCoreWaitFlag<0x2>(flagId); - AscendC::SyncAll(); - - uint32_t groupIdx = t_groupIdx; - - for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { - __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); - AscendC::GlobalTensor gmRemotePeer; - gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); - uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; - if (srcRowOffset < params.maxOutputSize) { - uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); - if (srcRowOffset + dataRows > params.maxOutputSize) { - dataRows = params.maxOutputSize - srcRowOffset; - } - //uint32_t dstRowOffset = preSumBeforeRank(2 * dstEpIdx * FLAGSTRIDE + groupIdx); - int32_t tmpBlock = AlignUp(params.expertPerRank, FLAGSTRIDE); - //uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * tmpBlock + groupIdx); - uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx); - MatrixCoord offsetC{srcRowOffset, 0}; - MatrixCoord offsetPeer{dstRowOffset, 0}; - MatrixCoord shapeC{dataRows, n2}; - int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC); - int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer); - if constexpr (std::is_same_v) { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]); - } else { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); - } - } - } - prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - } - blockEpilogue.Finalize(); - } CATLASS_DEVICE void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) {