Files
xc-llm-ascend/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp
lhchg dc99cfdc15 [CustomOp] support TensorList for dispatchFFNCombine (#5665)
### What this PR does / why we need it?
To support tensorList for dispatch_ffn_combine, to adjust eplb

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Single Operator Testing

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
2026-01-09 15:56:29 +08:00

177 lines
5.2 KiB
C++

#ifndef SYNC_UTIL_HPP
#define SYNC_UTIL_HPP
#include "kernel_operator.h"
#include "const_args.hpp"
#include "moe_distribute_base.h"
#ifndef HCCL_COMM
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
*((__gm__ T *)addr) = val;
}
template<typename T>
FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache);
}
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
using namespace AscendC;
GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr);
// Important: add hint to avoid dcci being optimized by compiler
__asm__ __volatile__("");
DataCacheCleanAndInvalid<uint8_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(global);
__asm__ __volatile__("");
}
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
gm_dcci((__gm__ uint8_t *)sig_addr);
if (*sig_addr == cmp_val) {
return *sig_addr;
}
// in case when peer pe enters next barrier
if (*sig_addr == cmp_val + 1) {
return *sig_addr;
}
} while (true);
// never reach
return -1;
}
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr);
AscendC::DataCopy(ub, sig, 8);
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
if (ub(0) != cmp_val) {
return;
}
} while (true);
return;
}
constexpr int32_t MAX_RANK_SIZE = 32;
class HcclShmem {
public:
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
size_t m_segmentSize;
int32_t m_rank;
int32_t m_rankSize;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
}
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
#endif
FORCE_INLINE_AICORE
GM_ADDR operator() () const { // No argument: return local peermem
#ifdef HCCL_COMM
return (GM_ADDR)(WinContext_->localWindowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
#ifdef HCCL_COMM
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
if (offset < 0 || offset >= m_segmentSize) {
return nullptr;
}
if (rankId < 0 || rankId >= m_rankSize) {
return nullptr;
}
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else
return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
#endif
}
FORCE_INLINE_AICORE
~HcclShmem() {
}
FORCE_INLINE_AICORE
void CrossRankSync() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
__gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset;
__gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048;
int count = gm_load(sync_base) + 1;
int vec_id = AscendC::GetBlockIdx();
int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation();
for(int i = vec_id; i < m_rankSize; i += vec_size) {
__gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16;
gm_store(sync_remote, count);
gm_dcci((__gm__ uint8_t*)sync_remote);
auto sync_check = sync_counter + i * 16;
gm_signal_wait_until_eq_for_barrier(sync_check, count);
}
AscendC::SyncAll<true>();
gm_store(sync_base, count);
}
FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
}
};
#endif