fix: resolve sync bug in DispathFFNCombine when expert num per card is 32 (#6416)
### What this PR does / why we need it?
Fix the synchronization deadlock in the DispathFFNCombine module that
occurs on NPU cards when the number of experts per card exceeds 16 (the
bug manifests most prominently when the count is set to 32 or 128).
### Does this PR introduce _any_ user-facing change?
No. This is a bug fix for internal synchronization logic specific to NPU
expert dispatch, with no impact on external APIs, interfaces, or
end-user behavior.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
This commit is contained in:
@@ -42,6 +42,10 @@ using namespace AscendC;
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Kernel {
|
namespace Catlass::Gemm::Kernel {
|
||||||
|
|
||||||
|
constexpr uint16_t SYNCFLAGC2V = 9;
|
||||||
|
constexpr uint16_t SYNCFLAGV2C = 10;
|
||||||
|
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||||
|
|
||||||
template <
|
template <
|
||||||
class BlockMmad_,
|
class BlockMmad_,
|
||||||
class BlockScheduler_,
|
class BlockScheduler_,
|
||||||
@@ -189,7 +193,7 @@ public:
|
|||||||
{
|
{
|
||||||
GMM1(params);
|
GMM1(params);
|
||||||
|
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(2);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
|
||||||
|
|
||||||
GMM2(params);
|
GMM2(params);
|
||||||
}
|
}
|
||||||
@@ -201,7 +205,7 @@ public:
|
|||||||
{
|
{
|
||||||
Dispatch(params);
|
Dispatch(params);
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
|
|
||||||
Combine(params);
|
Combine(params);
|
||||||
}
|
}
|
||||||
@@ -324,10 +328,13 @@ private:
|
|||||||
int64_t gmGroupOffsetC = 0;
|
int64_t gmGroupOffsetC = 0;
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
uint32_t syncGroupIdx = 0;
|
uint32_t syncGroupIdx = 0;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul
|
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
|
|
||||||
|
uint16_t syncgmmIdx = 0;
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||||
|
syncgmmIdx++;
|
||||||
|
|
||||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
||||||
__gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
|
__gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
|
||||||
__gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
|
__gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
|
||||||
@@ -370,7 +377,8 @@ private:
|
|||||||
|
|
||||||
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
||||||
for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
|
for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(0);
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmmIdx ++;
|
||||||
}
|
}
|
||||||
// Compute block location
|
// Compute block location
|
||||||
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
|
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
|
||||||
@@ -399,7 +407,7 @@ private:
|
|||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
blockMmad.SynchronizeBlock();
|
blockMmad.SynchronizeBlock();
|
||||||
}
|
}
|
||||||
blockMmad.Finalize(syncLoopIdx, 1);
|
blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
|
|
||||||
preCurrentmSum += currentM;
|
preCurrentmSum += currentM;
|
||||||
@@ -410,10 +418,16 @@ private:
|
|||||||
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
||||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmmIdx ++;
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
blockMmad.SynchronizeBlock();
|
blockMmad.SynchronizeBlock();
|
||||||
}
|
}
|
||||||
blockMmad.Finalize(syncLoopIdx + 1, 1);
|
blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
@@ -482,7 +496,7 @@ private:
|
|||||||
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
|
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
|
||||||
// Loop through the matmul of each groupIdx
|
// Loop through the matmul of each groupIdx
|
||||||
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
|
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(2);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
|
||||||
}
|
}
|
||||||
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
||||||
if (loopIdx + coreNum >= coreLoops) {
|
if (loopIdx + coreNum >= coreLoops) {
|
||||||
@@ -508,7 +522,7 @@ private:
|
|||||||
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
||||||
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
||||||
gmS2[gmOffsetS], layoutScale,
|
gmS2[gmOffsetS], layoutScale,
|
||||||
actualBlockShape, syncLoopIdx, 3
|
actualBlockShape, syncLoopIdx, 0
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -526,7 +540,7 @@ private:
|
|||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
blockMmad.SynchronizeBlock();
|
blockMmad.SynchronizeBlock();
|
||||||
}
|
}
|
||||||
blockMmad.Finalize(params.expertPerRank - 1, 3);
|
blockMmad.Finalize(params.expertPerRank - 1, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
@@ -625,7 +639,9 @@ private:
|
|||||||
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
|
uint16_t syncgmm1Idx = 0;
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmm1Idx++;
|
||||||
|
|
||||||
uint32_t curGroupOffset = 0;
|
uint32_t curGroupOffset = 0;
|
||||||
int32_t prevSumBeforeRank = 0;
|
int32_t prevSumBeforeRank = 0;
|
||||||
@@ -673,10 +689,11 @@ private:
|
|||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
||||||
syncLoopIdx++;
|
syncLoopIdx++;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V notifies C that the current communication round is complete
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
|
||||||
|
syncgmm1Idx++;
|
||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
||||||
uint32_t rowStartThisCore = 0;
|
uint32_t rowStartThisCore = 0;
|
||||||
@@ -699,7 +716,7 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
syncLoopIdx ++;
|
syncLoopIdx ++;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
|
|
||||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
uint32_t lastDequantExpertNum = params.expertPerRank;
|
||||||
@@ -707,7 +724,7 @@ private:
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
if (lastDequantExpertNum < params.expertPerRank) {
|
if (lastDequantExpertNum < params.expertPerRank) {
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
}
|
}
|
||||||
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
||||||
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
||||||
@@ -746,7 +763,7 @@ private:
|
|||||||
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
|
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
|
||||||
int32_t prevGroupSum2 = 0;
|
int32_t prevGroupSum2 = 0;
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3);
|
AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
|
|||||||
@@ -22,6 +22,8 @@
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Block {
|
namespace Catlass::Gemm::Block {
|
||||||
|
|
||||||
|
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||||
|
|
||||||
template<AscendC::HardEvent event>
|
template<AscendC::HardEvent event>
|
||||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||||
{
|
{
|
||||||
@@ -271,7 +273,7 @@ public:
|
|||||||
void Finalize(int32_t target, int32_t flag = 0)
|
void Finalize(int32_t target, int32_t flag = 0)
|
||||||
{
|
{
|
||||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||||
int32_t flagId = syncGroupIdx / 8 + flag;
|
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ using namespace AscendC;
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Kernel {
|
namespace Catlass::Gemm::Kernel {
|
||||||
|
|
||||||
|
constexpr uint16_t SYNCFLAGC2V = 9;
|
||||||
|
constexpr uint16_t SYNCFLAGV2C = 10;
|
||||||
|
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||||
|
|
||||||
template <
|
template <
|
||||||
class BlockMmad_,
|
class BlockMmad_,
|
||||||
class BlockScheduler_,
|
class BlockScheduler_,
|
||||||
@@ -198,7 +202,7 @@ public:
|
|||||||
{
|
{
|
||||||
GMM1(params);
|
GMM1(params);
|
||||||
|
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(2);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
|
||||||
|
|
||||||
GMM2(params);
|
GMM2(params);
|
||||||
}
|
}
|
||||||
@@ -210,7 +214,7 @@ public:
|
|||||||
{
|
{
|
||||||
Dispatch(params);
|
Dispatch(params);
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
|
|
||||||
CombineV2(params);
|
CombineV2(params);
|
||||||
}
|
}
|
||||||
@@ -507,7 +511,9 @@ CATLASS_DEVICE
|
|||||||
int64_t gmGroupOffsetC = 0;
|
int64_t gmGroupOffsetC = 0;
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
uint32_t syncGroupIdx = 0;
|
uint32_t syncGroupIdx = 0;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(0);
|
uint16_t syncgmmIdx = 0;
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||||
|
syncgmmIdx++;
|
||||||
AicSyncAll();
|
AicSyncAll();
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
@@ -553,7 +559,8 @@ CATLASS_DEVICE
|
|||||||
|
|
||||||
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
||||||
for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
|
for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(0);
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmmIdx ++;
|
||||||
}
|
}
|
||||||
// Compute block location
|
// Compute block location
|
||||||
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
|
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
|
||||||
@@ -592,7 +599,7 @@ CATLASS_DEVICE
|
|||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
blockMmad.SynchronizeBlock();
|
blockMmad.SynchronizeBlock();
|
||||||
}
|
}
|
||||||
blockMmad.Finalize(syncLoopIdx, 1);
|
blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
|
|
||||||
preCurrentmSum += currentM;
|
preCurrentmSum += currentM;
|
||||||
@@ -603,10 +610,16 @@ CATLASS_DEVICE
|
|||||||
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
||||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for(;syncGroupIdx < params.expertPerRank; syncGroupIdx++) {
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmmIdx ++;
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
blockMmad.SynchronizeBlock();
|
blockMmad.SynchronizeBlock();
|
||||||
}
|
}
|
||||||
blockMmad.Finalize(syncLoopIdx + 1, 1);
|
blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
@@ -674,7 +687,7 @@ CATLASS_DEVICE
|
|||||||
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
|
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
|
||||||
// Loop through the matmul of each groupIdx
|
// Loop through the matmul of each groupIdx
|
||||||
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
|
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(2);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
|
||||||
}
|
}
|
||||||
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
||||||
if (loopIdx + coreNum >= coreLoops) {
|
if (loopIdx + coreNum >= coreLoops) {
|
||||||
@@ -701,7 +714,7 @@ CATLASS_DEVICE
|
|||||||
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
||||||
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
||||||
gmS2[gmOffsetS], layoutScale,
|
gmS2[gmOffsetS], layoutScale,
|
||||||
actualBlockShape, syncLoopIdx, 3
|
actualBlockShape, syncLoopIdx, 0
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
blockMmad(
|
blockMmad(
|
||||||
@@ -709,7 +722,7 @@ CATLASS_DEVICE
|
|||||||
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
|
||||||
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
|
||||||
gmS2, layoutScale,
|
gmS2, layoutScale,
|
||||||
actualBlockShape, syncLoopIdx, 3
|
actualBlockShape, syncLoopIdx, 0
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -749,7 +762,9 @@ CATLASS_DEVICE
|
|||||||
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
|
uint16_t syncgmm1Idx = 0;
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
|
syncgmm1Idx++;
|
||||||
|
|
||||||
uint32_t curGroupOffset = 0;
|
uint32_t curGroupOffset = 0;
|
||||||
int32_t prevSumBeforeRank = 0;
|
int32_t prevSumBeforeRank = 0;
|
||||||
@@ -791,10 +806,11 @@ CATLASS_DEVICE
|
|||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
||||||
syncLoopIdx++;
|
syncLoopIdx++;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
|
||||||
|
syncgmm1Idx++;
|
||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
||||||
uint32_t rowStartThisCore = 0;
|
uint32_t rowStartThisCore = 0;
|
||||||
@@ -829,9 +845,9 @@ CATLASS_DEVICE
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
if (lastDequantExpertNum < params.expertPerRank) {
|
if (lastDequantExpertNum < params.expertPerRank) {
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
}
|
}
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1);
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
||||||
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
||||||
@@ -907,7 +923,7 @@ CATLASS_DEVICE
|
|||||||
}
|
}
|
||||||
if (loopIdx == startLoopIdx) {
|
if (loopIdx == startLoopIdx) {
|
||||||
for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
|
for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
|
||||||
int32_t flag_id = 3 + syncLoopIdx / 8;
|
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,8 @@
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Block {
|
namespace Catlass::Gemm::Block {
|
||||||
|
|
||||||
|
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||||
|
|
||||||
template<AscendC::HardEvent event>
|
template<AscendC::HardEvent event>
|
||||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||||
{
|
{
|
||||||
@@ -271,7 +273,7 @@ public:
|
|||||||
void Finalize(int32_t target, int32_t flag = 0)
|
void Finalize(int32_t target, int32_t flag = 0)
|
||||||
{
|
{
|
||||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||||
int32_t flagId = syncGroupIdx / 8 + flag;
|
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user