[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)
### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the parameters used by the small operators.
1. The operator no longer requires the GMM1 weights and scales to be permuted.
2. The operator no longer requires the GMM2 weights to be transposed.
As a result, the fused operator and the small operators can share the same
parameters (weights and scales), which simplifies model adaptation (see the
illustrative sketch below).
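To make the effect concrete, here is a minimal, illustrative sketch of the shared weight preparation, based on the test changes in this diff (the `torch_npu.npu_format_cast` / `FRACTAL_NZ` calls are the ones already shown below; the helper name `prepare_shared_weights` is hypothetical and only used for illustration):

```python
import torch
import torch_npu  # Ascend PyTorch adapter; assumed available on an NPU host


def prepare_shared_weights(gmm1_weight, gmm1_weight_scale,
                           gmm2_weight, gmm2_weight_scale):
    """Hypothetical helper: prepare one set of weights/scales usable by both
    the fused DispatchGmmCombineDecode operator and the small-operator path.

    With this change the fused operator no longer needs the GMM1 weights and
    scales to be permuted, nor the GMM2 weights to be transposed, so the only
    remaining preparation is the FRACTAL_NZ format cast that the small
    operators already use.
    """
    gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    return {
        "gmm1_weight": torch.nn.Parameter(gmm1_weight, requires_grad=False),
        "gmm1_weight_scale": torch.nn.Parameter(gmm1_weight_scale,
                                                requires_grad=False),
        # the fused operator consumes float32 scales, so keep fp32 copies too
        "gmm1_weight_scale_fp32": torch.nn.Parameter(gmm1_weight_scale.float(),
                                                     requires_grad=False),
        "gmm2_weight": torch.nn.Parameter(gmm2_weight, requires_grad=False),
        "gmm2_weight_scale": torch.nn.Parameter(gmm2_weight_scale,
                                                requires_grad=False),
        "gmm2_weight_scale_fp32": torch.nn.Parameter(gmm2_weight_scale.float(),
                                                     requires_grad=False),
    }
```

Before this change, the fused path additionally had to reshape and transpose `gmm1_weight`, permute its scales, and transpose `gmm2_weight`; that extra processing is removed from `FusionOp` in the test diff below.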
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
@@ -238,6 +238,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
uint64_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
if (epRankId < sharedExpertRankNum) {
maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
@@ -245,20 +246,23 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
}

size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t CVSwapBufferSize =
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize +
epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize;
size_t usrSize = maxTokenSize + tokenScaleSize + CVSwapBufferSize + maxSwigluGmm2Size + groupListSize + expandIdxSize +
epSendCountSize + resveredSize;

workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
return ge::GRAPH_SUCCESS;
@@ -178,7 +178,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmWorkspace, void *combiner)
@@ -189,7 +189,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
using L0TileShape = L0TileShape_;

using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::nZ>;
using BType = Gemm::GemmType<int8_t, layout::zN>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;

using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
@@ -261,12 +261,12 @@ private:
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;

uint32_t m_{0};
uint32_t n_{0};
uint32_t k_{0};
uint32_t maxTokenNum_{0};
uint32_t gmm1OutputDim_{0};
uint32_t tokenHiddenSize_{0};
uint32_t groupCount_{0};
uint32_t n2_{0};
uint32_t k2_{0};
uint32_t gmm2OutputDim_{0};
uint32_t gmm2InputDim_{0};
uint32_t globalRankId_{0};
uint32_t winSizePerRank_{0};
uint32_t blockDim_{0};
@@ -327,59 +327,62 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
if (isShareExpert) {
m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
} else {
m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
}

n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
n2_ = k_;
k2_ = n_ / 2;
gmm2OutputDim_ = tokenHiddenSize_;
gmm2InputDim_ = gmm1OutputDim_ / 2;
}

template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
{
GemmCoord gmm1ProblemShape{m_, n_, k_};
GemmCoord gmm2ProblemShape{m_, n2_, k2_};
GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_};
GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};

layout::RowMajor layoutX1{m_, k_};
layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(k_, n_);
layout::VectorLayout layoutScale1{n_};
layout::VectorLayout layoutPerTokenScale1{m_};
layout::RowMajor layoutX2{m_, k2_};
layout::nZ layoutWeight2 = layout::nZ::template MakeLayout<int8_t>(k2_, n2_);
layout::VectorLayout layoutScale2{n2_};
layout::VectorLayout layoutPerTokenScale2{m_};
layout::RowMajor layoutOutput{m_, n2_};
layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(tokenHiddenSize_, gmm1OutputDim_);
layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
layout::VectorLayout layoutX1Scale{maxTokenNum_};
layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
layout::zN layoutWeight2 = layout::zN::template MakeLayout<int8_t>(gmm2InputDim_, gmm2OutputDim_);
layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
layout::VectorLayout layoutX2Scale{maxTokenNum_};
layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};

size_t workspaceOffset = 0;
constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
GM_ADDR gmX2 = workspaceGM_;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(int8_t));
GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(int8_t);
int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(int8_t);
int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
int64_t tokenScaleSize = maxTokenNum_ * sizeof(float);
GM_ADDR gmX1 = workspaceGM_ + workspaceOffset;
GM_ADDR gmX2 = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxTokenSize);
GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
GM_ADDR gmX2Scale = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(tokenScaleSize);
GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;

GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (GMM1_L1M * GMM1_L1N) *
WORKSPACE_STAGES * sizeof(int32_t));
int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float);
int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize;
GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(float));
GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxSwigluGmm2Size);
GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t));
GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t));
GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(int8_t));
GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(ExpandXType));
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
@@ -388,7 +391,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList,
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
@@ -406,11 +409,11 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
}
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_);
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) {
@@ -427,7 +430,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
}
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut,
gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
layoutOutput, gmWorkspace, &combiner);
}
#endif // DISPATCH_GMM_COMBINE_DECODE_H
@@ -74,20 +74,11 @@ public:
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");

static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2;
using ChunkTileShape = MatrixShape<TileShape::ROW, CHUNK_TILE_COLUMN>;

using TileStrideMuls = Tile::TileStrideMuls<ArchTag, float, ChunkTileShape, ChunkTileShape, TileShape>;
using TileStrideDiv = Tile::TileStrideDiv<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;
using TileStrideMul = Tile::TileStrideMul<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;

static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");

static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
@@ -151,7 +142,7 @@ public:
ubOffset += TileShape::COUNT * sizeof(float);
ubTmpMx32B = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubDenominatorMxN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
}

CATLASS_DEVICE
@@ -185,7 +176,7 @@ public:
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;

bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
@@ -194,7 +185,6 @@ public:
gmD.SetGlobalBuffer(params.ptrD);

auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto ubChunkTileStride = MakeCoord(static_cast<int64_t>(ChunkTileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
@@ -206,9 +196,6 @@ public:
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;

auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);

auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
@@ -257,27 +244,29 @@ public:
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);

AscendC::PipeBarrier<PIPE_V>();
tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
AscendC::PipeBarrier<PIPE_V>();
tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};

auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN);
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::PipeBarrier<PIPE_V>();
// after dequant, the left half does x / (x + exp(-Dequant(x))), the right does nothing
if (isLeft) {
tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Muls(ubDenominatorMxN, ubTmpMxN, -1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Div(ubD, ubTmpMxN, ubDenominatorMxN, TileShape::COUNT);
} else {
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
tileOneBlkColumnBroadcastMul(ubD, ubTmpMxN, ubTmpMx32B);
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);

auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape);
auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);

AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
@@ -307,16 +296,12 @@ private:
AscendC::LocalTensor<float> ubTmpMxN;
AscendC::LocalTensor<float> ubTmpMx32B;
AscendC::LocalTensor<float> ubTmpMxChunkN;
AscendC::LocalTensor<float> ubDenominatorMxN;

TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;

TileStrideMuls tileStrideMuls;
TileStrideDiv tileStrideDiv;
TileStrideMul tileStrideMul;

CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
@@ -140,6 +140,7 @@ public:
ubQuantScale = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += CEIL_UP(tileRow * sizeof(float));
ubInputTmp = ubAbs;
ubInputRightHalf = ubAbs;
ubQuantF32 = ubAbs;
ubQuantS32 = ubAbs.ReinterpretCast<int32_t>();
ubQuantF16 = ubAbs.ReinterpretCast<half>();
@@ -188,10 +189,14 @@ public:
layout::RowMajor layoutUbInput{actualTileShape, ubTileStride};

AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(0);
// continue swiglu computing here and then quant
copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput);
copyGmToUbInput(ubInputRightHalf, gmTileInput[params.layoutInput.shape(1) >> 1], layoutUbInput, layoutGmTileInput);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);

AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
AscendC::Mul(ubInput, ubInput, ubInputRightHalf, tileCount);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Abs(ubAbs, ubInput, tileCount);
AscendC::PipeBarrier<PIPE_V>();
@@ -290,6 +295,7 @@ private:
AscendC::LocalTensor<float> ubQuantScale;
AscendC::LocalTensor<float> ubQuantScaleBrcb;
AscendC::LocalTensor<float> ubInputTmp;
AscendC::LocalTensor<float> ubInputRightHalf;
AscendC::LocalTensor<float> ubQuantF32;
AscendC::LocalTensor<int32_t> ubQuantS32;
AscendC::LocalTensor<half> ubQuantF16;
@@ -1232,7 +1238,6 @@ public:
__gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale,
LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput)
{
uint32_t nOut = n / 2;
uint32_t coreNumPerGroup = recvCoreNum / localExpertNum;
int64_t gmGroupOffsetScale = 0;
int64_t gmGroupOffsetPerTokenScale = 0;
@@ -1267,7 +1272,7 @@ public:
GemmCoord inGroupProblemShape{currentM, n, k};
LayoutPerTokenScale layoutPerTokenScale =
wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
LayoutD layoutD = layout::RowMajor{currentM, n};
EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
layoutScale,
gmTokenScale + gmGroupOffsetPerTokenScale,
@@ -1299,7 +1304,7 @@ public:
gmGroupOffsetScale += inGroupProblemShape.n();
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += currentM * nOut;
gmGroupOffsetD += currentM * n;

startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum;
}
@@ -1514,11 +1519,13 @@ public:
__asm__ __volatile__("");
totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1);
AscendC::PipeBarrier<PIPE_ALL>();
uint32_t n = params.problemShape.n();
uint32_t nOut = params.problemShape.n() / 2;
uint32_t quantRowOnce = 0;
CalQuantRow(nOut, quantRowOnce);
auto swigluLayout = layout::RowMajor{totalTokenCount, n};
typename BlockQuant<ArchTag>::Params quantParams{
gmSwigluOutput, params.layoutOutput, params.ptrDequantScale, params.layoutDequantScale,
gmSwigluOutput, swigluLayout, params.ptrDequantScale, params.layoutDequantScale,
params.ptrOutput, params.layoutOutput, quantRowOnce, nOut};

BlockQuant<ArchTag> blockQuant(resource, quantParams);
@@ -1850,6 +1857,7 @@ public:
params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N));

uint32_t mActual = groupList.GetValue(params.problemCount - 1);
uint32_t n = params.problemShape.n();
uint32_t nOut = params.problemShape.n() / 2;

{
@@ -1866,7 +1874,7 @@ public:
LayoutScale layoutScale = params.layoutScale;
LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
LayoutD layoutD = layout::RowMajor{currentM, n};

EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
layoutScale,
@@ -1899,7 +1907,7 @@ public:
gmGroupOffsetScale += inGroupProblemShape.n();
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += currentM * nOut;
gmGroupOffsetD += currentM * n;

startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
@@ -1910,8 +1918,9 @@ public:
{
uint32_t quantRowOnce = 0;
CalQuantRow(nOut, quantRowOnce);
auto swigluLayout = layout::RowMajor{mActual, n};
typename BlockQuant<ArchTag>::Params quantParams{ptrD,
params.layoutOutput,
swigluLayout,
params.ptrDequantScale,
params.layoutDequantScale,
params.ptrOutput,
@@ -12,9 +12,7 @@ import torchair
from vllm_ascend.utils import enable_custom_op

config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch.manual_seed(42)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
@@ -101,7 +99,21 @@ class DecodeMoeOps(torch.nn.Module):
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
raise NotImplementedError("To be implemented in subclass")
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)

self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
gmm1_weight_scale.float(), requires_grad=False)
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False)

def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
raise NotImplementedError("To be implemented in subclass")
@@ -132,19 +144,6 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num)
self.tp_hcomm_info = ""

def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)

def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x,
@@ -238,41 +237,14 @@ class FusionOp(DecodeMoeOps):
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)

def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
.view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
.transpose(1,2).contiguous()\
.view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
.transpose(1,2).contiguous()
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.ND)
gmm1_weight.add_(0)
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
gmm2_weight = torch_npu.npu_format_cast(
gmm2_weight.transpose(1, 2).contiguous(),
torch_npu.Format.FRACTAL_NZ)

gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()

self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)

def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x,
expert_ids=expert_ids,
gmm1_permuted_weight=self.gmm1_weight,
gmm1_permuted_weight_scale=self.gmm1_weight_scale,
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale,
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
expert_smooth_scales=smooth_scales,
expert_scales=expert_scales,
group_ep=self.ep_hcomm_info,
@@ -399,6 +371,9 @@ def run_once(local_rank_id,
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu()  # type: ignore
if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
fused_ops = torch.compile(fused_ops, backend=npu_backend)
small_op_token_output, small_op_count_output = small_ops(*input_datas)
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)