[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)

### What this PR does / why we need it?

This PR adapts the DispatchGmmCombineDecode operator to the parameters of the
small operators.
1. This operator no longer requires permuting the weights and scales of
GMM1.
2. This operator no longer requires transposing the weights of GMM2.

Therefore, this operator and the small operators can use the same
parameters (weights and scales), which is beneficial for model
adaptation.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-09 16:17:06 +08:00
committed by GitHub
parent 9a885d08d0
commit 9567e5dd8c
5 changed files with 118 additions and 142 deletions

View File

@@ -238,6 +238,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum; uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
uint64_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2; uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
if (epRankId < sharedExpertRankNum) { if (epRankId < sharedExpertRankNum) {
maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum; maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
@@ -245,20 +246,23 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
} }
size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE); size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t CVSwapBufferSize = size_t CVSwapBufferSize =
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE); CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE); size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE); size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE); size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE); size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE); size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize + size_t usrSize = maxTokenSize + tokenScaleSize + CVSwapBufferSize + maxSwigluGmm2Size + groupListSize + expandIdxSize +
epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize; epSendCountSize + resveredSize;
workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize; workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
return ge::GRAPH_SUCCESS; return ge::GRAPH_SUCCESS;

View File

@@ -178,7 +178,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_, template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
class DispatchPolicy_ = MmadAtlasA2Custom> class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale, layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmWorkspace, void *combiner) GM_ADDR gmWorkspace, void *combiner)
@@ -189,7 +189,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
using L0TileShape = L0TileShape_; using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>; using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::nZ>; using BType = Gemm::GemmType<int8_t, layout::zN>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>; using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>; using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
@@ -261,12 +261,12 @@ private:
GM_ADDR gmSmoothScales_; GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_; GM_ADDR gmexpertScales_;
uint32_t m_{0}; uint32_t maxTokenNum_{0};
uint32_t n_{0}; uint32_t gmm1OutputDim_{0};
uint32_t k_{0}; uint32_t tokenHiddenSize_{0};
uint32_t groupCount_{0}; uint32_t groupCount_{0};
uint32_t n2_{0}; uint32_t gmm2OutputDim_{0};
uint32_t k2_{0}; uint32_t gmm2InputDim_{0};
uint32_t globalRankId_{0}; uint32_t globalRankId_{0};
uint32_t winSizePerRank_{0}; uint32_t winSizePerRank_{0};
uint32_t blockDim_{0}; uint32_t blockDim_{0};
@@ -327,59 +327,62 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
bool isShareExpert = (epRankId_ < sharedExpertRankNum_); bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
if (isShareExpert) { if (isShareExpert) {
m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_; maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
} else { } else {
m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_); maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
} }
n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
n2_ = k_; gmm2OutputDim_ = tokenHiddenSize_;
k2_ = n_ / 2; gmm2InputDim_ = gmm1OutputDim_ / 2;
} }
template <TemplateMC2TypeClass> template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process() __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
{ {
GemmCoord gmm1ProblemShape{m_, n_, k_}; GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_};
GemmCoord gmm2ProblemShape{m_, n2_, k2_}; GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};
layout::RowMajor layoutX1{m_, k_}; layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(k_, n_); layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(tokenHiddenSize_, gmm1OutputDim_);
layout::VectorLayout layoutScale1{n_}; layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
layout::VectorLayout layoutPerTokenScale1{m_}; layout::VectorLayout layoutX1Scale{maxTokenNum_};
layout::RowMajor layoutX2{m_, k2_}; layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
layout::nZ layoutWeight2 = layout::nZ::template MakeLayout<int8_t>(k2_, n2_); layout::zN layoutWeight2 = layout::zN::template MakeLayout<int8_t>(gmm2InputDim_, gmm2OutputDim_);
layout::VectorLayout layoutScale2{n2_}; layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
layout::VectorLayout layoutPerTokenScale2{m_}; layout::VectorLayout layoutX2Scale{maxTokenNum_};
layout::RowMajor layoutOutput{m_, n2_}; layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};
size_t workspaceOffset = 0; size_t workspaceOffset = 0;
constexpr int32_t resveredWorkSpaceSize = 256 * 1024; constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
GM_ADDR gmX2 = workspaceGM_; int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(int8_t);
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(int8_t)); int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(int8_t);
GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset; int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float)); int64_t tokenScaleSize = maxTokenNum_ * sizeof(float);
GM_ADDR gmX1 = workspaceGM_ + workspaceOffset;
GM_ADDR gmX2 = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxTokenSize);
GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
GM_ADDR gmX2Scale = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(tokenScaleSize);
GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset; GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;
GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset; GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (GMM1_L1M * GMM1_L1N) * workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (GMM1_L1M * GMM1_L1N) *
WORKSPACE_STAGES * sizeof(int32_t)); WORKSPACE_STAGES * sizeof(int32_t));
int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float);
int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize;
GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset; GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(float)); GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxSwigluGmm2Size);
GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset; GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t)); workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t));
GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset; GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t)); workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t));
GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset; GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t)); workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(int8_t));
GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(ExpandXType));
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize); workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
@@ -388,7 +391,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
AscendC::TPipe tpipe; AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false> MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
dispatcher; dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList, dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process(); dispatcher.Process();
tpipe.Destroy(); tpipe.Destroy();
@@ -406,11 +409,11 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
} }
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape, GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>( Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1, gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered, layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_); sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>(); AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0}; Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) { if constexpr (g_coreType == AscendC::AIV) {
@@ -427,7 +430,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
} }
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler, GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2, Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut, gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
layoutOutput, gmWorkspace, &combiner); layoutOutput, gmWorkspace, &combiner);
} }
#endif // DISPATCH_GMM_COMBINE_DECODE_H #endif // DISPATCH_GMM_COMBINE_DECODE_H

View File

@@ -74,20 +74,11 @@ public:
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>, std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops"); "TileShape must be consistent for all tile compute ops");
static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2;
using ChunkTileShape = MatrixShape<TileShape::ROW, CHUNK_TILE_COLUMN>;
using TileStrideMuls = Tile::TileStrideMuls<ArchTag, float, ChunkTileShape, ChunkTileShape, TileShape>;
using TileStrideDiv = Tile::TileStrideDiv<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;
using TileStrideMul = Tile::TileStrideMul<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE, ArchTag::UB_SIZE,
"TileShape is too large to fit in UB"); "TileShape is too large to fit in UB");
@@ -151,7 +142,7 @@ public:
ubOffset += TileShape::COUNT * sizeof(float); ubOffset += TileShape::COUNT * sizeof(float);
ubTmpMx32B = resource.ubBuf.template GetBufferByByte<float>(ubOffset); ubTmpMx32B = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK; ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte<float>(ubOffset); ubDenominatorMxN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
} }
CATLASS_DEVICE CATLASS_DEVICE
@@ -185,7 +176,7 @@ public:
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape; MatrixCoord blockOffset = blockCoord * blockShape;
bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
AscendC::GlobalTensor<ElementScale> gmScale; AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale); gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale; AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
@@ -194,7 +185,6 @@ public:
gmD.SetGlobalBuffer(params.ptrD); gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L); auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto ubChunkTileStride = MakeCoord(static_cast<int64_t>(ChunkTileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord(); auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
@@ -206,9 +196,6 @@ public:
auto tileOffsetInBlock = tileCoord * tileShape; auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock; auto tileOffset = blockOffset + tileOffsetInBlock;
auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
@@ -257,27 +244,29 @@ public:
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale); tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]); AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
// after dequant, the left half does x / (x + exp(-Dequant(x))), the right dose nothing
if (isLeft) {
tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B); tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f); AscendC::Muls(ubDenominatorMxN, ubTmpMxN, -1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT); AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT); AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};
auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]); AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN); AscendC::Div(ubD, ubTmpMxN, ubDenominatorMxN, TileShape::COUNT);
} else {
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
tileOneBlkColumnBroadcastMul(ubD, ubTmpMxN, ubTmpMx32B);
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]); AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)]; auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape); auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]); AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
@@ -307,16 +296,12 @@ private:
AscendC::LocalTensor<float> ubTmpMxN; AscendC::LocalTensor<float> ubTmpMxN;
AscendC::LocalTensor<float> ubTmpMx32B; AscendC::LocalTensor<float> ubTmpMx32B;
AscendC::LocalTensor<float> ubTmpMxChunkN; AscendC::LocalTensor<float> ubDenominatorMxN;
TileRowBroadcastMul tileRowBroadcastMul; TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk; TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
TileStrideMuls tileStrideMuls;
TileStrideDiv tileStrideDiv;
TileStrideMul tileStrideMul;
CopyGmToUbC copyGmToUbC; CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale; CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;

View File

@@ -140,6 +140,7 @@ public:
ubQuantScale = resource.ubBuf.template GetBufferByByte<float>(ubOffset); ubQuantScale = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += CEIL_UP(tileRow * sizeof(float)); ubOffset += CEIL_UP(tileRow * sizeof(float));
ubInputTmp = ubAbs; ubInputTmp = ubAbs;
ubInputRightHalf = ubAbs;
ubQuantF32 = ubAbs; ubQuantF32 = ubAbs;
ubQuantS32 = ubAbs.ReinterpretCast<int32_t>(); ubQuantS32 = ubAbs.ReinterpretCast<int32_t>();
ubQuantF16 = ubAbs.ReinterpretCast<half>(); ubQuantF16 = ubAbs.ReinterpretCast<half>();
@@ -188,10 +189,14 @@ public:
layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; layout::RowMajor layoutUbInput{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(0); AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(0);
// continue swiglu computing here and then quant
copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput);
copyGmToUbInput(ubInputRightHalf, gmTileInput[params.layoutInput.shape(1) >> 1], layoutUbInput, layoutGmTileInput);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0); AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0); AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
AscendC::Mul(ubInput, ubInput, ubInputRightHalf, tileCount);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Abs(ubAbs, ubInput, tileCount); AscendC::Abs(ubAbs, ubInput, tileCount);
AscendC::PipeBarrier<PIPE_V>(); AscendC::PipeBarrier<PIPE_V>();
@@ -290,6 +295,7 @@ private:
AscendC::LocalTensor<float> ubQuantScale; AscendC::LocalTensor<float> ubQuantScale;
AscendC::LocalTensor<float> ubQuantScaleBrcb; AscendC::LocalTensor<float> ubQuantScaleBrcb;
AscendC::LocalTensor<float> ubInputTmp; AscendC::LocalTensor<float> ubInputTmp;
AscendC::LocalTensor<float> ubInputRightHalf;
AscendC::LocalTensor<float> ubQuantF32; AscendC::LocalTensor<float> ubQuantF32;
AscendC::LocalTensor<int32_t> ubQuantS32; AscendC::LocalTensor<int32_t> ubQuantS32;
AscendC::LocalTensor<half> ubQuantF16; AscendC::LocalTensor<half> ubQuantF16;
@@ -1232,7 +1238,6 @@ public:
__gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale, __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale,
LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput) LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput)
{ {
uint32_t nOut = n / 2;
uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; uint32_t coreNumPerGroup = recvCoreNum / localExpertNum;
int64_t gmGroupOffsetScale = 0; int64_t gmGroupOffsetScale = 0;
int64_t gmGroupOffsetPerTokenScale = 0; int64_t gmGroupOffsetPerTokenScale = 0;
@@ -1267,7 +1272,7 @@ public:
GemmCoord inGroupProblemShape{currentM, n, k}; GemmCoord inGroupProblemShape{currentM, n, k};
LayoutPerTokenScale layoutPerTokenScale = LayoutPerTokenScale layoutPerTokenScale =
wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); LayoutD layoutD = layout::RowMajor{currentM, n};
EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
layoutScale, layoutScale,
gmTokenScale + gmGroupOffsetPerTokenScale, gmTokenScale + gmGroupOffsetPerTokenScale,
@@ -1299,7 +1304,7 @@ public:
gmGroupOffsetScale += inGroupProblemShape.n(); gmGroupOffsetScale += inGroupProblemShape.n();
gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += currentM * nOut; gmGroupOffsetD += currentM * n;
startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum; startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum;
} }
@@ -1514,11 +1519,13 @@ public:
__asm__ __volatile__(""); __asm__ __volatile__("");
totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1); totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1);
AscendC::PipeBarrier<PIPE_ALL>(); AscendC::PipeBarrier<PIPE_ALL>();
uint32_t n = params.problemShape.n();
uint32_t nOut = params.problemShape.n() / 2; uint32_t nOut = params.problemShape.n() / 2;
uint32_t quantRowOnce = 0; uint32_t quantRowOnce = 0;
CalQuantRow(nOut, quantRowOnce); CalQuantRow(nOut, quantRowOnce);
auto swigluLayout = layout::RowMajor{totalTokenCount, n};
typename BlockQuant<ArchTag>::Params quantParams{ typename BlockQuant<ArchTag>::Params quantParams{
gmSwigluOutput, params.layoutOutput, params.ptrDequantScale, params.layoutDequantScale, gmSwigluOutput, swigluLayout, params.ptrDequantScale, params.layoutDequantScale,
params.ptrOutput, params.layoutOutput, quantRowOnce, nOut}; params.ptrOutput, params.layoutOutput, quantRowOnce, nOut};
BlockQuant<ArchTag> blockQuant(resource, quantParams); BlockQuant<ArchTag> blockQuant(resource, quantParams);
@@ -1850,6 +1857,7 @@ public:
params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N)); params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N));
uint32_t mActual = groupList.GetValue(params.problemCount - 1); uint32_t mActual = groupList.GetValue(params.problemCount - 1);
uint32_t n = params.problemShape.n();
uint32_t nOut = params.problemShape.n() / 2; uint32_t nOut = params.problemShape.n() / 2;
{ {
@@ -1866,7 +1874,7 @@ public:
LayoutScale layoutScale = params.layoutScale; LayoutScale layoutScale = params.layoutScale;
LayoutPerTokenScale layoutPerTokenScale = LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); LayoutD layoutD = layout::RowMajor{currentM, n};
EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
layoutScale, layoutScale,
@@ -1899,7 +1907,7 @@ public:
gmGroupOffsetScale += inGroupProblemShape.n(); gmGroupOffsetScale += inGroupProblemShape.n();
gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += currentM * nOut; gmGroupOffsetD += currentM * n;
startCoreIdx = (startCoreIdx + coreLoops) % coreNum; startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
} }
@@ -1910,8 +1918,9 @@ public:
{ {
uint32_t quantRowOnce = 0; uint32_t quantRowOnce = 0;
CalQuantRow(nOut, quantRowOnce); CalQuantRow(nOut, quantRowOnce);
auto swigluLayout = layout::RowMajor{mActual, n};
typename BlockQuant<ArchTag>::Params quantParams{ptrD, typename BlockQuant<ArchTag>::Params quantParams{ptrD,
params.layoutOutput, swigluLayout,
params.ptrDequantScale, params.ptrDequantScale,
params.layoutDequantScale, params.layoutDequantScale,
params.ptrOutput, params.ptrOutput,

View File

@@ -12,9 +12,7 @@ import torchair
from vllm_ascend.utils import enable_custom_op from vllm_ascend.utils import enable_custom_op
config = torchair.CompilerConfig() torch.manual_seed(42)
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch_npu.npu.config.allow_internal_format = True torch_npu.npu.config.allow_internal_format = True
enable_custom_op() enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs" LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
@@ -101,7 +99,21 @@ class DecodeMoeOps(torch.nn.Module):
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale): gmm2_weight, gmm2_weight_scale):
raise NotImplementedError("To be implemented in subclass") gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
gmm1_weight_scale.float(), requires_grad=False)
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError("To be implemented in subclass")
@@ -132,19 +144,6 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num) shared_expert_rank_num)
self.tp_hcomm_info = "" self.tp_hcomm_info = ""
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
outputs = torch_npu.npu_moe_distribute_dispatch_v2( outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x, x=x,
@@ -238,41 +237,14 @@ class FusionOp(DecodeMoeOps):
ep_world_size, moe_expert_num, global_rank_id, ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num) shared_expert_rank_num)
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
.view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
.transpose(1,2).contiguous()\
.view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
.transpose(1,2).contiguous()
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.ND)
gmm1_weight.add_(0)
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
gmm2_weight = torch_npu.npu_format_cast(
gmm2_weight.transpose(1, 2).contiguous(),
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode( output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x, x=x,
expert_ids=expert_ids, expert_ids=expert_ids,
gmm1_permuted_weight=self.gmm1_weight, gmm1_permuted_weight=self.gmm1_weight,
gmm1_permuted_weight_scale=self.gmm1_weight_scale, gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight, gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale, gmm2_weight_scale=self.gmm2_weight_scale_fp32,
expert_smooth_scales=smooth_scales, expert_smooth_scales=smooth_scales,
expert_scales=expert_scales, expert_scales=expert_scales,
group_ep=self.ep_hcomm_info, group_ep=self.ep_hcomm_info,
@@ -399,6 +371,9 @@ def run_once(local_rank_id,
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore *parameter).npu() # type: ignore
if test_graph: if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
fused_ops = torch.compile(fused_ops, backend=npu_backend) fused_ops = torch.compile(fused_ops, backend=npu_backend)
small_op_token_output, small_op_count_output = small_ops(*input_datas) small_op_token_output, small_op_count_output = small_ops(*input_datas)
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas) fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)