From 77ea8732241bfd80160824b17d1aee75909e1c24 Mon Sep 17 00:00:00 2001
From: xulei <33539210+serlar@users.noreply.github.com>
Date: Fri, 30 Jan 2026 21:21:20 +0800
Subject: [PATCH] fix: resolve sync bug in DispatchFFNCombine when expert num per card is 32 (#6416)

### What this PR does / why we need it?
Fix a synchronization deadlock in the DispatchFFNCombine module that occurs on NPU cards when the number of experts per card exceeds 16 (the bug manifests prominently when the count is set to 32/128).
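Reading the diff, two things appear to go wrong once the per-card expert count passes 16: the old rotating flag IDs (`syncLoopIdx / 8 + 1` for the C-to-V handshake, `groupIdx / 8 + 3` for the epilogue) grow into the fixed IDs 2 and 3 used by other handshakes, and GMM1 never drained the per-group flags it did not explicitly wait on. The patch therefore gives the two point-to-point handshakes dedicated IDs (`SYNCFLAGC2V = 9`, `SYNCFLAGV2C = 10`), rotates the per-group flags from 0 upward every `CROSS_CORE_FLAG_MAX_SET_COUNT = 15` sets, and adds drain loops so every `CrossCoreSetFlag` is matched by exactly one `CrossCoreWaitFlag`. Below is a minimal sketch of that rotate-and-drain pairing; the helper names (`SignalRounds`, `DrainRounds`) and the flat loop structure are hypothetical stand-ins, not the kernel's real control flow:

```cpp
#include "kernel_operator.h"  // AscendC device-side header (assumed build setup)

constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;

// Producer side (AIV in this patch): one set per expert group, stepping to
// the next flag ID after every 15 sets so no single flag accumulates more
// outstanding sets than its counter can hold.
__aicore__ inline void SignalRounds(uint32_t expertPerRank)
{
    for (uint32_t idx = 0; idx < expertPerRank; ++idx) {
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
    }
}

// Consumer side (AIC): exactly one wait per set, using the same
// idx -> flagId mapping. The patch adds a loop of this shape at the end of
// GMM1 so flags for groups a core never iterated over are still consumed.
__aicore__ inline void DrainRounds(uint32_t expertPerRank)
{
    for (uint32_t idx = 0; idx < expertPerRank; ++idx) {
        AscendC::CrossCoreWaitFlag<0x2>(idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
    }
}
```

With 32 experts per card, both sides then walk flag IDs 0, 1, 2 in lockstep (15 + 15 + 2 rounds); even at 128 experts the rotating range tops out at ID 8, which is why 9 and 10 are safe choices for the dedicated handshake flags.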
### Does this PR introduce _any_ user-facing change?
No, this is a bug fix for internal synchronization logic specific to NPU expert dispatch, with no impact on external APIs, interfaces, or end-user behaviors.

- vLLM version: v0.14.1
- vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd

Signed-off-by: xulei_ict
Co-authored-by: xulei_ict
---
 .../op_kernel/dispatch_ffn_combine_kernel.hpp | 47 +++++++++++++------
 ...block_mmad_preload_async_fixpipe_quant.hpp |  4 +-
 .../dispatch_ffn_combine_bf16_kernel.hpp      | 46 ++++++++++++------
 ...block_mmad_preload_async_fixpipe_quant.hpp |  4 +-
 4 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
index a0fe0ad8..422595aa 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
@@ -42,6 +42,10 @@ using namespace AscendC;
 
 namespace Catlass::Gemm::Kernel {
 
+constexpr uint16_t SYNCFLAGC2V = 9;
+constexpr uint16_t SYNCFLAGV2C = 10;
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template <
     class BlockMmad_,
     class BlockScheduler_,
@@ -189,7 +193,7 @@ public:
     {
         GMM1(params);
 
-        AscendC::CrossCoreWaitFlag<0x2>(2);
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
         GMM2(params);
     }
 
@@ -201,7 +205,7 @@ public:
     {
         Dispatch(params);
         AscendC::SyncAll();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
         Combine(params);
     }
 
@@ -324,9 +328,12 @@ private:
         int64_t gmGroupOffsetC = 0;
         uint32_t startCoreIdx = 0;
         uint32_t syncGroupIdx = 0;
-        AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul
         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
+
+        uint16_t syncgmmIdx = 0;
+        AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
+        syncgmmIdx++;
 
         constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
         __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
@@ -370,7 +377,8 @@
             for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                 for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
-                    AscendC::CrossCoreWaitFlag<0x2>(0);
+                    AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+                    syncgmmIdx ++;
                 }
                 // Compute block location
                 GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
@@ -399,7 +407,7 @@
                 if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
                     blockMmad.SynchronizeBlock();
                 }
-                blockMmad.Finalize(syncLoopIdx, 1);
+                blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
             }
 
             preCurrentmSum += currentM;
@@ -410,10 +418,16 @@
             gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
         }
+
+        for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
+            AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+            syncgmmIdx ++;
+        }
+
         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
             blockMmad.SynchronizeBlock();
         }
-        blockMmad.Finalize(syncLoopIdx + 1, 1);
+        blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
     }
 
     CATLASS_DEVICE
@@ -482,7 +496,7 @@
             uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
             // Loop through the matmul of each groupIdx
             if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
-                AscendC::CrossCoreWaitFlag<0x2>(2);
+                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
             }
             for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                 if (loopIdx + coreNum >= coreLoops) {
@@ -508,7 +522,7 @@
                     gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
                     gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
                     gmS2[gmOffsetS], layoutScale,
-                    actualBlockShape, syncLoopIdx, 3
+                    actualBlockShape, syncLoopIdx, 0
                 );
             }
         }
@@ -526,7 +540,7 @@
         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
             blockMmad.SynchronizeBlock();
         }
-        blockMmad.Finalize(params.expertPerRank - 1, 3);
+        blockMmad.Finalize(params.expertPerRank - 1, 0);
     }
 
     CATLASS_DEVICE
@@ -625,7 +639,9 @@
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
         AscendC::SyncAll();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+        uint16_t syncgmm1Idx = 0;
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+        syncgmm1Idx++;
 
         uint32_t curGroupOffset = 0;
         int32_t prevSumBeforeRank = 0;
@@ -673,10 +689,11 @@
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
                 syncLoopIdx++;
-                AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
+                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
             }
             AscendC::SyncAll();
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V notifies C that the current communication round is complete
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
+            syncgmm1Idx++;
 
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
                 uint32_t rowStartThisCore = 0;
@@ -699,7 +716,7 @@
             }
         }
         syncLoopIdx ++;
-        AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
         AscendC::SyncAll();
 
         uint32_t lastDequantExpertNum = params.expertPerRank;
@@ -707,7 +724,7 @@
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }
         if (lastDequantExpertNum < params.expertPerRank) {
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
         }
         if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
             uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
@@ -746,7 +763,7 @@
         BlockEpilogue2 blockEpilogue(resource, epilogueParams);
         int32_t prevGroupSum2 = 0;
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
-            AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3);
+            AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
             AscendC::SyncAll();
 
             for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
index c15e11b2..3b435f26 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
@@ -22,6 +22,8 @@
 
 namespace Catlass::Gemm::Block {
 
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template
 __aicore__ inline void SyncFlagFunc(int32_t eventID)
 {
@@ -271,7 +273,7 @@ public:
     void Finalize(int32_t target, int32_t flag = 0)
     {
         for(;syncGroupIdx <= target; syncGroupIdx++) {
-            int32_t flagId = syncGroupIdx / 8 + flag;
+            int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
             AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
         }
     }
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
index b1c74aa8..5877a370 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
@@ -42,6 +42,10 @@ using namespace AscendC;
 
 namespace Catlass::Gemm::Kernel {
 
+constexpr uint16_t SYNCFLAGC2V = 9;
+constexpr uint16_t SYNCFLAGV2C = 10;
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template <
     class BlockMmad_,
     class BlockScheduler_,
@@ -198,7 +202,7 @@ public:
     {
         GMM1(params);
 
-        AscendC::CrossCoreWaitFlag<0x2>(2);
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
         GMM2(params);
     }
 
@@ -210,7 +214,7 @@ public:
     {
         Dispatch(params);
         AscendC::SyncAll();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
         CombineV2(params);
     }
 
@@ -507,7 +511,9 @@ CATLASS_DEVICE
         int64_t gmGroupOffsetC = 0;
         uint32_t startCoreIdx = 0;
         uint32_t syncGroupIdx = 0;
-        AscendC::CrossCoreWaitFlag<0x2>(0);
+        uint16_t syncgmmIdx = 0;
+        AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
+        syncgmmIdx++;
         AicSyncAll();
         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
@@ -553,7 +559,8 @@
             for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                 for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
-                    AscendC::CrossCoreWaitFlag<0x2>(0);
+                    AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+                    syncgmmIdx ++;
                 }
                 // Compute block location
                 GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
@@ -592,7 +599,7 @@
                 if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
                     blockMmad.SynchronizeBlock();
                 }
-                blockMmad.Finalize(syncLoopIdx, 1);
+                blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
             }
 
             preCurrentmSum += currentM;
@@ -603,10 +610,16 @@
             gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
         }
+
+        for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
+            AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+            syncgmmIdx ++;
+        }
+
         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
             blockMmad.SynchronizeBlock();
         }
-        blockMmad.Finalize(syncLoopIdx + 1, 1);
+        blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
     }
 
     CATLASS_DEVICE
@@ -674,7 +687,7 @@ CATLASS_DEVICE
             uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
             // Loop through the matmul of each groupIdx
             if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
-                AscendC::CrossCoreWaitFlag<0x2>(2);
+                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
             }
             for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                 if (loopIdx + coreNum >= coreLoops) {
@@ -701,7 +714,7 @@
                     gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
                     gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
                     gmS2[gmOffsetS], layoutScale,
-                    actualBlockShape, syncLoopIdx, 3
+                    actualBlockShape, syncLoopIdx, 0
                 );
             } else {
                 blockMmad(
@@ -709,7 +722,7 @@
                     gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
                     gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
                     gmS2, layoutScale,
-                    actualBlockShape, syncLoopIdx, 3
+                    actualBlockShape, syncLoopIdx, 0
                 );
             }
         }
@@ -749,7 +762,9 @@ CATLASS_DEVICE
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
         AscendC::SyncAll();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+        uint16_t syncgmm1Idx = 0;
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+        syncgmm1Idx++;
 
         uint32_t curGroupOffset = 0;
         int32_t prevSumBeforeRank = 0;
@@ -791,10 +806,11 @@ CATLASS_DEVICE
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
                 syncLoopIdx++;
-                AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
+                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
             }
             AscendC::SyncAll();
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
+            syncgmm1Idx++;
 
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
                 uint32_t rowStartThisCore = 0;
@@ -829,9 +845,9 @@ CATLASS_DEVICE
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }
         if (lastDequantExpertNum < params.expertPerRank) {
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
         }
-        AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
         AscendC::SyncAll();
         if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
             uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
@@ -907,7 +923,7 @@ CATLASS_DEVICE
             }
             if (loopIdx == startLoopIdx) {
                 for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
-                    int32_t flag_id = 3 + syncLoopIdx / 8;
+                    int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
                     AscendC::CrossCoreWaitFlag<0x2>(flag_id);
                 }
             }
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
index c15e11b2..3b435f26 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp
@@ -22,6 +22,8 @@
 
 namespace Catlass::Gemm::Block {
 
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template
 __aicore__ inline void SyncFlagFunc(int32_t eventID)
 {
@@ -271,7 +273,7 @@ public:
     void Finalize(int32_t target, int32_t flag = 0)
     {
         for(;syncGroupIdx <= target; syncGroupIdx++) {
-            int32_t flagId = syncGroupIdx / 8 + flag;
+            int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
             AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
         }
     }