[bugfix] fix a bug in the dispatch_ffn_combine kernel (#6465)

### What this PR does / why we need it?
The kernel internals had an issue where `maxOutputSize` could be exceeded
(overflowed) in the SwiGLU dequant section; this has been fixed.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: LQLlulu <39671654+LQLlulu@users.noreply.github.com>
This commit is contained in:
LQLlulu
2026-02-02 08:32:42 +08:00
committed by GitHub
parent 347eb36a59
commit d1dcdfc408

View File

@@ -653,7 +653,7 @@ private:
m_prevSumBeforeRank = prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
uint32_t prevGroupSum1 = 0;
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
uint32_t n = params.problemShape.n();
@@ -661,6 +661,7 @@ private:
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// The ith core reads data from the ith rank's peermem
groupIdxDeq = groupIdx - 2;
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
if (rowStart < params.maxOutputSize) {
@@ -687,57 +688,55 @@ private:
}
}
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
syncLoopIdx++;
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
syncgmm1Idx++;
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
uint32_t dequantLen = prevGroupSum1 - dequantSum;
if (dequantLen >= params.maxOutputSize) {
dequantLen = dequantLen - params.maxOutputSize;
prevGroupSum1 += currentM;
syncgmm1Idx ++;
if (groupIdx + 1 <= params.epilogueGranularity) {
if (dequantSum1 + currentM <= params.maxOutputSize) {
dequantSum1 += currentM;
} else if (dequantSum1 < params.maxOutputSize) {
dequantSum1 = params.maxOutputSize;
}
}
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
dequantSum2 += currentM;
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
}
}
}
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
LayoutC layoutC{dequantSum1, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum2 > 0) {
uint32_t rowStartThisCore = dequantSum1;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum2;
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
}
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
dequantSum = 0;
}
}
syncLoopIdx ++;
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
uint32_t lastDequantExpertNum = params.expertPerRank;
if (params.epilogueGranularity < params.expertPerRank) {
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
if (lastDequantExpertNum < params.expertPerRank) {
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum;
if (prevGroupSum1 >= params.maxOutputSize) {
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
}
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
blockEpilogue.Finalize();
}