[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)

### What this PR does / why we need it?
Fix the precision issue in the dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations in the dispatch_ffn_combine
operator.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
guanguan0308
2026-03-23 10:14:03 +08:00
committed by GitHub
parent e68464a1d6
commit 44ef9a36ac
8 changed files with 531 additions and 462 deletions

View File

@@ -72,17 +72,6 @@ public:
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
//ub:192KB
n0 = params.n0;
size_t ubOffset = 0;
@@ -98,137 +87,20 @@ public:
source_scale_offset[i] = -1;
}
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
is_ping = true;
}
CATLASS_DEVICE
void Finalize()
{
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
CATLASS_DEVICE
~BlockEpilogue()
{
}
CATLASS_DEVICE
void operator() (
AscendC::GlobalTensor<ElementC> const &gmC,
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
GemmCoord& blockCoord,
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
uint32_t *mPreSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
auto &ubCFp32 = ubFp32List[is_ping];
auto &scaleUb = scaleUbList[is_ping];
// auto &ubOutFp32 = ubOutFp32List[is_ping];
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
layout::VectorLayout scaleLauout{actualBlockShape.m()};
if (source_scale_offset[event_id] != gmScaleOffset) {
source_scale_offset[event_id] = gmScaleOffset;
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // 注意必须是MTE2_S不能是MTE2_V否则会读到0造成乱码
AscendC::PipeBarrier<PIPE_V>();
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
float scale = scaleUb(row);
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
int32_t lenTile = actualBlockShape.m();
int32_t stTile = blockCoord.m();
int32_t edTile = stTile + lenTile;
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
if (stRankInExpert >= edTile) {
break;
}
else if (edRankInExpert <= stTile) {
continue;
}
int32_t stData = max(stRankInExpert, stTile);
int32_t edData = min(edRankInExpert, edTile);
uint32_t lenData = edData - stData;
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
}
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
}
CATLASS_DEVICE
@@ -238,14 +110,12 @@ public:
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
uint32_t *mPreSumBeforeRank
AscendC::GlobalTensor<int32_t> preSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
@@ -253,7 +123,7 @@ public:
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
@@ -263,10 +133,10 @@ public:
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id);
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
@@ -282,7 +152,7 @@ public:
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
@@ -290,7 +160,7 @@ public:
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
@@ -298,7 +168,8 @@ public:
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
}

View File

@@ -22,8 +22,6 @@
namespace Catlass::Gemm::Block {
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID)
{
@@ -153,9 +151,11 @@ public:
L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
{
syncGroupIdx = 0;
ptrSoftFlagBase_ = flagPtr;
expertPerRank_ = expertPerRank;
InitL1(resource, l1BufAddrStart);
InitL0A(resource);
InitL0B(resource);
@@ -272,9 +272,21 @@ public:
CATLASS_DEVICE
void Finalize(int32_t target, int32_t flag = 0)
{
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
if (ptrSoftFlagBase_ != nullptr) {
if (target < 0) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<int32_t> flagGlobal;
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
}
else {
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / 15 + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
}
}
private:
@@ -291,7 +303,6 @@ private:
layout::VectorLayout layoutScale;
int32_t syncLoopIdx;
int32_t flag;
CATLASS_DEVICE
L1TileMmadParams() = default;
};
@@ -310,11 +321,24 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
if constexpr (std::is_same_v<ElementA, int8_t>) {
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
if (ptrSoftFlagBase_ != nullptr) {
// Initialize the flag matrix (structure as below):
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// ...
// 16 0 0 0 0 0 0 0
// Then move it to L1
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
AscendC::GlobalTensor<int32_t> flagBase;
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
}
}
CATLASS_DEVICE
@@ -463,12 +487,20 @@ private:
if constexpr (std::is_same_v<ElementA, int8_t>) {
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
#ifdef __TILE_SYNC__
if (params.flag > 0) {
int32_t flagId = params.flag + params.syncLoopIdx / 8;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
#else
Finalize(params.syncLoopIdx, params.flag);
#endif
}
}
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
AscendC::LocalTensor<uint64_t> l1STensor;
AscendC::LocalTensor<int32_t> l1FTensor;
int32_t syncGroupIdx;
int32_t l1AEventList[L1_STAGES];
int32_t l1BEventList[L1_STAGES];
@@ -497,8 +529,11 @@ private:
CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm;
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
int32_t expertPerRank_;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP

View File

@@ -4,6 +4,10 @@
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
constexpr static int32_t NUMS_PER_FLAG = 16;
constexpr static int32_t CACHE_LINE = 512;
constexpr static int32_t FLAGSTRIDE = 16;
constexpr static int32_t RESET_VAL = 0xffff;
constexpr static int32_t ALIGN_128 = 128;
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
constexpr static int32_t UB_ALIGN = 32;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
#endif

View File

@@ -5,16 +5,23 @@
#include "kernel_operator.h"
#include "const_args.hpp"
#ifdef HCCL_COMM
#include "moe_distribute_base.h"
#ifndef HCCL_COMM
#include "shmem_api.h"
using namespace AscendC::HcclContextDef;
#else
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
constexpr int32_t MAX_RANK_SIZE = 32;
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t STATE_OFFSET = 512;
FORCE_INLINE_AICORE void AicSyncAll() {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
@@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache);
}
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
template<typename T>
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
using namespace AscendC;
GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr);
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
// Important: add hint to avoid dcci being optimized by compiler
__asm__ __volatile__("");
@@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr);
@@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
class HcclShmem {
public:
#ifdef HCCL_COMM
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
AscendC::LocalTensor<int32_t> ub;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
@@ -87,17 +95,13 @@ public:
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
for (int i = 0; i < m_rankSize; i++) {
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
}
}
#else
FORCE_INLINE_AICORE
HcclShmem(){
m_segmentSize = SHMEM_MEM;
}
FORCE_INLINE_AICORE
FORCE_INLINE_AICORE
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
symmetricPtr = symmetricPtr_;
m_rank = rank;
@@ -106,25 +110,26 @@ public:
#endif
FORCE_INLINE_AICORE
GM_ADDR operator() () const {
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
#ifdef HCCL_COMM
return m_ptrArray[m_rank];
return (GM_ADDR)(WinContext_->localWindowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const {
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
#ifdef HCCL_COMM
return m_ptrArray[index];
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
if (offset < 0 || offset >= m_segmentSize) {
return nullptr;
@@ -132,7 +137,8 @@ public:
if (rankId < 0 || rankId >= m_rankSize) {
return nullptr;
}
return m_ptrArray[rankId] + offset;
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif
@@ -176,6 +182,130 @@ public:
gm_store(sync_base, count);
}
FORCE_INLINE_AICORE
void InitStatusTargetSum()
{
using namespace AscendC;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
// ep state
//uint32_t coreIdx = get_block_idx();;
uint32_t coreIdx = GetBlockIdx();
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
//subblockid = 0
uint32_t stateOffset_ = STATE_OFFSET;
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
pipe_barrier(PIPE_ALL);
ctrBuffer.SetValue(0, epStateValue_);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
AscendC::GlobalTensor<int32_t> gmDstStates;
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
DataCopy(gmDstStates, ctrBuffer, 8);
}
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
uint32_t stateOffset_ = STATE_OFFSET;
uint32_t sendRankNum_ = m_rankSize / vec_size;
uint32_t remainderRankNum = m_rankSize % vec_size;
uint32_t startRankId_ = sendRankNum_ * vec_id;
if (vec_id < remainderRankNum) {
sendRankNum_++;
startRankId_ += vec_id;
} else {
startRankId_ += remainderRankNum;
}
uint32_t endRankId_ = startRankId_ + sendRankNum_;
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
if (startRankId_ < m_rankSize) {
AscendC::PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>(15), 0};
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
//unpermute
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
@@ -187,9 +317,11 @@ private:
int32_t m_rank;
int32_t m_rankSize;
size_t m_segmentSize;
float sumTarget_{0.0};
int32_t epStateValue_;
};
#endif
#endif