[Kernel] Add moe normal ops (#4810)

### What this PR does / why we need it?
1. Add implementations of the normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch, and DispatchLayout.

- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Computes the device memory layout information used by
the dispatch stage.

2. Provide PyTorch interfaces for the normal operators (get_dispatch_layout,
dispatch_prefill, and combine_prefill) to be used for MoE communication
during the prefill stage in vLLM.

- get_dispatch_layout: Computes the device memory layout information for
the dispatch operator; it must be called before dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Functionality has been validated locally with a Qwen model. Test cases
will be added once support for multi-NPU use cases in the CI pipeline is
finalized.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>

comm_args.h

@@ -0,0 +1,72 @@
#ifndef COMM_ARGS_H
#define COMM_ARGS_H
#include <cstdint>
#define FORCE_INLINE_AICORE __attribute__((always_inline)) inline __aicore__
#include "kernel_operator.h"
namespace Moe {
constexpr int CAM_MAX_RANK_SIZE = 384; // Maximum number of NPU cards supported by the communication library
constexpr int64_t IPC_BUFF_MAX_SIZE = 100 * 1024 * 1024;
constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // First 2MB as flag, then 100MB as data storage
constexpr int64_t PING_PONG_SIZE = 2;
constexpr int64_t UB_SINGLE_DMA_SIZE_MAX = 190 * 1024;
constexpr int64_t SMALL_DATA_SIZE = 1 * 1024 * 1024;
constexpr int64_t UB_SINGLE_PING_PONG_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX / 2;
constexpr int UB_ALIGN_SIZE = 32;
constexpr int64_t MAGIC_ALIGN_COUNT = UB_ALIGN_SIZE / sizeof(int32_t);
constexpr uint8_t COMM_NUM = 2; // Number of communication domains (EP and TP)
constexpr uint8_t COMM_EP_IDX = 0;
constexpr uint8_t COMM_TP_IDX = 1;
constexpr int DFX_COUNT = 50;
constexpr int64_t WAIT_SUCCESS = 112233445566;
constexpr int64_t IPC_CHUNK_FLAG = 0; // Start offset of the chunk-flag region used for send/recv
constexpr int64_t MAX_WAIT_ROUND_UNIT = 10 * 1000 * 1000; // Threshold for waiting to get Flag under normal conditions within the same SIO
constexpr static int32_t UB_HEAD_OFFSET = 96;
constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + UB_ALIGN_SIZE;
constexpr static int64_t UB_FLAG_SIZE = 2 * 1024;
constexpr static int64_t MAX_CORE_NUM = 48;
constexpr static uint64_t STATE_WIN_OFFSET = 900 * 1024;
constexpr static int64_t COMPARE_ALIGN_SIZE = 256;
constexpr static int64_t UB_SINGLE_TOTAL_SIZE_MAX = 192 * 1024;
constexpr static int64_t START_OFFSET_FOR_SHARE = 512;
enum Op : int {
COPYONLY = -1,
ADD = 0,
MUL = 1,
MAX = 2,
MIN = 3
};
struct CommArgs {
int rank = 0; // attr rank_id, global rank
int localRank = -1;
int rankSize = 0; // global rank size
int localRankSize = -1; // Number of cards interconnected in a full mesh
uint32_t extraFlag = 0; // 32-bit map; the meaning of each bit is documented above in this file
int testFlag = 0;
GM_ADDR peerMems[CAM_MAX_RANK_SIZE] = {}; // Peer buffers obtained at initialization; every allreduce receives the same parameter
/**
* @param sendCountMatrix One-dimensional array of size rankSize * rankSize.
* e.g. sendCountMatrix[1] corresponds to element [0][1] of the two-dimensional view,
* i.e. the number of elements card 0 needs to send to card 1.
*/
int64_t sendCountMatrix[CAM_MAX_RANK_SIZE * CAM_MAX_RANK_SIZE] = {}; // for all2allvc
int64_t sendCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv
int64_t sdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv
int64_t recvCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv
int64_t rdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv
int64_t batchSize; // number of tokens in the batch
int64_t hiddenSize; // hidden dimension of each token
int64_t topk; // number of experts selected per token
int64_t sharedExpertRankNum; // number of ranks hosting shared experts
int64_t expertNumPerRank; // number of experts hosted on each rank
int64_t dfx[DFX_COUNT] = {};
};
} // namespace Moe
#endif // COMM_ARGS_H
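
For orientation, the `sendCountMatrix` doc comment above implies row-major flattening of a `rankSize x rankSize` table. A minimal host-side sketch of that indexing (the helper name is hypothetical, not part of this PR):

```cpp
#include <cstdint>

// Hypothetical helper: element [src][dst] of the flattened all2allvc matrix,
// i.e. how many elements rank `src` sends to rank `dst`.
inline int64_t SendCount(const int64_t *sendCountMatrix, int src, int dst, int rankSize)
{
    return sendCountMatrix[static_cast<int64_t>(src) * rankSize + dst];
}
```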

cam_datacopy_gm2gm.h

@@ -0,0 +1,68 @@
#ifndef CAM_DATACOPY_GM2GM_H
#define CAM_DATACOPY_GM2GM_H
#include <type_traits>
#include "comm_args.h"
using namespace AscendC;
using namespace Moe;
template <typename T>
FORCE_INLINE_AICORE void SetAtomicOpType(int op)
{
switch (op) {
case ADD:
AscendC::SetAtomicAdd<T>();
break;
case MUL:
// Ignore setting the atomic register when performing mul
break;
case MAX:
AscendC::SetAtomicMax<T>();
break;
case MIN:
AscendC::SetAtomicMin<T>();
break;
default:
AscendC::SetAtomicNone();
}
}
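// Illustrative usage (not part of this header): accumulate UB data into GM.
//   SetAtomicOpType<float>(ADD); // subsequent UB->GM copies add in place
//   CpUB2GM(gmDst, ubSrc, nBytes);
//   AscendC::SetAtomicNone();    // restore plain-copy semantics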
template <typename T>
FORCE_INLINE_AICORE void CpUB2GM(__gm__ T *gmAddr, __ubuf__ T *ubAddr, uint32_t size)
{
LocalTensor<uint8_t> ubTensor;
GlobalTensor<uint8_t> gmTensor;
DataCopyExtParams dataCopyParams(1, size, 0, 0, 0);
ubTensor.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ubTensor.address_.bufferAddr = reinterpret_cast<uint64_t>(ubAddr);
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr));
DataCopyPad(gmTensor, ubTensor, dataCopyParams);
}
template <typename T>
FORCE_INLINE_AICORE void CpGM2UB(__ubuf__ T *ubAddr, __gm__ T *gmAddr, uint32_t size)
{
LocalTensor<uint8_t> ubTensor;
GlobalTensor<uint8_t> gmTensor;
DataCopyExtParams dataCopyParams(1, size, 0, 0, 0);
ubTensor.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ubTensor.address_.bufferAddr = reinterpret_cast<uint64_t>(ubAddr);
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr));
DataCopyPadExtParams<uint8_t> padParams;
DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams);
}
template<typename T>
FORCE_INLINE_AICORE void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, const uint32_t calCount)
{
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
TBuffAddr srcAddr, dstAddr;
srcAddr.bufferAddr = reinterpret_cast<uint64_t>(src);
dstAddr.bufferAddr = reinterpret_cast<uint64_t>(dst);
srcTensor.SetAddr(srcAddr);
dstTensor.SetAddr(dstAddr);
DataCopy(dstTensor, srcTensor, calCount);
}
#endif // CAM_DATACOPY_GM2GM_H
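
The two copy helpers above are the building blocks for staging GM-to-GM transfers through UB. A hedged sketch of a chunked staging loop, assuming this header is included, `ubScratch` points to a UB scratch region of at least one chunk, and the usual MTE2/MTE3 event pairing applies:

```cpp
// Sketch only: copy `totalBytes` from one GM buffer to another through UB,
// one UB_SINGLE_PING_PONG_ADD_SIZE_MAX-sized chunk at a time.
FORCE_INLINE_AICORE void CopyGM2GMStaged(__gm__ uint8_t *dst, __gm__ uint8_t *src,
                                         __ubuf__ uint8_t *ubScratch, uint32_t totalBytes)
{
    constexpr uint32_t CHUNK = UB_SINGLE_PING_PONG_ADD_SIZE_MAX;
    for (uint32_t off = 0; off < totalBytes; off += CHUNK) {
        uint32_t n = (totalBytes - off < CHUNK) ? (totalBytes - off) : CHUNK;
        CpGM2UB(ubScratch, src + off, n);                   // GM -> UB (MTE2)
        AscendC::SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);  // read must land before write
        AscendC::WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);
        CpUB2GM(dst + off, ubScratch, n);                   // UB -> GM (MTE3)
        AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);  // buffer reusable next round
        AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    }
}
```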

sync_collectives.h

@@ -0,0 +1,426 @@
#ifndef SYNC_COLLECTIVES_H
#define SYNC_COLLECTIVES_H
#include "comm_args.h"
using namespace AscendC;
using namespace Moe;
// Number of int64_t elements occupied by one synchronization flag
constexpr int64_t FLAG_UNIT_INT_NUM = 4;
// Memory size occupied by each synchronization unit (Bytes)
constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t);
// Bit offset of magic (the high 32 bits) when used in a comparison value
constexpr int64_t MAGIC_OFFSET = 32;
constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1);
class SyncCollectives {
public:
__aicore__ inline SyncCollectives() {}
__aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, TBuf<QuePosition::VECCALC> &tBuf)
{
this->rank = rank;
this->rankSize = rankSize;
this->shareAddrs = shareAddrs;
this->blockIdx = GetBlockIdx();
this->blockNum = GetBlockNum();
// Length of a single flag segment
segmentCount = GetBlockNum() * FLAG_UNIT_INT_NUM;
// Initialize the intra-card/inter-card synchronization address corresponding to the current core.
localSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]);
basicSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + GetBlockIdx() * FLAG_UNIT_INT_NUM;
blockOuterSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + segmentCount + GetBlockIdx() * FLAG_UNIT_INT_NUM;
this->tBuf = tBuf;
}
__aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID)
{
int64_t v = MergeMagicWithValue(magic, value);
SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v);
}
/**
* @brief Set the flag for the specified eventID of the designated card, with the value being a combination of magic and value.
* @param magic The operator batch, which will be combined into the high 32 bits of the flag value to be set.
* @param value The specific value to be set, which will be the low 32 bits of the flag value to be set.
* @param eventID Physically, it is an offset from the shared memory base address (requires scaling, not an absolute value).
* @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs structure, not a global or local id.
* (Local is not applicable in the 91093 scenario, and global is not applicable in the 910B multi-machine scenario.)
*/
__aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank)
{
int64_t v = MergeMagicWithValue(magic, value);
SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v);
}
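// Example (illustrative): publish value 1 for event 5 of batch `magic` on peer rank 2:
//   SetSyncFlag(magic, 1, 5, 2);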
__aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId)
{
return (blockMultiplier * blockNum) + targetCoreId;
}
/**
* @brief Wait for the flag of the specified eventID on the specified card to become a value
* composed of the combination of magic and value.
* @param magic The operator batch, combined into the high 32 bits of the flag
* value to wait for.
* @param value The specific value to wait for, placed in the low 32 bits of the
* flag value.
* @param eventID Physically, it is an offset from the shared memory base address (requires
* scaling, not an absolute value).
* @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs
* structure, not a global or local id. (Local is not applicable in the 91093
* scenario, and global is not applicable in the 910B multi-machine scenario.)
*/
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank)
{
int64_t v = MergeMagicWithValue(magic, value);
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v);
}
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID)
{
int64_t v = MergeMagicWithValue(magic, value);
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v);
}
/**
* @brief Wait for the flags starting from the specified eventID on the specified card to become
* a value composed of the combination of magic and value.<br>
* Note: [eventID, eventID + flagNum)
*/
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum)
{
int64_t v = MergeMagicWithValue(magic, value);
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v);
}
// Set inner-card synchronization flag (memory A)
__aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID)
{
int64_t value = MergeMagicWithValue(magic, eventID);
SetFlag(basicSyncAddr, value);
}
__aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock)
{
int64_t value = MergeMagicWithValue(magic, eventID);
SetFlag((__gm__ int64_t*)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value);
}
// Wait for a single inner-card synchronization flag (memory A)
__aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock)
{
int64_t value = MergeMagicWithValue(magic, eventID);
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value);
}
// Wait for all inner-card synchronization flags within the entire rank (memory A)
__aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank)
{
int64_t value = MergeMagicWithValue(magic, eventID);
WaitOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value);
}
// Check all inner-card synchronization flags within the entire rank (memory A)
__aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank)
{
int64_t value = MergeMagicWithValue(magic, eventID);
return CheckOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value);
}
// Set inter-card synchronization flag (memory B)
__aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID)
{
int64_t value = MergeMagicWithValue(magic, eventID);
SetFlag(blockOuterSyncAddr, value);
}
__aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock)
{
__gm__ int64_t* flagAddr = GetOuterFlagAddr(setRank, setBlock);
int64_t value = MergeMagicWithValue(magic, eventID);
SetFlag(flagAddr, value);
}
// Wait for a single inter-card synchronization flag (memory B)
__aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock)
{
int64_t value = MergeMagicWithValue(magic, eventID);
__gm__ int64_t* flagAddr = GetOuterFlagAddr(waitRank, waitBlock);
WaitOneRankPartFlag(flagAddr, 1, value);
}
// Wait for all inter-card synchronization flags within the entire rank (memory B)
__aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t eventID, int64_t rank)
{
int64_t value = MergeMagicWithValue(magic, eventID);
__gm__ int64_t* flagAddr;
flagAddr = GetOuterFlagAddr(rank, 0);
WaitOneRankPartFlag(flagAddr, blockNum, value);
}
// Wait for flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B)
__aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum)
{
int64_t value = MergeMagicWithValue(magic, eventID);
__gm__ int64_t* flagAddr;
int waitRank;
for (auto r = 0; r < rankSize; ++r) {
waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
WaitOneRankPartFlag(flagAddr, flagNum, value);
}
}
// Check flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B)
__aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock,
int64_t flagNum)
{
int64_t value = MergeMagicWithValue(magic, eventID);
__gm__ int64_t* flagAddr;
int waitRank;
for (auto r = 0; r < rankSize; ++r) {
waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
if (!CheckOneRankPartFlag(flagAddr, flagNum, value)) {
return false;
}
}
return true;
}
// Wait for all inter-card synchronization flags for all ranks, full rank synchronization (memory B)
__aicore__ inline void WaitAllRankOuterFlag(int32_t magic, int32_t eventID)
{
WaitAllRankPartOuterFlag(magic, eventID, 0, blockNum);
}
// Check all inter-card synchronization flags for all ranks, full rank synchronization (memory B)
__aicore__ inline bool CheckAllRankOuterFlag(int32_t magic, int32_t eventID)
{
return CheckAllRankPartOuterFlag(magic, eventID, 0, blockNum);
}
// Low-level interface, set synchronization flag
__aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue)
{
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
GlobalTensor<int64_t> globalSet;
globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM);
LocalTensor<int64_t> localSet = tBuf.GetWithOffset<int64_t>(1, 0);
localSet.SetValue(0, setValue);
// Copy global synchronization flag to local
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0); // Wait for SetValue to complete
DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM);
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0); // Wait for UB->GM to complete
}
// Low-level interface, wait for synchronization flag
__aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue)
{
WaitOneRankPartFlag(waitAddr, 1, waitValue);
}
// Read a flag and return its current value
__aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr)
{
GlobalTensor<int64_t> globalWait;
globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM);
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(1, 0);
// Copy global to local
DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM);
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
int64_t res = localWait.GetValue(0);
return res;
}
// Wait for multiple consecutive synchronization flags within a single card
__aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank,
int64_t startBlock, int64_t flagNum)
{
int64_t value = MergeMagicWithValue(magic, eventID);
__gm__ int64_t* flagAddr;
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
WaitOneRankPartFlag(flagAddr, flagNum, value);
}
// Get synchronization flag within a single card (memory A)
__aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock)
{
return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM);
}
__aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock)
{
return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM);
}
// In this rank's chunk-flag area, return WAIT_SUCCESS once the destRank chunk flag matches (magic, 0); on timeout, return the last value read
__aicore__ inline int64_t GetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout)
{
int64_t value = MergeMagicWithValue(magic, 0);
int64_t status = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) +
IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value, timeout);
return status;
}
// Set the destRank chunk Flag value in the rank Chunk Flag area to value
__aicore__ inline void SetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t eventId)
{
int64_t value = MergeMagicWithValue(magic, eventId);
SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value);
}
__aicore__ inline int64_t GetChunkRecvLen(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout)
{
int64_t len = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG +
destRank * FLAG_UNIT_INT_NUM, 0, timeout, true, magic);
return len;
}
private:
__aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value)
{
// Merge magic into the high 32 bits and value into the low 32 bits to form a comparison value
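// Example: magic = 3, value = 7 -> 0x0000000300000007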
return (static_cast<int64_t>(static_cast<uint32_t>(magic)) << MAGIC_OFFSET) | static_cast<int64_t>(value);
}
__aicore__ inline __gm__ int64_t* GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock)
{
return (__gm__ int64_t*)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM;
}
__aicore__ inline __gm__ int64_t* GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock)
{
return (__gm__ int64_t*)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM;
}
// Wait for a part of synchronization flags within a rank
__aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
{
GlobalTensor<int64_t> globalWait;
globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);
bool isSync = true;
int64_t checkedFlagNum = 0;
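// checkedFlagNum counts the leading flags already confirmed, so each
// re-poll resumes at the first unconfirmed flag instead of rescanning.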
do {
// Copy global synchronization flags to local
DataCopy(localWait, globalWait[checkedFlagNum * FLAG_UNIT_INT_NUM],
(flagNum - checkedFlagNum) * FLAG_UNIT_INT_NUM);
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
// Check if the synchronization flags are equal to checkValue
isSync = true;
int64_t remainToCheck = flagNum - checkedFlagNum;
for (auto i = 0; i < remainToCheck; ++i) {
// Continue waiting if any core has not reached the checkValue phase
int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) {
isSync = false;
checkedFlagNum += i;
break;
}
}
} while (!isSync);
}
// Wait for all synchronization flags within a rank
__aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
{
WaitOneRankPartFlag(waitAddr, blockNum, checkValue);
}
// Check partial synchronization flags within a rank, copy only once
__aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
{
GlobalTensor<int64_t> globalWait;
globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);
// Copy global synchronization flags to local
DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM);
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
// Check if the synchronization flags are equal to checkValue
bool isSync = true;
for (auto i = 0; i < flagNum; ++i) {
// Continue waiting if any core has not reached the checkValue phase
int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) {
isSync = false;
break;
}
}
return isSync;
}
__aicore__ inline int64_t GetChunkFlagValue(__gm__ int64_t* waitAddr, int64_t checkValue, int64_t timeout,
bool checkNonZero = false, int64_t magic = 0)
{
GlobalTensor<int64_t> globalWait;
globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM);
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(FLAG_UNIT_INT_NUM, 0);
bool isSync = true;
int64_t waitTimes = 0;
int64_t v = 0;
do {
// Copy global sync flag to local
DataCopy(localWait, globalWait[0], FLAG_UNIT_INT_NUM);
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
isSync = true;
v = localWait.GetValue(0);
if (checkNonZero) {
// Non-zero check mode
if (((v & MAGIC_MASK) == (static_cast<int64_t>(magic) << MAGIC_OFFSET)) && (v & 0xFFFFFFFF)) {
return v & 0xFFFFFFFF; // Return lower 32 bits when non-zero
}
} else {
// Exact value check mode
if (v == checkValue) {
return WAIT_SUCCESS;
}
}
isSync = false;
waitTimes++;
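// `timeout` is measured in multiples of MAX_WAIT_ROUND_UNIT polling rounds;
// the first check below guards the multiplication against signed overflow.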
if (timeout > INT64_MAX / MAX_WAIT_ROUND_UNIT || waitTimes >= (timeout * MAX_WAIT_ROUND_UNIT)) {
isSync = true;
return v; // Return the read flag value
}
} while (!isSync);
return checkNonZero ? 0 : v;
}
// Check all sync flags within a rank, copy only once
__aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
{
return CheckOneRankPartFlag(waitAddr, blockNum, checkValue);
}
int rank;
int rankSize;
int blockIdx;
int blockNum;
GM_ADDR *shareAddrs;
int64_t segmentCount; // Length of a single sync flag segment (count in int64_t)
__gm__ int64_t* localSyncAddr;
__gm__ int64_t* basicSyncAddr; // Intra-card sync flag address for the current block
__gm__ int64_t* blockOuterSyncAddr; // Inter-card sync flag address for the current block
TBuf<QuePosition::VECCALC> tBuf;
};
#endif // SYNC_COLLECTIVES_H
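
Taken together, the set/wait primitives above compose into barrier-style patterns. A minimal sketch (not from this PR) of a full-mesh barrier, assuming the header above is included and `sync` has been `Init()`-ed on every rank and core:

```cpp
// Every core publishes its inter-card flag, then spins until the same
// (magic, eventID) pair is visible from every core of every rank.
__aicore__ inline void FullMeshBarrier(SyncCollectives &sync, int32_t magic, int32_t eventID)
{
    sync.SetOuterFlag(magic, eventID);          // publish to memory B on this rank
    sync.WaitAllRankOuterFlag(magic, eventID);  // poll all ranks' memory B segments
}
```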