diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp
index f75dbf53..977b3c6e 100644
--- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp
+++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp
@@ -238,6 +238,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
     uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
     uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
     uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
+    uint64_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
     uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
     if (epRankId < sharedExpertRankNum) {
         maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
@@ -245,20 +246,23 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
         maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
     }
-    size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);
-    size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
+    size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
+    size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
+    size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
+    maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
+    size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
     size_t CVSwapBufferSize = CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
-    size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);
+    size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
+    size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
+    size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
+    maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
     size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
     size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
     size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
-    size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
-    size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
-    size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
     size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
-    size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize +
-        epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize;
+    size_t usrSize = maxTokenSize + tokenScaleSize + CVSwapBufferSize + maxSwigluGmm2Size + groupListSize + expandIdxSize +
+        epSendCountSize + resveredSize;
     workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
     return ge::GRAPH_SUCCESS;
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h
index 230a11a3..cb7dabb7 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h
@@ -178,7 +178,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
 template
 CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
-    layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
+    layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
     layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutPerTokenScale,
     GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmWorkspace, void *combiner)
@@ -189,7 +189,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
     using L0TileShape = L0TileShape_;
     using AType = Gemm::GemmType;
-    using BType = Gemm::GemmType;
+    using BType = Gemm::GemmType;
     using CType = Gemm::GemmType;
     using BlockMmad = Gemm::Block::BlockMmad;
@@ -261,12 +261,12 @@ private:
     GM_ADDR gmSmoothScales_;
     GM_ADDR gmexpertScales_;
-    uint32_t m_{0};
-    uint32_t n_{0};
-    uint32_t k_{0};
+    uint32_t maxTokenNum_{0};
+    uint32_t gmm1OutputDim_{0};
+    uint32_t tokenHiddenSize_{0};
     uint32_t groupCount_{0};
-    uint32_t n2_{0};
-    uint32_t k2_{0};
+    uint32_t gmm2OutputDim_{0};
+    uint32_t gmm2InputDim_{0};
     uint32_t globalRankId_{0};
     uint32_t winSizePerRank_{0};
     uint32_t blockDim_{0};
@@ -327,59 +327,62 @@ __aicore__ inline void DispatchGmmCombineDecode::Init(
     bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
     if (isShareExpert) {
-        m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
+        maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
     } else {
-        m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
+        maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
     }
-    n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
-    k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
+    gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
+    tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
     groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
-    n2_ = k_;
-    k2_ = n_ / 2;
+    gmm2OutputDim_ = tokenHiddenSize_;
+    gmm2InputDim_ = gmm1OutputDim_ / 2;
 }
 template
 __aicore__ inline void DispatchGmmCombineDecode::Process()
 {
-    GemmCoord gmm1ProblemShape{m_, n_, k_};
-    GemmCoord gmm2ProblemShape{m_, n2_, k2_};
+    GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_};
+    GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};
-    layout::RowMajor layoutX1{m_, k_};
-    layout::zN layoutWeight1 = layout::zN::template MakeLayout(k_, n_);
-    layout::VectorLayout layoutScale1{n_};
-    layout::VectorLayout layoutPerTokenScale1{m_};
-    layout::RowMajor layoutX2{m_, k2_};
-    layout::nZ layoutWeight2 = layout::nZ::template MakeLayout(k2_, n2_);
-    layout::VectorLayout layoutScale2{n2_};
-    layout::VectorLayout layoutPerTokenScale2{m_};
-    layout::RowMajor layoutOutput{m_, n2_};
+    layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
+    layout::zN layoutWeight1 = layout::zN::template MakeLayout(tokenHiddenSize_, gmm1OutputDim_);
+    layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
+    layout::VectorLayout layoutX1Scale{maxTokenNum_};
+    layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
+    layout::zN layoutWeight2 = layout::zN::template MakeLayout(gmm2InputDim_, gmm2OutputDim_);
+    layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
+    layout::VectorLayout layoutX2Scale{maxTokenNum_};
+    layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};
     size_t workspaceOffset = 0;
     constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
-    GM_ADDR gmX2 = workspaceGM_;
-    workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(int8_t));
-    GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset;
-    workspaceOffset += RoundUp(static_cast(m_) * sizeof(float));
+    int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(int8_t);
+    int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(int8_t);
+    int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
+    int64_t tokenScaleSize = maxTokenNum_ * sizeof(float);
+    GM_ADDR gmX1 = workspaceGM_ + workspaceOffset;
+    GM_ADDR gmX2 = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(maxTokenSize);
+    GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
+    GM_ADDR gmX2Scale = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(tokenScaleSize);
     GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;
-    GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp(static_cast(blockDim_) * (GMM1_L1M * GMM1_L1N) * WORKSPACE_STAGES * sizeof(int32_t));
+    int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float);
+    int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
+    int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize;
     GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
-    workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(float));
+    GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(maxSwigluGmm2Size);
     GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp(static_cast(groupCount_) * sizeof(int64_t));
     GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp(static_cast(bs_) * topK_ * sizeof(int32_t));
     GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp(static_cast(epRankSize_) * groupCount_ * sizeof(int32_t));
-    GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset;
-    workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(int8_t));
-    GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
-    workspaceOffset += RoundUp(static_cast(m_) * sizeof(float));
-    GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
-    workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(ExpandXType));
     GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp(resveredWorkSpaceSize);
@@ -388,7 +391,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process()
     AscendC::TPipe tpipe;
     MoeDistributeDispatchImpl::CamMoeDistributeDispatch dispatcher;
-    dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList,
+    dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
         gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
     dispatcher.Process();
     tpipe.Destroy();
@@ -406,11 +409,11 @@ __aicore__ inline void DispatchGmmCombineDecode::Process()
     }
     GmmDeqSwigluQuant(
-        gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
-        gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
-        layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
+        gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
+        gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
+        layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
         gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
-        sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_);
+        sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
     AscendC::PipeBarrier();
     Arch::CrossCoreFlag gmm1AivFinished{0};
     if constexpr (g_coreType == AscendC::AIV) {
@@ -427,7 +430,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process()
     }
     GmmDeq(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
-        gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut,
+        gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
         layoutOutput, gmWorkspace, &combiner);
 }
 #endif  // DISPATCH_GMM_COMBINE_DECODE_H
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h
index f1129dd0..f203dbfe 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h
@@ -74,20 +74,11 @@ public:
         std::is_same_v, "TileShape must be consistent for all tile compute ops");
-    static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2;
-    using ChunkTileShape = MatrixShape;
-
-    using TileStrideMuls = Tile::TileStrideMuls;
-    using TileStrideDiv = Tile::TileStrideDiv;
-    using TileStrideMul = Tile::TileStrideMul;
-
     static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
     static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
         TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
-        (TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
+        (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
         ArchTag::UB_SIZE, "TileShape is too large to fit in UB");
@@ -151,7 +142,7 @@ public:
         ubOffset += TileShape::COUNT * sizeof(float);
         ubTmpMx32B = resource.ubBuf.template GetBufferByByte(ubOffset);
         ubOffset += TileShape::ROW * BYTE_PER_BLK;
-        ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte(ubOffset);
+        ubDenominatorMxN = resource.ubBuf.template GetBufferByByte(ubOffset);
     }
     CATLASS_DEVICE
@@ -185,7 +176,7 @@ public:
         MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
         MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
         MatrixCoord blockOffset = blockCoord * blockShape;
-
+        bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
         AscendC::GlobalTensor gmScale;
         gmScale.SetGlobalBuffer(params.ptrScale);
         AscendC::GlobalTensor gmPerTokenScale;
@@ -194,7 +185,6 @@ public:
         gmD.SetGlobalBuffer(params.ptrD);
         auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L);
-        auto ubChunkTileStride = MakeCoord(static_cast(ChunkTileShape::COLUMN), 1L);
         auto tileShape = TileShape::ToCoord();
         EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
         uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
@@ -206,9 +196,6 @@ public:
             auto tileOffsetInBlock = tileCoord * tileShape;
             auto tileOffset = blockOffset + tileOffsetInBlock;
-            auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
-            auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);
-
             auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
             auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
@@ -257,27 +244,29 @@ public:
             tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
             AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]);
-            AscendC::PipeBarrier();
-            tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
-            AscendC::PipeBarrier();
-            tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f);
-            AscendC::PipeBarrier();
-            AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT);
-            AscendC::PipeBarrier();
-            AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT);
-            AscendC::PipeBarrier();
-            tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);
-            AscendC::PipeBarrier();
             auto &ubD = ubDList[ubListId];
-            LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};
-
-            auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
-            AscendC::WaitFlag(eventUbDMTE3VList[ubListId]);
-            tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN);
+            LayoutD layoutUbD{actualTileShape, ubTileStride};
+            AscendC::PipeBarrier();
+            // after dequant, the left half computes Dequant(x) / (1 + exp(-Dequant(x))); the right half does nothing
+            if (isLeft) {
+                tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
+                AscendC::PipeBarrier();
+                AscendC::Muls(ubDenominatorMxN, ubTmpMxN, -1.0f, TileShape::COUNT);
+                AscendC::PipeBarrier();
+                AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT);
+                AscendC::PipeBarrier();
+                AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT);
+                AscendC::PipeBarrier();
+                AscendC::WaitFlag(eventUbDMTE3VList[ubListId]);
+                AscendC::Div(ubD, ubTmpMxN, ubDenominatorMxN, TileShape::COUNT);
+            } else {
+                AscendC::WaitFlag(eventUbDMTE3VList[ubListId]);
+                tileOneBlkColumnBroadcastMul(ubD, ubTmpMxN, ubTmpMx32B);
+            }
             AscendC::SetFlag(eventUbDVMTE3List[ubListId]);
-            auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)];
-            auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape);
+            auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
+            auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
             AscendC::WaitFlag(eventUbDVMTE3List[ubListId]);
             copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
@@ -307,16 +296,12 @@ private:
     AscendC::LocalTensor ubTmpMxN;
     AscendC::LocalTensor ubTmpMx32B;
-    AscendC::LocalTensor ubTmpMxChunkN;
+    AscendC::LocalTensor ubDenominatorMxN;
     TileRowBroadcastMul tileRowBroadcastMul;
     TileBroadcastOneBlk tileBroadcastOneBlk;
     TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
-    TileStrideMuls tileStrideMuls;
-    TileStrideDiv tileStrideDiv;
-    TileStrideMul tileStrideMul;
-
     CopyGmToUbC copyGmToUbC;
     CopyGmToUbScale copyGmToUbScale;
     CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
index afd651e3..e1562201 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
@@ -140,6 +140,7 @@ public:
         ubQuantScale = resource.ubBuf.template GetBufferByByte(ubOffset);
         ubOffset += CEIL_UP(tileRow * sizeof(float));
         ubInputTmp = ubAbs;
+        ubInputRightHalf = ubAbs;
         ubQuantF32 = ubAbs;
         ubQuantS32 = ubAbs.ReinterpretCast();
         ubQuantF16 = ubAbs.ReinterpretCast();
@@ -188,10 +189,14 @@ public:
         layout::RowMajor layoutUbInput{actualTileShape, ubTileStride};
         AscendC::WaitFlag(0);
+        // finish the SwiGLU here: multiply the SiLU'd left half by the right half, then quantize
         copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput);
+        copyGmToUbInput(ubInputRightHalf, gmTileInput[params.layoutInput.shape(1) >> 1], layoutUbInput, layoutGmTileInput);
         AscendC::SetFlag(0);
         AscendC::WaitFlag(0);
+        AscendC::Mul(ubInput, ubInput, ubInputRightHalf, tileCount);
+        AscendC::PipeBarrier();
         AscendC::Abs(ubAbs, ubInput, tileCount);
         AscendC::PipeBarrier();
@@ -290,6 +295,7 @@ private:
     AscendC::LocalTensor ubQuantScale;
     AscendC::LocalTensor ubQuantScaleBrcb;
     AscendC::LocalTensor ubInputTmp;
+    AscendC::LocalTensor ubInputRightHalf;
     AscendC::LocalTensor ubQuantF32;
     AscendC::LocalTensor ubQuantS32;
     AscendC::LocalTensor ubQuantF16;
@@ -1232,7 +1238,6 @@ public:
        __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale,
        LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput)
    {
-        uint32_t nOut = n / 2;
        uint32_t coreNumPerGroup = recvCoreNum / localExpertNum;
        int64_t gmGroupOffsetScale = 0;
        int64_t gmGroupOffsetPerTokenScale = 0;
@@ -1267,7 +1272,7 @@ public:
            GemmCoord inGroupProblemShape{currentM, n, k};
            LayoutPerTokenScale layoutPerTokenScale =
                wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
-            LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
+            LayoutD layoutD = layout::RowMajor{currentM, n};
            EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
                                          layoutScale,
                                          gmTokenScale + gmGroupOffsetPerTokenScale,
@@ -1299,7 +1304,7 @@ public:
            gmGroupOffsetScale += inGroupProblemShape.n();
            gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
-            gmGroupOffsetD += currentM * nOut;
+            gmGroupOffsetD += currentM * n;
            startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum;
        }
@@ -1514,11 +1519,13 @@ public:
        __asm__ __volatile__("");
        totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1);
        AscendC::PipeBarrier();
+        uint32_t n = params.problemShape.n();
        uint32_t nOut = params.problemShape.n() / 2;
        uint32_t quantRowOnce = 0;
        CalQuantRow(nOut, quantRowOnce);
+        auto swigluLayout = layout::RowMajor{totalTokenCount, n};
        typename BlockQuant::Params quantParams{
-            gmSwigluOutput, params.layoutOutput, params.ptrDequantScale, params.layoutDequantScale,
+            gmSwigluOutput, swigluLayout, params.ptrDequantScale, params.layoutDequantScale,
            params.ptrOutput, params.layoutOutput, quantRowOnce, nOut};
        BlockQuant blockQuant(resource, quantParams);
@@ -1850,6 +1857,7 @@ public:
            params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N));
        uint32_t mActual = groupList.GetValue(params.problemCount - 1);
+        uint32_t n = params.problemShape.n();
        uint32_t nOut = params.problemShape.n() / 2;
        {
@@ -1866,7 +1874,7 @@ public:
            LayoutScale layoutScale = params.layoutScale;
            LayoutPerTokenScale layoutPerTokenScale =
                params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
-            LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
+            LayoutD layoutD = layout::RowMajor{currentM, n};
            EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
                                          layoutScale,
@@ -1899,7 +1907,7 @@ public:
            gmGroupOffsetScale += inGroupProblemShape.n();
            gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
-            gmGroupOffsetD += currentM * nOut;
+            gmGroupOffsetD += currentM * n;
            startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
        }
@@ -1910,8 +1918,9 @@ public:
        {
            uint32_t quantRowOnce = 0;
            CalQuantRow(nOut, quantRowOnce);
+            auto swigluLayout = layout::RowMajor{mActual, n};
            typename BlockQuant::Params quantParams{ptrD,
-                params.layoutOutput,
+                swigluLayout,
                params.ptrDequantScale,
                params.layoutDequantScale,
                params.ptrOutput,
diff --git a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
index 1333a390..d11254cc 100644
--- a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
+++ b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
@@ -12,9 +12,7 @@ import torchair
 from vllm_ascend.utils import enable_custom_op
-config = torchair.CompilerConfig()
-config.mode = "reduce-overhead"
-npu_backend = torchair.get_npu_backend(compiler_config=config)
+torch.manual_seed(42)
 torch_npu.npu.config.allow_internal_format = True
 enable_custom_op()
 LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
@@ -101,7 +99,21 @@ class DecodeMoeOps(torch.nn.Module):
     def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                        gmm2_weight, gmm2_weight_scale):
-        raise NotImplementedError("To be implemented in subclass")
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
+        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
+                                                    requires_grad=False)
+        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
+        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
+                                                    requires_grad=False)
+
+        self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
+            gmm1_weight_scale.float(), requires_grad=False)
+        self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
+            gmm2_weight_scale.float(), requires_grad=False)
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         raise NotImplementedError("To be implemented in subclass")
@@ -132,19 +144,6 @@ class SmallOps(DecodeMoeOps):
                          shared_expert_rank_num)
         self.tp_hcomm_info = ""
-    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
-                                       gmm2_weight, gmm2_weight_scale):
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
-        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
-                                                    requires_grad=False)
-        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
-        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
-                                                    requires_grad=False)
-
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         outputs = torch_npu.npu_moe_distribute_dispatch_v2(
             x=x,
@@ -238,41 +237,14 @@ class FusionOp(DecodeMoeOps):
                          ep_world_size, moe_expert_num, global_rank_id,
                          shared_expert_rank_num)
-    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
-                                       gmm2_weight, gmm2_weight_scale):
-        gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
-            .view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
-            .transpose(1,2).contiguous()\
-            .view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
-            .transpose(1,2).contiguous()
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.ND)
-        gmm1_weight.add_(0)
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
-        gmm2_weight = torch_npu.npu_format_cast(
-            gmm2_weight.transpose(1, 2).contiguous(),
-            torch_npu.Format.FRACTAL_NZ)
-
-        gmm1_weight_scale = gmm1_weight_scale.float()
-        gmm2_weight_scale = gmm2_weight_scale.float()
-
-        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
-        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
-                                                    requires_grad=False)
-        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
-        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
-                                                    requires_grad=False)
-
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
             x=x,
             expert_ids=expert_ids,
             gmm1_permuted_weight=self.gmm1_weight,
-            gmm1_permuted_weight_scale=self.gmm1_weight_scale,
+            gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
             gmm2_weight=self.gmm2_weight,
-            gmm2_weight_scale=self.gmm2_weight_scale,
+            gmm2_weight_scale=self.gmm2_weight_scale_fp32,
             expert_smooth_scales=smooth_scales,
             expert_scales=expert_scales,
             group_ep=self.ep_hcomm_info,
@@ -399,6 +371,9 @@ def run_once(local_rank_id,
         fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
                              *parameter).npu()  # type: ignore
     if test_graph:
+        config = torchair.CompilerConfig()
+        config.mode = "reduce-overhead"
+        npu_backend = torchair.get_npu_backend(compiler_config=config)
        fused_ops = torch.compile(fused_ops, backend=npu_backend)
     small_op_token_output, small_op_count_output = small_ops(*input_datas)
     fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
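Note (not part of the patch): the kernel changes above split the per-expert SwiGLU across two stages, applying SiLU to the left half of the GMM1 output in the dequant epilogue and multiplying by the right half in the dynamic-quant stage before GMM2. The sketch below shows the equivalent single-shot reference math; the function name, argument shapes, and the assumption that dequantization uses a per-output-channel weight scale and a per-token activation scale are illustrative, not the operator's actual API.

import torch

def ref_gmm1_swiglu_quant(x_q, w1, w1_scale, x_scale):
    # GMM1: int8 x int8 accumulated and dequantized with per-channel and per-token scales (assumed layout).
    y = torch.matmul(x_q.float(), w1.float()) * w1_scale * x_scale[:, None]
    # Left half is gated with SiLU, right half passes through: SwiGLU = SiLU(left) * right.
    gate, up = y.chunk(2, dim=-1)
    h = torch.nn.functional.silu(gate) * up
    # Per-token dynamic quantization back to int8 for GMM2.
    h_scale = h.abs().amax(dim=-1, keepdim=True) / 127.0
    h_q = torch.round(h / h_scale).clamp(-128, 127).to(torch.int8)
    return h_q, h_scale.squeeze(-1)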