[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)
### What this PR does / why we need it?
Fix the precision issue in dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations in dispatch_ffn_combine
operator.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -391,7 +391,6 @@ private:
|
||||
uint16_t syncgmmIdx = 0;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||
syncgmmIdx++;
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
@@ -405,7 +404,6 @@ private:
|
||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
if (currentM <= L1TileShape::M) {
|
||||
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||
}
|
||||
@@ -493,8 +491,6 @@ private:
|
||||
|
||||
uint32_t startCoreIdx = 0;
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
int64_t preCurrentmSum = 0;
|
||||
int32_t syncLoopIdx = -1;
|
||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
||||
@@ -503,8 +499,6 @@ private:
|
||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||
}
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
if (preCurrentmSum >= params.maxOutputSize) {
|
||||
|
||||
@@ -41,6 +41,7 @@ namespace {
|
||||
constexpr uint32_t EXPERTID_INDEX = 3;
|
||||
constexpr uint32_t BLOCK_NUM = 20;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
}
|
||||
|
||||
namespace optiling {
|
||||
@@ -240,7 +241,8 @@ static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext
|
||||
info.maxOutputSize * n2 * sizeof(int16_t) +
|
||||
info.maxOutputSize * info.K * sizeof(int16_t) +
|
||||
info.maxOutputSize * k2 * sizeof(int16_t) +
|
||||
info.worldSize * sizeof(int32_t) * 16;
|
||||
info.worldSize * sizeof(int32_t) * 16 +
|
||||
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
|
||||
// std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
|
||||
// std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
|
||||
|
||||
|
||||
@@ -213,11 +213,7 @@ public:
|
||||
CATLASS_DEVICE
|
||||
void operator()<AscendC::AIV>(Params const ¶ms)
|
||||
{
|
||||
Dispatch(params);
|
||||
AscendC::SyncAll<true>();
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
|
||||
CombineV2(params);
|
||||
DispatchAndCombine(params);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -241,7 +237,7 @@ private:
|
||||
|
||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
|
||||
|
||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
||||
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
|
||||
}
|
||||
|
||||
@@ -309,7 +305,7 @@ private:
|
||||
AscendC::DataCopyPad(
|
||||
tmpBuffer1,
|
||||
tokenPerExpert[rankId * expertPerRank],
|
||||
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0},
|
||||
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
|
||||
{}
|
||||
);
|
||||
|
||||
@@ -331,145 +327,6 @@ private:
|
||||
);
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRank(Params const ¶ms, int64_t localTokenPerExpertOffset){
|
||||
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||
uint32_t numPerCore = params.EP * params.expertPerRank;
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx == params.rank) {
|
||||
continue;
|
||||
}
|
||||
AscendC::GlobalTensor<int32_t> srcAddress;
|
||||
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
||||
AscendC::GlobalTensor<int32_t> dstAddress;
|
||||
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
||||
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
||||
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
||||
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
||||
CopyGmToUb copyGmToUb;
|
||||
CopyUbToGm copyUbToGm;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
|
||||
copyGmToUb(tmpBuffer, srcAddress[0],
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
copyUbToGm(dstAddress[0], tmpBuffer,
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
}
|
||||
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx != params.rank) {
|
||||
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
||||
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
|
||||
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
||||
gm_signal_wait_until_ne(sync_check, 0);
|
||||
}
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
||||
} else {
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
}
|
||||
}
|
||||
|
||||
AscendC::SyncAll<true>();
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const ¶ms, int64_t localTokenPerExpertOffset){
|
||||
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||
uint32_t numPerCore = params.EP * params.expertPerRank;
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx == params.rank) {
|
||||
continue;
|
||||
}
|
||||
AscendC::GlobalTensor<int32_t> srcAddress;
|
||||
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
||||
AscendC::GlobalTensor<int32_t> dstAddress;
|
||||
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
||||
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
||||
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
||||
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
||||
CopyGmToUb copyGmToUb;
|
||||
CopyUbToGm copyUbToGm;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
|
||||
copyGmToUb(tmpBuffer, srcAddress[0],
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
copyUbToGm(dstAddress[0], tmpBuffer,
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
}
|
||||
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx != params.rank) {
|
||||
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
||||
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
|
||||
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
||||
gm_signal_wait_until_ne(sync_check, 0);
|
||||
}
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
||||
} else {
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
}
|
||||
|
||||
int32_t prevSum = 0;
|
||||
for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
|
||||
prevSum += tmpBuffer(i);
|
||||
}
|
||||
preSumBeforeRank(dstEpIdx * 16) = prevSum;
|
||||
__asm__ __volatile__("");
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(preSumBeforeRank[dstEpIdx * 16]);
|
||||
__asm__ __volatile__("");
|
||||
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void GetSumPreRank(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result,
|
||||
uint32_t expertPerRank, uint32_t rankId, uint32_t EP) {
|
||||
@@ -506,20 +363,20 @@ CATLASS_DEVICE
|
||||
icache_preload(8);
|
||||
BlockScheduler blockScheduler;
|
||||
BlockMmad blockMmad(resource);
|
||||
float aivFinishGroups = 0.0f;
|
||||
__gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
|
||||
|
||||
int64_t gmGroupOffsetA = 0;
|
||||
int64_t gmGroupOffsetB = 0;
|
||||
int64_t gmGroupOffsetC = 0;
|
||||
uint32_t startCoreIdx = 0;
|
||||
uint32_t syncGroupIdx = 0;
|
||||
uint16_t syncgmmIdx = 0;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||
syncgmmIdx++;
|
||||
AicSyncAll();
|
||||
int64_t preCurrentmSum = 0;
|
||||
int32_t syncLoopIdx = -1;
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
uint16_t syncgmmIdx = 0;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||
syncgmmIdx++;
|
||||
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
@@ -533,9 +390,6 @@ CATLASS_DEVICE
|
||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
if (currentM <= L1TileShape::M) {
|
||||
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||
}
|
||||
@@ -630,7 +484,6 @@ CATLASS_DEVICE
|
||||
|
||||
uint32_t startCoreIdx = 0;
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
int64_t preCurrentmSum = 0;
|
||||
int32_t syncLoopIdx = -1;
|
||||
@@ -640,7 +493,6 @@ CATLASS_DEVICE
|
||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||
}
|
||||
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
@@ -651,7 +503,6 @@ CATLASS_DEVICE
|
||||
}
|
||||
AscendC::GlobalTensor<ElementB> gmB2;
|
||||
AscendC::GlobalTensor<ElementScale> gmS2;
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
|
||||
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
|
||||
@@ -721,7 +572,6 @@ CATLASS_DEVICE
|
||||
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
||||
|
||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||
|
||||
}
|
||||
|
||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||
@@ -729,8 +579,168 @@ CATLASS_DEVICE
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitArithProgress(Params const ¶ms) {
|
||||
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
|
||||
AscendC::GlobalTensor<float> flagGlobalBase;
|
||||
flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
|
||||
AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
|
||||
}
|
||||
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Dispatch(Params const ¶ms) {
|
||||
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const ¶ms, int64_t localTokenPerExpertOffset){
|
||||
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
|
||||
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
AscendC::LocalTensor<int32_t> prevSumBuf = tmpBuffer[numPerCore];
|
||||
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx == params.rank) {
|
||||
continue;
|
||||
}
|
||||
AscendC::GlobalTensor<int32_t> srcAddress;
|
||||
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
||||
AscendC::GlobalTensor<int32_t> dstAddress;
|
||||
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
||||
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
||||
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
||||
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
||||
CopyGmToUb copyGmToUb;
|
||||
CopyUbToGm copyUbToGm;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
|
||||
copyGmToUb(tmpBuffer, srcAddress[0],
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
copyUbToGm(dstAddress[0], tmpBuffer,
|
||||
layout::RowMajor{ 1, numPerCore},
|
||||
layout::RowMajor{1, numPerCore});
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
}
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
if (dstEpIdx != params.rank) {
|
||||
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
||||
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
|
||||
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
||||
gm_signal_wait_until_ne(sync_check, 0);
|
||||
}
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
||||
} else {
|
||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
int32_t prevSum = 0;
|
||||
int32_t j = 0;
|
||||
for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
|
||||
if (i >= params.rank * params.expertPerRank) {
|
||||
prevSumBuf(j) = prevSum;
|
||||
j++;
|
||||
}
|
||||
prevSum += tmpBuffer(i);
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
|
||||
AscendC::DataCopyParams{1, static_cast<uint16_t>(params.expertPerRank * sizeof(int32_t)), 0, 0});
|
||||
}
|
||||
|
||||
AscendC::SyncAll<true>();
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void ResetTokenPerExpert(int32_t num)
|
||||
{
|
||||
if (coreIdx != coreNum - 1) {
|
||||
return;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
AscendC::Duplicate(tmp, 0, num);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||
AscendC::DataCopy(tokenPerExpert, tmp, num);
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void UpdateAicFlags(const Params ¶ms)
|
||||
{
|
||||
float flagBase = 1.0f * params.expertPerRank;
|
||||
__gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
|
||||
float flag = 0.0f;
|
||||
float lastflag = -1.0f;
|
||||
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||
__gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
|
||||
AscendC::GlobalTensor<float> flagGM;
|
||||
flagGM.SetGlobalBuffer(flagPtr);
|
||||
int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
|
||||
AscendC::LocalTensor<float> dstValueBuffer = resource.ubBuf.template GetBufferByByte<float>(flagBufferSize);
|
||||
AscendC::LocalTensor<float> sharedTmpBuffer = resource.ubBuf.template GetBufferByByte<float>((flagBufferSize + 64));
|
||||
uint64_t mask[1] = {0};
|
||||
uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
|
||||
for (int32_t i = 0; i < 4; i ++) {
|
||||
if (i < params.EP) {
|
||||
mask[0] |= 1ull * (1ull << (i * 16));
|
||||
}
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
while (flag < flagBase) {
|
||||
flag = flagBase;
|
||||
AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
|
||||
AscendC::ReduceMin<float>(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
flag = min(flag, dstValueBuffer.GetValue(0));
|
||||
|
||||
if (flag > lastflag) {
|
||||
*aicFinishPtr = flag;
|
||||
gm_dcci(aicFinishPtr);
|
||||
lastflag = flag;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CombineSetFlag() {
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||
}
|
||||
|
||||
|
||||
CATLASS_DEVICE
|
||||
void DispatchAndCombine(Params const ¶ms) {
|
||||
icache_preload(8);
|
||||
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
|
||||
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset;
|
||||
@@ -744,35 +754,35 @@ CATLASS_DEVICE
|
||||
AscendC::SyncAll<true>();
|
||||
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
|
||||
|
||||
if (coreIdx == coreNum - 1) {
|
||||
if (coreIdx == 0) {
|
||||
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
||||
}
|
||||
|
||||
uint32_t curGroupOffset = 0;
|
||||
int32_t prevSumBeforeRank = 0;
|
||||
int32_t prevSum = 0;
|
||||
if (coreIdx < params.EP) {
|
||||
prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
|
||||
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
||||
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
||||
if(coreIdx == 0)
|
||||
{
|
||||
CopyGMToGM(ExpertTokenNums, cumsumMM[(params.EP - 1) * params.expertPerRank], params.expertPerRank, params.ubMoveNum);
|
||||
}
|
||||
uint16_t syncgmm1Idx = 0;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||
syncgmm1Idx++;
|
||||
|
||||
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
||||
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
||||
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
|
||||
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
|
||||
if (coreIdx == 0) {
|
||||
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
|
||||
}
|
||||
|
||||
uint32_t curGroupOffset = 0;
|
||||
int32_t prevSumBeforeRank = 0;
|
||||
int32_t groupIdxDeq = 0;
|
||||
int prevSum = 0;
|
||||
if (coreIdx < params.EP) {
|
||||
prevSum = preSumBeforeRank(coreIdx * 16);
|
||||
}
|
||||
uint32_t prevGroupSum1 = 0;
|
||||
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
|
||||
uint32_t dequantSum = 0;
|
||||
int32_t syncLoopIdx = -1;
|
||||
uint32_t n = params.problemShape.n();
|
||||
BlockEpilogue1 blockEpilogue(resource, n);
|
||||
|
||||
icache_preload(8);
|
||||
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
// The ith core reads data from the ith rank's peermem
|
||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
||||
if (rowStart < params.maxOutputSize) {
|
||||
@@ -785,99 +795,40 @@ CATLASS_DEVICE
|
||||
GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
|
||||
AscendC::GlobalTensor<ElementA> gmRemoteA;
|
||||
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
|
||||
gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
|
||||
|
||||
MatrixCoord offsetA{rowStart, 0};
|
||||
MatrixCoord shapeA{rows, params.problemShape.k()};
|
||||
MatrixCoord offsetPeer{rowSrc, 0};
|
||||
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
|
||||
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
|
||||
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
||||
syncLoopIdx++;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
|
||||
syncgmm1Idx++;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||
syncgmm1Idx ++;
|
||||
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
||||
uint32_t rowStartThisCore = 0;
|
||||
MatrixCoord offsetC{0U, 0};
|
||||
uint32_t dequantLen = prevGroupSum1 - dequantSum;
|
||||
if (dequantLen >= params.maxOutputSize) {
|
||||
dequantLen = dequantLen - params.maxOutputSize;
|
||||
}
|
||||
prevGroupSum1 += currentM;
|
||||
|
||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
||||
} else {
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
|
||||
// Token count and truncation logic for the first SwiGLU operation
|
||||
if (groupIdx + 1 <= params.epilogueGranularity) {
|
||||
if (dequantSum1 + currentM <= params.maxOutputSize) {
|
||||
dequantSum1 += currentM;
|
||||
} else if (dequantSum1 < params.maxOutputSize) {
|
||||
dequantSum1 = params.maxOutputSize;
|
||||
}
|
||||
}
|
||||
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
|
||||
dequantSum = 0;
|
||||
}
|
||||
}
|
||||
syncLoopIdx ++;
|
||||
|
||||
AscendC::SyncAll<true>();
|
||||
|
||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
||||
if (params.epilogueGranularity < params.expertPerRank) {
|
||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||
}
|
||||
if (lastDequantExpertNum < params.expertPerRank) {
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
}
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
||||
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
||||
MatrixCoord offsetC{rowStartThisCore, 0};
|
||||
uint32_t dequantLen = dequantSum;
|
||||
if (prevGroupSum1 >= params.maxOutputSize) {
|
||||
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
|
||||
}
|
||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
||||
} else {
|
||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
|
||||
// Token count and truncation logic for the second SwiGLU operation
|
||||
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
|
||||
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
|
||||
dequantSum2 += currentM;
|
||||
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
|
||||
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
blockEpilogue.Finalize();
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CombineV2(Params const ¶ms) {
|
||||
BlockScheduler blockScheduler;
|
||||
uint32_t n2 = params.problemShape.k();
|
||||
uint32_t k2 = params.problemShape.n() / 2;
|
||||
uint32_t startCoreIdx = 0;
|
||||
int64_t gmGroupOffsetC = 0;
|
||||
uint32_t aivCoreNum = coreNum;
|
||||
uint32_t aicCoreNum = coreNum / 2;
|
||||
uint32_t aivCoreIdx = coreIdx;
|
||||
uint32_t aicCoreIdx = get_block_idx();
|
||||
uint32_t aivSubCoreIdx = get_subblockid();
|
||||
uint32_t preSrcExpertSum = 0;
|
||||
|
||||
typename BlockEpilogue2::Params epilogueParams{
|
||||
static_cast<int32_t>(params.EP),
|
||||
@@ -891,11 +842,108 @@ CATLASS_DEVICE
|
||||
static_cast<int32_t>(peermemInfo.offsetD)
|
||||
};
|
||||
|
||||
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
|
||||
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
|
||||
|
||||
uint32_t n = params.problemShape.n();
|
||||
BlockEpilogue1 blockEpilogue1(resource, n);
|
||||
|
||||
// Synchronous wait: SwiGLU waits for GMM1 [1]
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
if (dequantSum1 > 0) {
|
||||
uint32_t rowStartThisCore = 0;
|
||||
MatrixCoord offsetC{0U, 0};
|
||||
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantSum1, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
||||
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
// Synchronization signal: SwiGLU notifies GMM2 [1]
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
|
||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
|
||||
// Synchronous wait: SwiGLU waits for GMM1 [2]
|
||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||
AscendC::SyncAll<true>();
|
||||
if (dequantSum2 > 0) {
|
||||
uint32_t rowStartThisCore = dequantSum1;
|
||||
MatrixCoord offsetC{rowStartThisCore, 0};
|
||||
uint32_t dequantLen = dequantSum2;
|
||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
||||
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
|
||||
}
|
||||
AscendC::SyncAll<true>();
|
||||
// Synchronization signal: SwiGLU notifies GMM2 [2]
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||
}
|
||||
|
||||
blockEpilogue1.Finalize();
|
||||
|
||||
CombineSetFlag();
|
||||
|
||||
CombineV2(params, blockEpilogue2);
|
||||
|
||||
AscendC::SyncAll<true>();
|
||||
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
|
||||
ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
|
||||
#endif
|
||||
shmem.InitStatusTargetSum();
|
||||
if (get_subblockid() == 0) {
|
||||
AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||
shmem.CrossRankSyncV2Set(ctrBuffer);
|
||||
} else {
|
||||
uint32_t uboffset = 0;
|
||||
uint32_t aicCoreNum = coreNum / 2;
|
||||
uint32_t aicCoreIdx = get_block_idx();
|
||||
uint32_t sendRankNum_ = params.EP / aicCoreNum;
|
||||
uint32_t remainderRankNum = params.EP % aicCoreNum;
|
||||
if (aicCoreIdx < remainderRankNum) {
|
||||
sendRankNum_++;
|
||||
}
|
||||
AscendC::LocalTensor<float> statusTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||
uboffset += sendRankNum_ * UB_ALIGN;
|
||||
AscendC::LocalTensor<float> gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||
uboffset += AlignUp(params.EP * sizeof(float), 32);
|
||||
AscendC::LocalTensor<uint32_t> gatherTmpTensor = resource.ubBuf.template GetBufferByByte<uint32_t>(uboffset);
|
||||
uboffset += AlignUp(sizeof(uint32_t), 32);
|
||||
AscendC::LocalTensor<float> statusSumOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||
uboffset += AlignUp(sizeof(float), 32);
|
||||
shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
|
||||
MoeTokenUnpermuteTilingData tilingData;
|
||||
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
|
||||
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
|
||||
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
|
||||
kernelMoeTokenUnpermuteOp.Process();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) {
|
||||
BlockScheduler blockScheduler;
|
||||
int32_t syncLoopIdx = 0;
|
||||
uint32_t startCoreIdx = 0;
|
||||
uint32_t aicCoreNum = coreNum / 2;
|
||||
uint32_t aicCoreIdx = get_block_idx();
|
||||
uint32_t aivSubCoreIdx = get_subblockid();
|
||||
uint32_t preSrcExpertSum = 0;
|
||||
uint32_t n2 = params.problemShape.k();
|
||||
uint32_t k2 = params.problemShape.n() / 2;
|
||||
icache_preload(8);
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||
if (preSrcExpertSum >= params.maxOutputSize) {
|
||||
currentExpertM = 0;
|
||||
} else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
|
||||
currentExpertM = params.maxOutputSize - preSrcExpertSum;
|
||||
}
|
||||
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
|
||||
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
|
||||
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
||||
@@ -916,11 +964,10 @@ CATLASS_DEVICE
|
||||
if(aivSubCoreIdx == 1) {
|
||||
m_offset += (m_rows / 2) * m0;
|
||||
}
|
||||
if (loopIdx == startLoopIdx) {
|
||||
for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
|
||||
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
||||
}
|
||||
|
||||
for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
|
||||
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
||||
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
||||
}
|
||||
|
||||
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
|
||||
@@ -930,34 +977,15 @@ CATLASS_DEVICE
|
||||
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
|
||||
}
|
||||
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
|
||||
} else {
|
||||
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
|
||||
}
|
||||
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
|
||||
m_offset += m0;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t expertRankM = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
mPreSumBeforeRank[dstEpIdx] += expertRankM;
|
||||
}
|
||||
preSrcExpertSum += currentExpertM;
|
||||
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
|
||||
}
|
||||
|
||||
blockEpilogue.Finalize();
|
||||
AscendC::SyncAll<true>();
|
||||
ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank);
|
||||
shmem.CrossRankSync();
|
||||
|
||||
MoeTokenUnpermuteTilingData tilingData;
|
||||
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, aivCoreNum);
|
||||
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
|
||||
|
||||
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
|
||||
kernelMoeTokenUnpermuteOp.Process();
|
||||
}
|
||||
|
||||
|
||||
@@ -973,6 +1001,7 @@ private:
|
||||
GM_ADDR expandedRowIdx;
|
||||
GM_ADDR ptrTokenPerExpert;
|
||||
GM_ADDR ptrSumBeforeRank;
|
||||
__gm__ float* ptrSoftFlagBase;
|
||||
|
||||
CATLASS_DEVICE
|
||||
WorkspaceInfo(){}
|
||||
@@ -1012,6 +1041,9 @@ private:
|
||||
|
||||
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
|
||||
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
|
||||
workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
|
||||
ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1041,7 +1073,6 @@ private:
|
||||
WorkspaceInfo workspaceInfo;
|
||||
PeermemInfo peermemInfo;
|
||||
|
||||
int64_t m_prevSumBeforeRank;
|
||||
|
||||
AscendC::GlobalTensor<ElementA> gmA;
|
||||
AscendC::GlobalTensor<ElementC> gmC;
|
||||
@@ -1057,7 +1088,7 @@ private:
|
||||
AscendC::GlobalTensor<int32_t> tokenPerExpert;
|
||||
AscendC::GlobalTensor<int32_t> cumsumMM;
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
|
||||
uint32_t mPreSumBeforeRank[32] = {0};
|
||||
|
||||
Layout3D tokenPerExpertLayout;
|
||||
HcclShmem shmem;
|
||||
};
|
||||
|
||||
@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
|
||||
GM_ADDR unpermuted_tokens,
|
||||
const MoeTokenUnpermuteTilingData *__restrict tiling_data)
|
||||
{
|
||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
this->blockNum = get_block_num() * get_subblockdim();
|
||||
this->blockIdx = get_block_idx();
|
||||
this->blockNum = get_block_num();
|
||||
|
||||
if (blockIdx >= blockNum) {
|
||||
return;
|
||||
|
||||
@@ -72,17 +72,6 @@ public:
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
|
||||
|
||||
//ub:192KB
|
||||
n0 = params.n0;
|
||||
size_t ubOffset = 0;
|
||||
@@ -98,137 +87,20 @@ public:
|
||||
source_scale_offset[i] = -1;
|
||||
}
|
||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
|
||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
||||
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||
is_ping = true;
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Finalize()
|
||||
{
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
~BlockEpilogue()
|
||||
{
|
||||
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
|
||||
GemmCoord& blockCoord,
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
auto &ubCFp32 = ubFp32List[is_ping];
|
||||
auto &scaleUb = scaleUbList[is_ping];
|
||||
// auto &ubOutFp32 = ubOutFp32List[is_ping];
|
||||
|
||||
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
|
||||
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
|
||||
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
|
||||
layout::VectorLayout scaleLauout{actualBlockShape.m()};
|
||||
if (source_scale_offset[event_id] != gmScaleOffset) {
|
||||
source_scale_offset[event_id] = gmScaleOffset;
|
||||
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
|
||||
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // 注意必须是MTE2_S,不能是MTE2_V,否则会读到0,造成乱码
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
|
||||
float scale = scaleUb(row);
|
||||
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
|
||||
|
||||
int32_t lenTile = actualBlockShape.m();
|
||||
int32_t stTile = blockCoord.m();
|
||||
int32_t edTile = stTile + lenTile;
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
if (stRankInExpert >= edTile) {
|
||||
break;
|
||||
}
|
||||
else if (edRankInExpert <= stTile) {
|
||||
continue;
|
||||
}
|
||||
int32_t stData = max(stRankInExpert, stTile);
|
||||
int32_t edData = min(edRankInExpert, edTile);
|
||||
uint32_t lenData = edData - stData;
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
}
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
|
||||
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
@@ -238,14 +110,12 @@ public:
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
|
||||
@@ -253,7 +123,7 @@ public:
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
|
||||
@@ -263,10 +133,10 @@ public:
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id);
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
@@ -282,7 +152,7 @@ public:
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
@@ -290,7 +160,7 @@ public:
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
@@ -298,7 +168,8 @@ public:
|
||||
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@
|
||||
|
||||
namespace Catlass::Gemm::Block {
|
||||
|
||||
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||
|
||||
template<AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||
{
|
||||
@@ -153,9 +151,11 @@ public:
|
||||
L1TileShape::K, L1TileShape::N);
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
|
||||
{
|
||||
syncGroupIdx = 0;
|
||||
ptrSoftFlagBase_ = flagPtr;
|
||||
expertPerRank_ = expertPerRank;
|
||||
InitL1(resource, l1BufAddrStart);
|
||||
InitL0A(resource);
|
||||
InitL0B(resource);
|
||||
@@ -272,9 +272,21 @@ public:
|
||||
CATLASS_DEVICE
|
||||
void Finalize(int32_t target, int32_t flag = 0)
|
||||
{
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
if (ptrSoftFlagBase_ != nullptr) {
|
||||
if (target < 0) {
|
||||
return;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||
AscendC::GlobalTensor<int32_t> flagGlobal;
|
||||
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
|
||||
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
|
||||
}
|
||||
else {
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / 15 + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
}
|
||||
}
|
||||
private:
|
||||
@@ -291,7 +303,6 @@ private:
|
||||
layout::VectorLayout layoutScale;
|
||||
int32_t syncLoopIdx;
|
||||
int32_t flag;
|
||||
|
||||
CATLASS_DEVICE
|
||||
L1TileMmadParams() = default;
|
||||
};
|
||||
@@ -310,11 +321,24 @@ private:
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
if (ptrSoftFlagBase_ != nullptr) {
|
||||
// Initialize the flag matrix (structure as below):
|
||||
// 1 0 0 0 0 0 0 0
|
||||
// 2 0 0 0 0 0 0 0
|
||||
// ...
|
||||
// 16 0 0 0 0 0 0 0
|
||||
// Then move it to L1
|
||||
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
|
||||
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
|
||||
AscendC::GlobalTensor<int32_t> flagBase;
|
||||
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
|
||||
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
@@ -463,12 +487,20 @@ private:
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
#ifdef __TILE_SYNC__
|
||||
if (params.flag > 0) {
|
||||
int32_t flagId = params.flag + params.syncLoopIdx / 8;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
#else
|
||||
Finalize(params.syncLoopIdx, params.flag);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<uint64_t> l1STensor;
|
||||
AscendC::LocalTensor<int32_t> l1FTensor;
|
||||
int32_t syncGroupIdx;
|
||||
int32_t l1AEventList[L1_STAGES];
|
||||
int32_t l1BEventList[L1_STAGES];
|
||||
@@ -497,8 +529,11 @@ private:
|
||||
CopyL1ToL0A copyL1ToL0A;
|
||||
CopyL1ToL0B copyL1ToL0B;
|
||||
CopyL0CToGm copyL0CToGm;
|
||||
|
||||
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
|
||||
int32_t expertPerRank_;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Block
|
||||
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
@@ -4,6 +4,10 @@
|
||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||
constexpr static int32_t CACHE_LINE = 512;
|
||||
constexpr static int32_t FLAGSTRIDE = 16;
|
||||
constexpr static int32_t RESET_VAL = 0xffff;
|
||||
constexpr static int32_t ALIGN_128 = 128;
|
||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
||||
constexpr static int32_t UB_ALIGN = 32;
|
||||
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||
#endif
|
||||
@@ -5,16 +5,23 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "const_args.hpp"
|
||||
|
||||
#ifdef HCCL_COMM
|
||||
#include "moe_distribute_base.h"
|
||||
|
||||
#ifndef HCCL_COMM
|
||||
#include "shmem_api.h"
|
||||
using namespace AscendC::HcclContextDef;
|
||||
|
||||
#else
|
||||
#include "shmem_api.h"
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
|
||||
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
|
||||
|
||||
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
|
||||
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
|
||||
|
||||
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
|
||||
constexpr uint32_t STATE_OFFSET = 512;
|
||||
|
||||
FORCE_INLINE_AICORE void AicSyncAll() {
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
|
||||
@@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
|
||||
return *((__gm__ T *)cache);
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
|
||||
using namespace AscendC;
|
||||
GlobalTensor<uint8_t> global;
|
||||
global.SetGlobalBuffer(addr);
|
||||
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
|
||||
|
||||
// Important: add hint to avoid dcci being optimized by compiler
|
||||
__asm__ __volatile__("");
|
||||
@@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
|
||||
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
|
||||
do {
|
||||
AscendC::LocalTensor<int32_t> ub;
|
||||
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
ub.address_.bufferAddr = 0;
|
||||
AscendC::GlobalTensor<int32_t> sig;
|
||||
sig.SetGlobalBuffer(sig_addr);
|
||||
@@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
|
||||
|
||||
class HcclShmem {
|
||||
public:
|
||||
#ifdef HCCL_COMM
|
||||
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
|
||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
||||
AscendC::LocalTensor<int32_t> ub;
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
@@ -87,17 +95,13 @@ public:
|
||||
m_rank = WinContext_->localUsrRankId;
|
||||
m_rankSize = WinContext_->rankSize;
|
||||
m_segmentSize = WinContext_->winSize;
|
||||
for (int i = 0; i < m_rankSize; i++) {
|
||||
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
|
||||
}
|
||||
}
|
||||
#else
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
m_segmentSize = SHMEM_MEM;
|
||||
}
|
||||
FORCE_INLINE_AICORE
|
||||
FORCE_INLINE_AICORE
|
||||
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
|
||||
symmetricPtr = symmetricPtr_;
|
||||
m_rank = rank;
|
||||
@@ -106,25 +110,26 @@ public:
|
||||
#endif
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() () const {
|
||||
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[m_rank];
|
||||
return (GM_ADDR)(WinContext_->localWindowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() (int32_t index) const {
|
||||
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[index];
|
||||
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||
#ifdef HCCL_COMM
|
||||
if (offset < 0 || offset >= m_segmentSize) {
|
||||
return nullptr;
|
||||
@@ -132,7 +137,8 @@ public:
|
||||
if (rankId < 0 || rankId >= m_rankSize) {
|
||||
return nullptr;
|
||||
}
|
||||
return m_ptrArray[rankId] + offset;
|
||||
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
|
||||
#endif
|
||||
@@ -176,6 +182,130 @@ public:
|
||||
gm_store(sync_base, count);
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
void InitStatusTargetSum()
|
||||
{
|
||||
using namespace AscendC;
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
|
||||
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
|
||||
// ep state
|
||||
//uint32_t coreIdx = get_block_idx();;
|
||||
uint32_t coreIdx = GetBlockIdx();
|
||||
GlobalTensor<int32_t> selfStatusTensor;
|
||||
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
|
||||
__asm__ __volatile__("");
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||
__asm__ __volatile__("");
|
||||
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
|
||||
if (state == 0) {
|
||||
sumTarget_ = static_cast<float>(1.0);
|
||||
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
|
||||
epStateValue_ = 0x3F800000; // 1.0f
|
||||
} else {
|
||||
sumTarget_ = static_cast<float>(0.0);
|
||||
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
|
||||
epStateValue_ = 0;
|
||||
}
|
||||
__asm__ __volatile__("");
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||
__asm__ __volatile__("");
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
|
||||
//subblockid = 0
|
||||
uint32_t stateOffset_ = STATE_OFFSET;
|
||||
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
|
||||
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
|
||||
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||
int vec_size = get_block_num();
|
||||
int vec_id = get_block_idx();
|
||||
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||
pipe_barrier(PIPE_ALL);
|
||||
|
||||
ctrBuffer.SetValue(0, epStateValue_);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
|
||||
AscendC::GlobalTensor<int32_t> gmDstStates;
|
||||
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
|
||||
DataCopy(gmDstStates, ctrBuffer, 8);
|
||||
}
|
||||
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
|
||||
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
|
||||
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||
int vec_size = get_block_num();
|
||||
int vec_id = get_block_idx();
|
||||
uint32_t stateOffset_ = STATE_OFFSET;
|
||||
|
||||
uint32_t sendRankNum_ = m_rankSize / vec_size;
|
||||
uint32_t remainderRankNum = m_rankSize % vec_size;
|
||||
uint32_t startRankId_ = sendRankNum_ * vec_id;
|
||||
if (vec_id < remainderRankNum) {
|
||||
sendRankNum_++;
|
||||
startRankId_ += vec_id;
|
||||
} else {
|
||||
startRankId_ += remainderRankNum;
|
||||
}
|
||||
uint32_t endRankId_ = startRankId_ + sendRankNum_;
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||
|
||||
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
|
||||
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
|
||||
|
||||
if (startRankId_ < m_rankSize) {
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
gatherTmpTensor.SetValue(0, 1);
|
||||
uint32_t mask = 1; // gatherMask + sum
|
||||
uint64_t rsvdCnt = 0;
|
||||
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
|
||||
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||
static_cast<uint16_t>(15), 0};
|
||||
|
||||
float sumOfFlag = static_cast<float>(-1.0);
|
||||
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
|
||||
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
|
||||
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
|
||||
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
|
||||
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
|
||||
intriParams);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
|
||||
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
|
||||
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
|
||||
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
sumOfFlag = statusSumOutTensor.GetValue(0);
|
||||
}
|
||||
}
|
||||
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||
|
||||
//unpermute
|
||||
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
__gm__ int32_t* SyncBaseAddr() {
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
||||
@@ -187,9 +317,11 @@ private:
|
||||
int32_t m_rank;
|
||||
int32_t m_rankSize;
|
||||
size_t m_segmentSize;
|
||||
float sumTarget_{0.0};
|
||||
int32_t epStateValue_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user