diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index 469a89e2..a0fe0ad8 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -224,7 +224,7 @@ private: tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); - tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank); + tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank); } template @@ -291,7 +291,7 @@ private: AscendC::DataCopyPad( tmpBuffer1, tokenPerExpert[rankId * expertPerRank], - {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0}, + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, ALIGN_128) - expertPerRank) * sizeof(int32_t)), 0}, {} ); @@ -547,7 +547,7 @@ private: CATLASS_DEVICE void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){ AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); - uint32_t numPerCore = params.EP * params.expertPerRank; + uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, ALIGN_128); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { if (dstEpIdx == params.rank) { continue; @@ -582,12 +582,13 @@ private: AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); } + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { if (dstEpIdx == params.rank) { continue; } int32_t intPer512 = CACHE_LINE / sizeof(int); - for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) { 
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); gm_signal_wait_until_ne(sync_check, 0); } @@ -776,7 +777,7 @@ private: } blockEpilogue.Finalize(); AscendC::SyncAll(); - ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank); + ResetTokenPerExpert(tokenPerExpert, params.EP * AlignUp(params.EP * params.expertPerRank, ALIGN_128)); shmem.CrossRankSync(); MoeTokenUnpermuteTilingData tilingData; MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp index 61a3d866..84cb6c4e 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp @@ -5,4 +5,5 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL; constexpr static int32_t NUMS_PER_FLAG = 16; constexpr static int32_t CACHE_LINE = 512; constexpr static int32_t RESET_VAL = 0xffff; +constexpr static int32_t ALIGN_128 = 128; #endif \ No newline at end of file