[Kernel]: Optimize DispatchFFNCombine performance (#6468)

### What this PR does / why we need it?

This PR focuses on performance optimization for the DispatchFFNCombine
operator. The key optimizations include:

1. Improving communication efficiency by merging the transmission of
tokens and scales;
2. Decoupling multi-core dependencies and reducing waiting bubbles in
the combine process through tile-granularity communication;
3. Optimizing the full-card synchronization overhead before the
umpermute operation.

These optimizations aim to reduce the overall execution latency of the
DispatchFFNCombine operator and enhance the runtime performance of the
model inference process on Ascend devices.

### Does this PR introduce _any_ user-facing change?

No. This PR only involves internal performance optimization of the
DispatchFFNCombine operator and does not introduce any changes to
user-facing APIs, interfaces, or behaviors.

### How was this patch tested?

1. Enable the DispatchFFNCombine operator by setting the environment
variable:
```
export VLLM_ASCEND_ENABLE_FUSED_MC2=1
```
2. Run the standard model inference test suite with the above
environment variable enabled;
4. Verify the correctness of model outputs (ensuring no functional
regression) and measure the performance improvement of the
DispatchFFNCombine operator (reduced latency and improved throughput).

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
This commit is contained in:
xulei
2026-02-09 16:30:34 +08:00
committed by GitHub
parent 9c6d031797
commit 8325528368
13 changed files with 897 additions and 356 deletions

View File

@@ -0,0 +1,243 @@
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "hccl_shmem.hpp"
#include "layout3d.hpp"
namespace Catlass::Epilogue::Block {
template <
uint32_t UB_STAGES_,
class CType_,
class LayoutPerTokenScale_,
class DType_,
class TileCopy_
>
class BlockEpilogue <
EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>,
CType_,
Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_,
TileCopy_
> {
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::VectorLayout>>;
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
struct Params {
__gm__ int32_t *ptrTokenPerExpert{nullptr};
int32_t EP;
int32_t expertPerRank;
int32_t n2;
LayoutC layoutC;
int32_t n0;
int32_t rank;
HcclShmem shmem;
int32_t offsetD;
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_,
LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) :
ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_),
expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_),
shmem(shmem_), offsetD(offsetD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
//ub:192KB
n0 = params.n0;
size_t ubOffset = 0;
for(int32_t i = 0; i < 2; i++) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += max_len * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += max_len * sizeof(ElementD);
ubFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += max_len * sizeof(float);
scaleUbList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += (max_len / n0) * sizeof(float);
source_scale_offset[i] = -1;
}
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank);
is_ping = true;
}
CATLASS_DEVICE
void Finalize()
{
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
}
CATLASS_DEVICE
~BlockEpilogue()
{
}
CATLASS_DEVICE
void operator() (
AscendC::GlobalTensor<ElementC> const &gmC,
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
GemmCoord& blockCoord,
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
auto &ubCFp32 = ubFp32List[is_ping];
auto &scaleUb = scaleUbList[is_ping];
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id);
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
layout::VectorLayout scaleLauout{actualBlockShape.m()};
if (source_scale_offset[event_id] != gmScaleOffset) {
source_scale_offset[event_id] = gmScaleOffset;
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // Note that the value must be MTE2_S instead of MTE2_V.
// Otherwise, 0 will be read, causing garbled characters.
AscendC::PipeBarrier<PIPE_V>();
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
float scale = scaleUb(row);
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
int32_t lenTile = actualBlockShape.m();
int32_t stTile = blockCoord.m();
int32_t edTile = stTile + lenTile;
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id);
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
if (stRankInExpert >= edTile) {
break;
}
else if (edRankInExpert <= stTile) {
continue;
}
int32_t stData = max(stRankInExpert, stTile);
int32_t edData = min(edRankInExpert, edTile);
uint32_t lenData = edData - stData;
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
}
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
AscendC::LocalTensor<float> ubFp32List[UB_STAGES];
AscendC::LocalTensor<float> scaleUbList[UB_STAGES];
int32_t source_scale_offset[UB_STAGES];
int32_t max_len = 8 * 32 / 4 * 128;
int32_t n0;
bool is_ping = false;
int32_t repeat = 128;
CopyGmToUbC copyGmToUbC;
CopyUbToGmD copyUbToGmD;
CopyScaleGmToUb copyScaleGmToUb;
AscendC::GlobalTensor<int32_t> tokenPerExpert;
Layout3D tokenPerExpertLayout;
};
}
#endif

View File

@@ -22,8 +22,6 @@
namespace Catlass::Gemm::Block {
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID)
{
@@ -153,9 +151,11 @@ public:
L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
{
syncGroupIdx = 0;
ptrSoftFlagBase_ = flagPtr;
expertPerRank_ = expertPerRank;
InitL1(resource, l1BufAddrStart);
InitL0A(resource);
InitL0B(resource);
@@ -272,9 +272,21 @@ public:
CATLASS_DEVICE
void Finalize(int32_t target, int32_t flag = 0)
{
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
if (ptrSoftFlagBase_ != nullptr) {
if (target < 0) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<int32_t> flagGlobal;
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
}
else {
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / 15 + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
}
}
private:
@@ -291,7 +303,6 @@ private:
layout::VectorLayout layoutScale;
int32_t syncLoopIdx;
int32_t flag;
CATLASS_DEVICE
L1TileMmadParams() = default;
};
@@ -310,11 +321,24 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
if constexpr (std::is_same_v<ElementA, int8_t>) {
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
if (ptrSoftFlagBase_ != nullptr) {
// Initialize the flag matrix (structure as below):
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// ...
// 16 0 0 0 0 0 0 0
// Then move it to L1
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
AscendC::GlobalTensor<int32_t> flagBase;
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
}
}
CATLASS_DEVICE
@@ -463,12 +487,20 @@ private:
if constexpr (std::is_same_v<ElementA, int8_t>) {
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
#ifdef __TILE_SYNC__
if (params.flag > 0) {
int32_t flagId = params.flag + params.syncLoopIdx / 8;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
#else
Finalize(params.syncLoopIdx, params.flag);
#endif
}
}
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
AscendC::LocalTensor<uint64_t> l1STensor;
AscendC::LocalTensor<int32_t> l1FTensor;
int32_t syncGroupIdx;
int32_t l1AEventList[L1_STAGES];
int32_t l1BEventList[L1_STAGES];
@@ -497,8 +529,11 @@ private:
CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm;
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
int32_t expertPerRank_;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP

View File

@@ -5,5 +5,7 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
constexpr static int32_t NUMS_PER_FLAG = 16;
constexpr static int32_t CACHE_LINE = 512;
constexpr static int32_t RESET_VAL = 0xffff;
constexpr static int32_t ALIGN_128 = 128;
constexpr static int32_t FLAGSTRIDE = 16;
constexpr static int32_t UB_ALIGN = 32;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
#endif

View File

@@ -33,13 +33,13 @@ namespace Catlass::Epilogue {
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
struct EpilogueAtlasA2PerTokenDequantV2 {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
};

View File

@@ -5,13 +5,28 @@
#include "kernel_operator.h"
#include "const_args.hpp"
#ifdef HCCL_COMM
#include "moe_distribute_base.h"
using namespace AscendC::HcclContextDef;
#ifndef HCCL_COMM
#else
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
constexpr int32_t MAX_RANK_SIZE = 32;
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t STATE_OFFSET = 512;
FORCE_INLINE_AICORE void AicSyncAll() {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
AscendC::CrossCoreWaitFlag<0x0>(8);
}
template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
@@ -23,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache);
}
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
template<typename T>
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
using namespace AscendC;
GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr);
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
// Important: add hint to avoid dcci being optimized by compiler
__asm__ __volatile__("");
@@ -37,26 +53,20 @@ FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
gm_dcci((__gm__ uint8_t *)sig_addr);
if (*sig_addr == cmp_val) {
return *sig_addr;
}
// in case when peer pe enters next barrier
if (*sig_addr == cmp_val + 1) {
return *sig_addr;
}
} while (true);
// never reach
return -1;
}
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr);
@@ -71,59 +81,53 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
}
constexpr int32_t MAX_RANK_SIZE = 32;
class HcclShmem {
public:
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
size_t m_segmentSize;
int32_t m_rank;
int32_t m_rankSize;
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
AscendC::LocalTensor<int32_t> ub;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
}
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
}
#else
FORCE_INLINE_AICORE
HcclShmem(){
m_segmentSize = SHMEM_MEM;
}
FORCE_INLINE_AICORE
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
symmetricPtr = symmetricPtr_;
m_rank = rank;
m_rankSize = rankSize;
}
#endif
FORCE_INLINE_AICORE
GM_ADDR operator() () const { // No argument: return local peermem
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
#ifdef HCCL_COMM
return (GM_ADDR)(WinContext_->localWindowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
#ifdef HCCL_COMM
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
@@ -136,15 +140,28 @@ public:
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else
return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif
}
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
FORCE_INLINE_AICORE
~HcclShmem() {
}
FORCE_INLINE_AICORE
void CrossRankSync() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
@@ -165,12 +182,146 @@ public:
gm_store(sync_base, count);
}
FORCE_INLINE_AICORE
void InitStatusTargetSum()
{
using namespace AscendC;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
// ep state
//uint32_t coreIdx = get_block_idx();;
uint32_t coreIdx = GetBlockIdx();
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
//subblockid = 0
uint32_t stateOffset_ = STATE_OFFSET;
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
pipe_barrier(PIPE_ALL);
ctrBuffer.SetValue(0, epStateValue_);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
AscendC::GlobalTensor<int32_t> gmDstStates;
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
DataCopy(gmDstStates, ctrBuffer, 8);
}
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
uint32_t stateOffset_ = STATE_OFFSET;
uint32_t sendRankNum_ = m_rankSize / vec_size;
uint32_t remainderRankNum = m_rankSize % vec_size;
uint32_t startRankId_ = sendRankNum_ * vec_id;
if (vec_id < remainderRankNum) {
sendRankNum_++;
startRankId_ += vec_id;
} else {
startRankId_ += remainderRankNum;
}
uint32_t endRankId_ = startRankId_ + sendRankNum_;
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
if (startRankId_ < m_rankSize) {
AscendC::PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>(15), 0};
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
//unpermute
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
}
private:
GM_ADDR symmetricPtr;
int32_t m_rank;
int32_t m_rankSize;
size_t m_segmentSize;
float sumTarget_{0.0};
int32_t epStateValue_;
};
#endif