[BugFix] dispatch_ffn_combine kernal rollback combinev2 part (#8405)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: l00893928 <liuquanlu@huawei.com> Co-authored-by: l00893928 <liuquanlu@huawei.com>
This commit is contained in:
@@ -224,7 +224,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
|
||||
constexpr uint32_t ubStages = 2;
|
||||
|
||||
using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant<ubStages>;
|
||||
|
||||
|
||||
using ScaleType = Gemm::GemmType<uint64_t, layout::VectorLayout>;
|
||||
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
||||
using ElementMulType = Gemm::GemmType<float, layout::RowMajor>;
|
||||
@@ -234,7 +234,8 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
|
||||
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
|
||||
D1Type, TileElemWiseMuls, TileCopy1>;
|
||||
|
||||
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2<ubStages>;
|
||||
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
|
||||
|
||||
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
|
||||
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
|
||||
D2Type, TileCopy2>;
|
||||
|
||||
@@ -571,6 +571,7 @@ private:
|
||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||
blockMmad.SynchronizeBlock();
|
||||
}
|
||||
blockMmad.Finalize(params.expertPerRank - 1, 0);
|
||||
}
|
||||
|
||||
|
||||
@@ -837,20 +838,16 @@ private:
|
||||
|
||||
uint32_t n2 = params.problemShape.k();
|
||||
|
||||
|
||||
typename BlockEpilogue2::Params epilogueParams{
|
||||
static_cast<int32_t>(params.EP),
|
||||
static_cast<int32_t>(params.expertPerRank),
|
||||
static_cast<int32_t>(params.rank),
|
||||
reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert),
|
||||
params.layoutD2,
|
||||
static_cast<int32_t>(n2),
|
||||
static_cast<int32_t>(L1TileShape::N),
|
||||
shmem,
|
||||
static_cast<int32_t>(peermemInfo.offsetD)
|
||||
static_cast<int32_t>(n2)
|
||||
};
|
||||
|
||||
uint32_t n = params.problemShape.n();
|
||||
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
|
||||
|
||||
BlockEpilogue1 blockEpilogue1(resource, n);
|
||||
|
||||
// Synchronous wait: SwiGLU waits for GMM1 [1]
|
||||
@@ -889,16 +886,13 @@ private:
|
||||
}
|
||||
|
||||
blockEpilogue1.Finalize();
|
||||
|
||||
|
||||
CombineSetFlag();
|
||||
|
||||
CombineV2(params, blockEpilogue2);
|
||||
|
||||
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
|
||||
CombineV1(params, blockEpilogue2);
|
||||
AscendC::SyncAll<true>();
|
||||
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
|
||||
ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
|
||||
#endif
|
||||
|
||||
shmem.InitStatusTargetSum();
|
||||
if (get_subblockid() == 0) {
|
||||
AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
@@ -929,6 +923,49 @@ private:
|
||||
}
|
||||
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void CombineV1(Params const ¶ms, BlockEpilogue2 & blockEpilogue) {
|
||||
uint32_t n2 = params.problemShape.k();
|
||||
int32_t prevGroupSum2 = 0;
|
||||
|
||||
icache_preload(8);
|
||||
for (uint32_t t_groupIdx = 0; t_groupIdx < params.expertPerRank; ++t_groupIdx) {
|
||||
int32_t flagId = t_groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(flagId);
|
||||
AscendC::SyncAll<true>();
|
||||
|
||||
uint32_t groupIdx = t_groupIdx;
|
||||
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
__gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx);
|
||||
AscendC::GlobalTensor<ElementD2> gmRemotePeer;
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr));
|
||||
uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2;
|
||||
if (srcRowOffset < params.maxOutputSize) {
|
||||
uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
if (srcRowOffset + dataRows > params.maxOutputSize) {
|
||||
dataRows = params.maxOutputSize - srcRowOffset;
|
||||
}
|
||||
//uint32_t dstRowOffset = preSumBeforeRank(2 * dstEpIdx * FLAGSTRIDE + groupIdx);
|
||||
int32_t tmpBlock = AlignUp(params.expertPerRank, FLAGSTRIDE);
|
||||
//uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * tmpBlock + groupIdx);
|
||||
uint32_t dstRowOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
|
||||
MatrixCoord offsetC{srcRowOffset, 0};
|
||||
MatrixCoord offsetPeer{dstRowOffset, 0};
|
||||
MatrixCoord shapeC{dataRows, n2};
|
||||
int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
|
||||
int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
|
||||
} else {
|
||||
blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]);
|
||||
}
|
||||
}
|
||||
}
|
||||
prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
}
|
||||
blockEpilogue.Finalize();
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) {
|
||||
|
||||
Reference in New Issue
Block a user