[bugfix]fix some bug in dispatch_ffn_combine kernel (#6465)
### What this PR does / why we need it?
The kernel internals had an issue with maxoutputsize overflow in the
swiglu section, which has been fixed.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: LQLlulu <39671654+LQLlulu@users.noreply.github.com>
This commit is contained in:
@@ -653,7 +653,7 @@ private:
|
||||
m_prevSumBeforeRank = prevSumBeforeRank;
|
||||
}
|
||||
int prevSum = prevSumBeforeRank;
|
||||
uint32_t prevGroupSum1 = 0;
|
||||
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
|
||||
uint32_t dequantSum = 0;
|
||||
int32_t syncLoopIdx = -1;
|
||||
uint32_t n = params.problemShape.n();
|
||||
@@ -661,6 +661,7 @@ private:
|
||||
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
// The ith core reads data from the ith rank's peermem
|
||||
groupIdxDeq = groupIdx - 2;
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
||||
if (rowStart < params.maxOutputSize) {
|
||||
@@ -687,57 +688,55 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
||||
syncLoopIdx++;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
|
||||
syncgmm1Idx++;
|
||||
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
||||
uint32_t rowStartThisCore = 0;
|
||||
MatrixCoord offsetC{0U, 0};
|
||||
uint32_t dequantLen = prevGroupSum1 - dequantSum;
|
||||
if (dequantLen >= params.maxOutputSize) {
|
||||
dequantLen = dequantLen - params.maxOutputSize;
|
||||
prevGroupSum1 += currentM;
|
||||
syncgmm1Idx ++;
|
||||
if (groupIdx + 1 <= params.epilogueGranularity) {
|
||||
if (dequantSum1 + currentM <= params.maxOutputSize) {
|
||||
dequantSum1 += currentM;
|
||||
} else if (dequantSum1 < params.maxOutputSize) {
|
||||
dequantSum1 = params.maxOutputSize;
|
||||
}
|
||||
|
||||
}
|
||||
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
|
||||
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
|
||||
dequantSum2 += currentM;
|
||||
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
|
||||
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
|
||||
if (dequantSum1 > 0) {
|
||||
uint32_t rowStartThisCore = 0;
|
||||
MatrixCoord offsetC{0U, 0};
|
||||
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantSum1, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
if (dequantSum2 > 0) {
|
||||
uint32_t rowStartThisCore = dequantSum1;
|
||||
MatrixCoord offsetC{rowStartThisCore, 0};
|
||||
uint32_t dequantLen = dequantSum2;
|
||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
||||
}
|
||||
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
|
||||
dequantSum = 0;
|
||||
}
|
||||
}
|
||||
syncLoopIdx ++;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
|
||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
||||
if (params.epilogueGranularity < params.expertPerRank) {
|
||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||
}
|
||||
if (lastDequantExpertNum < params.expertPerRank) {
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
}
|
||||
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
||||
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
||||
MatrixCoord offsetC{rowStartThisCore, 0};
|
||||
uint32_t dequantLen = dequantSum;
|
||||
if (prevGroupSum1 >= params.maxOutputSize) {
|
||||
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
|
||||
}
|
||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
||||
AscendC::SyncAll<true>();
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
}
|
||||
blockEpilogue.Finalize();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user