From 832552836842e00cc9de75bc853a90338694ccc4 Mon Sep 17 00:00:00 2001 From: xulei <33539210+serlar@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:30:34 +0800 Subject: [PATCH] [Kernel]: Optimize DispatchFFNCombine performance (#6468) ### What this PR does / why we need it? This PR focuses on performance optimization for the DispatchFFNCombine operator. The key optimizations include: 1. Improving communication efficiency by merging the transmission of tokens and scales; 2. Decoupling multi-core dependencies and reducing waiting bubbles in the combine process through tile-granularity communication; 3. Optimizing the full-card synchronization overhead before the unpermute operation. These optimizations aim to reduce the overall execution latency of the DispatchFFNCombine operator and enhance the runtime performance of the model inference process on Ascend devices. ### Does this PR introduce _any_ user-facing change? No. This PR only involves internal performance optimization of the DispatchFFNCombine operator and does not introduce any changes to user-facing APIs, interfaces, or behaviors. ### How was this patch tested? 1. Enable the DispatchFFNCombine operator by setting the environment variable: ``` export VLLM_ASCEND_ENABLE_FUSED_MC2=1 ``` 2. Run the standard model inference test suite with the above environment variable enabled; 3. Verify the correctness of model outputs (ensuring no functional regression) and measure the performance improvement of the DispatchFFNCombine operator (reduced latency and improved throughput). 
- vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd Signed-off-by: xulei_ict Co-authored-by: xulei_ict --- .../op_host/dispatch_ffn_combine_tiling.cpp | 12 +- .../op_kernel/dispatch_ffn_combine.cpp | 20 +- .../op_kernel/dispatch_ffn_combine.h | 2 +- .../op_kernel/dispatch_ffn_combine_kernel.hpp | 549 ++++++++++++------ .../moe_init_routing_quant_v2/moe_v2_common.h | 2 + .../moe_v2_fullload_dynamic_quant.h | 84 +-- .../moe_v2_gather_dynamic_quant.h | 39 +- .../op_kernel/unpermute/moe_token_unpermute.h | 4 +- .../utils/block_epilogue_pertoken_v2.hpp | 243 ++++++++ ...block_mmad_preload_async_fixpipe_quant.hpp | 53 +- .../op_kernel/utils/const_args.hpp | 4 +- .../utils/dispatch_policy_custom.hpp | 4 +- .../op_kernel/utils/hccl_shmem.hpp | 237 ++++++-- 13 files changed, 897 insertions(+), 356 deletions(-) create mode 100644 csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp index 8b16f0b9..d41d4d93 100644 --- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp @@ -17,11 +17,11 @@ #include "error_log.h" #include "hcom_topo_info.h" #include "register/op_def_registry.h" -#include "dispatch_ffn_combine_tiling.h" +#include "../op_kernel/dispatch_ffn_combine_tiling.h" #include #include #include -#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" +#include "../op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" using namespace AscendC; using namespace ge; @@ -278,8 +278,12 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) + info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 + 
info.maxOutputSize * sizeof(float) * 2 + - std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) + - std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t)); + info.maxOutputSize * info.N * sizeof(int16_t) + + info.maxOutputSize * n2 * sizeof(int16_t) + + info.maxOutputSize * info.K * sizeof(int8_t) + + info.maxOutputSize * k2 * sizeof(int8_t) + + info.worldSize * sizeof(int32_t) * 16 + + (info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16; workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace); diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp index 476f43e5..86680f30 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp @@ -23,29 +23,11 @@ extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1 GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM) { REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData); - if (TILING_KEY_IS(1000000)) { - KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM); - op.Process(); - } else if (TILING_KEY_IS(1000001)) { - KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM); - op.Process(); - } else if (TILING_KEY_IS(1000010)) { + if (TILING_KEY_IS(1000010)) { KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2); GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); 
DispatchFFNCombine op; op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM); op.Process(); - } else if (TILING_KEY_IS(1000011)) { - KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM); - op.Process(); } } \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h index 31c0471f..4e73b832 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -234,7 +234,7 @@ __aicore__ inline void DispatchFFNCombine::Process() using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; - using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant; + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2; using TileCopy2 = Epilogue::Tile::TileCopy; using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index e0956c92..fed0232c 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -22,21 +22,38 @@ #include "catlass/matrix_coord.hpp" #include "catlass/epilogue/tile/tile_copy.hpp" -#include "utils/block_mmad_preload_async_fixpipe_quant.hpp" -#include "utils/copy_gm_to_l1_custom.hpp" -#include "utils/copy_l0c_to_gm_custom.hpp" -#include "utils/block_epilogue_pertoken_row.hpp" -#include "utils/block_epilogue_pertoken_swiglu.hpp" -#include "utils/hccl_shmem.hpp" -#include "utils/const_args.hpp" -#include "utils/layout3d.hpp" -#include "utils/get_tensor_addr.hpp" - -#include 
"moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" -#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" -#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" -#include "unpermute/moe_token_unpermute.h" - +#ifndef HCCL_COMM + #include "block_mmad_preload_async_fixpipe_quant.hpp" + #include "copy_gm_to_l1_custom.hpp" + #include "copy_l0c_to_gm_custom.hpp" + #include "block_epilogue_pertoken_row.hpp" + #include "block_epilogue_pertoken_v2.hpp" + #include "block_epilogue_pertoken_swiglu.hpp" + #include "hccl_shmem.hpp" + #include "const_args.hpp" + #include "layout3d.hpp" + #include "tiling/moe_init_routing_quant_v2_tiling.h" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" + #include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" + #include "moe_token_unpermute.h" + #include "get_tensor_addr.hpp" + inline __gm__ struct OpSystemRunCfg g_opSystemRunCfg{Catlass::L2_OFFSET}; +#else + #include "utils/block_mmad_preload_async_fixpipe_quant.hpp" + #include "utils/copy_gm_to_l1_custom.hpp" + #include "utils/copy_l0c_to_gm_custom.hpp" + #include "utils/block_epilogue_pertoken_row.hpp" + #include "utils/block_epilogue_pertoken_v2.hpp" + #include "utils/block_epilogue_pertoken_swiglu.hpp" + #include "utils/hccl_shmem.hpp" + #include "utils/const_args.hpp" + #include "utils/layout3d.hpp" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" + #include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" + #include "unpermute/moe_token_unpermute.h" + #include "utils/get_tensor_addr.hpp" +#endif using namespace AscendC; @@ -44,7 +61,6 @@ namespace Catlass::Gemm::Kernel { constexpr uint16_t SYNCFLAGC2V = 9; constexpr uint16_t SYNCFLAGV2C = 10; -constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; template < class BlockMmad_, @@ -104,6 +120,7 @@ public: uint32_t rank; uint32_t rankSize; int32_t ubMoveNum; + 
GM_ADDR symmetricPtr; //-------------- GM_ADDR expertIdx; GM_ADDR moeInitRoutingQuantV2Scale; @@ -193,9 +210,7 @@ public: void operator()(Params const ¶ms) { GMM1(params); - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); - GMM2(params); } @@ -204,32 +219,26 @@ public: CATLASS_DEVICE void operator()(Params const ¶ms) { - Dispatch(params); - AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); - - Combine(params); + DispatchAndCombine(params); } private: CATLASS_DEVICE void initBuffer(Params const ¶ms) { + #ifndef HCCL_COMM + shmem.initShmem(params.symmetricPtr, params.rank, params.rankSize); + #endif workspaceInfo = WorkspaceInfo(params); peermemInfo = PeermemInfo(params, shmem); - cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); - gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); - gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); - gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2)); - tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); - tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank); + preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank)); } template @@ -285,6 +294,51 @@ private: AscendC::WaitFlag(EVENT_ID1); } + // Move tokens and scales together, then write them to different positions respectively + template + CATLASS_DEVICE void CopyGMToGMPerToken( + AscendC::GlobalTensor dst, + AscendC::GlobalTensor dstScale, + 
AscendC::GlobalTensor src, + int32_t rows, + int32_t hiddenSize + ) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + + constexpr int32_t BufferNum = 2; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + constexpr int tmpBufferOffset = 96 * 1024; // half of UB + AscendC::LocalTensor tmpBuffer2 = resource.ubBuf.template GetBufferByByte(tmpBufferOffset); + uint32_t copyInNum = hiddenSize + ALIGN_512; + // [ReduceScatter] 2. Pre Interface Sync + int pingpongId = 0; + for (uint32_t processIndex = 0; processIndex < rows; ++processIndex) { + AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2; + AscendC::LocalTensor bufScale = buf[hiddenSize].template ReinterpretCast(); + auto inputOffset = processIndex * copyInNum; + auto outputOffset = processIndex * hiddenSize; + // [ReduceScatter] 2. Pre Interface Sync + AscendC::WaitFlag(EVENT_ID); + // [ReduceScatter] 3. Start shmem_mte_get_mem_nbi + AscendC::DataCopy(buf, src[inputOffset], copyInNum); + AscendC::SetFlag(EVENT_ID); + AscendC::WaitFlag(EVENT_ID); + AscendC::DataCopy(dst[outputOffset], buf, hiddenSize); + AscendC::DataCopyPad(dstScale[processIndex], bufScale, {1, 4, 0, 0, 0}); + + // [ReduceScatter] 4. Post Interface Sync + AscendC::SetFlag(EVENT_ID); + pingpongId = (pingpongId + 1) % BufferNum; + } + // [ReduceScatter] 4. 
Post Interface Sync + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + CATLASS_DEVICE void GetCumsumForMMAIV(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP) { @@ -296,7 +350,7 @@ private: AscendC::DataCopyPad( tmpBuffer1, tokenPerExpert[rankId * expertPerRank], - {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, ALIGN_128) - expertPerRank) * sizeof(int32_t)), 0}, + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0}, {} ); @@ -323,6 +377,8 @@ private: icache_preload(8); BlockScheduler blockScheduler; BlockMmad blockMmad(resource); + float aivFinishGroups = 0.0f; + __gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE; int64_t gmGroupOffsetA = 0; int64_t gmGroupOffsetB = 0; @@ -331,11 +387,10 @@ private: uint32_t syncGroupIdx = 0; int64_t preCurrentmSum = 0; int32_t syncLoopIdx = -1; - + uint16_t syncgmmIdx = 0; AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul syncgmmIdx++; - AscendC::PipeBarrier(); for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { @@ -350,9 +405,7 @@ private: int32_t arrayGroupIdx = params.listLen == 1 ? 
0 : groupIdx; gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1))); gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1))); - AscendC::PipeBarrier(); - if (currentM <= L1TileShape::M) { gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -372,6 +425,7 @@ private: AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); syncgmmIdx ++; } + // Compute block location GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); @@ -399,6 +453,7 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } + // Synchronization signal: GMM1 notifies SwiGLU [1] blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V); } @@ -419,6 +474,7 @@ private: if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } + // Synchronization signal: GMM1 notifies SwiGLU [2] blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V); } @@ -427,7 +483,7 @@ private: icache_preload(8); BlockScheduler blockScheduler; BlockMmad blockMmad(resource); - + uint32_t n2 = params.problemShape.k(); uint32_t k2 = params.problemShape.n() / 2; @@ -458,11 +514,9 @@ private: } AscendC::GlobalTensor gmB2; AscendC::GlobalTensor gmS2; - AscendC::PipeBarrier(); int32_t arrayGroupIdx = params.listLen == 1 ? 
0 : groupIdx; gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB2))); gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale2))); - if (currentM <= L1TileShape::M) { gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -482,6 +536,7 @@ private: if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); } + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { if (loopIdx + coreNum >= coreLoops) { syncLoopIdx = groupIdx; @@ -502,12 +557,12 @@ private: int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * n2 : 0); // One scale group per expert if (currentM > 0) { blockMmad( - gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, - gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, - gmC2[gmGroupOffsetC + gmOffsetC], layoutC, - gmS2[gmOffsetS], layoutScale, - actualBlockShape, syncLoopIdx, 0 - ); + gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, + gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, + gmC2[gmGroupOffsetC + gmOffsetC], layoutC, + gmS2[gmOffsetS], layoutScale, + actualBlockShape, syncLoopIdx, 0 + ); } } preCurrentmSum += currentM; @@ -518,34 +573,34 @@ private: gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(params.expertPerRank - 1, 0); } - CATLASS_DEVICE - void ResetTokenPerExpert(AscendC::GlobalTensor & tokenPerExpert, int32_t num) - { - if (coreIdx != coreNum - 1) { - return; - } + + CATLASS_DEVICE + void InitArithProgress(Params const ¶ms) { + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::LocalTensor tmp = 
resource.ubBuf.template GetBufferByByte(0); - AscendC::Duplicate(tmp, 0, num); + AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::DataCopy(tokenPerExpert, tmp, num); + + AscendC::GlobalTensor flagGlobalBase; + flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase); + AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE); } + CATLASS_DEVICE - void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const ¶ms, int64_t localTokenPerExpertOffset){ + void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const ¶ms, int64_t localTokenPerExpertOffset){ + uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128); AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); - uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, ALIGN_128); + AscendC::LocalTensor prevSumBuf = tmpBuffer[numPerCore]; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { if (dstEpIdx == params.rank) { continue; @@ -562,11 +617,11 @@ private: using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; CopyGmToUb copyGmToUb; CopyUbToGm copyUbToGm; - + AscendC::WaitFlag(EVENT_ID0); - - copyGmToUb(tmpBuffer, srcAddress[0], - layout::RowMajor{ 1, numPerCore}, + + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, numPerCore}, layout::RowMajor{1, numPerCore}); AscendC::SetFlag(EVENT_ID0); @@ -574,35 +629,125 @@ private: AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - copyUbToGm(dstAddress[0], tmpBuffer, - layout::RowMajor{ 1, numPerCore}, + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, numPerCore}, layout::RowMajor{1, numPerCore}); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); } for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { - if (dstEpIdx == params.rank) { - continue; + if 
(dstEpIdx != params.rank) { + int32_t intPer512 = CACHE_LINE / sizeof(int); + for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) { + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); + gm_signal_wait_until_ne(sync_check, 0); + } + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + } else { + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); } - int32_t intPer512 = CACHE_LINE / sizeof(int); - for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, ALIGN_128); checkIdx += intPer512) { - __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); - gm_signal_wait_until_ne(sync_check, 0); + AscendC::PipeBarrier(); + int32_t prevSum = 0; + int32_t j = 0; + for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) { + if (i >= params.rank * params.expertPerRank) { + prevSumBuf(j) = prevSum; + j++; + } + prevSum += tmpBuffer(i); } - AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf, + AscendC::DataCopyParams{1, static_cast(params.expertPerRank * sizeof(int32_t)), 0, 0}); + } + + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void 
ResetTokenPerExpert(int32_t num) + { + if (coreIdx != coreNum - 1) { + return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmp, 0, num); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert, tmp, num); + } + + CATLASS_DEVICE + void UpdateAicFlags(const Params ¶ms) + { + float flagBase = 1.0f * params.expertPerRank; + __gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE; + float flag = 0.0f; + float lastflag = -1.0f; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + __gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase; + AscendC::GlobalTensor flagGM; + flagGM.SetGlobalBuffer(flagPtr); + int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE; + AscendC::LocalTensor dstValueBuffer = resource.ubBuf.template GetBufferByByte(flagBufferSize); + AscendC::LocalTensor sharedTmpBuffer = resource.ubBuf.template GetBufferByByte((flagBufferSize + 64)); + uint64_t mask[1] = {0}; + uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE)); + for (int32_t i = 0; i < 4; i ++) { + if (i < params.EP) { + mask[0] |= 1ull * (1ull << (i * 16)); + } + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + while (flag < flagBase) { + flag = flagBase; + AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + + AscendC::ReduceMin(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + flag = min(flag, dstValueBuffer.GetValue(0)); + + if (flag > lastflag) { + *aicFinishPtr = flag; + 
gm_dcci(aicFinishPtr); + lastflag = flag; + } } - AscendC::SyncAll(); } CATLASS_DEVICE - void Dispatch(Params const ¶ms) { + void CombineSetFlag() { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + } + + + CATLASS_DEVICE + void DispatchAndCombine(Params const ¶ms) { icache_preload(8); int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t); GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem @@ -617,10 +762,19 @@ private: ¶ms.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey); AscendC::SyncAll(); - CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset); + + CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset); + if (coreIdx == 0) { GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); } + + uint32_t curGroupOffset = 0; + int32_t prevSumBeforeRank = 0; + int32_t prevSum = 0; + if (coreIdx < params.EP) { + prevSum = preSumBeforeRank(coreIdx * params.expertPerRank); + } AscendC::SyncAll(); AscendC::GlobalTensor ExpertTokenNums; @@ -633,24 +787,12 @@ private: AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); syncgmm1Idx++; - uint32_t curGroupOffset = 0; - int32_t prevSumBeforeRank = 0; - int32_t groupIdxDeq = 0; - if (coreIdx < params.EP) { - for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) { - prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i)); - } - m_prevSumBeforeRank = prevSumBeforeRank; - } - int prevSum = prevSumBeforeRank; uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0; uint32_t dequantSum = 0; - int32_t 
syncLoopIdx = -1; - uint32_t n = params.problemShape.n(); - BlockEpilogue1 blockEpilogue(resource, n); + + icache_preload(8); for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { // The ith core reads data from the ith rank's peermem - groupIdxDeq = groupIdx - 2; uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; @@ -664,24 +806,23 @@ private: GM_ADDR otherRankPtr = shmem(0, dstEpIdx); AscendC::GlobalTensor gmRemoteA; gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); - AscendC::GlobalTensor gmRemotePerTokenScale; - gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale)); + MatrixCoord offsetA{rowStart, 0}; - MatrixCoord shapeA{rows, params.problemShape.k()}; MatrixCoord offsetPeer{rowSrc, 0}; int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); - int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); + int64_t gmOffsetPeer = rowSrc * (params.problemShape.k() + ALIGN_512); // Communication data - CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); - // Communication scale - CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows); + CopyGMToGMPerToken(gmA[gmOffsetA], gmPerTokenScale1[rowStart], gmRemoteA[gmOffsetPeer], rows, params.problemShape.k()); } - } + } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete - prevGroupSum1 += currentM; + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); syncgmm1Idx ++; + + prevGroupSum1 += currentM; + + // Token count 
and truncation logic for the first SwiGLU operation if (groupIdx + 1 <= params.epilogueGranularity) { if (dequantSum1 + currentM <= params.maxOutputSize) { dequantSum1 += currentM; @@ -689,6 +830,8 @@ private: dequantSum1 = params.maxOutputSize; } } + + // Token count and truncation logic for the second SwiGLU operation if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) { if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) { dequantSum2 += currentM; @@ -698,21 +841,42 @@ private: } } + uint32_t n2 = params.problemShape.k(); + + typename BlockEpilogue2::Params epilogueParams{ + static_cast(params.EP), + static_cast(params.expertPerRank), + static_cast(params.rank), + reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert), + params.layoutD2, + static_cast(n2), + static_cast(L1TileShape::N), + shmem, + static_cast(peermemInfo.offsetD) + }; + + uint32_t n = params.problemShape.n(); + BlockEpilogue2 blockEpilogue2(resource, epilogueParams); + BlockEpilogue1 blockEpilogue1(resource, n); + + // Synchronous wait: SwiGLU waits for GMM1 [1] AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::SyncAll(); - - if (dequantSum1 > 0) { + if (dequantSum1 > 0) { uint32_t rowStartThisCore = 0; MatrixCoord offsetC{0U, 0}; MatrixCoord shapeC{dequantSum1, params.problemShape.n()}; LayoutC layoutC{dequantSum1, params.problemShape.n()}; int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + // Synchronization signal: SwiGLU notifies GMM2 [1] + 
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) { + // Synchronous wait: SwiGLU waits for GMM1 [2] AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::SyncAll(); if (dequantSum2 > 0) { @@ -723,76 +887,118 @@ private: LayoutC layoutC{dequantLen, params.problemShape.n()}; int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + // Synchronization signal: SwiGLU notifies GMM2 [2] + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } - blockEpilogue.Finalize(); + + blockEpilogue1.Finalize(); + + + CombineSetFlag(); + + CombineV2(params, blockEpilogue2); + + AscendC::SyncAll(); + #ifndef __CROSSRANKSYNCANDALLGATHERV1__ + ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128)); + #endif + shmem.InitStatusTargetSum(); + if (get_subblockid() == 0) { + AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0); + shmem.CrossRankSyncV2Set(ctrBuffer); + } else { + uint32_t uboffset = 0; + uint32_t aicCoreNum = coreNum / 2; + uint32_t aicCoreIdx = get_block_idx(); + uint32_t sendRankNum_ = params.EP / aicCoreNum; + uint32_t remainderRankNum = params.EP % aicCoreNum; + if (aicCoreIdx < remainderRankNum) { + sendRankNum_++; + } + AscendC::LocalTensor statusTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sendRankNum_ * UB_ALIGN; + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += params.EP * sizeof(float); + 
AscendC::LocalTensor gatherTmpTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sizeof(uint32_t); + AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sizeof(float); + shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor); + MoeTokenUnpermuteTilingData tilingData; + MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2); + KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; + kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); + kernelMoeTokenUnpermuteOp.Process(); + } + } CATLASS_DEVICE - void Combine(Params const ¶ms) { - int32_t prevSumBeforeRank = 0; - if (coreIdx < params.EP) { - prevSumBeforeRank = m_prevSumBeforeRank; - } - - int prevSum = prevSumBeforeRank; + void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) { + BlockScheduler blockScheduler; + int32_t syncLoopIdx = 0; + uint32_t startCoreIdx = 0; + uint32_t aicCoreNum = coreNum / 2; + uint32_t aicCoreIdx = get_block_idx(); + uint32_t aivSubCoreIdx = get_subblockid(); + uint32_t preSrcExpertSum = 0; uint32_t n2 = params.problemShape.k(); uint32_t k2 = params.problemShape.n() / 2; - - // TODO compute the cumsum of tokenPerExpert - typename BlockEpilogue2::Params epilogueParams{ - static_cast(params.EP), - static_cast(params.expertPerRank), - reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace), - static_cast(n2) - }; - BlockEpilogue2 blockEpilogue(resource, epilogueParams); - int32_t prevGroupSum2 = 0; + icache_preload(8); for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { - AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); - AscendC::SyncAll(); + uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preSrcExpertSum >= 
params.maxOutputSize) { + currentExpertM = 0; + } else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) { + currentExpertM = params.maxOutputSize - preSrcExpertSum; + } + GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx; - for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { - __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); - AscendC::GlobalTensor gmRemotePeer; - gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); - uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; - if (srcRowOffset < params.maxOutputSize) { - uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); - if (srcRowOffset + dataRows > params.maxOutputSize) { - dataRows = params.maxOutputSize - srcRowOffset; - } - uint32_t dstRowOffset = prevSum; - prevSum += dataRows; - MatrixCoord offsetC{srcRowOffset, 0}; - MatrixCoord offsetPeer{dstRowOffset, 0}; - MatrixCoord shapeC{dataRows, n2}; - int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC); - int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer); - if constexpr (std::is_same_v) { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]); - } else { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) { + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + int32_t m0 = 16; + // Block count, the shape of each block is (m0, actualBlockShape.n()) + int32_t m_rows = 
(actualBlockShape.m() + m0 - 1) / m0; + int32_t aiv_m_rows = m_rows / 2; + if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) { + aiv_m_rows += 1; + } + uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset + if(aivSubCoreIdx == 1) { + m_offset += (m_rows / 2) * m0; + } + + + for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) { + int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; + AscendC::CrossCoreWaitFlag<0x2>(flag_id); + } + + for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) { + GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1}; + uint32_t actualm = m0; + if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){ + actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0; } + GemmCoord realTileShape{actualm, actualBlockShape.n(), 1}; + blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank); + m_offset += m0; } } - prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + preSrcExpertSum += currentExpertM; + startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum; } blockEpilogue.Finalize(); - AscendC::SyncAll(); - ResetTokenPerExpert(tokenPerExpert, params.EP * AlignUp(params.EP * params.expertPerRank, ALIGN_128)); - shmem.CrossRankSync(); - MoeTokenUnpermuteTilingData tilingData; - MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); - KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; - - kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); - kernelMoeTokenUnpermuteOp.Process(); } + private: struct WorkspaceInfo { GM_ADDR ptrA; @@ -804,6 +1010,9 @@ private: GM_ADDR ptrPerTokenScale2; GM_ADDR expandedRowIdx; GM_ADDR ptrTokenPerExpert; + GM_ADDR ptrSumBeforeRank; + __gm__ float* ptrSoftFlagBase; + CATLASS_DEVICE WorkspaceInfo(){} @@ -831,15 +1040,21 @@ private: workspaceOffset += 
(params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); ptrC = params.ptrWorkspace + workspaceOffset; - ptrC2 = ptrC; - workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC), - params.maxOutputSize * n2 * sizeof(ElementC)); + workspaceOffset += params.maxOutputSize * params.problemShape.n() * sizeof(ElementC); + ptrC2 = params.ptrWorkspace + workspaceOffset; + workspaceOffset += params.maxOutputSize * n2 * sizeof(ElementC); ptrA = params.ptrWorkspace + workspaceOffset; - ptrPermutedToken = ptrA; - workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA), - params.maxOutputSize * k2 * sizeof(ElementA)); + + workspaceOffset += params.maxOutputSize * params.problemShape.k() * sizeof(ElementA); + ptrPermutedToken = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA); + ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE; + ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset); } }; @@ -866,12 +1081,9 @@ private: uint32_t coreIdx; uint32_t coreNum; - Params params; WorkspaceInfo workspaceInfo; PeermemInfo peermemInfo; - int64_t m_prevSumBeforeRank; - AscendC::GlobalTensor gmA; AscendC::GlobalTensor gmC; @@ -883,10 +1095,11 @@ private: AscendC::GlobalTensor tokenPerExpert; AscendC::GlobalTensor cumsumMM; + AscendC::GlobalTensor preSumBeforeRank; Layout3D tokenPerExpertLayout; HcclShmem shmem; }; } // namespace Catlass::Gemm::Kernel -#endif // DISPATH_FFN_COMBINE_KERNEL_HPP +#endif // DISPATCH_FFN_COMBINE_KERNEL_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h index c190033a..e362f50a 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h +++ 
b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h @@ -44,6 +44,8 @@ constexpr int64_t EXERPT_TOKENS_COUNT = 2; constexpr int64_t EXERPT_TOKENS_CUMSUM = 1; constexpr int64_t EXERPT_TOKENS_NONE = 0; constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1; +constexpr int64_t ALIGN_512 = 512; +constexpr int64_t ALIGN_128 = 128; const __gm__ int32_t assist[256] = { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h index 824e9af3..9d77c5e2 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h @@ -35,7 +35,6 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { __aicore__ inline void CopyOutIdx(); __aicore__ inline void CopyOutEmpty(); __aicore__ inline void CopyOutXQuant1H(); - __aicore__ inline void CopyOutXQuantEH(); __aicore__ inline void ComputeExpertTokenCountOrCumsum(); __aicore__ inline void Compute(LocalTensor& smoothLocal); @@ -49,6 +48,7 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { int64_t k_; int64_t n_; int64_t cols_; + int64_t cols_scale_; int64_t activateRows_; int64_t expertNum; int64_t expertCapacity; @@ -63,12 +63,10 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { TQue smoothInQueue; TQue calcQueue; TQue inputXOutQueue; - TQue scaleOutQueue; GlobalTensor xGm_; GlobalTensor expertIdxGm_; GlobalTensor quantSmoothGm; - GlobalTensor dynamicQuantScaleGm; GlobalTensor expandedXGm_; GlobalTensor expandedRowIdxGm_; @@ -225,7 +223,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - 
LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = outLocal[this->cols_].template ReinterpretCast(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[colsAlign], RoundMode::CAST_NONE, this->cols_); @@ -259,7 +257,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); - scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -275,7 +272,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; - DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast((this->cols_ + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; LocalTensor smoothLocal; if (smoothType == 1) { @@ -295,7 +292,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { xCopyInQueue_.EnQue(xLocal); Compute(smoothLocal); - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) { int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); @@ -303,76 +299,15 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) { continue; } - DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams); - DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + DataCopyPad(expandedXGm_[outIndex * this->cols_scale_], outLocal, intriParams); } xCopyInQueue_.FreeTensor(xLocal); inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); - } - - if (smoothType == 1) { - smoothInQueue.FreeTensor(smoothLocal); } 
expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); } -template -__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuantEH() { - LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); - expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); - - Muls(expandDstToSrcRowLocal.ReinterpretCast(), expandDstToSrcRowLocal.ReinterpretCast(), (float)-1, - this->totalLength); - pipe_barrier(PIPE_V); - LocalTensor sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast(); - Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->totalLength); - - int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; - int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1; - - DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; - DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; - DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; - - for (int64_t row = curRowsStart; row <= curRowsEnd; row++) { - if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) { - break; - } - int32_t srcIdx = sortedRowIdx.GetValue(row); - int32_t expertIdx = expandedExpertIdxLocal.GetValue(row); - - LocalTensor inLocal = xCopyInQueue_.AllocTensor(); - LocalTensor smoothLocal = smoothInQueue.AllocTensor(); - if constexpr (IsSameType::value) { - DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); - } else { - DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); - } - DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0}); - xCopyInQueue_.EnQue(inLocal); - smoothInQueue.EnQue(smoothLocal); - smoothLocal = smoothInQueue.DeQue(); - - Compute(smoothLocal); - - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); - DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0}); - - LocalTensor 
outLocal = inputXOutQueue.DeQue(); - DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams); - - xCopyInQueue_.FreeTensor(inLocal); - smoothInQueue.FreeTensor(smoothLocal); - inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); - } - - expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); - expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal); -} - template __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, @@ -386,6 +321,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp this->k_ = tilingData->k; this->n_ = tilingData->n; this->cols_ = tilingData->cols; + this->cols_scale_ = this->cols_ + ALIGN_512; this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; this->activateRows_ = this->gatherOutTilingData_->activateRows; @@ -416,7 +352,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp Align(this->expertNum, sizeof(int32_t))); } quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); - dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); int64_t kvFactor = 2; int64_t buffSize = this->sortNum_ * sizeof(int32_t); @@ -440,8 +375,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp } pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float))); - pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t))); - pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_scale_, sizeof(int8_t))); } template @@ -457,11 +391,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Process() { } else { CopyOutEmpty(); } - if (smoothType == 2) { - CopyOutXQuantEH(); - } 
else { - CopyOutXQuant1H(); - } + CopyOutXQuant1H(); } } } // namespace MoeInitRoutingQuantV2 diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h index 924e8548..64852f31 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h @@ -66,6 +66,7 @@ class MoeV2GatherDynamicQuant { int64_t needCoreNum; int64_t blockIdx; int64_t cols; + int64_t cols_scale_; int64_t n; int64_t k; int64_t totalLength; @@ -117,7 +118,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = outLocal[this->cols].template ReinterpretCast(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); @@ -151,7 +152,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); - scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -163,7 +163,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr int64_t currentLoopStartRow = initialRow / this->k; int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; DataCopyExtParams copyInParams{1, static_cast(this->cols * sizeof(T)), 0, 0, 0}; - DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast((this->cols + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; LocalTensor smoothLocal; @@ -187,7 
+187,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr // Compute quantization Compute(smoothLocal); - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { @@ -197,15 +196,11 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { continue; } - DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams); - DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + // Scale is placed after the data position + DataCopyPad(expandedXGm[outIndex * cols_scale_], outLocal, copyOutParams); } inputXInQueue.FreeTensor(inLocal); inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); - } - if (smoothType == 1) { - smoothInQueue.FreeTensor(smoothLocal); } expandRowIdxInQueue.FreeTensor(indicesLocal); } @@ -463,6 +458,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR this->needCoreNum = this->gatherOutTilingData->needCoreNum; this->activateRows = this->gatherOutTilingData->activateRows; this->cols = tilingData->cols; + this->cols_scale_ = this->cols + ALIGN_512; this->n = tilingData->n; this->k = tilingData->k; this->totalLength = tilingData->n * tilingData->k; @@ -518,32 +514,15 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); - pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); } template __aicore__ inline void MoeV2GatherDynamicQuant::Process() { if (this->blockIdx < this->needCoreNum) { currentLoopRows = perLoopRows; - if 
(colLoops > 1) { // A single row cannot be fully loaded; workspace is required - if (smoothType == 2) { - for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { - CopyInExpandedExpertIdx(loop); - CopyOutPartialXQuantEH(loop); - } - currentLoopRows = lastLoopRows; - CopyInExpandedExpertIdx(this->rowLoops - 1); - CopyOutPartialXQuantEH(this->rowLoops - 1); - } else { - for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { - CopyInExpandedRowIdx(loop); - CopyOutPartialXQuant1H(loop); - } - currentLoopRows = lastLoopRows; - CopyInExpandedRowIdx(this->rowLoops - 1); - CopyOutPartialXQuant1H(this->rowLoops - 1); - } - } else { // A single row can be fully loaded + if (colLoops > 1) { // Cannot fit all data in one row, workspace is required + trap(); // Not supported + } else { // All data can fit in one row if (smoothType == 2) { for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { CopyInExpandedExpertIdx(loop); diff --git a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h index 1255b5cf..adb805b8 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h +++ b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h @@ -85,8 +85,8 @@ KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADD GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data) { - this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); - this->blockNum = get_block_num() * get_subblockdim(); + this->blockIdx = get_block_idx(); + this->blockNum = get_block_num(); if (blockIdx >= blockNum) { return; diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp new file mode 100644 index 00000000..926622dc --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp @@ -0,0 +1,243 @@ 
+#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "hccl_shmem.hpp" +#include "layout3d.hpp" + +namespace Catlass::Epilogue::Block { +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequantV2, + CType_, + Gemm::GemmType, + DType_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub>; + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + struct Params { + __gm__ int32_t *ptrTokenPerExpert{nullptr}; + int32_t EP; + int32_t expertPerRank; + int32_t n2; + LayoutC layoutC; + int32_t n0; + int32_t rank; + HcclShmem shmem; + int32_t offsetD; + + CATLASS_DEVICE + Params() {}; + CATLASS_DEVICE + Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_, + LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) : + ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), + expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_), + shmem(shmem_), offsetD(offsetD_) + {} + }; + + + CATLASS_DEVICE + 
BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + //ub:192KB + n0 = params.n0; + size_t ubOffset = 0; + for(int32_t i = 0; i < 2; i++) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementD); + ubFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(float); + scaleUbList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += (max_len / n0) * sizeof(float); + source_scale_offset[i] = -1; + } + tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert)); + tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank); + is_ping = true; + } + + CATLASS_DEVICE + void Finalize() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + } + CATLASS_DEVICE + ~BlockEpilogue() + { + + } + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + AscendC::GlobalTensor const &gmPerTokenScale, + GemmCoord& blockCoord, + GemmCoord& actualBlockShape, + int32_t groupIdx, + int32_t preSrcExpertSum, + AscendC::GlobalTensor preSumBeforeRank + ){ + is_ping = !is_ping; + auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1; + auto event_id_2 = is_ping ? 
EVENT_ID2 : EVENT_ID3; + + auto &ubC = ubCList[is_ping]; + auto &ubD = ubDList[is_ping]; + int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n(); + auto gmTileC = gmC[gmCOffset]; + auto &ubCFp32 = ubFp32List[is_ping]; + auto &scaleUb = scaleUbList[is_ping]; + + LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2}; + LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0}; + + + AscendC::WaitFlag(event_id); + copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM); + AscendC::SetFlag(event_id); + + AscendC::WaitFlag(event_id); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4}); + AscendC::SetFlag(event_id); + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); + + int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m(); + layout::VectorLayout scaleLauout{actualBlockShape.m()}; + if (source_scale_offset[event_id] != gmScaleOffset) { + source_scale_offset[event_id] = gmScaleOffset; + copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout); + } + + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + + + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); // Note that the value must be MTE2_S instead of MTE2_V. + // Otherwise, 0 will be read, causing garbled characters. 
+ AscendC::PipeBarrier(); + for (int32_t row = 0; row < actualBlockShape.m(); ++row) { + float scale = scaleUb(row); + Muls(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8}); + } + AscendC::PipeBarrier(); + AscendC::WaitFlag(event_id); + AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8}); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id); + + int32_t lenTile = actualBlockShape.m(); + int32_t stTile = blockCoord.m(); + int32_t edTile = stTile + lenTile; + int32_t preSumRankInExpert = 0; + int32_t tileOffset = 0; + + AscendC::WaitFlag(event_id); + for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) { + int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx); + int32_t stRankInExpert = preSumRankInExpert; + int32_t edRankInExpert = stRankInExpert + lenRankInExpert; + preSumRankInExpert += lenRankInExpert; + if (stRankInExpert >= edTile) { + break; + } + else if (edRankInExpert <= stTile) { + continue; + } + int32_t stData = max(stRankInExpert, stTile); + int32_t edData = min(edRankInExpert, edTile); + uint32_t lenData = edData - stData; + if (lenData <= 0){ + continue; + } + + uint32_t dstOffsetInExpert = 0; + if (stTile > stRankInExpert) { + dstOffsetInExpert = stTile - stRankInExpert; + } + AscendC::GlobalTensor gmRemotePeer; + __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx); + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr)); + MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()}; + int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset); + auto gmTileD = gmRemotePeer[gmDstOffset]; + LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2}; + LayoutC layoutUB2{lenData, actualBlockShape.n(), n0}; + copyUbToGmD(gmTileD, 
ubD[tileOffset * n0], layoutGM2, layoutUB2); + tileOffset += lenData; + } + AscendC::SetFlag(event_id); + + } +private: + + Params params; + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + AscendC::LocalTensor ubFp32List[UB_STAGES]; + AscendC::LocalTensor scaleUbList[UB_STAGES]; + int32_t source_scale_offset[UB_STAGES]; + + int32_t max_len = 8 * 32 / 4 * 128; + int32_t n0; + bool is_ping = false; + + + int32_t repeat = 128; + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; + + CopyScaleGmToUb copyScaleGmToUb; + AscendC::GlobalTensor tokenPerExpert; + Layout3D tokenPerExpertLayout; +}; +} +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp index 3b435f26..4f935180 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -22,8 +22,6 @@ namespace Catlass::Gemm::Block { -constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; - template __aicore__ inline void SyncFlagFunc(int32_t eventID) { @@ -153,9 +151,11 @@ public: L1TileShape::K, L1TileShape::N); CATLASS_DEVICE - BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + BlockMmad(Arch::Resource &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0) { syncGroupIdx = 0; + ptrSoftFlagBase_ = flagPtr; + expertPerRank_ = expertPerRank; InitL1(resource, l1BufAddrStart); InitL0A(resource); InitL0B(resource); @@ -272,9 +272,21 @@ public: CATLASS_DEVICE void Finalize(int32_t target, int32_t flag = 0) { - for(;syncGroupIdx <= target; syncGroupIdx++) { - int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag; - AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + if (ptrSoftFlagBase_ != nullptr) { + if (target < 0) { 
+ return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::GlobalTensor flagGlobal; + flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE); + AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE); + } + else { + for(;syncGroupIdx <= target; syncGroupIdx++) { + int32_t flagId = syncGroupIdx / 15 + flag; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } } } private: @@ -291,7 +303,6 @@ private: layout::VectorLayout layoutScale; int32_t syncLoopIdx; int32_t flag; - CATLASS_DEVICE L1TileMmadParams() = default; }; @@ -310,11 +321,24 @@ private: AscendC::SetFlag(l1AEventList[i]); AscendC::SetFlag(l1BEventList[i]); } + uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; if constexpr (std::is_same_v) { - uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; l1STensor = resource.l1Buf.template GetBufferByByte(l1SOffset); AscendC::SetFlag(0); } + if (ptrSoftFlagBase_ != nullptr) { + // Initialize the flag matrix (structure as below): + // 1 0 0 0 0 0 0 0 + // 2 0 0 0 0 0 0 0 + // ... 
+ // 16 0 0 0 0 0 0 0 + // Then move it to L1 + uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE; + l1FTensor = resource.l1Buf.template GetBufferByByte(l1FOffset); + AscendC::GlobalTensor flagBase; + flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_)); + AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE); + } } CATLASS_DEVICE @@ -463,12 +487,20 @@ private: if constexpr (std::is_same_v) { AscendC::SetFlag(0); } + #ifdef __TILE_SYNC__ + if (params.flag > 0) { + int32_t flagId = params.flag + params.syncLoopIdx / 8; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } + #else Finalize(params.syncLoopIdx, params.flag); + #endif } } AscendC::LocalTensor l1ATensorList[L1_STAGES]; AscendC::LocalTensor l1BTensorList[L1_STAGES]; AscendC::LocalTensor l1STensor; + AscendC::LocalTensor l1FTensor; int32_t syncGroupIdx; int32_t l1AEventList[L1_STAGES]; int32_t l1BEventList[L1_STAGES]; @@ -497,8 +529,11 @@ private: CopyL1ToL0A copyL1ToL0A; CopyL1ToL0B copyL1ToL0B; CopyL0CToGm copyL0CToGm; + + __gm__ int32_t* ptrSoftFlagBase_ = nullptr; + int32_t expertPerRank_; }; } // namespace Catlass::Gemm::Block -#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp index 84cb6c4e..3249138e 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp @@ -5,5 +5,7 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL; constexpr static int32_t NUMS_PER_FLAG = 16; constexpr static int32_t CACHE_LINE = 512; constexpr static int32_t RESET_VAL = 0xffff; -constexpr static int32_t ALIGN_128 = 128; +constexpr static int32_t FLAGSTRIDE = 16; +constexpr static int32_t UB_ALIGN = 32; +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; #endif \ No newline 
at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp index 31fdbad1..7e30114e 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp @@ -33,13 +33,13 @@ namespace Catlass::Epilogue { }; template - struct EpilogueAtlasA2PerTokenDequantQuant { + struct EpilogueAtlasA2PerTokenDequantSwigluQuant { using ArchTag = Arch::AtlasA2; static constexpr uint32_t UB_STAGES = UB_STAGES_; }; template - struct EpilogueAtlasA2PerTokenDequantSwigluQuant { + struct EpilogueAtlasA2PerTokenDequantV2 { using ArchTag = Arch::AtlasA2; static constexpr uint32_t UB_STAGES = UB_STAGES_; }; diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp index cfbb4daf..93d4c9e7 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp @@ -5,13 +5,28 @@ #include "kernel_operator.h" #include "const_args.hpp" +#ifdef HCCL_COMM #include "moe_distribute_base.h" +using namespace AscendC::HcclContextDef; -#ifndef HCCL_COMM +#else #include "shmem_api.h" #endif #define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ +constexpr int32_t MAX_RANK_SIZE = 32; +constexpr int32_t SHMEM_MEM = 700 * MB_SIZE; + +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; + +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t STATE_OFFSET = 512; + +FORCE_INLINE_AICORE void AicSyncAll() { + AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8); + AscendC::CrossCoreWaitFlag<0x0>(8); +} template FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) { @@ -23,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) { return *((__gm__ T *)cache); } -FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { 
+template +FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) { using namespace AscendC; GlobalTensor global; - global.SetGlobalBuffer(addr); + global.SetGlobalBuffer(reinterpret_cast(addr)); // Important: add hint to avoid dcci being optimized by compiler __asm__ __volatile__(""); @@ -37,26 +53,20 @@ FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) { do { gm_dcci((__gm__ uint8_t *)sig_addr); - if (*sig_addr == cmp_val) { return *sig_addr; } - - // in case when peer pe enters next barrier if (*sig_addr == cmp_val + 1) { return *sig_addr; } } while (true); - - // never reach return -1; } - FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { do { AscendC::LocalTensor ub; - ub.address_.logicPos = static_cast(TPosition::VECIN); + ub.address_.logicPos = static_cast(AscendC::TPosition::VECIN); ub.address_.bufferAddr = 0; AscendC::GlobalTensor sig; sig.SetGlobalBuffer(sig_addr); @@ -71,59 +81,53 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32 } -constexpr int32_t MAX_RANK_SIZE = 32; class HcclShmem { public: #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context - __gm__ HcclOpResParamCustom *WinContext_{nullptr}; - Hccl hccl_; - size_t m_segmentSize; - int32_t m_rank; - int32_t m_rankSize; + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + Hccl hccl_; + AscendC::LocalTensor ub; + FORCE_INLINE_AICORE + HcclShmem(){ + auto contextGM0 = AscendC::GetHcclContext(); + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; - FORCE_INLINE_AICORE - HcclShmem(){ - auto contextGM0 = AscendC::GetHcclContext(); - WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; - - m_rank = WinContext_->localUsrRankId; - m_rankSize = WinContext_->rankSize; - m_segmentSize = WinContext_->winSize; - - } - - FORCE_INLINE_AICORE - size_t SegmentSize() const { - return m_segmentSize; - } - - 
FORCE_INLINE_AICORE - int32_t RankSize() const { - return m_rankSize; - } + m_rank = WinContext_->localUsrRankId; + m_rankSize = WinContext_->rankSize; + m_segmentSize = WinContext_->winSize; + } + #else + FORCE_INLINE_AICORE + HcclShmem(){ + m_segmentSize = SHMEM_MEM; + } + FORCE_INLINE_AICORE + void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) { + symmetricPtr = symmetricPtr_; + m_rank = rank; + m_rankSize = rankSize; + } #endif FORCE_INLINE_AICORE - GM_ADDR operator() () const { // No argument: return local peermem + GM_ADDR operator() () const { // No parameters: return pointer to local peermem #ifdef HCCL_COMM return (GM_ADDR)(WinContext_->localWindowsIn); #else - return reinterpret_cast(shmemi_get_state()->heap_base); + return reinterpret_cast(shmem_ptr(symmetricPtr, m_rank)); #endif } FORCE_INLINE_AICORE - GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address + GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem #ifdef HCCL_COMM return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn : ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn); #else - return reinterpret_cast(shmem_ptr(shmemi_get_state()->heap_base, index)); + return reinterpret_cast(shmem_ptr(symmetricPtr, index)); #endif } - - FORCE_INLINE_AICORE GM_ADDR operator () (int64_t offset, int32_t rankId) const { #ifdef HCCL_COMM @@ -136,15 +140,28 @@ public: return (GM_ADDR)((rankId == m_rank) ? 
WinContext_->localWindowsIn : ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset; #else - return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId); + return reinterpret_cast(shmem_ptr((symmetricPtr + offset), rankId)); #endif } + + FORCE_INLINE_AICORE + size_t SegmentSize() const { + return m_segmentSize; + } + + FORCE_INLINE_AICORE + int32_t RankSize() const { + return m_rankSize; + } + + FORCE_INLINE_AICORE ~HcclShmem() { } + FORCE_INLINE_AICORE void CrossRankSync() { uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); @@ -165,12 +182,146 @@ public: gm_store(sync_base, count); } + + FORCE_INLINE_AICORE + void InitStatusTargetSum() + { + using namespace AscendC; + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET; + //uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE); + // ep state + //uint32_t coreIdx = get_block_idx();; + uint32_t coreIdx = GetBlockIdx(); + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Set(AscendC::LocalTensor ctrBuffer) { + //subblockid = 0 + uint32_t stateOffset_ = STATE_OFFSET; + // uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_; + + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_; + //uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int 
vec_size = get_block_num(); + int vec_id = get_block_idx(); + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + pipe_barrier(PIPE_ALL); + + ctrBuffer.SetValue(0, epStateValue_); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) { + AscendC::GlobalTensor gmDstStates; + gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx))); + DataCopy(gmDstStates, ctrBuffer, 8); + } + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Wait(AscendC::LocalTensor statusTensor, AscendC::LocalTensor gatherMaskOutTensor, + AscendC::LocalTensor gatherTmpTensor, AscendC::LocalTensor statusSumOutTensor) { + + uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int vec_size = get_block_num(); + int vec_id = get_block_idx(); + uint32_t stateOffset_ = STATE_OFFSET; + + uint32_t sendRankNum_ = m_rankSize / vec_size; + uint32_t remainderRankNum = m_rankSize % vec_size; + uint32_t startRankId_ = sendRankNum_ * vec_id; + if (vec_id < remainderRankNum) { + sendRankNum_++; + startRankId_ += vec_id; + } else { + startRankId_ += remainderRankNum; + } + uint32_t endRankId_ = startRankId_ + sendRankNum_; + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + + AscendC::GlobalTensor epStatusSpaceGlobalTensor_; + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset)); + + if (startRankId_ < m_rankSize) { + AscendC::PipeBarrier(); + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + // DataCopyParams intriParams{static_cast(sendRankNum_), 1, + // static_cast((moeSendNum_ > 512) ? 
7 : 15), 0}; + AscendC::DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast(15), 0}; + + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_}; + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + + //unpermute + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE __gm__ int32_t* SyncBaseAddr() { uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); return (__gm__ int32_t*)(*this)() + flag_offset + 2048; } + +private: + GM_ADDR symmetricPtr; + int32_t m_rank; + int32_t m_rankSize; + size_t m_segmentSize; + float sumTarget_{0.0}; + int32_t epStateValue_; }; + + #endif