[CustomOp] support TensorList for dispatchFFNCombine (#5665)
### What this PR does / why we need it?
To support tensorList for dispatch_ffn_combine, to adjust eplb
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Single Operator Testing
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
This commit is contained in:
@@ -70,23 +70,24 @@ public:
|
||||
__gm__ int32_t *ptrTokenPerExpert{nullptr};
|
||||
int32_t EP;
|
||||
int32_t expertPerRank;
|
||||
int32_t n2;
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params() {};
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
|
||||
Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_, int32_t n2_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_), n2(n2_) {}
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
size_t ubOffset = 4096;
|
||||
size_t ubOffset = 0;
|
||||
int32_t eventVMTE2 = 0;
|
||||
int32_t eventMTE2V = 0;
|
||||
int32_t eventMTE3V = 0;
|
||||
int32_t eventVMTE3 = 0;
|
||||
constexpr int32_t blockN = 12000;
|
||||
int32_t blockN = params.n2;
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||
ubOffset += blockN * sizeof(ElementC);
|
||||
|
||||
@@ -84,16 +84,16 @@ public:
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, int32_t n, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
size_t ubOffset = 0;
|
||||
int32_t eventVMTE2 = 0;
|
||||
int32_t eventMTE2V = 0;
|
||||
int32_t eventMTE3V = 0;
|
||||
int32_t eventVMTE3 = 0;
|
||||
constexpr uint32_t blockN = 4096;
|
||||
constexpr uint32_t ChunkTileLen = blockN / 2;
|
||||
constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
||||
uint32_t blockN = n;
|
||||
uint32_t ChunkTileLen = blockN / 2;
|
||||
uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
||||
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||
|
||||
@@ -3,4 +3,6 @@
|
||||
#define CONST_ARGS_HPP
|
||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||
constexpr static int32_t CACHE_LINE = 512;
|
||||
constexpr static int32_t RESET_VAL = 0xffff;
|
||||
#endif
|
||||
@@ -0,0 +1,16 @@
|
||||
#ifndef GET_TENSOR_ADDR_HPP
|
||||
#define GET_TENSOR_ADDR_HPP
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) {
|
||||
__gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
|
||||
uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address.
|
||||
// Moving 3 bits to the right means dividing by sizeof(uint64 t).
|
||||
__gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
|
||||
return reinterpret_cast<__gm__ T*>(*(retPtr + index));
|
||||
}
|
||||
|
||||
#endif // GET_TENSOR_ADDR_HPP
|
||||
@@ -53,17 +53,34 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
|
||||
do {
|
||||
AscendC::LocalTensor<int32_t> ub;
|
||||
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ub.address_.bufferAddr = 0;
|
||||
AscendC::GlobalTensor<int32_t> sig;
|
||||
sig.SetGlobalBuffer(sig_addr);
|
||||
AscendC::DataCopy(ub, sig, 8);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
|
||||
if (ub(0) != cmp_val) {
|
||||
return;
|
||||
}
|
||||
} while (true);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||
class HcclShmem {
|
||||
public:
|
||||
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
|
||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
||||
size_t m_segmentSize;
|
||||
int32_t m_rank;
|
||||
int32_t m_rankSize;
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
@@ -73,18 +90,13 @@ public:
|
||||
m_rankSize = WinContext_->rankSize;
|
||||
m_segmentSize = WinContext_->winSize;
|
||||
|
||||
for (int i = 0; i < m_rankSize; i++) {
|
||||
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
size_t SegmentSize() const {
|
||||
return m_segmentSize;
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
int32_t RankSize() const {
|
||||
return m_rankSize;
|
||||
@@ -94,7 +106,7 @@ public:
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() () const { // No argument: return local peermem
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[m_rank];
|
||||
return (GM_ADDR)(WinContext_->localWindowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
|
||||
#endif
|
||||
@@ -103,7 +115,8 @@ public:
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[index];
|
||||
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
|
||||
#endif
|
||||
@@ -120,7 +133,8 @@ public:
|
||||
if (rankId < 0 || rankId >= m_rankSize) {
|
||||
return nullptr;
|
||||
}
|
||||
return m_ptrArray[rankId] + offset;
|
||||
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
|
||||
#else
|
||||
return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user