[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)

### What this PR does / why we need it?
Fix the precision issue in the dispatch_ffn_combine_bf16 operator, and remove redundant synchronization operations from the dispatch_ffn_combine operator.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
Author: guanguan0308
Date: 2026-03-23 10:14:03 +08:00 (committed by GitHub)
Commit: 44ef9a36ac (parent: e68464a1d6)
8 changed files with 531 additions and 462 deletions

View File

@@ -391,7 +391,6 @@ private:
uint16_t syncgmmIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
syncgmmIdx++;
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -405,7 +404,6 @@ private:
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
AscendC::PipeBarrier<PIPE_ALL>();
if (currentM <= L1TileShape::M) {
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
@@ -493,8 +491,6 @@ private:
uint32_t startCoreIdx = 0;
AscendC::PipeBarrier<PIPE_ALL>();
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
uint32_t lastDequantExpertNum = params.expertPerRank;
@@ -503,8 +499,6 @@ private:
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preCurrentmSum >= params.maxOutputSize) {

View File

@@ -41,6 +41,7 @@ namespace {
constexpr uint32_t EXPERTID_INDEX = 3;
constexpr uint32_t BLOCK_NUM = 20;
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr uint64_t MB_SIZE = 1024 * 1024UL;
}
namespace optiling {
@@ -240,7 +241,8 @@ static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext
info.maxOutputSize * n2 * sizeof(int16_t) +
info.maxOutputSize * info.K * sizeof(int16_t) +
info.maxOutputSize * k2 * sizeof(int16_t) +
info.worldSize * sizeof(int32_t) * 16;
info.worldSize * sizeof(int32_t) * 16 +
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
// std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
// std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
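
The hunk above grows the workspace by one more flag-strided table. A rough restatement of the visible terms as a standalone function (a sketch: field and parameter names mirror the tiling code, and the surrounding terms of the full formula are elided):

```cpp
#include <cstddef>
#include <cstdint>

struct TilingInfo {
    size_t maxOutputSize, K, worldSize, expertPerRank;
};

// Sketch of the workspace terms shown in the diff; the trailing "* 16" is the
// 16-int flag stride used for per-rank and per-expert bookkeeping entries.
size_t WorkspaceBytes(const TilingInfo& info, size_t n2, size_t k2) {
    return info.maxOutputSize * n2 * sizeof(int16_t)
         + info.maxOutputSize * info.K * sizeof(int16_t)
         + info.maxOutputSize * k2 * sizeof(int16_t)
         + info.worldSize * sizeof(int32_t) * 16
         // added by this commit: the per-expert preSumBeforeRank table plus
         // per-core soft flags, each entry padded to a 16-int stride
         + (info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
}
```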

View File

@@ -213,11 +213,7 @@ public:
CATLASS_DEVICE
void operator()<AscendC::AIV>(Params const &params)
{
Dispatch(params);
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
CombineV2(params);
DispatchAndCombine(params);
}
private:
@@ -241,7 +237,7 @@ private:
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
}
@@ -309,7 +305,7 @@ private:
AscendC::DataCopyPad(
tmpBuffer1,
tokenPerExpert[rankId * expertPerRank],
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0},
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
{}
);
@@ -331,145 +327,6 @@ private:
);
}
CATLASS_DEVICE
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRank(Params const &params, int64_t localTokenPerExpertOffset){
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
uint32_t numPerCore = params.EP * params.expertPerRank;
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) {
continue;
}
AscendC::GlobalTensor<int32_t> srcAddress;
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
AscendC::GlobalTensor<int32_t> dstAddress;
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
}
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx != params.rank) {
int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0);
}
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
} else {
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
}
}
AscendC::SyncAll<true>();
}
CATLASS_DEVICE
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
uint32_t numPerCore = params.EP * params.expertPerRank;
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) {
continue;
}
AscendC::GlobalTensor<int32_t> srcAddress;
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
AscendC::GlobalTensor<int32_t> dstAddress;
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
}
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx != params.rank) {
int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0);
}
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
} else {
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
}
int32_t prevSum = 0;
for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
prevSum += tmpBuffer(i);
}
preSumBeforeRank(dstEpIdx * 16) = prevSum;
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(preSumBeforeRank[dstEpIdx * 16]);
__asm__ __volatile__("");
}
AscendC::SyncAll<true>();
}
CATLASS_DEVICE
void GetSumPreRank(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result,
uint32_t expertPerRank, uint32_t rankId, uint32_t EP) {
@@ -506,20 +363,20 @@ CATLASS_DEVICE
icache_preload(8);
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
float aivFinishGroups = 0.0f;
__gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
int64_t gmGroupOffsetC = 0;
uint32_t startCoreIdx = 0;
uint32_t syncGroupIdx = 0;
uint16_t syncgmmIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
syncgmmIdx++;
AicSyncAll();
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
AscendC::PipeBarrier<PIPE_ALL>();
uint16_t syncgmmIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
syncgmmIdx++;
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -533,9 +390,6 @@ CATLASS_DEVICE
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
AscendC::PipeBarrier<PIPE_ALL>();
if (currentM <= L1TileShape::M) {
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
@@ -630,7 +484,6 @@ CATLASS_DEVICE
uint32_t startCoreIdx = 0;
AscendC::PipeBarrier<PIPE_ALL>();
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
@@ -640,7 +493,6 @@ CATLASS_DEVICE
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
@@ -651,7 +503,6 @@ CATLASS_DEVICE
}
AscendC::GlobalTensor<ElementB> gmB2;
AscendC::GlobalTensor<ElementScale> gmS2;
AscendC::PipeBarrier<PIPE_ALL>();
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
@@ -721,7 +572,6 @@ CATLASS_DEVICE
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
@@ -729,8 +579,168 @@ CATLASS_DEVICE
}
}
CATLASS_DEVICE
void InitArithProgress(Params const &params) {
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<float> flagGlobalBase;
flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
}
CATLASS_DEVICE
void Dispatch(Params const &params) {
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::LocalTensor<int32_t> prevSumBuf = tmpBuffer[numPerCore];
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) {
continue;
}
AscendC::GlobalTensor<int32_t> srcAddress;
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
AscendC::GlobalTensor<int32_t> dstAddress;
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
}
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx != params.rank) {
int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0);
}
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
} else {
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
}
AscendC::PipeBarrier<PIPE_ALL>();
int32_t prevSum = 0;
int32_t j = 0;
for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
if (i >= params.rank * params.expertPerRank) {
prevSumBuf(j) = prevSum;
j++;
}
prevSum += tmpBuffer(i);
}
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
AscendC::DataCopyParams{1, static_cast<uint16_t>(params.expertPerRank * sizeof(int32_t)), 0, 0});
}
AscendC::SyncAll<true>();
}
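
The `Adds(..., 0x800000, ...)` / `Adds(..., -0x800000, ...)` pair above implements a non-zero-bias handshake. A standalone sketch of the idea (standard C++ with `std::atomic` standing in for the GM polling done by `gm_signal_wait_until_ne`):

```cpp
#include <atomic>
#include <cstdint>

constexpr int32_t kBias = 0x800000;

// The sender biases every token count before writing it into the peer's
// zero-initialized buffer, so "slot != 0" reliably means "data arrived"
// even when the real count is zero.
void Send(std::atomic<int32_t>* peerSlot, int32_t tokenCount) {
    peerSlot->store(tokenCount + kBias, std::memory_order_release);
}

int32_t Receive(std::atomic<int32_t>* slot) {
    int32_t v;
    while ((v = slot->load(std::memory_order_acquire)) == 0) { /* spin */ }
    return v - kBias; // recover the real count, which may legitimately be 0
}
```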
CATLASS_DEVICE
void ResetTokenPerExpert(int32_t num)
{
if (coreIdx != coreNum - 1) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::Duplicate(tmp, 0, num);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert, tmp, num);
}
CATLASS_DEVICE
void UpdateAicFlags(const Params &params)
{
float flagBase = 1.0f * params.expertPerRank;
__gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
float flag = 0.0f;
float lastflag = -1.0f;
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
__gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
AscendC::GlobalTensor<float> flagGM;
flagGM.SetGlobalBuffer(flagPtr);
int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
AscendC::LocalTensor<float> dstValueBuffer = resource.ubBuf.template GetBufferByByte<float>(flagBufferSize);
AscendC::LocalTensor<float> sharedTmpBuffer = resource.ubBuf.template GetBufferByByte<float>((flagBufferSize + 64));
uint64_t mask[1] = {0};
uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
for (int32_t i = 0; i < 4; i ++) {
if (i < params.EP) {
mask[0] |= 1ull * (1ull << (i * 16));
}
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while (flag < flagBase) {
flag = flagBase;
AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::ReduceMin<float>(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
flag = min(flag, dstValueBuffer.GetValue(0));
if (flag > lastflag) {
*aicFinishPtr = flag;
gm_dcci(aicFinishPtr);
lastflag = flag;
}
}
}
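
`UpdateAicFlags` tracks the slowest producer: it reduces the per-slot progress values to a minimum and republishes only when the watermark advances. A host-side sketch of one polling iteration (a plain loop stands in for the masked `ReduceMin` over the UB tile):

```cpp
#include <algorithm>
#include <vector>

// Returns the new watermark; *aicFinish is rewritten only when it advances,
// mirroring the "if (flag > lastflag)" publish above. The caller loops until
// the watermark reaches flagBase (all expertPerRank groups finished).
float UpdateWatermark(const std::vector<float>& progress, // one slot per rank
                      float lastFlag, float flagBase, float* aicFinish) {
    float flag = flagBase;
    for (float p : progress) {
        flag = std::min(flag, p);
    }
    if (flag > lastFlag) {
        *aicFinish = flag;
        return flag;
    }
    return lastFlag;
}
```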
CATLASS_DEVICE
void CombineSetFlag() {
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
CATLASS_DEVICE
void DispatchAndCombine(Params const &params) {
icache_preload(8);
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset;
@@ -744,35 +754,35 @@ CATLASS_DEVICE
AscendC::SyncAll<true>();
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
if (coreIdx == coreNum - 1) {
if (coreIdx == 0) {
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
}
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t prevSum = 0;
if (coreIdx < params.EP) {
prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
}
AscendC::SyncAll<true>();
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
if(coreIdx == 0)
{
CopyGMToGM(ExpertTokenNums, cumsumMM[(params.EP - 1) * params.expertPerRank], params.expertPerRank, params.ubMoveNum);
}
uint16_t syncgmm1Idx = 0;
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx++;
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
if (coreIdx == 0) {
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
}
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t groupIdxDeq = 0;
int prevSum = 0;
if (coreIdx < params.EP) {
prevSum = preSumBeforeRank(coreIdx * 16);
}
uint32_t prevGroupSum1 = 0;
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
uint32_t n = params.problemShape.n();
BlockEpilogue1 blockEpilogue(resource, n);
icache_preload(8);
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// The ith core reads data from the ith rank's peermem
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
if (rowStart < params.maxOutputSize) {
@@ -785,99 +795,40 @@ CATLASS_DEVICE
GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
AscendC::GlobalTensor<ElementA> gmRemoteA;
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
MatrixCoord offsetA{rowStart, 0};
MatrixCoord shapeA{rows, params.problemShape.k()};
MatrixCoord offsetPeer{rowSrc, 0};
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
if constexpr (std::is_same_v<ElementA, int8_t>) {
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
}
}
}
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
syncLoopIdx++;
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
syncgmm1Idx++;
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx ++;
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
uint32_t dequantLen = prevGroupSum1 - dequantSum;
if (dequantLen >= params.maxOutputSize) {
dequantLen = dequantLen - params.maxOutputSize;
}
prevGroupSum1 += currentM;
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
if constexpr (std::is_same_v<ElementA, int8_t>) {
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
} else {
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
// Token count and truncation logic for the first SwiGLU operation
if (groupIdx + 1 <= params.epilogueGranularity) {
if (dequantSum1 + currentM <= params.maxOutputSize) {
dequantSum1 += currentM;
} else if (dequantSum1 < params.maxOutputSize) {
dequantSum1 = params.maxOutputSize;
}
}
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
dequantSum = 0;
}
}
syncLoopIdx ++;
AscendC::SyncAll<true>();
uint32_t lastDequantExpertNum = params.expertPerRank;
if (params.epilogueGranularity < params.expertPerRank) {
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
if (lastDequantExpertNum < params.expertPerRank) {
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum;
if (prevGroupSum1 >= params.maxOutputSize) {
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
}
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
if constexpr (std::is_same_v<ElementA, int8_t>) {
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
} else {
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
// Token count and truncation logic for the second SwiGLU operation
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
dequantSum2 += currentM;
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
}
}
}
blockEpilogue.Finalize();
}
CATLASS_DEVICE
void CombineV2(Params const &params) {
BlockScheduler blockScheduler;
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
uint32_t startCoreIdx = 0;
int64_t gmGroupOffsetC = 0;
uint32_t aivCoreNum = coreNum;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aivCoreIdx = coreIdx;
uint32_t aicCoreIdx = get_block_idx();
uint32_t aivSubCoreIdx = get_subblockid();
uint32_t preSrcExpertSum = 0;
typename BlockEpilogue2::Params epilogueParams{
static_cast<int32_t>(params.EP),
@@ -891,11 +842,108 @@ CATLASS_DEVICE
static_cast<int32_t>(peermemInfo.offsetD)
};
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
uint32_t n = params.problemShape.n();
BlockEpilogue1 blockEpilogue1(resource, n);
// Synchronous wait: SwiGLU waits for GMM1 [1]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
LayoutC layoutC{dequantSum1, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
}
AscendC::SyncAll<true>();
// Synchronization signal: SwiGLU notifies GMM2 [1]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
// Synchronous wait: SwiGLU waits for GMM1 [2]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum2 > 0) {
uint32_t rowStartThisCore = dequantSum1;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum2;
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
}
AscendC::SyncAll<true>();
// Synchronization signal: SwiGLU notifies GMM2 [2]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
blockEpilogue1.Finalize();
CombineSetFlag();
CombineV2(params, blockEpilogue2);
AscendC::SyncAll<true>();
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
#endif
shmem.InitStatusTargetSum();
if (get_subblockid() == 0) {
AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
shmem.CrossRankSyncV2Set(ctrBuffer);
} else {
uint32_t uboffset = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t sendRankNum_ = params.EP / aicCoreNum;
uint32_t remainderRankNum = params.EP % aicCoreNum;
if (aicCoreIdx < remainderRankNum) {
sendRankNum_++;
}
AscendC::LocalTensor<float> statusTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += sendRankNum_ * UB_ALIGN;
AscendC::LocalTensor<float> gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += AlignUp(params.EP * sizeof(float), 32);
AscendC::LocalTensor<uint32_t> gatherTmpTensor = resource.ubBuf.template GetBufferByByte<uint32_t>(uboffset);
uboffset += AlignUp(sizeof(uint32_t), 32);
AscendC::LocalTensor<float> statusSumOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += AlignUp(sizeof(float), 32);
shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
MoeTokenUnpermuteTilingData tilingData;
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
kernelMoeTokenUnpermuteOp.Process();
}
}
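
The `dequantSum1`/`dequantSum2` bookkeeping introduced above saturates both SwiGLU batches at `maxOutputSize`. A standalone sketch mirroring the two truncation branches in the diff:

```cpp
#include <cstdint>

// dequantSum1 counts tokens for the first SwiGLU batch (groups up to
// epilogueGranularity), dequantSum2 for the second; both stop growing once
// the combined total would run past the output buffer.
void AccumulateDequantSums(uint32_t groupIdx, uint32_t currentM,
                           uint32_t epilogueGranularity, uint32_t maxOutputSize,
                           uint32_t& dequantSum1, uint32_t& dequantSum2) {
    if (groupIdx + 1 <= epilogueGranularity) {
        if (dequantSum1 + currentM <= maxOutputSize) {
            dequantSum1 += currentM;
        } else if (dequantSum1 < maxOutputSize) {
            dequantSum1 = maxOutputSize;                // clamp the first batch
        }
    } else if (dequantSum1 < maxOutputSize) {
        if (dequantSum1 + dequantSum2 + currentM <= maxOutputSize) {
            dequantSum2 += currentM;
        } else if (dequantSum1 + dequantSum2 < maxOutputSize) {
            dequantSum2 = maxOutputSize - dequantSum1;  // clamp the second batch
        }
    }
}
```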
CATLASS_DEVICE
void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
BlockScheduler blockScheduler;
int32_t syncLoopIdx = 0;
uint32_t startCoreIdx = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t aivSubCoreIdx = get_subblockid();
uint32_t preSrcExpertSum = 0;
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
icache_preload(8);
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preSrcExpertSum >= params.maxOutputSize) {
currentExpertM = 0;
} else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
currentExpertM = params.maxOutputSize - preSrcExpertSum;
}
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
@@ -916,11 +964,10 @@ CATLASS_DEVICE
if(aivSubCoreIdx == 1) {
m_offset += (m_rows / 2) * m0;
}
if (loopIdx == startLoopIdx) {
for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
}
for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
}
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
@@ -930,34 +977,15 @@ CATLASS_DEVICE
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
}
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
if constexpr (std::is_same_v<ElementA, int8_t>) {
blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
} else {
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
}
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
m_offset += m0;
}
}
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t expertRankM = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
mPreSumBeforeRank[dstEpIdx] += expertRankM;
}
preSrcExpertSum += currentExpertM;
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
}
blockEpilogue.Finalize();
AscendC::SyncAll<true>();
ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank);
shmem.CrossRankSync();
MoeTokenUnpermuteTilingData tilingData;
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, aivCoreNum);
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
kernelMoeTokenUnpermuteOp.Process();
}
@@ -973,6 +1001,7 @@ private:
GM_ADDR expandedRowIdx;
GM_ADDR ptrTokenPerExpert;
GM_ADDR ptrSumBeforeRank;
__gm__ float* ptrSoftFlagBase;
CATLASS_DEVICE
WorkspaceInfo(){}
@@ -1012,6 +1041,9 @@ private:
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
}
};
@@ -1041,7 +1073,6 @@ private:
WorkspaceInfo workspaceInfo;
PeermemInfo peermemInfo;
int64_t m_prevSumBeforeRank;
AscendC::GlobalTensor<ElementA> gmA;
AscendC::GlobalTensor<ElementC> gmC;
@@ -1057,7 +1088,7 @@ private:
AscendC::GlobalTensor<int32_t> tokenPerExpert;
AscendC::GlobalTensor<int32_t> cumsumMM;
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
uint32_t mPreSumBeforeRank[32] = {0};
Layout3D tokenPerExpertLayout;
HcclShmem shmem;
};

View File

@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
GM_ADDR unpermuted_tokens,
const MoeTokenUnpermuteTilingData *__restrict tiling_data)
{
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
this->blockNum = get_block_num() * get_subblockdim();
this->blockIdx = get_block_idx();
this->blockNum = get_block_num();
if (blockIdx >= blockNum) {
return;

View File

@@ -72,17 +72,6 @@ public:
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
//ub:192KB
n0 = params.n0;
size_t ubOffset = 0;
@@ -98,137 +87,20 @@ public:
source_scale_offset[i] = -1;
}
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
is_ping = true;
}
CATLASS_DEVICE
void Finalize()
{
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
CATLASS_DEVICE
~BlockEpilogue()
{
}
CATLASS_DEVICE
void operator() (
AscendC::GlobalTensor<ElementC> const &gmC,
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
GemmCoord& blockCoord,
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
uint32_t *mPreSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
auto &ubCFp32 = ubFp32List[is_ping];
auto &scaleUb = scaleUbList[is_ping];
// auto &ubOutFp32 = ubOutFp32List[is_ping];
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
layout::VectorLayout scaleLauout{actualBlockShape.m()};
if (source_scale_offset[event_id] != gmScaleOffset) {
source_scale_offset[event_id] = gmScaleOffset;
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // Note: this must be MTE2_S, not MTE2_V; otherwise the scalar read returns 0 and the output is garbled
AscendC::PipeBarrier<PIPE_V>();
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
float scale = scaleUb(row);
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
int32_t lenTile = actualBlockShape.m();
int32_t stTile = blockCoord.m();
int32_t edTile = stTile + lenTile;
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
if (stRankInExpert >= edTile) {
break;
}
else if (edRankInExpert <= stTile) {
continue;
}
int32_t stData = max(stRankInExpert, stTile);
int32_t edData = min(edRankInExpert, edTile);
uint32_t lenData = edData - stData;
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
}
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
}
CATLASS_DEVICE
@@ -238,14 +110,12 @@ public:
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
uint32_t *mPreSumBeforeRank
AscendC::GlobalTensor<int32_t> preSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
@@ -253,7 +123,7 @@ public:
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
@@ -263,10 +133,10 @@ public:
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id);
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
@@ -282,7 +152,7 @@ public:
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
@@ -290,7 +160,7 @@ public:
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
@@ -298,7 +168,8 @@ public:
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
}
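
The scatter loop above intersects the current output tile's row interval with each destination rank's run of rows inside the expert, copying only the overlap and advancing a running offset into the tile. A standalone sketch of that interval arithmetic:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct CopySpan { int32_t dstRank, srcRowInTile, dstRowOffset, len; };

// The tile covers rows [stTile, stTile + lenTile) of the expert's tokens;
// rank r owns the next tokensPerRank[r] rows in order. Only the overlap of
// the two intervals is copied to rank r's peer memory.
std::vector<CopySpan> PlanScatter(int32_t stTile, int32_t lenTile,
                                  const std::vector<int32_t>& tokensPerRank) {
    std::vector<CopySpan> plan;
    int32_t edTile = stTile + lenTile, preSum = 0, tileOffset = 0;
    for (int32_t rank = 0; rank < (int32_t)tokensPerRank.size(); ++rank) {
        int32_t st = preSum, ed = st + tokensPerRank[rank];
        preSum = ed;
        if (st >= edTile) break;     // ranks are ordered; nothing left to send
        if (ed <= stTile) continue;  // this rank's rows end before the tile
        int32_t lo = std::max(st, stTile), hi = std::min(ed, edTile);
        // rows of this rank already covered by earlier tiles
        int32_t dstOff = (stTile > st) ? stTile - st : 0;
        plan.push_back({rank, tileOffset, dstOff, hi - lo});
        tileOffset += hi - lo;
    }
    return plan;
}
```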

View File

@@ -22,8 +22,6 @@
namespace Catlass::Gemm::Block {
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID)
{
@@ -153,9 +151,11 @@ public:
L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
{
syncGroupIdx = 0;
ptrSoftFlagBase_ = flagPtr;
expertPerRank_ = expertPerRank;
InitL1(resource, l1BufAddrStart);
InitL0A(resource);
InitL0B(resource);
@@ -272,9 +272,21 @@ public:
CATLASS_DEVICE
void Finalize(int32_t target, int32_t flag = 0)
{
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
if (ptrSoftFlagBase_ != nullptr) {
if (target < 0) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<int32_t> flagGlobal;
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
}
else {
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / 15 + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
}
}
private:
@@ -291,7 +303,6 @@ private:
layout::VectorLayout layoutScale;
int32_t syncLoopIdx;
int32_t flag;
CATLASS_DEVICE
L1TileMmadParams() = default;
};
@@ -310,11 +321,24 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
if constexpr (std::is_same_v<ElementA, int8_t>) {
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
if (ptrSoftFlagBase_ != nullptr) {
// Initialize the flag matrix (structure as below):
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// ...
// 16 0 0 0 0 0 0 0
// Then move it to L1
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
AscendC::GlobalTensor<int32_t> flagBase;
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
}
}
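
The flag matrix staged into L1 above lets the AIC publish progress without a scalar GM write: row g carries the value g + 1 in its first lane, so signalling "group g done" in `Finalize` is a single 16-int row copy. A sketch of the table layout (assuming the 1..16 column shown in the comment):

```cpp
#include <cstdint>

constexpr int32_t FLAGSTRIDE = 16;

// Row g encodes progress value g + 1 in its first lane, zeros elsewhere.
void BuildFlagMatrix(int32_t* table, int32_t expertPerRank) {
    for (int32_t g = 0; g < expertPerRank; ++g) {
        for (int32_t lane = 0; lane < FLAGSTRIDE; ++lane) {
            table[g * FLAGSTRIDE + lane] = (lane == 0) ? g + 1 : 0;
        }
    }
}

// Publishing "group done" is one row copy (the DataCopy analogue above).
void PublishProgress(int32_t* gmFlagSlot, const int32_t* table, int32_t group) {
    for (int32_t lane = 0; lane < FLAGSTRIDE; ++lane) {
        gmFlagSlot[lane] = table[group * FLAGSTRIDE + lane];
    }
}
```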
CATLASS_DEVICE
@@ -463,12 +487,20 @@ private:
if constexpr (std::is_same_v<ElementA, int8_t>) {
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
#ifdef __TILE_SYNC__
if (params.flag > 0) {
int32_t flagId = params.flag + params.syncLoopIdx / 8;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
#else
Finalize(params.syncLoopIdx, params.flag);
#endif
}
}
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
AscendC::LocalTensor<uint64_t> l1STensor;
AscendC::LocalTensor<int32_t> l1FTensor;
int32_t syncGroupIdx;
int32_t l1AEventList[L1_STAGES];
int32_t l1BEventList[L1_STAGES];
@@ -497,8 +529,11 @@ private:
CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm;
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
int32_t expertPerRank_;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP

View File

@@ -4,6 +4,10 @@
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
constexpr static int32_t NUMS_PER_FLAG = 16;
constexpr static int32_t CACHE_LINE = 512;
constexpr static int32_t FLAGSTRIDE = 16;
constexpr static int32_t RESET_VAL = 0xffff;
constexpr static int32_t ALIGN_128 = 128;
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
constexpr static int32_t UB_ALIGN = 32;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
#endif
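
`ALIGN_128` pairs with the `AlignUp(params.EP * params.expertPerRank, 128)` padding used in the layouts above. A minimal `AlignUp` matching that usage (assumed semantics: round up to the nearest multiple):

```cpp
#include <cstdint>

constexpr uint32_t AlignUp(uint32_t x, uint32_t a) {
    return (x + a - 1) / a * a;
}
static_assert(AlignUp(33, 128) == 128, "partial blocks pad to a full 128 run");
static_assert(AlignUp(256, 128) == 256, "aligned values are unchanged");
```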

View File

@@ -5,16 +5,23 @@
#include "kernel_operator.h"
#include "const_args.hpp"
#ifdef HCCL_COMM
#include "moe_distribute_base.h"
#ifndef HCCL_COMM
#include "shmem_api.h"
using namespace AscendC::HcclContextDef;
#else
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
constexpr int32_t MAX_RANK_SIZE = 32;
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t STATE_OFFSET = 512;
FORCE_INLINE_AICORE void AicSyncAll() {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
@@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache);
}
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
template<typename T>
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
using namespace AscendC;
GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr);
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
// Important: add hint to avoid dcci being optimized by compiler
__asm__ __volatile__("");
@@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr);
@@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
class HcclShmem {
public:
#ifdef HCCL_COMM
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
AscendC::LocalTensor<int32_t> ub;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
@@ -87,17 +95,13 @@ public:
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
for (int i = 0; i < m_rankSize; i++) {
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
}
}
#else
FORCE_INLINE_AICORE
HcclShmem(){
m_segmentSize = SHMEM_MEM;
}
FORCE_INLINE_AICORE
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
symmetricPtr = symmetricPtr_;
m_rank = rank;
@@ -106,25 +110,26 @@ public:
#endif
FORCE_INLINE_AICORE
GM_ADDR operator() () const {
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
#ifdef HCCL_COMM
return m_ptrArray[m_rank];
return (GM_ADDR)(WinContext_->localWindowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const {
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
#ifdef HCCL_COMM
return m_ptrArray[index];
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
if (offset < 0 || offset >= m_segmentSize) {
return nullptr;
@@ -132,7 +137,8 @@ public:
if (rankId < 0 || rankId >= m_rankSize) {
return nullptr;
}
return m_ptrArray[rankId] + offset;
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif
@@ -176,6 +182,130 @@ public:
gm_store(sync_base, count);
}
FORCE_INLINE_AICORE
void InitStatusTargetSum()
{
using namespace AscendC;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
// ep state
//uint32_t coreIdx = get_block_idx();;
uint32_t coreIdx = GetBlockIdx();
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
}
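
`InitStatusTargetSum` makes the barrier self-resetting: the persisted per-core status word and the expected sum toggle together between rounds, so no explicit clear pass is needed between launches. A sketch of the toggle:

```cpp
#include <cstdint>

struct BarrierState { int32_t word; }; // persisted in peer memory per core

// Flips the stored state between 0 and the bit pattern of 1.0f, and returns
// the sum target the following wait phase should expect from every rank.
float ToggleAndGetTarget(BarrierState& s, int32_t& publishValue) {
    if (s.word == 0) {
        s.word = 0x3F800000;  // IEEE-754 bits of 1.0f
        publishValue = 0x3F800000;
        return 1.0f;          // next wait expects every rank to post 1.0f
    }
    s.word = 0;
    publishValue = 0;
    return 0.0f;              // next wait expects every rank to post 0.0f
}
```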
FORCE_INLINE_AICORE
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
//subblockid = 0
uint32_t stateOffset_ = STATE_OFFSET;
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
pipe_barrier(PIPE_ALL);
ctrBuffer.SetValue(0, epStateValue_);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
AscendC::GlobalTensor<int32_t> gmDstStates;
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
DataCopy(gmDstStates, ctrBuffer, 8);
}
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
uint32_t stateOffset_ = STATE_OFFSET;
uint32_t sendRankNum_ = m_rankSize / vec_size;
uint32_t remainderRankNum = m_rankSize % vec_size;
uint32_t startRankId_ = sendRankNum_ * vec_id;
if (vec_id < remainderRankNum) {
sendRankNum_++;
startRankId_ += vec_id;
} else {
startRankId_ += remainderRankNum;
}
uint32_t endRankId_ = startRankId_ + sendRankNum_;
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
if (startRankId_ < m_rankSize) {
AscendC::PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>(15), 0};
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
//unpermute
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
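
`CrossRankSyncV2Wait` exits when the sum of its slice of status words lands inside a half-unit window around `sumTarget_ * sendRankNum_`, which distinguishes an all-0.0 round from an all-1.0 round while tolerating float accumulation. A plain sketch of the exit test (the strided `DataCopy` + `GatherMask` + `Sum` collapse to a loop):

```cpp
#include <vector>

// True once every rank in this core's slice has posted the current round's
// value; the device code spins on the equivalent condition.
bool SliceArrived(const std::vector<float>& statusSlice, float sumTarget) {
    float sum = 0.0f;
    for (float s : statusSlice) sum += s;
    float n = static_cast<float>(statusSlice.size());
    return sum >= sumTarget * n - 0.5f && sum <= sumTarget * n + 0.5f;
}
```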
FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
@@ -187,9 +317,11 @@ private:
int32_t m_rank;
int32_t m_rankSize;
size_t m_segmentSize;
float sumTarget_{0.0};
int32_t epStateValue_;
};
#endif
#endif