From 44ef9a36ac275aa63ee1b31a86e742efa4f4f3fa Mon Sep 17 00:00:00 2001
From: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
Date: Mon, 23 Mar 2026 10:14:03 +0800
Subject: [PATCH] [fix]: fix precision issue in dispatch_ffn_combine_bf16 and
 remove redundant sync (#7198)

### What this PR does / why we need it?
Fix the precision issue in the dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations from the dispatch_ffn_combine
operator.

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
---
 .../op_kernel/dispatch_ffn_combine_kernel.hpp |   6 -
 .../dispatch_ffn_combine_bf16_tiling.cpp      |   4 +-
 .../dispatch_ffn_combine_bf16_kernel.hpp      | 597 +++++++++---------
 .../op_kernel/unpermute/moe_token_unpermute.h |   4 +-
 .../utils/block_epilogue_pertoken_v2.hpp      | 151 +----
 ...block_mmad_preload_async_fixpipe_quant.hpp |  53 +-
 .../op_kernel/utils/const_args.hpp            |   4 +
 .../op_kernel/utils/hccl_shmem.hpp            | 174 ++++-
 8 files changed, 531 insertions(+), 462 deletions(-)

diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
index 7051349b..df7d88f5 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
@@ -391,7 +391,6 @@ private:
         uint16_t syncgmmIdx = 0;
         AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
         syncgmmIdx++;
-        AscendC::PipeBarrier();
 
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -405,7 +404,6 @@ private:
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
             gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1)));
             gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1)));
-            AscendC::PipeBarrier();
             if (currentM <= L1TileShape::M) {
                 gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
             }
@@ -493,8 +491,6 @@ private:

         uint32_t startCoreIdx = 0;

-        AscendC::PipeBarrier();
-
         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
         uint32_t lastDequantExpertNum = params.expertPerRank;
@@ -503,8 +499,6 @@ private:
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }

-        AscendC::PipeBarrier();
-
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
             if (preCurrentmSum >= params.maxOutputSize) {
diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp
index 482470a2..899c3f9a 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp
@@ -41,6 +41,7 @@ namespace {
     constexpr uint32_t EXPERTID_INDEX = 3;
     constexpr uint32_t BLOCK_NUM = 20;
    constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
+    constexpr uint64_t MB_SIZE = 1024 * 1024UL;
 }

 namespace optiling {
@@ -240,7 +241,8 @@ static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext
                           info.maxOutputSize * n2 * sizeof(int16_t) +
                           info.maxOutputSize * info.K * sizeof(int16_t) +
                           info.maxOutputSize * k2 * sizeof(int16_t) +
-                          info.worldSize * sizeof(int32_t) * 16;
+                          info.worldSize * sizeof(int32_t) * 16 +
+                          (info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
     // std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
     // std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));

diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
index 51e939be..a2e6ba35 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
@@ -213,11 +213,7 @@ public:
     CATLASS_DEVICE
     void operator()(Params const &params)
     {
-        Dispatch(params);
-        AscendC::SyncAll();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
-
-        CombineV2(params);
+        DispatchAndCombine(params);
     }

 private:
@@ -241,7 +237,7 @@ private:
         tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
-        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
+        tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
         preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
     }
@@ -309,7 +305,7 @@ private:
         AscendC::DataCopyPad(
             tmpBuffer1,
             tokenPerExpert[rankId * expertPerRank],
-            {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0},
+            {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
             {}
         );
@@ -331,145 +327,6 @@ private:
         );
     }

-CATLASS_DEVICE
-    void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRank(Params const &params, int64_t localTokenPerExpertOffset){
-        AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0);
-        AscendC::LocalTensor ubFloat = resource.ubBuf.template GetBufferByByte(0);
-        uint32_t numPerCore = params.EP * params.expertPerRank;
-        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
-            if (dstEpIdx == params.rank) {
-                continue;
-            }
-            AscendC::GlobalTensor srcAddress;
-            srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
-            AscendC::GlobalTensor dstAddress;
-            __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
-            dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
-
-            AscendC::SetFlag(EVENT_ID0);
-            using TType = Gemm::GemmType;
-            using CopyGmToUb = Epilogue::Tile::CopyGm2Ub;
-            using CopyUbToGm = Epilogue::Tile::CopyUb2Gm;
-            CopyGmToUb copyGmToUb;
-            CopyUbToGm copyUbToGm;
-
-            AscendC::WaitFlag(EVENT_ID0);
-
-            copyGmToUb(tmpBuffer, srcAddress[0],
-                layout::RowMajor{ 1, numPerCore},
-                layout::RowMajor{1, numPerCore});
-
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-            AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-            copyUbToGm(dstAddress[0], tmpBuffer,
-                layout::RowMajor{ 1, numPerCore},
-                layout::RowMajor{1, numPerCore});
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-        }
-
-        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
-            if (dstEpIdx != params.rank) {
-                int32_t intPer512 = CACHE_LINE / sizeof(int);
-                for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
-                    __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
-                    gm_signal_wait_until_ne(sync_check, 0);
-                }
-                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-                AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
-                AscendC::PipeBarrier();
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-                AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
-            } else {
-                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-            }
-        }
-
-        AscendC::SyncAll();
-    }
-
-    CATLASS_DEVICE
-    void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
-        AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0);
-        AscendC::LocalTensor ubFloat = resource.ubBuf.template GetBufferByByte(0);
-        uint32_t numPerCore = params.EP * params.expertPerRank;
-        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
-            if (dstEpIdx == params.rank) {
-                continue;
-            }
-            AscendC::GlobalTensor srcAddress;
-            srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
-            AscendC::GlobalTensor dstAddress;
-            __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
-            dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
-
-            AscendC::SetFlag(EVENT_ID0);
-            using TType = Gemm::GemmType;
-            using CopyGmToUb = Epilogue::Tile::CopyGm2Ub;
-            using CopyUbToGm = Epilogue::Tile::CopyUb2Gm;
-            CopyGmToUb copyGmToUb;
-            CopyUbToGm copyUbToGm;
-
-            AscendC::WaitFlag(EVENT_ID0);
-
-            copyGmToUb(tmpBuffer, srcAddress[0],
-                layout::RowMajor{ 1, numPerCore},
-                layout::RowMajor{1, numPerCore});
-
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-            AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-            copyUbToGm(dstAddress[0], tmpBuffer,
-                layout::RowMajor{ 1, numPerCore},
-                layout::RowMajor{1, numPerCore});
-            AscendC::SetFlag(EVENT_ID0);
-            AscendC::WaitFlag(EVENT_ID0);
-        }
-
-        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
-            if (dstEpIdx != params.rank) {
-                int32_t intPer512 = CACHE_LINE / sizeof(int);
-                for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
-                    __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
-                    gm_signal_wait_until_ne(sync_check, 0);
-                }
-                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-                AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
-                AscendC::PipeBarrier();
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-                AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
-            } else {
-                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
-                AscendC::SetFlag(EVENT_ID0);
-                AscendC::WaitFlag(EVENT_ID0);
-            }
-
-            int32_t prevSum = 0;
-            for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
-                prevSum += tmpBuffer(i);
-            }
-            preSumBeforeRank(dstEpIdx * 16) = prevSum;
-            __asm__ __volatile__("");
-            DataCacheCleanAndInvalid(preSumBeforeRank[dstEpIdx * 16]);
-            __asm__ __volatile__("");
-
-        }
-        AscendC::SyncAll();
-    }
-
     CATLASS_DEVICE
     void GetSumPreRank(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP)
     {
@@ -506,20 +363,20 @@ CATLASS_DEVICE
         icache_preload(8);
         BlockScheduler blockScheduler;
         BlockMmad blockMmad(resource);
+        float aivFinishGroups = 0.0f;
+        __gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;

         int64_t gmGroupOffsetA = 0;
         int64_t gmGroupOffsetB = 0;
         int64_t gmGroupOffsetC = 0;

         uint32_t startCoreIdx = 0;
         uint32_t syncGroupIdx = 0;
-        uint16_t syncgmmIdx = 0;
-        AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
-        syncgmmIdx++;
-
         AicSyncAll();

         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
-        AscendC::PipeBarrier();
+        uint16_t syncgmmIdx = 0;
+        AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
+        syncgmmIdx++;

         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -533,9 +390,6 @@ CATLASS_DEVICE
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
             gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1)));
             gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1)));
-
-            AscendC::PipeBarrier();
-
             if (currentM <= L1TileShape::M) {
                 gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
             }
@@ -630,7 +484,6 @@ CATLASS_DEVICE

         uint32_t startCoreIdx = 0;

-        AscendC::PipeBarrier();

         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
@@ -640,7 +493,6 @@ CATLASS_DEVICE
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }

-        AscendC::PipeBarrier();

         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -651,7 +503,6 @@ CATLASS_DEVICE
             }
             AscendC::GlobalTensor gmB2;
             AscendC::GlobalTensor gmS2;
-            AscendC::PipeBarrier();
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
             gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB2)));
             gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale2)));
@@ -721,7 +572,6 @@ CATLASS_DEVICE
             gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();

             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
-
         }

         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
@@ -729,8 +579,168 @@ CATLASS_DEVICE
         }
     }

+    CATLASS_DEVICE
+    void InitArithProgress(Params const &params) {
+        AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0);
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+        AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+
+        AscendC::GlobalTensor flagGlobalBase;
+        flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
+        AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
+    }
+
+
     CATLASS_DEVICE
-    void Dispatch(Params const &params) {
+    void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
+        uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
+        AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0);
+        AscendC::LocalTensor prevSumBuf = tmpBuffer[numPerCore];
+
+        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
+            if (dstEpIdx == params.rank) {
+                continue;
+            }
+            AscendC::GlobalTensor srcAddress;
+            srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
+            AscendC::GlobalTensor dstAddress;
+            __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
+            dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
+
+            AscendC::SetFlag(EVENT_ID0);
+            using TType = Gemm::GemmType;
+            using CopyGmToUb = Epilogue::Tile::CopyGm2Ub;
+            using CopyUbToGm = Epilogue::Tile::CopyUb2Gm;
+            CopyGmToUb copyGmToUb;
+            CopyUbToGm copyUbToGm;
+
+            AscendC::WaitFlag(EVENT_ID0);
+
+            copyGmToUb(tmpBuffer, srcAddress[0],
+                layout::RowMajor{ 1, numPerCore},
+                layout::RowMajor{1, numPerCore});
+
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            copyUbToGm(dstAddress[0], tmpBuffer,
+                layout::RowMajor{ 1, numPerCore},
+                layout::RowMajor{1, numPerCore});
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+        }
+        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
+            if (dstEpIdx != params.rank) {
+                int32_t intPer512 = CACHE_LINE / sizeof(int);
+                for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
+                    __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
+                    gm_signal_wait_until_ne(sync_check, 0);
+                }
+                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
+                AscendC::SetFlag(EVENT_ID0);
+                AscendC::WaitFlag(EVENT_ID0);
+                AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
+                AscendC::PipeBarrier();
+                AscendC::SetFlag(EVENT_ID0);
+                AscendC::WaitFlag(EVENT_ID0);
+                AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
+            } else {
+                AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
+                AscendC::SetFlag(EVENT_ID0);
+                AscendC::WaitFlag(EVENT_ID0);
+            }
+            AscendC::PipeBarrier();
+            int32_t prevSum = 0;
+            int32_t j = 0;
+            for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
+                if (i >= params.rank * params.expertPerRank) {
+                    prevSumBuf(j) = prevSum;
+                    j++;
+                }
+                prevSum += tmpBuffer(i);
+            }
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
+                AscendC::DataCopyParams{1, static_cast(params.expertPerRank * sizeof(int32_t)), 0, 0});
+        }
+
+        AscendC::SyncAll();
+    }
+
+    CATLASS_DEVICE
+    void ResetTokenPerExpert(int32_t num)
+    {
+        if (coreIdx != coreNum - 1) {
+            return;
+        }
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+        AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0);
+        AscendC::Duplicate(tmp, 0, num);
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+        AscendC::DataCopy(tokenPerExpert, tmp, num);
+    }
+
+    CATLASS_DEVICE
+    void UpdateAicFlags(const Params &params)
+    {
+        float flagBase = 1.0f * params.expertPerRank;
+        __gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
+        float flag = 0.0f;
+        float lastflag = -1.0f;
+        AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0);
+        __gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
+        AscendC::GlobalTensor flagGM;
+        flagGM.SetGlobalBuffer(flagPtr);
+        int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
+        AscendC::LocalTensor dstValueBuffer = resource.ubBuf.template GetBufferByByte(flagBufferSize);
+        AscendC::LocalTensor sharedTmpBuffer = resource.ubBuf.template GetBufferByByte((flagBufferSize + 64));
+        uint64_t mask[1] = {0};
+        uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
+        for (int32_t i = 0; i < 4; i ++) {
+            if (i < params.EP) {
+                mask[0] |= 1ull * (1ull << (i * 16));
+            }
+        }
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+        while (flag < flagBase) {
+            flag = flagBase;
+            AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+
+            AscendC::ReduceMin(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
+
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            flag = min(flag, dstValueBuffer.GetValue(0));
+
+            if (flag > lastflag) {
+                *aicFinishPtr = flag;
+                gm_dcci(aicFinishPtr);
+                lastflag = flag;
+            }
+        }
+    }
+
+
+    CATLASS_DEVICE
+    void CombineSetFlag() {
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::SetFlag(EVENT_ID1);
+    }
+
+
+    CATLASS_DEVICE
+    void DispatchAndCombine(Params const &params) {
         icache_preload(8);
         int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
         GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset;
@@ -744,35 +754,35 @@ CATLASS_DEVICE
         AscendC::SyncAll();
         CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);

-        if (coreIdx == coreNum - 1) {
+        if (coreIdx == 0) {
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
+
+        uint32_t curGroupOffset = 0;
+        int32_t prevSumBeforeRank = 0;
+        int32_t prevSum = 0;
+        if (coreIdx < params.EP) {
+            prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
+        }
         AscendC::SyncAll();
+
+        AscendC::GlobalTensor ExpertTokenNums;
+        ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
+        if(coreIdx == 0)
+        {
+            CopyGMToGM(ExpertTokenNums, cumsumMM[(params.EP - 1) * params.expertPerRank], params.expertPerRank, params.ubMoveNum);
+        }
         uint16_t syncgmm1Idx = 0;
         AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
         syncgmm1Idx++;

-        AscendC::GlobalTensor ExpertTokenNums;
-        ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
-        AscendC::GlobalTensor LcalCumsumMM;
-        LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
-        if (coreIdx == 0) {
-            CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
-        }
-
-        uint32_t curGroupOffset = 0;
-        int32_t prevSumBeforeRank = 0;
-        int32_t groupIdxDeq = 0;
-        int prevSum = 0;
-        if (coreIdx < params.EP) {
-            prevSum = preSumBeforeRank(coreIdx * 16);
-        }
-        uint32_t prevGroupSum1 = 0;
+        uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
         uint32_t dequantSum = 0;
-        int32_t syncLoopIdx = -1;
-        uint32_t n = params.problemShape.n();
-        BlockEpilogue1 blockEpilogue(resource, n);
+
+        icache_preload(8);
         for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
+            // The ith core reads data from the ith rank's peermem
+            uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
             for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
                 uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
                 if (rowStart < params.maxOutputSize) {
@@ -785,99 +795,40 @@ CATLASS_DEVICE
                     GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
                     AscendC::GlobalTensor gmRemoteA;
                     gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
-                    AscendC::GlobalTensor gmRemotePerTokenScale;
-                    gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
+
                     MatrixCoord offsetA{rowStart, 0};
-                    MatrixCoord shapeA{rows, params.problemShape.k()};
                     MatrixCoord offsetPeer{rowSrc, 0};
                     int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
                     int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
                     CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
-                    if constexpr (std::is_same_v) {
-                        CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
-                    }
                 }
             }
-
-            if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
-                syncLoopIdx++;
-                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
-            }
             AscendC::SyncAll();
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
-            syncgmm1Idx++;
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+            syncgmm1Idx ++;

-            if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
-                uint32_t rowStartThisCore = 0;
-                MatrixCoord offsetC{0U, 0};
-                uint32_t dequantLen = prevGroupSum1 - dequantSum;
-                if (dequantLen >= params.maxOutputSize) {
-                    dequantLen = dequantLen - params.maxOutputSize;
-                }
+            prevGroupSum1 += currentM;

-                MatrixCoord shapeC{dequantLen, params.problemShape.n()};
-                LayoutC layoutC{dequantLen, params.problemShape.n()};
-                int64_t gmOffsetC = layoutC.GetOffset(offsetC);
-                int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
-                if constexpr (std::is_same_v) {
-                    blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
-                } else {
-                    blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
+            // Token count and truncation logic for the first SwiGLU operation
+            if (groupIdx + 1 <= params.epilogueGranularity) {
+                if (dequantSum1 + currentM <= params.maxOutputSize) {
+                    dequantSum1 += currentM;
+                } else if (dequantSum1 < params.maxOutputSize) {
+                    dequantSum1 = params.maxOutputSize;
                 }
             }
-            prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
-            dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
-            if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
-                dequantSum = 0;
-            }
-        }
-        syncLoopIdx ++;
-        AscendC::SyncAll();
-
-        uint32_t lastDequantExpertNum = params.expertPerRank;
-        if (params.epilogueGranularity < params.expertPerRank) {
-            lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
-        }
-        if (lastDequantExpertNum < params.expertPerRank) {
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
-        }
-        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
-        AscendC::SyncAll();
-        if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
-            uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
-            MatrixCoord offsetC{rowStartThisCore, 0};
-            uint32_t dequantLen = dequantSum;
-            if (prevGroupSum1 >= params.maxOutputSize) {
-                dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
-            }
-            MatrixCoord shapeC{dequantLen, params.problemShape.n()};
-            LayoutC layoutC{dequantLen, params.problemShape.n()};
-            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
-            int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
-            if constexpr (std::is_same_v) {
-                blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
-            } else {
-                blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
+            // Token count and truncation logic for the second SwiGLU operation
+            if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
+                if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
+                    dequantSum2 += currentM;
+                } else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
+                    dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
+                }
             }
         }
-        blockEpilogue.Finalize();
-    }
-
-    CATLASS_DEVICE
-    void CombineV2(Params const &params) {
-        BlockScheduler blockScheduler;
         uint32_t n2 = params.problemShape.k();
-        uint32_t k2 = params.problemShape.n() / 2;
-        uint32_t startCoreIdx = 0;
-        int64_t gmGroupOffsetC = 0;
-        uint32_t aivCoreNum = coreNum;
-        uint32_t aicCoreNum = coreNum / 2;
-        uint32_t aivCoreIdx = coreIdx;
-        uint32_t aicCoreIdx = get_block_idx();
-        uint32_t aivSubCoreIdx = get_subblockid();
-        uint32_t preSrcExpertSum = 0;

         typename BlockEpilogue2::Params epilogueParams{
             static_cast(params.EP),
@@ -891,11 +842,108 @@ CATLASS_DEVICE
             static_cast(peermemInfo.offsetD)
         };

-        BlockEpilogue2 blockEpilogue(resource, epilogueParams);
+        BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
+        uint32_t n = params.problemShape.n();
+        BlockEpilogue1 blockEpilogue1(resource, n);
+
+        // Synchronous wait: SwiGLU waits for GMM1 [1]
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
+        AscendC::SyncAll();
+        if (dequantSum1 > 0) {
+            uint32_t rowStartThisCore = 0;
+            MatrixCoord offsetC{0U, 0};
+            MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
+            LayoutC layoutC{dequantSum1, params.problemShape.n()};
+            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
+            int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
+            // blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
+            blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
+        }
+        AscendC::SyncAll();
+        // Synchronization signal: SwiGLU notifies GMM2 [1]
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
+
+        if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
+            // Synchronous wait: SwiGLU waits for GMM1 [2]
+            AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
+            AscendC::SyncAll();
+            if (dequantSum2 > 0) {
+                uint32_t rowStartThisCore = dequantSum1;
+                MatrixCoord offsetC{rowStartThisCore, 0};
+                uint32_t dequantLen = dequantSum2;
+                MatrixCoord shapeC{dequantLen, params.problemShape.n()};
+                LayoutC layoutC{dequantLen, params.problemShape.n()};
+                int64_t gmOffsetC = layoutC.GetOffset(offsetC);
+                int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
+                // blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
+                blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
+            }
+            AscendC::SyncAll();
+            // Synchronization signal: SwiGLU notifies GMM2 [2]
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
+        }
+
+        blockEpilogue1.Finalize();
+
+        CombineSetFlag();
+
+        CombineV2(params, blockEpilogue2);
+
+        AscendC::SyncAll();
+        #ifndef __CROSSRANKSYNCANDALLGATHERV1__
+        ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
+        #endif
+        shmem.InitStatusTargetSum();
+        if (get_subblockid() == 0) {
+            AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0);
+            shmem.CrossRankSyncV2Set(ctrBuffer);
+        } else {
+            uint32_t uboffset = 0;
+            uint32_t aicCoreNum = coreNum / 2;
+            uint32_t aicCoreIdx = get_block_idx();
+            uint32_t sendRankNum_ = params.EP / aicCoreNum;
+            uint32_t remainderRankNum = params.EP % aicCoreNum;
+            if (aicCoreIdx < remainderRankNum) {
+                sendRankNum_++;
+            }
+            AscendC::LocalTensor statusTensor = resource.ubBuf.template GetBufferByByte(uboffset);
+            uboffset += sendRankNum_ * UB_ALIGN;
+            AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(uboffset);
+            uboffset += AlignUp(params.EP * sizeof(float), 32);
+            AscendC::LocalTensor gatherTmpTensor = resource.ubBuf.template GetBufferByByte(uboffset);
+            uboffset += AlignUp(sizeof(uint32_t), 32);
+            AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(uboffset);
+            uboffset += AlignUp(sizeof(float), 32);
+            shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
+            MoeTokenUnpermuteTilingData tilingData;
+            MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
+            KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp;
+            kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData);
+            kernelMoeTokenUnpermuteOp.Process();
+        }
+
+    }
+
+    CATLASS_DEVICE
+    void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
+        BlockScheduler blockScheduler;
         int32_t syncLoopIdx = 0;
+        uint32_t startCoreIdx = 0;
+        uint32_t aicCoreNum = coreNum / 2;
+        uint32_t aicCoreIdx = get_block_idx();
+        uint32_t aivSubCoreIdx = get_subblockid();
+        uint32_t preSrcExpertSum = 0;
+        uint32_t n2 = params.problemShape.k();
+        uint32_t k2 = params.problemShape.n() / 2;

+        icache_preload(8);
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
+            if (preSrcExpertSum >= params.maxOutputSize) {
+                currentExpertM = 0;
+            } else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
+                currentExpertM = params.maxOutputSize - preSrcExpertSum;
+            }
             GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
             blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
             uint32_t coreLoops = blockScheduler.GetCoreLoops();
@@ -916,11 +964,10 @@ CATLASS_DEVICE
                 if(aivSubCoreIdx == 1) {
                     m_offset += (m_rows / 2) * m0;
                 }
-                if (loopIdx == startLoopIdx) {
-                    for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
-                        int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
-                        AscendC::CrossCoreWaitFlag<0x2>(flag_id);
-                    }
+
+                for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
+                    int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
+                    AscendC::CrossCoreWaitFlag<0x2>(flag_id);
                 }

                 for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
@@ -930,34 +977,15 @@ CATLASS_DEVICE
                         actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
                     }
                     GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
-                    if constexpr (std::is_same_v) {
-                        blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
-                    } else {
-                        blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
-                    }
+                    blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
                     m_offset += m0;
                 }
             }
-
-            for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
-                int32_t expertRankM = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
-                mPreSumBeforeRank[dstEpIdx] += expertRankM;
-            }
             preSrcExpertSum += currentExpertM;
             startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
         }

         blockEpilogue.Finalize();
-        AscendC::SyncAll();
-        ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank);
-        shmem.CrossRankSync();
-
-        MoeTokenUnpermuteTilingData tilingData;
-        MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, aivCoreNum);
-        KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp;
-
-        kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData);
-        kernelMoeTokenUnpermuteOp.Process();
     }
@@ -973,6 +1001,7 @@ private:
         GM_ADDR expandedRowIdx;
         GM_ADDR ptrTokenPerExpert;
         GM_ADDR ptrSumBeforeRank;
+        __gm__ float* ptrSoftFlagBase;

         CATLASS_DEVICE
         WorkspaceInfo(){}
@@ -1012,6 +1041,9 @@ private:
             workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);

             ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
+            workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
+
+            ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
         }
     };
@@ -1041,7 +1073,6 @@ private:
     WorkspaceInfo workspaceInfo;
     PeermemInfo peermemInfo;
-    int64_t m_prevSumBeforeRank;

     AscendC::GlobalTensor gmA;
     AscendC::GlobalTensor gmC;
@@ -1057,7 +1088,7 @@ private:
     AscendC::GlobalTensor tokenPerExpert;
     AscendC::GlobalTensor cumsumMM;
     AscendC::GlobalTensor preSumBeforeRank;
-    uint32_t mPreSumBeforeRank[32] = {0};
+
     Layout3D tokenPerExpertLayout;
     HcclShmem shmem;
 };
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h
index 1255b5cf..adb805b8 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h
@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADD
                               GM_ADDR unpermuted_tokens,
                               const MoeTokenUnpermuteTilingData *__restrict tiling_data)
 {
-    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
-    this->blockNum = get_block_num() * get_subblockdim();
+    this->blockIdx = get_block_idx();
+    this->blockNum = get_block_num();

     if (blockIdx >= blockNum) {
         return;
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp
index eaab8104..251408ef 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp
@@ -72,17 +72,6 @@ public:
     CATLASS_DEVICE
     BlockEpilogue(Arch::Resource const &resource, Params const &params = Params{}) : params(params)
     {
-        AscendC::SetFlag(EVENT_ID0);
-        AscendC::SetFlag(EVENT_ID1);
-        AscendC::SetFlag(EVENT_ID2);
-        AscendC::SetFlag(EVENT_ID3);
-        AscendC::SetFlag(EVENT_ID2);
-        AscendC::SetFlag(EVENT_ID3);
-        AscendC::SetFlag(EVENT_ID0);
-        AscendC::SetFlag(EVENT_ID1);
-
-
         //ub:192KB
         n0 = params.n0;
         size_t ubOffset = 0;
@@ -98,137 +87,20 @@ public:
             source_scale_offset[i] = -1;
         }
         tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
-        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
+        tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
         is_ping = true;
     }

     CATLASS_DEVICE
     void Finalize()
     {
-        AscendC::WaitFlag(EVENT_ID0);
-        AscendC::WaitFlag(EVENT_ID1);
-        AscendC::WaitFlag(EVENT_ID2);
-        AscendC::WaitFlag(EVENT_ID3);
-        AscendC::WaitFlag(EVENT_ID2);
-        AscendC::WaitFlag(EVENT_ID3);
-        AscendC::WaitFlag(EVENT_ID0);
-        AscendC::WaitFlag(EVENT_ID1);
-
+        AscendC::WaitFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID1);
     }

     CATLASS_DEVICE
     ~BlockEpilogue()
     {
-
     }

-    CATLASS_DEVICE
-    void operator() (
-        AscendC::GlobalTensor const &gmC,
-        AscendC::GlobalTensor const &gmPerTokenScale,
-        GemmCoord& blockCoord,
-        GemmCoord& actualBlockShape,
-        int32_t groupIdx,
-        int32_t preSrcExpertSum,
-        AscendC::GlobalTensor preSumBeforeRank,
-        uint32_t *mPreSumBeforeRank
-    ){
-        is_ping = !is_ping;
-        auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
-        auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
-
-        auto &ubC = ubCList[is_ping];
-        auto &ubD = ubDList[is_ping];
-        int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
-        auto gmTileC = gmC[gmCOffset];
-        auto &ubCFp32 = ubFp32List[is_ping];
-        auto &scaleUb = scaleUbList[is_ping];
-        // auto &ubOutFp32 = ubOutFp32List[is_ping];
-
-        LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
-        LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
-
-
-        AscendC::WaitFlag(event_id); //for debug
-        copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
-        AscendC::SetFlag(event_id); //for debug
-
-        AscendC::WaitFlag(event_id);
-        AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
-        AscendC::SetFlag(event_id);
-
-
-        AscendC::WaitFlag(event_id_2);
-        AscendC::WaitFlag(event_id_2);
-
-        int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
-        layout::VectorLayout scaleLauout{actualBlockShape.m()};
-        if (source_scale_offset[event_id] != gmScaleOffset) {
-            source_scale_offset[event_id] = gmScaleOffset;
-            copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
-        }
-
-        AscendC::SetFlag(event_id_2);
-        AscendC::SetFlag(event_id_2);
-
-
-
-
-        AscendC::WaitFlag(event_id_2);
-        AscendC::WaitFlag(event_id_2); // Note: this must be MTE2_S, not MTE2_V; otherwise zeros are read back, producing corrupted data
-        AscendC::PipeBarrier();
-        for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
-            float scale = scaleUb(row);
-            Muls(ubCFp32[n0 * row], ubCFp32[n0 * row], scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
-        }
-        AscendC::PipeBarrier();
-        AscendC::WaitFlag(event_id);
-        AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
-        AscendC::SetFlag(event_id_2);
-        AscendC::SetFlag(event_id_2);
-        AscendC::SetFlag(event_id);
-
-        int32_t lenTile = actualBlockShape.m();
-        int32_t stTile = blockCoord.m();
-        int32_t edTile = stTile + lenTile;
-        int32_t preSumRankInExpert = 0;
-        int32_t tileOffset = 0;
-
-        AscendC::WaitFlag(event_id); //for debug
-        for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
-            int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
-            int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
-            int32_t stRankInExpert = preSumRankInExpert;
-            int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
-            preSumRankInExpert += lenRankInExpert;
-            if (stRankInExpert >= edTile) {
-                break;
-            }
-            else if (edRankInExpert <= stTile) {
-                continue;
-            }
-            int32_t stData = max(stRankInExpert, stTile);
-            int32_t edData = min(edRankInExpert, edTile);
-            uint32_t lenData = edData - stData;
-            if (lenData <= 0){
-                continue;
-            }
-
-            uint32_t dstOffsetInExpert = 0;
-            if (stTile > stRankInExpert) {
-                dstOffsetInExpert = stTile - stRankInExpert;
-            }
-            AscendC::GlobalTensor gmRemotePeer;
-            __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
-            gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
-            MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
-            int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
-            auto gmTileD = gmRemotePeer[gmDstOffset];
-            LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
-            LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
-            copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
-            tileOffset += lenData;
-        }
-        AscendC::SetFlag(event_id);
-    }
-
     CATLASS_DEVICE
@@ -238,14 +110,12 @@ public:
     void operator() (
         AscendC::GlobalTensor const &gmC,
         GemmCoord& blockCoord,
         GemmCoord& actualBlockShape,
         int32_t groupIdx,
         int32_t preSrcExpertSum,
-        AscendC::GlobalTensor preSumBeforeRank,
-        uint32_t *mPreSumBeforeRank
+        AscendC::GlobalTensor preSumBeforeRank
     ){
         is_ping = !is_ping;
         auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;

         auto &ubC = ubCList[is_ping];
-        auto &ubD = ubDList[is_ping];
         int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
         auto gmTileC = gmC[gmCOffset];
@@ -253,7 +123,7 @@ public:
         LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
         LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};

-        AscendC::SetFlag(event_id); //for debug
+        AscendC::WaitFlag(event_id); //for debug
         copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
         AscendC::SetFlag(event_id); //for debug
@@ -263,10 +133,10 @@ public:
         int32_t preSumRankInExpert = 0;
         int32_t tileOffset = 0;

-        AscendC::WaitFlag(event_id); //for debug
+        AscendC::WaitFlag(event_id);
         for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
             int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
-            int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
+            int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
             int32_t stRankInExpert = preSumRankInExpert;
             int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
             preSumRankInExpert += lenRankInExpert;
@@ -282,7 +152,7 @@ public:
             if (lenData <= 0){
                 continue;
             }
-
+
             uint32_t dstOffsetInExpert = 0;
             if (stTile > stRankInExpert) {
                 dstOffsetInExpert = stTile - stRankInExpert;
@@ -290,7 +160,7 @@ public:
             AscendC::GlobalTensor gmRemotePeer;
             __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
             gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
-            MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
+            MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
             int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
             auto gmTileD = gmRemotePeer[gmDstOffset];
             LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
@@ -298,7 +168,8 @@ public:
             copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
             tileOffset += lenData;
         }
-        AscendC::WaitFlag(event_id);
+
+        AscendC::SetFlag(event_id);
     }
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
index 3b435f26..4f935180 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
@@ -22,8 +22,6 @@

 namespace Catlass::Gemm::Block {

-constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
-
 template
 __aicore__ inline void SyncFlagFunc(int32_t eventID)
 {
@@ -153,9 +151,11 @@ public:
         L1TileShape::K, L1TileShape::N);

     CATLASS_DEVICE
-    BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0)
+    BlockMmad(Arch::Resource &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
     {
         syncGroupIdx = 0;
+        ptrSoftFlagBase_ = flagPtr;
+        expertPerRank_ = expertPerRank;
         InitL1(resource, l1BufAddrStart);
         InitL0A(resource);
         InitL0B(resource);
@@ -272,9 +272,21 @@ public:
     CATLASS_DEVICE
     void Finalize(int32_t target, int32_t flag = 0)
     {
-        for(;syncGroupIdx <= target; syncGroupIdx++) {
-            int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
-            AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
+        if (ptrSoftFlagBase_ != nullptr) {
+            if (target < 0) {
+                return;
+            }
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            AscendC::GlobalTensor flagGlobal;
+            flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
+            AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
+        }
+        else {
+            for(;syncGroupIdx <= target; syncGroupIdx++) {
+                int32_t flagId = syncGroupIdx / 15 + flag;
+                AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
+            }
         }
     }
 private:
@@ -291,7 +303,6 @@ private:
         layout::VectorLayout layoutScale;
         int32_t syncLoopIdx;
         int32_t flag;
-
         CATLASS_DEVICE
         L1TileMmadParams() = default;
     };
@@ -310,11 +321,24 @@ private:
             AscendC::SetFlag(l1AEventList[i]);
             AscendC::SetFlag(l1BEventList[i]);
         }
+        uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
         if constexpr (std::is_same_v) {
-            uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
             l1STensor = resource.l1Buf.template GetBufferByByte(l1SOffset);
             AscendC::SetFlag(0);
         }
+        if (ptrSoftFlagBase_ != nullptr) {
+            // Initialize the flag matrix (structure as below):
+            // 1 0 0 0 0 0 0 0
+            // 2 0 0 0 0 0 0 0
+            // ...
+ // 16 0 0 0 0 0 0 0 + // Then move it to L1 + uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE; + l1FTensor = resource.l1Buf.template GetBufferByByte(l1FOffset); + AscendC::GlobalTensor flagBase; + flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_)); + AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE); + } } CATLASS_DEVICE @@ -463,12 +487,20 @@ private: if constexpr (std::is_same_v) { AscendC::SetFlag(0); } + #ifdef __TILE_SYNC__ + if (params.flag > 0) { + int32_t flagId = params.flag + params.syncLoopIdx / 8; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } + #else Finalize(params.syncLoopIdx, params.flag); + #endif } } AscendC::LocalTensor l1ATensorList[L1_STAGES]; AscendC::LocalTensor l1BTensorList[L1_STAGES]; AscendC::LocalTensor l1STensor; + AscendC::LocalTensor l1FTensor; int32_t syncGroupIdx; int32_t l1AEventList[L1_STAGES]; int32_t l1BEventList[L1_STAGES]; @@ -497,8 +529,11 @@ private: CopyL1ToL0A copyL1ToL0A; CopyL1ToL0B copyL1ToL0B; CopyL0CToGm copyL0CToGm; + + __gm__ int32_t* ptrSoftFlagBase_ = nullptr; + int32_t expertPerRank_; }; } // namespace Catlass::Gemm::Block -#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp index 12262c68..f315b217 100644 --- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp @@ -4,6 +4,10 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL; constexpr static int32_t NUMS_PER_FLAG = 16; constexpr static int32_t CACHE_LINE = 512; +constexpr static int32_t FLAGSTRIDE = 16; constexpr static int32_t RESET_VAL = 0xffff; +constexpr static int32_t ALIGN_128 = 128; constexpr uint32_t MAX_EXPERTS_PER_RANK = 32; +constexpr static int32_t UB_ALIGN = 32; +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; #endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp index ec88e8fd..93d4c9e7 100644 --- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp @@ -5,16 +5,23 @@ #include "kernel_operator.h" #include "const_args.hpp" +#ifdef HCCL_COMM #include "moe_distribute_base.h" - -#ifndef HCCL_COMM -#include "shmem_api.h" using namespace AscendC::HcclContextDef; + +#else +#include "shmem_api.h" #endif #define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ constexpr int32_t MAX_RANK_SIZE = 32; -constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE; +constexpr int32_t SHMEM_MEM = 700 * MB_SIZE; + +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; + +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t STATE_OFFSET = 512; FORCE_INLINE_AICORE void AicSyncAll() { AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8); @@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) { return *((__gm__ T *)cache); } -FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { +template +FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) { using namespace AscendC; GlobalTensor global; - global.SetGlobalBuffer(addr); + global.SetGlobalBuffer(reinterpret_cast(addr)); // Important: add hint to avoid dcci being optimized by compiler __asm__ 
__volatile__(""); @@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t * FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { do { AscendC::LocalTensor ub; - ub.address_.logicPos = static_cast(TPosition::VECIN); + ub.address_.logicPos = static_cast(AscendC::TPosition::VECIN); ub.address_.bufferAddr = 0; AscendC::GlobalTensor sig; sig.SetGlobalBuffer(sig_addr); @@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32 class HcclShmem { public: - #ifdef HCCL_COMM + #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context __gm__ HcclOpResParamCustom *WinContext_{nullptr}; Hccl hccl_; - GM_ADDR m_ptrArray[MAX_RANK_SIZE]; + AscendC::LocalTensor ub; FORCE_INLINE_AICORE HcclShmem(){ auto contextGM0 = AscendC::GetHcclContext(); @@ -87,17 +95,13 @@ public: m_rank = WinContext_->localUsrRankId; m_rankSize = WinContext_->rankSize; m_segmentSize = WinContext_->winSize; - for (int i = 0; i < m_rankSize; i++) { - m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn : - ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn); - } } #else FORCE_INLINE_AICORE HcclShmem(){ m_segmentSize = SHMEM_MEM; } - FORCE_INLINE_AICORE + FORCE_INLINE_AICORE void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) { symmetricPtr = symmetricPtr_; m_rank = rank; @@ -106,25 +110,26 @@ public: #endif FORCE_INLINE_AICORE - GM_ADDR operator() () const { + GM_ADDR operator() () const { // No parameters: return pointer to local peermem #ifdef HCCL_COMM - return m_ptrArray[m_rank]; + return (GM_ADDR)(WinContext_->localWindowsIn); #else return reinterpret_cast(shmem_ptr(symmetricPtr, m_rank)); #endif } FORCE_INLINE_AICORE - GM_ADDR operator() (int32_t index) const { + GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem #ifdef HCCL_COMM - return m_ptrArray[index]; + return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn); #else return reinterpret_cast(shmem_ptr(symmetricPtr, index)); #endif } FORCE_INLINE_AICORE - GM_ADDR operator () (int64_t offset, int32_t rankId) const { + GM_ADDR operator () (int64_t offset, int32_t rankId) const { #ifdef HCCL_COMM if (offset < 0 || offset >= m_segmentSize) { return nullptr; @@ -132,7 +137,8 @@ public: if (rankId < 0 || rankId >= m_rankSize) { return nullptr; } - return m_ptrArray[rankId] + offset; + return (GM_ADDR)((rankId == m_rank) ? 
WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset; #else return reinterpret_cast(shmem_ptr((symmetricPtr + offset), rankId)); #endif @@ -176,6 +182,130 @@ public: gm_store(sync_base, count); } + + FORCE_INLINE_AICORE + void InitStatusTargetSum() + { + using namespace AscendC; + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET; + //uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE); + // ep state + //uint32_t coreIdx = get_block_idx();; + uint32_t coreIdx = GetBlockIdx(); + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Set(AscendC::LocalTensor ctrBuffer) { + //subblockid = 0 + uint32_t stateOffset_ = STATE_OFFSET; + // uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_; + + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_; + //uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int vec_size = get_block_num(); + int vec_id = get_block_idx(); + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + pipe_barrier(PIPE_ALL); + + ctrBuffer.SetValue(0, epStateValue_); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) { + AscendC::GlobalTensor gmDstStates; + gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx))); + DataCopy(gmDstStates, ctrBuffer, 8); + } + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Wait(AscendC::LocalTensor statusTensor, AscendC::LocalTensor gatherMaskOutTensor, + AscendC::LocalTensor gatherTmpTensor, AscendC::LocalTensor statusSumOutTensor) { + + uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int vec_size = get_block_num(); + int vec_id = get_block_idx(); + uint32_t stateOffset_ = STATE_OFFSET; + + uint32_t sendRankNum_ = m_rankSize / vec_size; + uint32_t remainderRankNum = m_rankSize % vec_size; + uint32_t startRankId_ = sendRankNum_ * vec_id; + if (vec_id < remainderRankNum) { + sendRankNum_++; + startRankId_ += vec_id; + } else { + startRankId_ += remainderRankNum; + } + uint32_t endRankId_ = startRankId_ + sendRankNum_; + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + + AscendC::GlobalTensor epStatusSpaceGlobalTensor_; + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset)); + + if (startRankId_ < m_rankSize) { + AscendC::PipeBarrier(); + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + // DataCopyParams intriParams{static_cast(sendRankNum_), 1, + // static_cast((moeSendNum_ > 512) ? 
7 : 15), 0}; + AscendC::DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast(15), 0}; + + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_}; + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + + //unpermute + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE __gm__ int32_t* SyncBaseAddr() { uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); @@ -187,9 +317,11 @@ private: int32_t m_rank; int32_t m_rankSize; size_t m_segmentSize; + float sumTarget_{0.0}; + int32_t epStateValue_; }; -#endif \ No newline at end of file +#endif
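
---

Reviewer note (illustrative, not part of the patch): the dispatch path above
publishes each rank's tokenPerExpert counts into peer memory with a 0x800000
bias added (AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore)), so a
receiver can spin with gm_signal_wait_until_ne(addr, 0) even when the true
count is zero, then strip the bias back out after the wait. Below is a minimal
host-side C++ analogue of that handshake, using std::atomic in place of the
AscendC global tensors; the names (kBias, slot, producer, consumer) are
invented for the sketch and do not come from the kernel:

    #include <atomic>
    #include <cassert>
    #include <cstdint>
    #include <thread>

    // Bias chosen so a published value is never 0, even when count == 0.
    constexpr int32_t kBias = 0x800000;

    std::atomic<int32_t> slot{0};  // stands in for the peer-memory cache line

    void producer(int32_t count) {
        // Publish count + bias; the nonzero store doubles as the arrival signal.
        slot.store(count + kBias, std::memory_order_release);
    }

    int32_t consumer() {
        int32_t v;
        // Spin until the slot becomes nonzero (cf. gm_signal_wait_until_ne).
        while ((v = slot.load(std::memory_order_acquire)) == 0) { }
        return v - kBias;  // strip the bias to recover the real token count
    }

    int main() {
        std::thread t(producer, /*count=*/0);  // zero tokens still signals
        int32_t count = consumer();
        t.join();
        assert(count == 0);
        return 0;
    }

The same hedging applies to the soft-flag scheme in UpdateAicFlags: each AIC
core writes a monotonically increasing per-core progress value into the flag
array, and one AIV polls the array and publishes min(progress), so consumers
only ever observe globally completed expert groups. A sketch of that
reduction under the same assumptions (plain C++ threads, invented names, not
the AscendC implementation):

    #include <algorithm>
    #include <array>
    #include <atomic>
    #include <cstdint>

    constexpr int kWorkers = 4;
    std::array<std::atomic<int32_t>, kWorkers> progress{};  // one flag per core
    std::atomic<int32_t> globalProgress{-1};                 // cf. aicFinishPtr

    void publishMinProgress(int32_t groupsPerWorker) {
        int32_t last = -1;
        while (last < groupsPerWorker) {
            int32_t m = progress[0].load();
            for (int i = 1; i < kWorkers; ++i) {
                m = std::min(m, progress[i].load());  // cf. AscendC::ReduceMin
            }
            if (m > last) {          // publish only forward progress
                globalProgress.store(m);
                last = m;
            }
        }
    }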