fix: resolve sync bug in DispathFFNCombine when expert num per card is 32 (#6416)

### What this PR does / why we need it?

Fix a synchronization deadlock in the DispathFFNCombine module that occurs
on NPU cards when the number of experts per card exceeds 16 (the bug shows
up most prominently at 32 or 128 experts per card).
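
The root cause, as reflected in the diff below, is that the kernel raised one
cross-core notification per expert on a fixed flag ID, while each flag ID
appears to tolerate at most CROSS_CORE_FLAG_MAX_SET_COUNT (15) sets before the
matching waits drain it; beyond that count the AIC/AIV pair stalls. The fix
rotates the flag ID every 15 notifications on both the set and wait sides. A
minimal sketch of that pattern (illustrative only; `expertPerRank` comes from
the kernel parameters, while `setIdx`/`waitIdx` are placeholder names):

```cpp
// Sketch of the flag-ID rotation, assuming each cross-core flag ID can be
// set at most CROSS_CORE_FLAG_MAX_SET_COUNT times before it is waited on.
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;

// Producer (AIV) side: one notification per expert, rotating the flag ID so
// that no single ID accumulates more than 15 pending sets.
uint16_t setIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < expertPerRank; ++groupIdx) {
    // ... finish this expert's dispatch / cumsum work ...
    AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(setIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
    setIdx++;
}

// Consumer (AIC) side: mirror the same rotation so every wait pairs with the
// corresponding set, regardless of how many experts the card holds.
uint16_t waitIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < expertPerRank; ++groupIdx) {
    AscendC::CrossCoreWaitFlag<0x2>(waitIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
    waitIdx++;
    // ... run this expert's grouped matmul ...
}
```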

### Does this PR introduce _any_ user-facing change?
No. This is a bug fix for internal synchronization logic specific to NPU
expert dispatch, with no impact on external APIs, interfaces, or end-user
behavior.


- vLLM version: v0.14.1
- vLLM main: dc917cceb8

Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
Authored by xulei on 2026-01-30 21:21:20 +08:00, committed by GitHub
Parent: 56f5d3bd49
Commit: 77ea873224
4 changed files with 69 additions and 32 deletions


@@ -42,6 +42,10 @@ using namespace AscendC;
 namespace Catlass::Gemm::Kernel {
+constexpr uint16_t SYNCFLAGC2V = 9;
+constexpr uint16_t SYNCFLAGV2C = 10;
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
 template <
 class BlockMmad_,
 class BlockScheduler_,
@@ -198,7 +202,7 @@ public:
 {
 GMM1(params);
-AscendC::CrossCoreWaitFlag<0x2>(2);
+AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
 GMM2(params);
 }
@@ -210,7 +214,7 @@ public:
 {
 Dispatch(params);
 AscendC::SyncAll<true>();
-AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
 CombineV2(params);
 }
@@ -507,7 +511,9 @@ CATLASS_DEVICE
 int64_t gmGroupOffsetC = 0;
 uint32_t startCoreIdx = 0;
 uint32_t syncGroupIdx = 0;
-AscendC::CrossCoreWaitFlag<0x2>(0);
+uint16_t syncgmmIdx = 0;
+AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
+syncgmmIdx++;
 AicSyncAll();
 int64_t preCurrentmSum = 0;
 int32_t syncLoopIdx = -1;
@@ -553,7 +559,8 @@ CATLASS_DEVICE
 for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
 for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
-AscendC::CrossCoreWaitFlag<0x2>(0);
+AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+syncgmmIdx ++;
 }
 // Compute block location
 GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
@@ -592,7 +599,7 @@ CATLASS_DEVICE
 if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
 blockMmad.SynchronizeBlock();
 }
-blockMmad.Finalize(syncLoopIdx, 1);
+blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
 }
 preCurrentmSum += currentM;
@@ -603,10 +610,16 @@ CATLASS_DEVICE
 gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
 startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
 }
+for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
+AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+syncgmmIdx ++;
+}
 if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
 blockMmad.SynchronizeBlock();
 }
-blockMmad.Finalize(syncLoopIdx + 1, 1);
+blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
 }
CATLASS_DEVICE
@@ -674,7 +687,7 @@ CATLASS_DEVICE
 uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
 // Loop through the matmul of each groupIdx
 if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
-AscendC::CrossCoreWaitFlag<0x2>(2);
+AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
 }
 for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
 if (loopIdx + coreNum >= coreLoops) {
@@ -701,7 +714,7 @@ CATLASS_DEVICE
 gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
 gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
 gmS2[gmOffsetS], layoutScale,
-actualBlockShape, syncLoopIdx, 3
+actualBlockShape, syncLoopIdx, 0
 );
 } else {
 blockMmad(
@@ -709,7 +722,7 @@ CATLASS_DEVICE
 gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
 gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
 gmS2, layoutScale,
-actualBlockShape, syncLoopIdx, 3
+actualBlockShape, syncLoopIdx, 0
 );
 }
 }
@@ -749,7 +762,9 @@ CATLASS_DEVICE
 GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
 }
 AscendC::SyncAll<true>();
-AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+uint16_t syncgmm1Idx = 0;
+AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+syncgmm1Idx++;
 uint32_t curGroupOffset = 0;
 int32_t prevSumBeforeRank = 0;
@@ -791,10 +806,11 @@ CATLASS_DEVICE
 if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
 syncLoopIdx++;
-AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
+AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
 }
 AscendC::SyncAll<true>();
-AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
+syncgmm1Idx++;
 if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
 uint32_t rowStartThisCore = 0;
@@ -829,9 +845,9 @@ CATLASS_DEVICE
 lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
 }
 if (lastDequantExpertNum < params.expertPerRank) {
-AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
 }
-AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
+AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
 AscendC::SyncAll<true>();
 if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
 uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
@@ -907,7 +923,7 @@ CATLASS_DEVICE
 }
 if (loopIdx == startLoopIdx) {
 for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
-int32_t flag_id = 3 + syncLoopIdx / 8;
+int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
 AscendC::CrossCoreWaitFlag<0x2>(flag_id);
 }
 }


@@ -22,6 +22,8 @@
 namespace Catlass::Gemm::Block {
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
 template<AscendC::HardEvent event>
 __aicore__ inline void SyncFlagFunc(int32_t eventID)
 {
@@ -271,7 +273,7 @@ public:
 void Finalize(int32_t target, int32_t flag = 0)
 {
 for(;syncGroupIdx <= target; syncGroupIdx++) {
-int32_t flagId = syncGroupIdx / 8 + flag;
+int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
 AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
 }
 }