[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)
### What this PR does / why we need it?
Fix the precision issue in the dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations in the dispatch_ffn_combine
operator.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -72,17 +72,6 @@ public:
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
|
||||
|
||||
//ub:192KB
|
||||
n0 = params.n0;
|
||||
size_t ubOffset = 0;
|
||||
@@ -98,137 +87,20 @@ public:
|
||||
source_scale_offset[i] = -1;
|
||||
}
|
||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
|
||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
||||
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||
is_ping = true;
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Finalize()
|
||||
{
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
~BlockEpilogue()
|
||||
{
|
||||
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
|
||||
GemmCoord& blockCoord,
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
auto &ubCFp32 = ubFp32List[is_ping];
|
||||
auto &scaleUb = scaleUbList[is_ping];
|
||||
// auto &ubOutFp32 = ubOutFp32List[is_ping];
|
||||
|
||||
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
|
||||
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
|
||||
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
|
||||
layout::VectorLayout scaleLauout{actualBlockShape.m()};
|
||||
if (source_scale_offset[event_id] != gmScaleOffset) {
|
||||
source_scale_offset[event_id] = gmScaleOffset;
|
||||
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
|
||||
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // 注意必须是MTE2_S,不能是MTE2_V,否则会读到0,造成乱码
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
|
||||
float scale = scaleUb(row);
|
||||
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
|
||||
|
||||
int32_t lenTile = actualBlockShape.m();
|
||||
int32_t stTile = blockCoord.m();
|
||||
int32_t edTile = stTile + lenTile;
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
if (stRankInExpert >= edTile) {
|
||||
break;
|
||||
}
|
||||
else if (edRankInExpert <= stTile) {
|
||||
continue;
|
||||
}
|
||||
int32_t stData = max(stRankInExpert, stTile);
|
||||
int32_t edData = min(edRankInExpert, edTile);
|
||||
uint32_t lenData = edData - stData;
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
}
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
|
||||
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
@@ -238,14 +110,12 @@ public:
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
|
||||
@@ -253,7 +123,7 @@ public:
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
|
||||
@@ -263,10 +133,10 @@ public:
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id);
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
@@ -282,7 +152,7 @@ public:
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
@@ -290,7 +160,7 @@ public:
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
@@ -298,7 +168,8 @@ public:
|
||||
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@
|
||||
|
||||
namespace Catlass::Gemm::Block {
|
||||
|
||||
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||
|
||||
template<AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||
{
|
||||
@@ -153,9 +151,11 @@ public:
|
||||
L1TileShape::K, L1TileShape::N);
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
|
||||
{
|
||||
syncGroupIdx = 0;
|
||||
ptrSoftFlagBase_ = flagPtr;
|
||||
expertPerRank_ = expertPerRank;
|
||||
InitL1(resource, l1BufAddrStart);
|
||||
InitL0A(resource);
|
||||
InitL0B(resource);
|
||||
@@ -272,9 +272,21 @@ public:
|
||||
CATLASS_DEVICE
|
||||
void Finalize(int32_t target, int32_t flag = 0)
|
||||
{
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
if (ptrSoftFlagBase_ != nullptr) {
|
||||
if (target < 0) {
|
||||
return;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
|
||||
AscendC::GlobalTensor<int32_t> flagGlobal;
|
||||
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
|
||||
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
|
||||
}
|
||||
else {
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / 15 + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
}
|
||||
}
|
||||
private:
|
||||
@@ -291,7 +303,6 @@ private:
|
||||
layout::VectorLayout layoutScale;
|
||||
int32_t syncLoopIdx;
|
||||
int32_t flag;
|
||||
|
||||
CATLASS_DEVICE
|
||||
L1TileMmadParams() = default;
|
||||
};
|
||||
@@ -310,11 +321,24 @@ private:
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
if (ptrSoftFlagBase_ != nullptr) {
|
||||
// Initialize the flag matrix (structure as below):
|
||||
// 1 0 0 0 0 0 0 0
|
||||
// 2 0 0 0 0 0 0 0
|
||||
// ...
|
||||
// 16 0 0 0 0 0 0 0
|
||||
// Then move it to L1
|
||||
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
|
||||
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
|
||||
AscendC::GlobalTensor<int32_t> flagBase;
|
||||
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
|
||||
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
@@ -463,12 +487,20 @@ private:
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
#ifdef __TILE_SYNC__
|
||||
if (params.flag > 0) {
|
||||
int32_t flagId = params.flag + params.syncLoopIdx / 8;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
#else
|
||||
Finalize(params.syncLoopIdx, params.flag);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<uint64_t> l1STensor;
|
||||
AscendC::LocalTensor<int32_t> l1FTensor;
|
||||
int32_t syncGroupIdx;
|
||||
int32_t l1AEventList[L1_STAGES];
|
||||
int32_t l1BEventList[L1_STAGES];
|
||||
@@ -497,8 +529,11 @@ private:
|
||||
CopyL1ToL0A copyL1ToL0A;
|
||||
CopyL1ToL0B copyL1ToL0B;
|
||||
CopyL0CToGm copyL0CToGm;
|
||||
|
||||
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
|
||||
int32_t expertPerRank_;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Block
|
||||
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
@@ -4,6 +4,10 @@
|
||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||
constexpr static int32_t CACHE_LINE = 512;
|
||||
constexpr static int32_t FLAGSTRIDE = 16;
|
||||
constexpr static int32_t RESET_VAL = 0xffff;
|
||||
constexpr static int32_t ALIGN_128 = 128;
|
||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
||||
constexpr static int32_t UB_ALIGN = 32;
|
||||
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
|
||||
#endif
|
||||
@@ -5,16 +5,23 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "const_args.hpp"
|
||||
|
||||
#ifdef HCCL_COMM
|
||||
#include "moe_distribute_base.h"
|
||||
|
||||
#ifndef HCCL_COMM
|
||||
#include "shmem_api.h"
|
||||
using namespace AscendC::HcclContextDef;
|
||||
|
||||
#else
|
||||
#include "shmem_api.h"
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
|
||||
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
|
||||
|
||||
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
|
||||
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
|
||||
|
||||
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
|
||||
constexpr uint32_t STATE_OFFSET = 512;
|
||||
|
||||
FORCE_INLINE_AICORE void AicSyncAll() {
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
|
||||
@@ -31,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
|
||||
return *((__gm__ T *)cache);
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
|
||||
using namespace AscendC;
|
||||
GlobalTensor<uint8_t> global;
|
||||
global.SetGlobalBuffer(addr);
|
||||
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
|
||||
|
||||
// Important: add hint to avoid dcci being optimized by compiler
|
||||
__asm__ __volatile__("");
|
||||
@@ -58,7 +66,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
|
||||
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
|
||||
do {
|
||||
AscendC::LocalTensor<int32_t> ub;
|
||||
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
ub.address_.bufferAddr = 0;
|
||||
AscendC::GlobalTensor<int32_t> sig;
|
||||
sig.SetGlobalBuffer(sig_addr);
|
||||
@@ -75,10 +83,10 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
|
||||
|
||||
class HcclShmem {
|
||||
public:
|
||||
#ifdef HCCL_COMM
|
||||
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
|
||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
||||
AscendC::LocalTensor<int32_t> ub;
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
@@ -87,17 +95,13 @@ public:
|
||||
m_rank = WinContext_->localUsrRankId;
|
||||
m_rankSize = WinContext_->rankSize;
|
||||
m_segmentSize = WinContext_->winSize;
|
||||
for (int i = 0; i < m_rankSize; i++) {
|
||||
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
|
||||
}
|
||||
}
|
||||
#else
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
m_segmentSize = SHMEM_MEM;
|
||||
}
|
||||
FORCE_INLINE_AICORE
|
||||
FORCE_INLINE_AICORE
|
||||
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
|
||||
symmetricPtr = symmetricPtr_;
|
||||
m_rank = rank;
|
||||
@@ -106,25 +110,26 @@ public:
|
||||
#endif
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() () const {
|
||||
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[m_rank];
|
||||
return (GM_ADDR)(WinContext_->localWindowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() (int32_t index) const {
|
||||
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[index];
|
||||
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||
#ifdef HCCL_COMM
|
||||
if (offset < 0 || offset >= m_segmentSize) {
|
||||
return nullptr;
|
||||
@@ -132,7 +137,8 @@ public:
|
||||
if (rankId < 0 || rankId >= m_rankSize) {
|
||||
return nullptr;
|
||||
}
|
||||
return m_ptrArray[rankId] + offset;
|
||||
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
|
||||
#endif
|
||||
@@ -176,6 +182,130 @@ public:
|
||||
gm_store(sync_base, count);
|
||||
}
|
||||
|
||||
|
||||
// Flips this core's status word in local peer memory between 0 and the bit
// pattern of 1.0f, and records the new value in sumTarget_ / epStateValue_.
//
// The status word lives at byte offset (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET
// from the local peermem base returned by (*this)(), at int32 index
// coreIdx * UB_ALIGN. If the stored word is 0 the new state is 1.0f
// (0x3F800000), otherwise the new state is 0.
FORCE_INLINE_AICORE
|
||||
void InitStatusTargetSum()
|
||||
{
|
||||
using namespace AscendC;
|
||||
// Byte offset of the per-core self-state area inside the last MB of the segment.
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
|
||||
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
|
||||
// ep state
|
||||
//uint32_t coreIdx = get_block_idx();;
|
||||
uint32_t coreIdx = GetBlockIdx();
|
||||
GlobalTensor<int32_t> selfStatusTensor;
|
||||
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
|
||||
// NOTE(review): the empty asm statements around the cache operation look like the
// same "keep the compiler from optimizing dcci away" hint used in gm_dcci — confirm.
__asm__ __volatile__("");
|
||||
// Clean/invalidate the cache line holding this core's status word before reading it.
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||
__asm__ __volatile__("");
|
||||
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
|
||||
// Toggle: a stored 0 becomes 1.0f, anything else becomes 0.
if (state == 0) {
|
||||
sumTarget_ = static_cast<float>(1.0);
|
||||
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
|
||||
epStateValue_ = 0x3F800000; // 1.0f
|
||||
} else {
|
||||
sumTarget_ = static_cast<float>(0.0);
|
||||
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
|
||||
epStateValue_ = 0;
|
||||
}
|
||||
__asm__ __volatile__("");
|
||||
// Second clean/invalidate of the same cache line, issued after the status word was updated.
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
|
||||
__asm__ __volatile__("");
|
||||
}
|
||||
|
||||
// Publishes this rank's current state value (epStateValue_) to every rank.
//
// For each destination rank handled by this core, copies 8 int32 elements of
// ctrBuffer to byte offset (m_segmentSize - MB_SIZE) + m_rank * STATE_OFFSET in
// that rank's peer memory, i.e. into the slot indexed by the sending rank.
//
// @param ctrBuffer  Local staging tensor; element 0 is overwritten with
//                   epStateValue_ before the copies are issued.
// NOTE(review): only element 0 is written here, but 8 elements are copied out;
// the other 7 carry whatever the caller left in ctrBuffer — confirm that is intended.
FORCE_INLINE_AICORE
|
||||
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
|
||||
//subblockid = 0
|
||||
uint32_t stateOffset_ = STATE_OFFSET;
|
||||
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
|
||||
|
||||
// Destination byte offset: slot m_rank inside the state area of the last MB.
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
|
||||
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||
int vec_size = get_block_num();
|
||||
int vec_id = get_block_idx();
|
||||
|
||||
// Both cross-core flags are set here; SEND is waited on immediately below,
// RECV is waited on only after the remote writes have been issued.
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||
pipe_barrier(PIPE_ALL);
|
||||
|
||||
ctrBuffer.SetValue(0, epStateValue_);
|
||||
// Order the scalar SetValue above before the MTE3 copies below.
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
// Destination ranks are strided across cores: this core handles
// vec_id, vec_id + vec_size, vec_id + 2 * vec_size, ...
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
|
||||
AscendC::GlobalTensor<int32_t> gmDstStates;
|
||||
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
|
||||
DataCopy(gmDstStates, ctrBuffer, 8);
|
||||
}
|
||||
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||
}
|
||||
|
||||
// Polls the local state area until the slots assigned to this core sum to the
// expected target.
//
// The m_rankSize slots are split into contiguous ranges across cores. Each core
// repeatedly copies its range from local peer memory (starting at byte offset
// m_segmentSize - MB_SIZE), gathers values from it, sums them, and loops until
// the sum lies within +/-0.5 of sumTarget_ * sendRankNum_.
//
// @param statusTensor         Local destination of the raw state copy.
// @param gatherMaskOutTensor  Local output of GatherMask.
// @param gatherTmpTensor      Local mask pattern tensor; element 0 is set to 1 here.
// @param statusSumOutTensor   Local output of Sum; element 0 is read back as the sum.
FORCE_INLINE_AICORE
|
||||
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
|
||||
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
|
||||
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
|
||||
int vec_size = get_block_num();
|
||||
int vec_id = get_block_idx();
|
||||
uint32_t stateOffset_ = STATE_OFFSET;
|
||||
|
||||
// Contiguous partition of ranks over cores: every core gets m_rankSize / vec_size
// slots, and the first (m_rankSize % vec_size) cores get one extra.
uint32_t sendRankNum_ = m_rankSize / vec_size;
|
||||
uint32_t remainderRankNum = m_rankSize % vec_size;
|
||||
uint32_t startRankId_ = sendRankNum_ * vec_id;
|
||||
if (vec_id < remainderRankNum) {
|
||||
sendRankNum_++;
|
||||
startRankId_ += vec_id;
|
||||
} else {
|
||||
startRankId_ += remainderRankNum;
|
||||
}
|
||||
// NOTE(review): endRankId_ is computed but not used anywhere in this function.
uint32_t endRankId_ = startRankId_ + sendRankNum_;
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
|
||||
|
||||
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
|
||||
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
|
||||
|
||||
// Cores whose range starts past the last rank have nothing to poll.
if (startRankId_ < m_rankSize) {
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
gatherTmpTensor.SetValue(0, 1);
|
||||
uint32_t mask = 1; // gatherMask + sum
|
||||
uint64_t rsvdCnt = 0;
|
||||
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
|
||||
// One burst per assigned rank. NOTE(review): block length 1 plus source gap 15
// presumably steps one STATE_OFFSET (512 B) per rank if a block is 32 B — confirm.
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
|
||||
static_cast<uint16_t>(15), 0};
|
||||
|
||||
// Start outside the accepted window so the loop body runs at least once
// (the lower bound is never below -0.5 because sumTarget_ is 0.0 or 1.0).
float sumOfFlag = static_cast<float>(-1.0);
|
||||
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
|
||||
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
|
||||
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
|
||||
|
||||
// Order the scalar SetValue on gatherTmpTensor before the vector ops in the loop.
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
|
||||
|
||||
// Busy-wait: re-read the state slots until their sum falls inside the window.
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
|
||||
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
|
||||
intriParams);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
|
||||
|
||||
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
|
||||
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
|
||||
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
|
||||
// Make the vector result visible to the scalar GetValue below.
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
|
||||
sumOfFlag = statusSumOutTensor.GetValue(0);
|
||||
}
|
||||
}
|
||||
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
|
||||
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
|
||||
|
||||
//unpermute
|
||||
// Matches the SEND_SYNC_EVENT_ID flag set near the top of this function.
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
__gm__ int32_t* SyncBaseAddr() {
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
||||
@@ -187,9 +317,11 @@ private:
|
||||
int32_t m_rank;
|
||||
int32_t m_rankSize;
|
||||
size_t m_segmentSize;
|
||||
float sumTarget_{0.0};
|
||||
int32_t epStateValue_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user