From b992b115454ab1bdbdf57306ea7c661c190c6ce8 Mon Sep 17 00:00:00 2001
From: LQLlulu <39671654+LQLlulu@users.noreply.github.com>
Date: Sat, 18 Apr 2026 22:45:08 +0800
Subject: [PATCH] [BugFix] dispatch_ffn_combine kernel: roll back CombineV2
 part (#8405)

### What this PR does / why we need it?
Rolls the combine stage of the dispatch_ffn_combine kernel back from CombineV2 to CombineV1:
- switches EpilogueDispatchPolicy2 from Epilogue::EpilogueAtlasA2PerTokenDequantV2 back to Epilogue::EpilogueAtlasA2PerTokenDequant and narrows the BlockEpilogue2::Params it is built from accordingly;
- reinstates the CombineV1 routine and calls it in place of CombineSetFlag() + CombineV2;
- adds a blockMmad.Finalize(params.expertPerRank - 1, 0) call after the last matmul block and re-initializes the shared-memory status via shmem.InitStatusTargetSum().

### Does this PR introduce _any_ user-facing change?
No. The change is internal to the kernel.

### How was this patch tested?

Signed-off-by: l00893928
Co-authored-by: l00893928
---
 .../op_kernel/dispatch_ffn_combine.h          |  5 +-
 .../op_kernel/dispatch_ffn_combine_kernel.hpp | 60 ++++++++++++++----
 2 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
index 4e73b832..faca1356 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
@@ -224,7 +224,7 @@ __aicore__ inline void DispatchFFNCombine::Process()
     constexpr uint32_t ubStages = 2;
     using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant;
-
+
     using ScaleType = Gemm::GemmType;
     using PerTokenScaleType = Gemm::GemmType;
     using ElementMulType = Gemm::GemmType;
@@ -234,7 +234,8 @@ __aicore__ inline void DispatchFFNCombine::Process()
     using BlockEpilogue1 = Epilogue::Block::BlockEpilogue;

-    using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2;
+    using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant;
+    using TileCopy2 = Epilogue::Tile::TileCopy;
     using BlockEpilogue2 = Epilogue::Block::BlockEpilogue;

diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
index df7d88f5..bdec03d6 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
@@ -571,6 +571,7 @@ private:
         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
             blockMmad.SynchronizeBlock();
         }
+        blockMmad.Finalize(params.expertPerRank - 1, 0);
     }

@@ -837,20 +838,16 @@ private:
         uint32_t n2 = params.problemShape.k();
+
         typename BlockEpilogue2::Params epilogueParams{
             static_cast<uint32_t>(params.EP),
             static_cast<uint32_t>(params.expertPerRank),
-            static_cast<uint32_t>(params.rank),
             reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert),
-            params.layoutD2,
-            static_cast<uint32_t>(n2),
-            static_cast<uint32_t>(L1TileShape::N),
-            shmem,
-            static_cast<uint64_t>(peermemInfo.offsetD)
+            static_cast<uint32_t>(n2)
         };

         uint32_t n = params.problemShape.n();
-        BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
+
         BlockEpilogue1 blockEpilogue1(resource, n);

         // Synchronous wait: SwiGLU waits for GMM1 [1]
@@ -889,16 +886,13 @@ private:
         }
         blockEpilogue1.Finalize();
-
-
-        CombineSetFlag();
-
-        CombineV2(params, blockEpilogue2);
-
+        BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
+        CombineV1(params, blockEpilogue2);
         AscendC::SyncAll();
 #ifndef __CROSSRANKSYNCANDALLGATHERV1__
         ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
 #endif
+        shmem.InitStatusTargetSum();
         if (get_subblockid() == 0) {
             AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0);
@@ -929,6 +923,46 @@ private:
         }
     }
+    CATLASS_DEVICE
+    void CombineV1(Params const &params, BlockEpilogue2 & blockEpilogue) {
+        uint32_t n2 = params.problemShape.k();
+        int32_t prevGroupSum2 = 0;
+
+        icache_preload(8);
+        for (uint32_t t_groupIdx = 0; t_groupIdx < params.expertPerRank; ++t_groupIdx) {
+            int32_t flagId = t_groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
+            AscendC::CrossCoreWaitFlag<0x2>(flagId);
+            AscendC::SyncAll();
+
+            uint32_t groupIdx = t_groupIdx;
+
+            for (int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
+                __gm__ void *dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx);
+                AscendC::GlobalTensor<ElementD2> gmRemotePeer;
+                gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2 *>(dstPeermemPtr));
+                uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2;
+                if (srcRowOffset < params.maxOutputSize) {
+                    uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
+                    if (srcRowOffset + dataRows > params.maxOutputSize) {
+                        dataRows = params.maxOutputSize - srcRowOffset;
+                    }
+                    uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
+                    MatrixCoord offsetC{srcRowOffset, 0};
+                    MatrixCoord offsetPeer{dstRowOffset, 0};
+                    MatrixCoord shapeC{dataRows, n2};
+                    int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
+                    int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
+                    if constexpr (std::is_same_v) {
+                        blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
+                    } else {
+                        blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]);
+                    }
+                }
+            }
+            prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
+        }
+        blockEpilogue.Finalize();
+    }

     CATLASS_DEVICE
     void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
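
Note for reviewers: the least obvious part of the reinstated CombineV1 is the row bookkeeping, i.e. how srcRowOffset is derived from the cumulative token counts and clamped to maxOutputSize. The standalone host-side sketch below reproduces that arithmetic on plain vectors. It assumes cumsumMM(ep * expertPerRank + g) holds the running sum of group g's token counts over peer ranks 0..ep, which is how CombineV1 indexes it; the container names and toy sizes are illustrative stand-ins, not the kernel's actual types.

// Host-side sketch of CombineV1's row bookkeeping. The names cumsumMM,
// prevGroupSum2, srcRowOffset and maxOutputSize mirror the kernel above;
// the std::vector layout is an assumption made for illustration.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t EP = 3, expertPerRank = 2, maxOutputSize = 16;

    // tokens[ep][g]: rows held locally for expert group g of peer rank ep.
    std::vector<std::vector<uint32_t>> tokens = {{3, 1}, {2, 4}, {5, 2}};

    // cumsumMM[ep * expertPerRank + g]: running sum of group g's rows over
    // peer ranks 0..ep, matching how CombineV1 indexes its cumsum table.
    std::vector<uint32_t> cumsumMM(EP * expertPerRank, 0);
    for (uint32_t g = 0; g < expertPerRank; ++g) {
        uint32_t run = 0;
        for (uint32_t ep = 0; ep < EP; ++ep) {
            run += tokens[ep][g];
            cumsumMM[ep * expertPerRank + g] = run;
        }
    }

    // Rows are laid out group-major; within a group they are ordered by
    // destination rank. This reproduces CombineV1's srcRowOffset math.
    uint32_t prevGroupSum2 = 0;
    for (uint32_t g = 0; g < expertPerRank; ++g) {
        for (uint32_t ep = 0; ep < EP; ++ep) {
            uint32_t srcRowOffset =
                (ep == 0 ? 0 : cumsumMM[(ep - 1) * expertPerRank + g]) + prevGroupSum2;
            if (srcRowOffset >= maxOutputSize) {
                continue;  // block begins past the output buffer: skip it
            }
            uint32_t dataRows = tokens[ep][g];
            if (srcRowOffset + dataRows > maxOutputSize) {
                dataRows = maxOutputSize - srcRowOffset;  // clamp the last partial block
            }
            std::printf("group %u -> rank %u: rows [%u, %u)\n",
                        g, ep, srcRowOffset, srcRowOffset + dataRows);
        }
        prevGroupSum2 += cumsumMM[(EP - 1) * expertPerRank + g];
    }
    return 0;
}

With the toy counts above, the blocks come out contiguous ([0,3), [3,5), [5,10), then [10,11), [11,15), [15,16) with the last block clamped), which is exactly the invariant the kernel relies on when it copies each block to its peer.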
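A second note on the if constexpr branch inside CombineV1: it picks the blockEpilogue call signature at compile time, one branch passing the per-token scale (gmPerTokenScale2) and the other not, so each instantiation only has to compile the call it actually uses. The condition's template arguments did not survive above, so the sketch below shows only the general pattern; EpilogueWithScale, EpiloguePlain and the int32_t check are invented for illustration and are not the kernel's real types.

// Minimal sketch of compile-time epilogue dispatch, assuming the quantized
// path is identified by an int32_t accumulator element (an assumption, not
// the kernel's actual condition).
#include <cstdint>
#include <cstdio>
#include <type_traits>

struct EpilogueWithScale {  // hypothetical: dequantizes with a per-token scale
    void operator()(const int32_t *acc, const float *scale) {
        std::printf("dequant: %f\n", acc[0] * scale[0]);
    }
};

struct EpiloguePlain {      // hypothetical: plain copy, no scale needed
    void operator()(const float *acc) { std::printf("copy: %f\n", acc[0]); }
};

template <typename Epilogue, typename Element>
void Combine(Epilogue &epi, const Element *acc, const float *scale) {
    // Same shape as CombineV1's branch: the unused path is discarded at
    // compile time, so each instantiation needs only one call signature.
    if constexpr (std::is_same_v<Element, int32_t>) {
        epi(acc, scale);
    } else {
        epi(acc);
    }
}

int main() {
    int32_t qacc[] = {40};
    float scale[] = {0.25f};
    float facc[] = {10.0f};
    EpilogueWithScale e1;
    EpiloguePlain e2;
    Combine(e1, qacc, scale);  // prints "dequant: 10.000000"
    Combine(e2, facc, scale);  // prints "copy: 10.000000"
    return 0;
}

The payoff of this pattern is that the quantized and non-quantized instantiations share one loop body and differ only in the single call line, with no runtime branching or virtual dispatch on the hot path.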