fix: resolve sync bug in DispathFFNCombine when expert num per card is 32 (#6416)
### What this PR does / why we need it?
Fix the synchronization deadlock in the DispathFFNCombine module that
occurs on NPU cards when the number of experts per card exceeds 16 (the
bug shows up prominently at 32 or 128 experts per card).
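
At a high level, one plausible reading of the diff is that the deadlock comes from how cross-core sync flag IDs were derived: several of the old flag IDs were computed as `counter / 8 + base` (see the removed lines below), so once the per-card expert count passes 16 the `/8 + 1` and `/8 + 3` families start mapping different synchronization purposes onto the same flag IDs. The snippet below is illustrative only (plain C++, not code from this patch); it just prints where the two old families overlap for the 32-expert case:

```cpp
// Rough illustration of the old flag-ID scheme for 32 expert groups per card.
// The "/8 + 1" family spans IDs 1..4 and the "/8 + 3" family spans IDs 3..6,
// so the two ranges collide from group 16 onward.
#include <cstdio>

int main() {
    const int expertPerRank = 32;
    for (int g = 0; g < expertPerRank; ++g) {
        int family1 = g / 8 + 1;  // old: AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1)
        int family3 = g / 8 + 3;  // old: AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3)
        std::printf("group %2d -> /8+1 family uses flag %d, /8+3 family uses flag %d%s\n",
                    g, family1, family3,
                    family1 >= 3 ? "  (collides with the /8+3 range)" : "");
    }
    return 0;
}
```

The patch avoids this by reserving dedicated flag IDs for the coarse Cube-to-Vector and Vector-to-Cube handshakes (`SYNCFLAGC2V = 9`, `SYNCFLAGV2C = 10`) and by deriving the remaining per-group flag IDs as `counter / CROSS_CORE_FLAG_MAX_SET_COUNT` (15) instead of `counter / 8`.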
### Does this PR introduce _any_ user-facing change?
No. This is a bug fix for the internal synchronization logic of the NPU
expert dispatch path, with no impact on external APIs, interfaces, or
end-user behavior.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
```diff
@@ -42,6 +42,10 @@ using namespace AscendC;
 
 namespace Catlass::Gemm::Kernel {
 
+constexpr uint16_t SYNCFLAGC2V = 9;
+constexpr uint16_t SYNCFLAGV2C = 10;
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template <
     class BlockMmad_,
     class BlockScheduler_,
```
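
The three new constants replace magic numbers scattered through the kernel: `SYNCFLAGV2C` takes over the hard-coded flag 2, `SYNCFLAGC2V` takes over one of the group-indexed flag bases, and `CROSS_CORE_FLAG_MAX_SET_COUNT` becomes the divisor wherever a running counter selects a flag ID. Below is a minimal sketch of that bucketing pattern; the helper name is hypothetical, and the premise that a single cross-core flag should not accumulate more than 15 outstanding sets is inferred from the constant's name, not stated in the diff:

```cpp
// Sketch only: the flag-ID bucketing expression the patch uses inline.
#include <cassert>
#include <cstdint>

constexpr int32_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;  // assumed per-flag set limit

// Hypothetical helper mirroring "counter / CROSS_CORE_FLAG_MAX_SET_COUNT + base":
// counters 0..14 map to `base`, 15..29 to `base + 1`, and so on.
inline int32_t BucketedFlagId(int32_t counter, int32_t base = 0)
{
    return counter / CROSS_CORE_FLAG_MAX_SET_COUNT + base;
}

int main()
{
    assert(BucketedFlagId(0) == 0);
    assert(BucketedFlagId(14) == 0);
    assert(BucketedFlagId(15) == 1);      // rolls over to the next flag ID after 15 uses
    assert(BucketedFlagId(31, 9) == 11);  // base offsets still apply
    return 0;
}
```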
```diff
@@ -189,7 +193,7 @@ public:
     {
         GMM1(params);
 
-        AscendC::CrossCoreWaitFlag<0x2>(2);
+        AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
 
         GMM2(params);
     }
@@ -201,7 +205,7 @@ public:
     {
         Dispatch(params);
         AscendC::SyncAll<true>();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
 
         Combine(params);
     }
@@ -324,9 +328,12 @@ private:
         int64_t gmGroupOffsetC = 0;
         uint32_t startCoreIdx = 0;
         uint32_t syncGroupIdx = 0;
-        AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul
         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
 
+        uint16_t syncgmmIdx = 0;
+        AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
+        syncgmmIdx++;
+
         constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
         __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
@@ -370,7 +377,8 @@ private:
 
             for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                 for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
-                    AscendC::CrossCoreWaitFlag<0x2>(0);
+                    AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+                    syncgmmIdx ++;
                 }
                 // Compute block location
                 GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
@@ -399,7 +407,7 @@ private:
                 if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
                     blockMmad.SynchronizeBlock();
                 }
-                blockMmad.Finalize(syncLoopIdx, 1);
+                blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
             }
 
             preCurrentmSum += currentM;
@@ -410,10 +418,16 @@ private:
            gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
            startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
        }
 
+        for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
+            AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+            syncgmmIdx ++;
+        }
+
        if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
            blockMmad.SynchronizeBlock();
        }
-        blockMmad.Finalize(syncLoopIdx + 1, 1);
+        blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
    }
 
    CATLASS_DEVICE
```
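
Two things to note in the hunks above, both inferred from the diff rather than stated in it: the per-group wait now advances a dedicated counter (`syncgmmIdx`) whose flag ID rolls over every 15 uses, and a trailing drain loop runs before the final `Finalize` so that the wait count still matches the number of per-group flags set by the AIV side even when this core's main loop stops before the last expert group. The toy model below (plain C++, not kernel code, with a hypothetical early-stop value) only checks that the set and wait counts balance under that reading:

```cpp
// Toy model: why the trailing drain loop keeps cross-core set/wait counts balanced.
#include <cassert>
#include <cstdint>

int main()
{
    const uint32_t expertPerRank = 32;
    const uint32_t groupsVisitedByThisCore = 20;  // hypothetical: main loop stops early

    const uint32_t setsByAiv = expertPerRank;     // assume one CrossCoreSetFlag per group
    uint32_t waitsByAic = 0;
    uint32_t syncGroupIdx = 0;

    // Main loop: one wait per group this core actually reaches.
    for (; syncGroupIdx < groupsVisitedByThisCore; ++syncGroupIdx) {
        ++waitsByAic;
    }
    // Drain loop added by the patch: consume the remaining per-group flags.
    for (; syncGroupIdx < expertPerRank; ++syncGroupIdx) {
        ++waitsByAic;
    }
    assert(waitsByAic == setsByAiv);  // balanced; without the drain loop they diverge
    return 0;
}
```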
```diff
@@ -482,7 +496,7 @@ private:
         uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
         // Loop through the matmul of each groupIdx
         if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
-            AscendC::CrossCoreWaitFlag<0x2>(2);
+            AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
         }
         for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
             if (loopIdx + coreNum >= coreLoops) {
@@ -508,7 +522,7 @@ private:
                     gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
                     gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
                     gmS2[gmOffsetS], layoutScale,
-                    actualBlockShape, syncLoopIdx, 3
+                    actualBlockShape, syncLoopIdx, 0
                 );
             }
         }
@@ -526,7 +540,7 @@ private:
         if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
             blockMmad.SynchronizeBlock();
         }
-        blockMmad.Finalize(params.expertPerRank - 1, 3);
+        blockMmad.Finalize(params.expertPerRank - 1, 0);
     }
 
     CATLASS_DEVICE
@@ -625,7 +639,9 @@ private:
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
         AscendC::SyncAll<true>();
-        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
+        uint16_t syncgmm1Idx = 0;
+        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
+        syncgmm1Idx++;
 
         uint32_t curGroupOffset = 0;
         int32_t prevSumBeforeRank = 0;
@@ -673,10 +689,11 @@ private:
 
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
                 syncLoopIdx++;
-                AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
+                AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
             }
             AscendC::SyncAll<true>();
-            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V notifies C that the current communication round is complete
+            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
+            syncgmm1Idx++;
 
             if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
                 uint32_t rowStartThisCore = 0;
@@ -699,7 +716,7 @@ private:
                 }
             }
             syncLoopIdx ++;
-            AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
+            AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
             AscendC::SyncAll<true>();
 
             uint32_t lastDequantExpertNum = params.expertPerRank;
@@ -707,7 +724,7 @@ private:
                 lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
             }
             if (lastDequantExpertNum < params.expertPerRank) {
-                AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
+                AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
             }
             if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
                 uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
@@ -746,7 +763,7 @@ private:
         BlockEpilogue2 blockEpilogue(resource, epilogueParams);
         int32_t prevGroupSum2 = 0;
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
-            AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3);
+            AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
             AscendC::SyncAll<true>();
 
             for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
```
The same constant is also introduced in the block-level GEMM component (namespace `Catlass::Gemm::Block`), where `Finalize` switches its flag-ID derivation from `/ 8` to `/ CROSS_CORE_FLAG_MAX_SET_COUNT`:

```diff
@@ -22,6 +22,8 @@
 
 namespace Catlass::Gemm::Block {
 
+constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
+
 template<AscendC::HardEvent event>
 __aicore__ inline void SyncFlagFunc(int32_t eventID)
 {
@@ -271,7 +273,7 @@ public:
     void Finalize(int32_t target, int32_t flag = 0)
     {
         for(;syncGroupIdx <= target; syncGroupIdx++) {
-            int32_t flagId = syncGroupIdx / 8 + flag;
+            int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
             AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
         }
     }
```
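
With `Finalize` dividing by `CROSS_CORE_FLAG_MAX_SET_COUNT`, the flag ID it sets for a given expert group matches the ID the epilogue loop waits on (`groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT`), and no single ID is set more than 15 times for the 32-expert configuration. The check below is a plain C++ sketch of that invariant; the 15-sets-per-flag limit itself is an assumption taken from the constant's name:

```cpp
// Sketch: producer (Finalize, flag base 0) and consumer (epilogue wait) agree on
// the flag ID per expert group, and each ID absorbs at most 15 sets for 32 groups.
#include <cassert>
#include <cstdint>

int main()
{
    constexpr int32_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
    const int32_t expertPerRank = 32;

    int32_t setsOnFlag[4] = {0, 0, 0, 0};
    for (int32_t g = 0; g < expertPerRank; ++g) {
        int32_t producerId = g / CROSS_CORE_FLAG_MAX_SET_COUNT + 0;  // Finalize(target, /*flag=*/0)
        int32_t consumerId = g / CROSS_CORE_FLAG_MAX_SET_COUNT;      // wait side in the epilogue loop
        assert(producerId == consumerId);
        ++setsOnFlag[producerId];
    }
    for (int32_t id = 0; id < 4; ++id) {
        assert(setsOnFlag[id] <= CROSS_CORE_FLAG_MAX_SET_COUNT);     // counts: 15, 15, 2, 0
    }
    return 0;
}
```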