From d1dcdfc4084825d2d8f6ff39f1e69767e5f88c40 Mon Sep 17 00:00:00 2001 From: LQLlulu <39671654+LQLlulu@users.noreply.github.com> Date: Mon, 2 Feb 2026 08:32:42 +0800 Subject: [PATCH] [bugfix]fix some bug in dispatch_ffn_combine kernel (#6465) ### What this PR does / why we need it? The kernel internals had an issue with maxoutputsize overflow in the swiglu section, which has been fixed. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd --------- Signed-off-by: LQLlulu <39671654+LQLlulu@users.noreply.github.com> --- .../op_kernel/dispatch_ffn_combine_kernel.hpp | 87 +++++++++---------- 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index 422595aa..462004d2 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -653,7 +653,7 @@ private: m_prevSumBeforeRank = prevSumBeforeRank; } int prevSum = prevSumBeforeRank; - uint32_t prevGroupSum1 = 0; + uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0; uint32_t dequantSum = 0; int32_t syncLoopIdx = -1; uint32_t n = params.problemShape.n(); @@ -661,6 +661,7 @@ private: for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { // The ith core reads data from the ith rank's peermem groupIdxDeq = groupIdx - 2; + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; if (rowStart < params.maxOutputSize) { @@ -687,57 +688,55 @@ private: } } - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { - syncLoopIdx++; - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); - } AscendC::SyncAll(); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete - syncgmm1Idx++; - - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { - uint32_t rowStartThisCore = 0; - MatrixCoord offsetC{0U, 0}; - uint32_t dequantLen = prevGroupSum1 - dequantSum; - if (dequantLen >= params.maxOutputSize) { - dequantLen = dequantLen - params.maxOutputSize; + prevGroupSum1 += currentM; + syncgmm1Idx ++; + if (groupIdx + 1 <= params.epilogueGranularity) { + if (dequantSum1 + currentM <= params.maxOutputSize) { + dequantSum1 += currentM; + } else if (dequantSum1 < params.maxOutputSize) { + dequantSum1 = params.maxOutputSize; } - + } + if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) { + if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) { + dequantSum2 += currentM; + } else if (dequantSum1 + dequantSum2 < params.maxOutputSize) { + dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2; + } + } + } + + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); + AscendC::SyncAll(); + + if (dequantSum1 > 0) { + uint32_t rowStartThisCore = 0; + MatrixCoord offsetC{0U, 0}; + MatrixCoord shapeC{dequantSum1, params.problemShape.n()}; + LayoutC layoutC{dequantSum1, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) { + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); + AscendC::SyncAll(); + if (dequantSum2 > 0) { + uint32_t rowStartThisCore = dequantSum1; + MatrixCoord offsetC{rowStartThisCore, 0}; + uint32_t dequantLen = dequantSum2; MatrixCoord shapeC{dequantLen, params.problemShape.n()}; LayoutC layoutC{dequantLen, params.problemShape.n()}; int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); } - prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) { - dequantSum = 0; - } - } - syncLoopIdx ++; - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); - AscendC::SyncAll(); - - uint32_t lastDequantExpertNum = params.expertPerRank; - if (params.epilogueGranularity < params.expertPerRank) { - lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; - } - if (lastDequantExpertNum < params.expertPerRank) { - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); - } - if (prevGroupSum1 - dequantSum < params.maxOutputSize) { - uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; - MatrixCoord offsetC{rowStartThisCore, 0}; - uint32_t dequantLen = dequantSum; - if (prevGroupSum1 >= params.maxOutputSize) { - dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize); - } - MatrixCoord shapeC{dequantLen, params.problemShape.n()}; - LayoutC layoutC{dequantLen, params.problemShape.n()}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } blockEpilogue.Finalize(); }