[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)
### What this PR does / why we need it?
Fix the precision issue in the dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations from the dispatch_ffn_combine operator.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -391,7 +391,6 @@ private:
|
|||||||
uint16_t syncgmmIdx = 0;
|
uint16_t syncgmmIdx = 0;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||||
syncgmmIdx++;
|
syncgmmIdx++;
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
@@ -405,7 +404,6 @@ private:
|
|||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
if (currentM <= L1TileShape::M) {
|
if (currentM <= L1TileShape::M) {
|
||||||
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||||
}
|
}
|
||||||
@@ -493,8 +491,6 @@ private:
|
|||||||
|
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
uint32_t lastDequantExpertNum = params.expertPerRank;
|
||||||
@@ -503,8 +499,6 @@ private:
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
if (preCurrentmSum >= params.maxOutputSize) {
|
if (preCurrentmSum >= params.maxOutputSize) {
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ namespace {
|
|||||||
constexpr uint32_t EXPERTID_INDEX = 3;
|
constexpr uint32_t EXPERTID_INDEX = 3;
|
||||||
constexpr uint32_t BLOCK_NUM = 20;
|
constexpr uint32_t BLOCK_NUM = 20;
|
||||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||||
|
constexpr uint64_t MB_SIZE = 1024 * 1024UL;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace optiling {
|
namespace optiling {
|
||||||
@@ -240,7 +241,8 @@ static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext
|
|||||||
info.maxOutputSize * n2 * sizeof(int16_t) +
|
info.maxOutputSize * n2 * sizeof(int16_t) +
|
||||||
info.maxOutputSize * info.K * sizeof(int16_t) +
|
info.maxOutputSize * info.K * sizeof(int16_t) +
|
||||||
info.maxOutputSize * k2 * sizeof(int16_t) +
|
info.maxOutputSize * k2 * sizeof(int16_t) +
|
||||||
info.worldSize * sizeof(int32_t) * 16;
|
info.worldSize * sizeof(int32_t) * 16 +
|
||||||
|
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
|
||||||
// std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
|
// std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
|
||||||
// std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
|
// std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
|
||||||
|
|
||||||
|
|||||||
@@ -213,11 +213,7 @@ public:
|
|||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void operator()<AscendC::AIV>(Params const &params)
|
void operator()<AscendC::AIV>(Params const &params)
|
||||||
{
|
{
|
||||||
Dispatch(params);
|
DispatchAndCombine(params);
|
||||||
AscendC::SyncAll<true>();
|
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
|
||||||
|
|
||||||
CombineV2(params);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -241,7 +237,7 @@ private:
|
|||||||
|
|
||||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
|
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
|
||||||
|
|
||||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||||
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
|
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,7 +305,7 @@ private:
|
|||||||
AscendC::DataCopyPad(
|
AscendC::DataCopyPad(
|
||||||
tmpBuffer1,
|
tmpBuffer1,
|
||||||
tokenPerExpert[rankId * expertPerRank],
|
tokenPerExpert[rankId * expertPerRank],
|
||||||
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0},
|
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
|
||||||
{}
|
{}
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -331,145 +327,6 @@ private:
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRank(Params const &params, int64_t localTokenPerExpertOffset){
|
|
||||||
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
|
||||||
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
|
|
||||||
uint32_t numPerCore = params.EP * params.expertPerRank;
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
|
||||||
if (dstEpIdx == params.rank) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
AscendC::GlobalTensor<int32_t> srcAddress;
|
|
||||||
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
|
||||||
AscendC::GlobalTensor<int32_t> dstAddress;
|
|
||||||
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
|
||||||
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
|
||||||
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
|
||||||
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
|
||||||
CopyGmToUb copyGmToUb;
|
|
||||||
CopyUbToGm copyUbToGm;
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
|
|
||||||
copyGmToUb(tmpBuffer, srcAddress[0],
|
|
||||||
layout::RowMajor{ 1, numPerCore},
|
|
||||||
layout::RowMajor{1, numPerCore});
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
copyUbToGm(dstAddress[0], tmpBuffer,
|
|
||||||
layout::RowMajor{ 1, numPerCore},
|
|
||||||
layout::RowMajor{1, numPerCore});
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
|
||||||
if (dstEpIdx != params.rank) {
|
|
||||||
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
|
||||||
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
|
|
||||||
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
|
||||||
gm_signal_wait_until_ne(sync_check, 0);
|
|
||||||
}
|
|
||||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
|
||||||
} else {
|
|
||||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AscendC::SyncAll<true>();
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
|
|
||||||
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
|
||||||
AscendC::LocalTensor<float> ubFloat = resource.ubBuf.template GetBufferByByte<float>(0);
|
|
||||||
uint32_t numPerCore = params.EP * params.expertPerRank;
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
|
||||||
if (dstEpIdx == params.rank) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
AscendC::GlobalTensor<int32_t> srcAddress;
|
|
||||||
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
|
||||||
AscendC::GlobalTensor<int32_t> dstAddress;
|
|
||||||
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
|
||||||
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
|
||||||
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
|
||||||
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
|
||||||
CopyGmToUb copyGmToUb;
|
|
||||||
CopyUbToGm copyUbToGm;
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
|
|
||||||
copyGmToUb(tmpBuffer, srcAddress[0],
|
|
||||||
layout::RowMajor{ 1, numPerCore},
|
|
||||||
layout::RowMajor{1, numPerCore});
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
copyUbToGm(dstAddress[0], tmpBuffer,
|
|
||||||
layout::RowMajor{ 1, numPerCore},
|
|
||||||
layout::RowMajor{1, numPerCore});
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
|
||||||
if (dstEpIdx != params.rank) {
|
|
||||||
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
|
||||||
for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
|
|
||||||
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
|
||||||
gm_signal_wait_until_ne(sync_check, 0);
|
|
||||||
}
|
|
||||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
|
||||||
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
|
||||||
} else {
|
|
||||||
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t prevSum = 0;
|
|
||||||
for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
|
|
||||||
prevSum += tmpBuffer(i);
|
|
||||||
}
|
|
||||||
preSumBeforeRank(dstEpIdx * 16) = prevSum;
|
|
||||||
__asm__ __volatile__("");
|
|
||||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(preSumBeforeRank[dstEpIdx * 16]);
|
|
||||||
__asm__ __volatile__("");
|
|
||||||
|
|
||||||
}
|
|
||||||
AscendC::SyncAll<true>();
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void GetSumPreRank(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result,
|
void GetSumPreRank(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result,
|
||||||
uint32_t expertPerRank, uint32_t rankId, uint32_t EP) {
|
uint32_t expertPerRank, uint32_t rankId, uint32_t EP) {
|
||||||
@@ -506,20 +363,20 @@ CATLASS_DEVICE
|
|||||||
icache_preload(8);
|
icache_preload(8);
|
||||||
BlockScheduler blockScheduler;
|
BlockScheduler blockScheduler;
|
||||||
BlockMmad blockMmad(resource);
|
BlockMmad blockMmad(resource);
|
||||||
|
float aivFinishGroups = 0.0f;
|
||||||
|
__gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
|
||||||
|
|
||||||
int64_t gmGroupOffsetA = 0;
|
int64_t gmGroupOffsetA = 0;
|
||||||
int64_t gmGroupOffsetB = 0;
|
int64_t gmGroupOffsetB = 0;
|
||||||
int64_t gmGroupOffsetC = 0;
|
int64_t gmGroupOffsetC = 0;
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
uint32_t syncGroupIdx = 0;
|
uint32_t syncGroupIdx = 0;
|
||||||
uint16_t syncgmmIdx = 0;
|
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
|
||||||
syncgmmIdx++;
|
|
||||||
AicSyncAll();
|
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
uint16_t syncgmmIdx = 0;
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||||
|
syncgmmIdx++;
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
@@ -533,9 +390,6 @@ CATLASS_DEVICE
|
|||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
if (currentM <= L1TileShape::M) {
|
if (currentM <= L1TileShape::M) {
|
||||||
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||||
}
|
}
|
||||||
@@ -630,7 +484,6 @@ CATLASS_DEVICE
|
|||||||
|
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
@@ -640,7 +493,6 @@ CATLASS_DEVICE
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
@@ -651,7 +503,6 @@ CATLASS_DEVICE
|
|||||||
}
|
}
|
||||||
AscendC::GlobalTensor<ElementB> gmB2;
|
AscendC::GlobalTensor<ElementB> gmB2;
|
||||||
AscendC::GlobalTensor<ElementScale> gmS2;
|
AscendC::GlobalTensor<ElementScale> gmS2;
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
|
||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
|
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
|
||||||
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
|
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
|
||||||
@@ -721,7 +572,6 @@ CATLASS_DEVICE
|
|||||||
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
|
||||||
|
|
||||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||||
@@ -729,8 +579,168 @@ CATLASS_DEVICE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-initialises the soft-flag region in global memory.
//
// Fills (params.EP + 1) * FLAGSTRIDE floats of the UB scratch buffer (taken at
// byte offset 0 of resource.ubBuf) with 0.0f and copies them to
// workspaceInfo.ptrSoftFlagBase.
//
// @param params  Kernel parameters; only params.EP is read here.
CATLASS_DEVICE
|
||||||
|
void InitArithProgress(Params const &params) {
|
||||||
|
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||||
|
// MTE3 -> V event before the vector fill.
// NOTE(review): presumably guards against an earlier copy-out that still reads
// this UB region - confirm against the callers.
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||||
|
AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
|
||||||
|
// V -> MTE3 event: the copy-out below must not start before Duplicate has
// finished writing the buffer.
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
|
||||||||
|
AscendC::GlobalTensor<float> flagGlobalBase;
|
||||||
|
flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
|
||||||
|
// Same element count as the Duplicate above, so the whole zeroed span is written out.
AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void Dispatch(Params const &params) {
|
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
|
||||||
|
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
|
||||||
|
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||||
|
AscendC::LocalTensor<int32_t> prevSumBuf = tmpBuffer[numPerCore];
|
||||||
|
|
||||||
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
|
if (dstEpIdx == params.rank) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
AscendC::GlobalTensor<int32_t> srcAddress;
|
||||||
|
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
|
||||||
|
AscendC::GlobalTensor<int32_t> dstAddress;
|
||||||
|
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
|
||||||
|
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
|
||||||
|
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
|
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
||||||
|
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
|
||||||
|
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
|
||||||
|
CopyGmToUb copyGmToUb;
|
||||||
|
CopyUbToGm copyUbToGm;
|
||||||
|
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
|
|
||||||
|
copyGmToUb(tmpBuffer, srcAddress[0],
|
||||||
|
layout::RowMajor{ 1, numPerCore},
|
||||||
|
layout::RowMajor{1, numPerCore});
|
||||||
|
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
copyUbToGm(dstAddress[0], tmpBuffer,
|
||||||
|
layout::RowMajor{ 1, numPerCore},
|
||||||
|
layout::RowMajor{1, numPerCore});
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
|
}
|
||||||
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
|
if (dstEpIdx != params.rank) {
|
||||||
|
int32_t intPer512 = CACHE_LINE / sizeof(int);
|
||||||
|
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
|
||||||
|
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
|
||||||
|
gm_signal_wait_until_ne(sync_check, 0);
|
||||||
|
}
|
||||||
|
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
|
||||||
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
|
||||||
|
} else {
|
||||||
|
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
}
|
||||||
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
int32_t prevSum = 0;
|
||||||
|
int32_t j = 0;
|
||||||
|
for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
|
||||||
|
if (i >= params.rank * params.expertPerRank) {
|
||||||
|
prevSumBuf(j) = prevSum;
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
prevSum += tmpBuffer(i);
|
||||||
|
}
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
|
||||||
|
AscendC::DataCopyParams{1, static_cast<uint16_t>(params.expertPerRank * sizeof(int32_t)), 0, 0});
|
||||||
|
}
|
||||||
|
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears the first `num` int32 entries of the tokenPerExpert global tensor.
//
// Only the core with coreIdx == coreNum - 1 does the work; every other core
// returns immediately. The zeros are staged in the UB scratch buffer (byte
// offset 0 of resource.ubBuf) and then copied to tokenPerExpert.
//
// @param num  Number of int32 elements to zero.
CATLASS_DEVICE
|
||||||
|
void ResetTokenPerExpert(int32_t num)
|
||||||
|
{
|
||||||
|
// Single-writer guard: all cores except the last one are no-ops.
if (coreIdx != coreNum - 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// MTE3 -> V event before the vector fill.
// NOTE(review): presumably protects a previous copy-out that still reads this
// UB region - confirm against the callers.
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||||
|
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||||
|
AscendC::Duplicate(tmp, 0, num);
|
||||||
|
// V -> MTE3 event: the copy-out must see the fully zeroed buffer.
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::DataCopy(tokenPerExpert, tmp, num);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Polls the soft flags and publishes their minimum as a single progress value.
//
// Each pass copies params.EP * FLAGSTRIDE floats from
// workspaceInfo.ptrSoftFlagBase into UB, takes a masked ReduceMin over them,
// clamps the result to params.expertPerRank, and - whenever the value has
// grown since the last publish - stores it at
// workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE followed by gm_dcci.
// The loop ends once the minimum reaches params.expertPerRank.
//
// @param params  Kernel parameters; reads params.EP and params.expertPerRank.
CATLASS_DEVICE
|
||||||
|
void UpdateAicFlags(const Params &params)
|
||||||
|
{
|
||||||
|
// Target value: the loop below runs until the reduced flag reaches this.
float flagBase = 1.0f * params.expertPerRank;
|
||||||
|
// Slot written by this function: the one right after the params.EP flag slots.
__gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
|
||||||
|
float flag = 0.0f;
|
||||||
|
// Starts below any possible flag so the first observed value is always published.
float lastflag = -1.0f;
|
||||||
|
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
|
||||||
|
__gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
|
||||||
|
AscendC::GlobalTensor<float> flagGM;
|
||||||
|
flagGM.SetGlobalBuffer(flagPtr);
|
||||||
|
int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
|
||||||
|
// NOTE(review): flagBufferSize is an element count but is passed to
// GetBufferByByte as an offset, while tmpBuffer1 receives
// params.EP * FLAGSTRIDE floats starting at offset 0 - confirm these
// buffers are not meant to be disjoint from tmpBuffer1.
AscendC::LocalTensor<float> dstValueBuffer = resource.ubBuf.template GetBufferByByte<float>(flagBufferSize);
|
||||||
|
AscendC::LocalTensor<float> sharedTmpBuffer = resource.ubBuf.template GetBufferByByte<float>((flagBufferSize + 64));
|
||||||
|
uint64_t mask[1] = {0};
|
||||||
|
uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
|
||||||
|
// Enable one lane every 16 elements, for up to four flags per repeat.
// NOTE(review): the literal 16 presumably mirrors FLAGSTRIDE - confirm.
for (int32_t i = 0; i < 4; i ++) {
|
||||||
|
if (i < params.EP) {
|
||||||
|
mask[0] |= 1ull * (1ull << (i * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// S -> V event: the scalar-built mask must be complete before the vector unit uses it.
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||||
|
while (flag < flagBase) {
|
||||||
|
// Start from the cap so the min() below clamps the published value to flagBase.
flag = flagBase;
|
||||||
|
AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
|
||||||||
|
AscendC::ReduceMin<float>(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
|
||||||
|
|
||||||||
|
// V -> S event: the scalar read of the result must wait for the reduction.
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||||
|
flag = min(flag, dstValueBuffer.GetValue(0));
|
||||||
|
|
||||||||
|
// Publish only when progress advanced, so the slot is never rewritten with a stale value.
if (flag > lastflag) {
|
||||||
|
*aicFinishPtr = flag;
|
||||||
|
gm_dcci(aicFinishPtr);
|
||||||
|
lastflag = flag;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Sets the MTE3_MTE2 hardware event on both EVENT_ID0 and EVENT_ID1.
// NOTE(review): presumably primes the two events so that the first
// WaitFlag<MTE3_MTE2> on each ID in the combine stage does not block -
// confirm against the matching WaitFlag calls.
CATLASS_DEVICE
|
||||||
|
void CombineSetFlag() {
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CATLASS_DEVICE
|
||||||
|
void DispatchAndCombine(Params const &params) {
|
||||||
icache_preload(8);
|
icache_preload(8);
|
||||||
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
|
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
|
||||||
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset;
|
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset;
|
||||||
@@ -744,35 +754,35 @@ CATLASS_DEVICE
|
|||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
|
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
|
||||||
|
|
||||||
if (coreIdx == coreNum - 1) {
|
if (coreIdx == 0) {
|
||||||
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t curGroupOffset = 0;
|
||||||
|
int32_t prevSumBeforeRank = 0;
|
||||||
|
int32_t prevSum = 0;
|
||||||
|
if (coreIdx < params.EP) {
|
||||||
|
prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
|
||||||
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
||||||
|
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
||||||
|
if(coreIdx == 0)
|
||||||
|
{
|
||||||
|
CopyGMToGM(ExpertTokenNums, cumsumMM[(params.EP - 1) * params.expertPerRank], params.expertPerRank, params.ubMoveNum);
|
||||||
|
}
|
||||||
uint16_t syncgmm1Idx = 0;
|
uint16_t syncgmm1Idx = 0;
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
syncgmm1Idx++;
|
syncgmm1Idx++;
|
||||||
|
|
||||||
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
|
||||||
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
|
||||||
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
|
|
||||||
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
|
|
||||||
if (coreIdx == 0) {
|
|
||||||
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t curGroupOffset = 0;
|
|
||||||
int32_t prevSumBeforeRank = 0;
|
|
||||||
int32_t groupIdxDeq = 0;
|
|
||||||
int prevSum = 0;
|
|
||||||
if (coreIdx < params.EP) {
|
|
||||||
prevSum = preSumBeforeRank(coreIdx * 16);
|
|
||||||
}
|
|
||||||
uint32_t prevGroupSum1 = 0;
|
|
||||||
uint32_t dequantSum = 0;
|
uint32_t dequantSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
|
||||||
uint32_t n = params.problemShape.n();
|
icache_preload(8);
|
||||||
BlockEpilogue1 blockEpilogue(resource, n);
|
|
||||||
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
|
// The ith core reads data from the ith rank's peermem
|
||||||
|
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
||||||
if (rowStart < params.maxOutputSize) {
|
if (rowStart < params.maxOutputSize) {
|
||||||
@@ -785,99 +795,40 @@ CATLASS_DEVICE
|
|||||||
GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
|
GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
|
||||||
AscendC::GlobalTensor<ElementA> gmRemoteA;
|
AscendC::GlobalTensor<ElementA> gmRemoteA;
|
||||||
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
|
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
|
||||||
AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
|
|
||||||
gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
|
|
||||||
MatrixCoord offsetA{rowStart, 0};
|
MatrixCoord offsetA{rowStart, 0};
|
||||||
MatrixCoord shapeA{rows, params.problemShape.k()};
|
|
||||||
MatrixCoord offsetPeer{rowSrc, 0};
|
MatrixCoord offsetPeer{rowSrc, 0};
|
||||||
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
|
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
|
||||||
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
|
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
|
||||||
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
|
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
|
||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
|
||||||
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
|
|
||||||
syncLoopIdx++;
|
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
|
||||||
}
|
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
syncgmm1Idx++;
|
syncgmm1Idx ++;
|
||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
prevGroupSum1 += currentM;
|
||||||
uint32_t rowStartThisCore = 0;
|
|
||||||
MatrixCoord offsetC{0U, 0};
|
|
||||||
uint32_t dequantLen = prevGroupSum1 - dequantSum;
|
|
||||||
if (dequantLen >= params.maxOutputSize) {
|
|
||||||
dequantLen = dequantLen - params.maxOutputSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
// Token count and truncation logic for the first SwiGLU operation
|
||||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
if (groupIdx + 1 <= params.epilogueGranularity) {
|
||||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
if (dequantSum1 + currentM <= params.maxOutputSize) {
|
||||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
dequantSum1 += currentM;
|
||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
} else if (dequantSum1 < params.maxOutputSize) {
|
||||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
dequantSum1 = params.maxOutputSize;
|
||||||
} else {
|
|
||||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
|
||||||
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
|
||||||
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
|
|
||||||
dequantSum = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
syncLoopIdx ++;
|
|
||||||
|
|
||||||
AscendC::SyncAll<true>();
|
// Token count and truncation logic for the second SwiGLU operation
|
||||||
|
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
|
||||||
uint32_t lastDequantExpertNum = params.expertPerRank;
|
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
|
||||||
if (params.epilogueGranularity < params.expertPerRank) {
|
dequantSum2 += currentM;
|
||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
|
||||||
}
|
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
|
||||||
if (lastDequantExpertNum < params.expertPerRank) {
|
}
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
|
||||||
}
|
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
|
||||||
AscendC::SyncAll<true>();
|
|
||||||
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
|
|
||||||
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
|
|
||||||
MatrixCoord offsetC{rowStartThisCore, 0};
|
|
||||||
uint32_t dequantLen = dequantSum;
|
|
||||||
if (prevGroupSum1 >= params.maxOutputSize) {
|
|
||||||
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
|
|
||||||
}
|
|
||||||
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
|
||||||
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
|
||||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
|
||||||
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
|
||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
|
||||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
|
||||||
} else {
|
|
||||||
blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blockEpilogue.Finalize();
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void CombineV2(Params const &params) {
|
|
||||||
BlockScheduler blockScheduler;
|
|
||||||
uint32_t n2 = params.problemShape.k();
|
uint32_t n2 = params.problemShape.k();
|
||||||
uint32_t k2 = params.problemShape.n() / 2;
|
|
||||||
uint32_t startCoreIdx = 0;
|
|
||||||
int64_t gmGroupOffsetC = 0;
|
|
||||||
uint32_t aivCoreNum = coreNum;
|
|
||||||
uint32_t aicCoreNum = coreNum / 2;
|
|
||||||
uint32_t aivCoreIdx = coreIdx;
|
|
||||||
uint32_t aicCoreIdx = get_block_idx();
|
|
||||||
uint32_t aivSubCoreIdx = get_subblockid();
|
|
||||||
uint32_t preSrcExpertSum = 0;
|
|
||||||
|
|
||||||
typename BlockEpilogue2::Params epilogueParams{
|
typename BlockEpilogue2::Params epilogueParams{
|
||||||
static_cast<int32_t>(params.EP),
|
static_cast<int32_t>(params.EP),
|
||||||
@@ -891,11 +842,108 @@ CATLASS_DEVICE
|
|||||||
static_cast<int32_t>(peermemInfo.offsetD)
|
static_cast<int32_t>(peermemInfo.offsetD)
|
||||||
};
|
};
|
||||||
|
|
||||||
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
|
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
|
||||||
|
|
||||||
|
uint32_t n = params.problemShape.n();
|
||||||
|
BlockEpilogue1 blockEpilogue1(resource, n);
|
||||||
|
|
||||||
|
// Synchronous wait: SwiGLU waits for GMM1 [1]
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
if (dequantSum1 > 0) {
|
||||||
|
uint32_t rowStartThisCore = 0;
|
||||||
|
MatrixCoord offsetC{0U, 0};
|
||||||
|
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
|
||||||
|
LayoutC layoutC{dequantSum1, params.problemShape.n()};
|
||||||
|
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||||
|
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||||
|
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
|
||||||
|
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum);
|
||||||
|
}
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
// Synchronization signal: SwiGLU notifies GMM2 [1]
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
|
|
||||||
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
|
||||||
|
// Synchronous wait: SwiGLU waits for GMM1 [2]
|
||||||
|
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
if (dequantSum2 > 0) {
|
||||||
|
uint32_t rowStartThisCore = dequantSum1;
|
||||||
|
MatrixCoord offsetC{rowStartThisCore, 0};
|
||||||
|
uint32_t dequantLen = dequantSum2;
|
||||||
|
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
|
||||||
|
LayoutC layoutC{dequantLen, params.problemShape.n()};
|
||||||
|
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||||
|
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
|
||||||
|
// blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
|
||||||
|
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum);
|
||||||
|
}
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
// Synchronization signal: SwiGLU notifies GMM2 [2]
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
|
||||||
|
}
|
||||||
|
|
||||||
|
blockEpilogue1.Finalize();
|
||||||
|
|
||||||
|
CombineSetFlag();
|
||||||
|
|
||||||
|
CombineV2(params, blockEpilogue2);
|
||||||
|
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
|
||||||
|
ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
|
||||||
|
#endif
|
||||||
|
shmem.InitStatusTargetSum();
|
||||||
|
if (get_subblockid() == 0) {
|
||||||
|
AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
|
||||||
|
shmem.CrossRankSyncV2Set(ctrBuffer);
|
||||||
|
} else {
|
||||||
|
uint32_t uboffset = 0;
|
||||||
|
uint32_t aicCoreNum = coreNum / 2;
|
||||||
|
uint32_t aicCoreIdx = get_block_idx();
|
||||||
|
uint32_t sendRankNum_ = params.EP / aicCoreNum;
|
||||||
|
uint32_t remainderRankNum = params.EP % aicCoreNum;
|
||||||
|
if (aicCoreIdx < remainderRankNum) {
|
||||||
|
sendRankNum_++;
|
||||||
|
}
|
||||||
|
AscendC::LocalTensor<float> statusTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||||
|
uboffset += sendRankNum_ * UB_ALIGN;
|
||||||
|
AscendC::LocalTensor<float> gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||||
|
uboffset += AlignUp(params.EP * sizeof(float), 32);
|
||||||
|
AscendC::LocalTensor<uint32_t> gatherTmpTensor = resource.ubBuf.template GetBufferByByte<uint32_t>(uboffset);
|
||||||
|
uboffset += AlignUp(sizeof(uint32_t), 32);
|
||||||
|
AscendC::LocalTensor<float> statusSumOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
|
||||||
|
uboffset += AlignUp(sizeof(float), 32);
|
||||||
|
shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
|
||||||
|
MoeTokenUnpermuteTilingData tilingData;
|
||||||
|
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
|
||||||
|
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
|
||||||
|
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
|
||||||
|
kernelMoeTokenUnpermuteOp.Process();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
CATLASS_DEVICE
|
||||||
|
void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
|
||||||
|
BlockScheduler blockScheduler;
|
||||||
int32_t syncLoopIdx = 0;
|
int32_t syncLoopIdx = 0;
|
||||||
|
uint32_t startCoreIdx = 0;
|
||||||
|
uint32_t aicCoreNum = coreNum / 2;
|
||||||
|
uint32_t aicCoreIdx = get_block_idx();
|
||||||
|
uint32_t aivSubCoreIdx = get_subblockid();
|
||||||
|
uint32_t preSrcExpertSum = 0;
|
||||||
|
uint32_t n2 = params.problemShape.k();
|
||||||
|
uint32_t k2 = params.problemShape.n() / 2;
|
||||||
|
icache_preload(8);
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
|
||||||
|
if (preSrcExpertSum >= params.maxOutputSize) {
|
||||||
|
currentExpertM = 0;
|
||||||
|
} else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
|
||||||
|
currentExpertM = params.maxOutputSize - preSrcExpertSum;
|
||||||
|
}
|
||||||
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
|
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
|
||||||
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
|
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
|
||||||
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
||||||
@@ -916,11 +964,10 @@ CATLASS_DEVICE
|
|||||||
if(aivSubCoreIdx == 1) {
|
if(aivSubCoreIdx == 1) {
|
||||||
m_offset += (m_rows / 2) * m0;
|
m_offset += (m_rows / 2) * m0;
|
||||||
}
|
}
|
||||||
if (loopIdx == startLoopIdx) {
|
|
||||||
for (;syncLoopIdx <= groupIdx; syncLoopIdx++) {
|
for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
|
||||||
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
|
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
|
||||||
@@ -930,34 +977,15 @@ CATLASS_DEVICE
|
|||||||
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
|
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
|
||||||
}
|
}
|
||||||
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
|
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
|
||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
|
||||||
blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
|
|
||||||
} else {
|
|
||||||
blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank);
|
|
||||||
}
|
|
||||||
m_offset += m0;
|
m_offset += m0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
|
||||||
int32_t expertRankM = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
|
||||||
mPreSumBeforeRank[dstEpIdx] += expertRankM;
|
|
||||||
}
|
|
||||||
preSrcExpertSum += currentExpertM;
|
preSrcExpertSum += currentExpertM;
|
||||||
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
|
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
blockEpilogue.Finalize();
|
blockEpilogue.Finalize();
|
||||||
AscendC::SyncAll<true>();
|
|
||||||
ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank);
|
|
||||||
shmem.CrossRankSync();
|
|
||||||
|
|
||||||
MoeTokenUnpermuteTilingData tilingData;
|
|
||||||
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, aivCoreNum);
|
|
||||||
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
|
|
||||||
|
|
||||||
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
|
|
||||||
kernelMoeTokenUnpermuteOp.Process();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -973,6 +1001,7 @@ private:
|
|||||||
GM_ADDR expandedRowIdx;
|
GM_ADDR expandedRowIdx;
|
||||||
GM_ADDR ptrTokenPerExpert;
|
GM_ADDR ptrTokenPerExpert;
|
||||||
GM_ADDR ptrSumBeforeRank;
|
GM_ADDR ptrSumBeforeRank;
|
||||||
|
__gm__ float* ptrSoftFlagBase;
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
WorkspaceInfo(){}
|
WorkspaceInfo(){}
|
||||||
@@ -1012,6 +1041,9 @@ private:
|
|||||||
|
|
||||||
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
|
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
|
||||||
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
|
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
|
||||||
|
workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
|
||||||
|
ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1041,7 +1073,6 @@ private:
|
|||||||
WorkspaceInfo workspaceInfo;
|
WorkspaceInfo workspaceInfo;
|
||||||
PeermemInfo peermemInfo;
|
PeermemInfo peermemInfo;
|
||||||
|
|
||||||
int64_t m_prevSumBeforeRank;
|
|
||||||
|
|
||||||
AscendC::GlobalTensor<ElementA> gmA;
|
AscendC::GlobalTensor<ElementA> gmA;
|
||||||
AscendC::GlobalTensor<ElementC> gmC;
|
AscendC::GlobalTensor<ElementC> gmC;
|
||||||
@@ -1057,7 +1088,7 @@ private:
|
|||||||
AscendC::GlobalTensor<int32_t> tokenPerExpert;
|
AscendC::GlobalTensor<int32_t> tokenPerExpert;
|
||||||
AscendC::GlobalTensor<int32_t> cumsumMM;
|
AscendC::GlobalTensor<int32_t> cumsumMM;
|
||||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
|
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
|
||||||
uint32_t mPreSumBeforeRank[32] = {0};
|
|
||||||
Layout3D tokenPerExpertLayout;
|
Layout3D tokenPerExpertLayout;
|
||||||
HcclShmem shmem;
|
HcclShmem shmem;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
|
|||||||
GM_ADDR unpermuted_tokens,
|
GM_ADDR unpermuted_tokens,
|
||||||
const MoeTokenUnpermuteTilingData *__restrict tiling_data)
|
const MoeTokenUnpermuteTilingData *__restrict tiling_data)
|
||||||
{
|
{
|
||||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
this->blockIdx = get_block_idx();
|
||||||
this->blockNum = get_block_num() * get_subblockdim();
|
this->blockNum = get_block_num();
|
||||||
|
|
||||||
if (blockIdx >= blockNum) {
|
if (blockIdx >= blockNum) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -72,17 +72,6 @@ public:
|
|||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
|
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
|
||||||
{
|
{
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//ub:192KB
|
//ub:192KB
|
||||||
n0 = params.n0;
|
n0 = params.n0;
|
||||||
size_t ubOffset = 0;
|
size_t ubOffset = 0;
|
||||||
@@ -98,137 +87,20 @@ public:
|
|||||||
source_scale_offset[i] = -1;
|
source_scale_offset[i] = -1;
|
||||||
}
|
}
|
||||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
|
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
|
||||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||||
is_ping = true;
|
is_ping = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void Finalize()
|
void Finalize()
|
||||||
{
|
{
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
~BlockEpilogue()
|
~BlockEpilogue()
|
||||||
{
|
{
|
||||||
|
|
||||||
}
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void operator() (
|
|
||||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
|
||||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
|
|
||||||
GemmCoord& blockCoord,
|
|
||||||
GemmCoord& actualBlockShape,
|
|
||||||
int32_t groupIdx,
|
|
||||||
int32_t preSrcExpertSum,
|
|
||||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
|
||||||
uint32_t *mPreSumBeforeRank
|
|
||||||
){
|
|
||||||
is_ping = !is_ping;
|
|
||||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
|
||||||
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
|
|
||||||
|
|
||||||
auto &ubC = ubCList[is_ping];
|
|
||||||
auto &ubD = ubDList[is_ping];
|
|
||||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
|
||||||
auto gmTileC = gmC[gmCOffset];
|
|
||||||
auto &ubCFp32 = ubFp32List[is_ping];
|
|
||||||
auto &scaleUb = scaleUbList[is_ping];
|
|
||||||
// auto &ubOutFp32 = ubOutFp32List[is_ping];
|
|
||||||
|
|
||||||
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
|
|
||||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
|
||||||
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
|
|
||||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
|
|
||||||
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
|
|
||||||
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
|
||||||
|
|
||||||
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
|
|
||||||
layout::VectorLayout scaleLauout{actualBlockShape.m()};
|
|
||||||
if (source_scale_offset[event_id] != gmScaleOffset) {
|
|
||||||
source_scale_offset[event_id] = gmScaleOffset;
|
|
||||||
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
|
|
||||||
}
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // 注意必须是MTE2_S,不能是MTE2_V,否则会读到0,造成乱码
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
|
|
||||||
float scale = scaleUb(row);
|
|
||||||
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
|
|
||||||
}
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
|
||||||
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
|
|
||||||
|
|
||||||
int32_t lenTile = actualBlockShape.m();
|
|
||||||
int32_t stTile = blockCoord.m();
|
|
||||||
int32_t edTile = stTile + lenTile;
|
|
||||||
int32_t preSumRankInExpert = 0;
|
|
||||||
int32_t tileOffset = 0;
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
|
|
||||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
|
||||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
|
||||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
|
||||||
int32_t stRankInExpert = preSumRankInExpert;
|
|
||||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
|
||||||
preSumRankInExpert += lenRankInExpert;
|
|
||||||
if (stRankInExpert >= edTile) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
else if (edRankInExpert <= stTile) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
int32_t stData = max(stRankInExpert, stTile);
|
|
||||||
int32_t edData = min(edRankInExpert, edTile);
|
|
||||||
uint32_t lenData = edData - stData;
|
|
||||||
if (lenData <= 0){
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t dstOffsetInExpert = 0;
|
|
||||||
if (stTile > stRankInExpert) {
|
|
||||||
dstOffsetInExpert = stTile - stRankInExpert;
|
|
||||||
}
|
|
||||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
|
||||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
|
||||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
|
||||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
|
||||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
|
||||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
|
||||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
|
||||||
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
|
|
||||||
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
|
|
||||||
tileOffset += lenData;
|
|
||||||
}
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
@@ -238,14 +110,12 @@ public:
|
|||||||
GemmCoord& actualBlockShape,
|
GemmCoord& actualBlockShape,
|
||||||
int32_t groupIdx,
|
int32_t groupIdx,
|
||||||
int32_t preSrcExpertSum,
|
int32_t preSrcExpertSum,
|
||||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
AscendC::GlobalTensor<int32_t> preSumBeforeRank
|
||||||
uint32_t *mPreSumBeforeRank
|
|
||||||
){
|
){
|
||||||
is_ping = !is_ping;
|
is_ping = !is_ping;
|
||||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||||
|
|
||||||
auto &ubC = ubCList[is_ping];
|
auto &ubC = ubCList[is_ping];
|
||||||
auto &ubD = ubDList[is_ping];
|
|
||||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||||
auto gmTileC = gmC[gmCOffset];
|
auto gmTileC = gmC[gmCOffset];
|
||||||
|
|
||||||
@@ -253,7 +123,7 @@ public:
|
|||||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||||
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||||
|
|
||||||
@@ -263,10 +133,10 @@ public:
|
|||||||
int32_t preSumRankInExpert = 0;
|
int32_t preSumRankInExpert = 0;
|
||||||
int32_t tileOffset = 0;
|
int32_t tileOffset = 0;
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id);
|
||||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
|
||||||
int32_t stRankInExpert = preSumRankInExpert;
|
int32_t stRankInExpert = preSumRankInExpert;
|
||||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||||
preSumRankInExpert += lenRankInExpert;
|
preSumRankInExpert += lenRankInExpert;
|
||||||
@@ -282,7 +152,7 @@ public:
|
|||||||
if (lenData <= 0){
|
if (lenData <= 0){
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t dstOffsetInExpert = 0;
|
uint32_t dstOffsetInExpert = 0;
|
||||||
if (stTile > stRankInExpert) {
|
if (stTile > stRankInExpert) {
|
||||||
dstOffsetInExpert = stTile - stRankInExpert;
|
dstOffsetInExpert = stTile - stRankInExpert;
|
||||||
@@ -290,7 +160,7 @@ public:
|
|||||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
|
||||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||||
@@ -298,7 +168,8 @@ public:
|
|||||||
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
|
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
|
||||||
tileOffset += lenData;
|
tileOffset += lenData;
|
||||||
}
|
}
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,6 @@
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Block {
|
namespace Catlass::Gemm::Block {
|
||||||
|
|
||||||
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
|
||||||
|
|
||||||
template<AscendC::HardEvent event>
|
template<AscendC::HardEvent event>
|
||||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||||
{
|
{
|
||||||
@@ -153,9 +151,11 @@ public:
|
|||||||
L1TileShape::K, L1TileShape::N);
|
L1TileShape::K, L1TileShape::N);
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
|
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
|
||||||
{
|
{
|
||||||
syncGroupIdx = 0;
|
syncGroupIdx = 0;
|
||||||
|
ptrSoftFlagBase_ = flagPtr;
|
||||||
|
expertPerRank_ = expertPerRank;
|
||||||
InitL1(resource, l1BufAddrStart);
|
InitL1(resource, l1BufAddrStart);
|
||||||
InitL0A(resource);
|
InitL0A(resource);
|
||||||
InitL0B(resource);
|
InitL0B(resource);
|
||||||
@@ -272,9 +272,21 @@ public:
|
|||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void Finalize(int32_t target, int32_t flag = 0)
|
void Finalize(int32_t target, int32_t flag = 0)
|
||||||
{
|
{
|
||||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
if (ptrSoftFlagBase_ != nullptr) {
|
||||||
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
|
if (target < 0) {
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
return;
|
||||||
|
}
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::GlobalTensor<int32_t> flagGlobal;
|
||||||
|
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
|
||||||
|
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||||
|
int32_t flagId = syncGroupIdx / 15 + flag;
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
@@ -291,7 +303,6 @@ private:
|
|||||||
layout::VectorLayout layoutScale;
|
layout::VectorLayout layoutScale;
|
||||||
int32_t syncLoopIdx;
|
int32_t syncLoopIdx;
|
||||||
int32_t flag;
|
int32_t flag;
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
L1TileMmadParams() = default;
|
L1TileMmadParams() = default;
|
||||||
};
|
};
|
||||||
@@ -310,11 +321,24 @@ private:
|
|||||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||||
}
|
}
|
||||||
|
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
|
||||||
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||||
}
|
}
|
||||||
|
if (ptrSoftFlagBase_ != nullptr) {
|
||||||
|
// Initialize the flag matrix (structure as below):
|
||||||
|
// 1 0 0 0 0 0 0 0
|
||||||
|
// 2 0 0 0 0 0 0 0
|
||||||
|
// ...
|
||||||
|
// 16 0 0 0 0 0 0 0
|
||||||
|
// Then move it to L1
|
||||||
|
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
|
||||||
|
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
|
||||||
|
AscendC::GlobalTensor<int32_t> flagBase;
|
||||||
|
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
|
||||||
|
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
@@ -463,12 +487,20 @@ private:
|
|||||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||||
}
|
}
|
||||||
|
#ifdef __TILE_SYNC__
|
||||||
|
if (params.flag > 0) {
|
||||||
|
int32_t flagId = params.flag + params.syncLoopIdx / 8;
|
||||||
|
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||||
|
}
|
||||||
|
#else
|
||||||
Finalize(params.syncLoopIdx, params.flag);
|
Finalize(params.syncLoopIdx, params.flag);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
||||||
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
||||||
AscendC::LocalTensor<uint64_t> l1STensor;
|
AscendC::LocalTensor<uint64_t> l1STensor;
|
||||||
|
AscendC::LocalTensor<int32_t> l1FTensor;
|
||||||
int32_t syncGroupIdx;
|
int32_t syncGroupIdx;
|
||||||
int32_t l1AEventList[L1_STAGES];
|
int32_t l1AEventList[L1_STAGES];
|
||||||
int32_t l1BEventList[L1_STAGES];
|
int32_t l1BEventList[L1_STAGES];
|
||||||
@@ -497,8 +529,11 @@ private:
|
|||||||
CopyL1ToL0A copyL1ToL0A;
|
CopyL1ToL0A copyL1ToL0A;
|
||||||
CopyL1ToL0B copyL1ToL0B;
|
CopyL1ToL0B copyL1ToL0B;
|
||||||
CopyL0CToGm copyL0CToGm;
|
CopyL0CToGm copyL0CToGm;
|
||||||
|
|
||||||
|
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
|
||||||
|
int32_t expertPerRank_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace Catlass::Gemm::Block
|
} // namespace Catlass::Gemm::Block
|
||||||
|
|
||||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||||
@@ -4,6 +4,10 @@
|
|||||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||||
constexpr static int32_t CACHE_LINE = 512;
|
constexpr static int32_t CACHE_LINE = 512;
|
||||||
|
constexpr static int32_t FLAGSTRIDE = 16;
|
||||||
constexpr static int32_t RESET_VAL = 0xffff;
|
constexpr static int32_t RESET_VAL = 0xffff;
|
||||||
|
constexpr static int32_t ALIGN_128 = 128;
|
||||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
||||||
|
constexpr static int32_t UB_ALIGN = 32;
|
||||||
|
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||||
#endif
|
#endif
|
||||||
@@ -5,16 +5,23 @@
|
|||||||
#include "kernel_operator.h"
|
#include "kernel_operator.h"
|
||||||
#include "const_args.hpp"
|
#include "const_args.hpp"
|
||||||
|
|
||||||
|
#ifdef HCCL_COMM
|
||||||
#include "moe_distribute_base.h"
|
#include "moe_distribute_base.h"
|
||||||
|
|
||||||
#ifndef HCCL_COMM
|
|
||||||
#include "shmem_api.h"
|
|
||||||
using namespace AscendC::HcclContextDef;
|
using namespace AscendC::HcclContextDef;
|
||||||
|
|
||||||
|
#else
|
||||||
|
#include "shmem_api.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||||
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
|
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
|
||||||
|
|
||||||
|
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
|
||||||
|
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
|
||||||
|
|
||||||
|
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
|
||||||
|
constexpr uint32_t STATE_OFFSET = 512;
|
||||||
|
|
||||||
FORCE_INLINE_AICORE void AicSyncAll() {
|
FORCE_INLINE_AICORE void AicSyncAll() {
|
||||||
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
|
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
|
||||||
@@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
|
|||||||
return *((__gm__ T *)cache);
|
return *((__gm__ T *)cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
|
template<typename T>
|
||||||
|
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
|
||||||
using namespace AscendC;
|
using namespace AscendC;
|
||||||
GlobalTensor<uint8_t> global;
|
GlobalTensor<uint8_t> global;
|
||||||
global.SetGlobalBuffer(addr);
|
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
|
||||||
|
|
||||||
// Important: add hint to avoid dcci being optimized by compiler
|
// Important: add hint to avoid dcci being optimized by compiler
|
||||||
__asm__ __volatile__("");
|
__asm__ __volatile__("");
|
||||||
@@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
|
|||||||
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
|
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
|
||||||
do {
|
do {
|
||||||
AscendC::LocalTensor<int32_t> ub;
|
AscendC::LocalTensor<int32_t> ub;
|
||||||
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||||
ub.address_.bufferAddr = 0;
|
ub.address_.bufferAddr = 0;
|
||||||
AscendC::GlobalTensor<int32_t> sig;
|
AscendC::GlobalTensor<int32_t> sig;
|
||||||
sig.SetGlobalBuffer(sig_addr);
|
sig.SetGlobalBuffer(sig_addr);
|
||||||
@@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
|
|||||||
|
|
||||||
class HcclShmem {
|
class HcclShmem {
|
||||||
public:
|
public:
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
|
||||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
AscendC::LocalTensor<int32_t> ub;
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
HcclShmem(){
|
HcclShmem(){
|
||||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||||
@@ -87,17 +95,13 @@ public:
|
|||||||
m_rank = WinContext_->localUsrRankId;
|
m_rank = WinContext_->localUsrRankId;
|
||||||
m_rankSize = WinContext_->rankSize;
|
m_rankSize = WinContext_->rankSize;
|
||||||
m_segmentSize = WinContext_->winSize;
|
m_segmentSize = WinContext_->winSize;
|
||||||
for (int i = 0; i < m_rankSize; i++) {
|
|
||||||
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
|
|
||||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
HcclShmem(){
|
HcclShmem(){
|
||||||
m_segmentSize = SHMEM_MEM;
|
m_segmentSize = SHMEM_MEM;
|
||||||
}
|
}
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
|
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
|
||||||
symmetricPtr = symmetricPtr_;
|
symmetricPtr = symmetricPtr_;
|
||||||
m_rank = rank;
|
m_rank = rank;
|
||||||
@@ -106,25 +110,26 @@ public:
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
GM_ADDR operator() () const {
|
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM
|
||||||
return m_ptrArray[m_rank];
|
return (GM_ADDR)(WinContext_->localWindowsIn);
|
||||||
#else
|
#else
|
||||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
|
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
GM_ADDR operator() (int32_t index) const {
|
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM
|
||||||
return m_ptrArray[index];
|
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
|
||||||
|
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
|
||||||
#else
|
#else
|
||||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
|
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM
|
||||||
if (offset < 0 || offset >= m_segmentSize) {
|
if (offset < 0 || offset >= m_segmentSize) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@@ -132,7 +137,8 @@ public:
|
|||||||
if (rankId < 0 || rankId >= m_rankSize) {
|
if (rankId < 0 || rankId >= m_rankSize) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return m_ptrArray[rankId] + offset;
|
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
|
||||||
|
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
|
||||||
#else
|
#else
|
||||||
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
|
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
|
||||||
#endif
|
#endif
|
||||||
@@ -176,6 +182,130 @@ public:
|
|||||||
gm_store(sync_base, count);
|
gm_store(sync_base, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
FORCE_INLINE_AICORE
|
||||||
|
void InitStatusTargetSum()
|
||||||
|
{
|
||||||
|
using namespace AscendC;
|
||||||
|
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
|
||||||
|
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
|
||||||
|
// ep state
|
||||||
|
//uint32_t coreIdx = get_block_idx();;
|
||||||
|
uint32_t coreIdx = GetBlockIdx();
|
||||||
|
GlobalTensor<int32_t> selfStatusTensor;
|
||||||
|
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
|
||||||
|
__asm__ __volatile__("");
|
||||||
|
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||||
|
__asm__ __volatile__("");
|
||||||
|
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
|
||||||
|
if (state == 0) {
|
||||||
|
sumTarget_ = static_cast<float>(1.0);
|
||||||
|
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
|
||||||
|
epStateValue_ = 0x3F800000; // 1.0f
|
||||||
|
} else {
|
||||||
|
sumTarget_ = static_cast<float>(0.0);
|
||||||
|
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
|
||||||
|
epStateValue_ = 0;
|
||||||
|
}
|
||||||
|
__asm__ __volatile__("");
|
||||||
|
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||||
|
__asm__ __volatile__("");
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE_AICORE
|
||||||
|
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
|
||||||
|
//subblockid = 0
|
||||||
|
uint32_t stateOffset_ = STATE_OFFSET;
|
||||||
|
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
|
||||||
|
|
||||||
|
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
|
||||||
|
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||||
|
int vec_size = get_block_num();
|
||||||
|
int vec_id = get_block_idx();
|
||||||
|
|
||||||
|
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||||
|
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||||
|
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||||
|
pipe_barrier(PIPE_ALL);
|
||||||
|
|
||||||
|
ctrBuffer.SetValue(0, epStateValue_);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||||
|
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
|
||||||
|
AscendC::GlobalTensor<int32_t> gmDstStates;
|
||||||
|
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
|
||||||
|
DataCopy(gmDstStates, ctrBuffer, 8);
|
||||||
|
}
|
||||||
|
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE_AICORE
|
||||||
|
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
|
||||||
|
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
|
||||||
|
|
||||||
|
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||||
|
int vec_size = get_block_num();
|
||||||
|
int vec_id = get_block_idx();
|
||||||
|
uint32_t stateOffset_ = STATE_OFFSET;
|
||||||
|
|
||||||
|
uint32_t sendRankNum_ = m_rankSize / vec_size;
|
||||||
|
uint32_t remainderRankNum = m_rankSize % vec_size;
|
||||||
|
uint32_t startRankId_ = sendRankNum_ * vec_id;
|
||||||
|
if (vec_id < remainderRankNum) {
|
||||||
|
sendRankNum_++;
|
||||||
|
startRankId_ += vec_id;
|
||||||
|
} else {
|
||||||
|
startRankId_ += remainderRankNum;
|
||||||
|
}
|
||||||
|
uint32_t endRankId_ = startRankId_ + sendRankNum_;
|
||||||
|
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
|
||||||
|
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
|
||||||
|
|
||||||
|
if (startRankId_ < m_rankSize) {
|
||||||
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
gatherTmpTensor.SetValue(0, 1);
|
||||||
|
uint32_t mask = 1; // gatherMask + sum
|
||||||
|
uint64_t rsvdCnt = 0;
|
||||||
|
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||||
|
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
|
||||||
|
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||||
|
static_cast<uint16_t>(15), 0};
|
||||||
|
|
||||||
|
float sumOfFlag = static_cast<float>(-1.0);
|
||||||
|
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
|
||||||
|
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
|
||||||
|
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
|
||||||
|
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||||
|
|
||||||
|
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
|
||||||
|
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
|
||||||
|
intriParams);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||||
|
|
||||||
|
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
|
||||||
|
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
|
||||||
|
|
||||||
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
|
||||||
|
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||||
|
sumOfFlag = statusSumOutTensor.GetValue(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||||
|
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||||
|
|
||||||
|
//unpermute
|
||||||
|
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
__gm__ int32_t* SyncBaseAddr() {
|
__gm__ int32_t* SyncBaseAddr() {
|
||||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
||||||
@@ -187,9 +317,11 @@ private:
|
|||||||
int32_t m_rank;
|
int32_t m_rank;
|
||||||
int32_t m_rankSize;
|
int32_t m_rankSize;
|
||||||
size_t m_segmentSize;
|
size_t m_segmentSize;
|
||||||
|
float sumTarget_{0.0};
|
||||||
|
int32_t epStateValue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user