[Kernel]: Optimize DispatchFFNCombine performance (#6468)

### What this PR does / why we need it?

This PR focuses on performance optimization for the DispatchFFNCombine
operator. The key optimizations include:

1. Improving communication efficiency by merging the transmission of
tokens and their per-token scales into a single transfer (sketched below);
2. Decoupling inter-core dependencies and reducing wait bubbles in the
combine phase by communicating at tile granularity;
3. Reducing the device-wide synchronization overhead before the
unpermute operation.

These optimizations reduce the overall execution latency of the
DispatchFFNCombine operator and improve inference performance on Ascend
devices.
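
For illustration, a minimal host-side C++ sketch of the merged token/scale
row layout behind optimization 1. This is an assumption-level sketch, not
the kernel code: the names are hypothetical, and the 512-byte row padding
mirrors the kernel's `ALIGN_512` stride, where each int8 token row is
followed by its float32 scale so a single copy transmits both.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Padding that keeps each packed row 512-byte aligned, mirroring ALIGN_512 in the kernel.
constexpr size_t kRowPad = 512;

// Pack one quantized token row and its per-token scale into a single contiguous slot,
// so both can be moved with one copy instead of two separate transfers.
std::vector<int8_t> PackRow(const std::vector<int8_t>& token, float scale) {
    std::vector<int8_t> slot(token.size() + kRowPad, 0);
    std::memcpy(slot.data(), token.data(), token.size());
    std::memcpy(slot.data() + token.size(), &scale, sizeof(float));
    return slot;
}

// Split a packed slot back into token data and scale, which is what the receiving
// side does after reading the combined row.
void UnpackRow(const std::vector<int8_t>& slot, size_t hiddenSize,
               std::vector<int8_t>& token, float& scale) {
    token.assign(slot.begin(), slot.begin() + hiddenSize);
    std::memcpy(&scale, slot.data() + hiddenSize, sizeof(float));
}
```

In the kernel this corresponds to `CopyGMToGMPerToken`, which reads
`hiddenSize + ALIGN_512` elements in one `DataCopy` and then writes the token
and its scale to their separate destinations.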

### Does this PR introduce _any_ user-facing change?

No. This PR only involves internal performance optimization of the
DispatchFFNCombine operator and does not introduce any changes to
user-facing APIs, interfaces, or behaviors.

### How was this patch tested?

1. Enable the DispatchFFNCombine operator by setting the environment
variable:
```
export VLLM_ASCEND_ENABLE_FUSED_MC2=1
```
2. Run the standard model inference test suite with the above
environment variable enabled;
3. Verify the correctness of model outputs (ensuring no functional
regression) and measure the performance improvement of the
DispatchFFNCombine operator (reduced latency and improved throughput).

- vLLM version: v0.14.1
- vLLM main: dc917cceb8

Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
Author: xulei
Date: 2026-02-09 16:30:34 +08:00
Committed by: GitHub
Parent: 9c6d031797
Commit: 8325528368
13 changed files with 897 additions and 356 deletions

View File

@@ -17,11 +17,11 @@
#include "error_log.h"
#include "hcom_topo_info.h"
#include "register/op_def_registry.h"
#include "dispatch_ffn_combine_tiling.h"
#include "../op_kernel/dispatch_ffn_combine_tiling.h"
#include <vector>
#include <map>
#include <algorithm>
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "../op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
using namespace AscendC;
using namespace ge;
@@ -278,8 +278,12 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con
uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) +
info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 +
info.maxOutputSize * sizeof(float) * 2 +
std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
info.maxOutputSize * info.N * sizeof(int16_t) +
info.maxOutputSize * n2 * sizeof(int16_t) +
info.maxOutputSize * info.K * sizeof(int8_t) +
info.maxOutputSize * k2 * sizeof(int8_t) +
info.worldSize * sizeof(int32_t) * 16 +
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace);

View File

@@ -23,29 +23,11 @@ extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1
GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
if (TILING_KEY_IS(1000000)) {
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000001)) {
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000010)) {
if (TILING_KEY_IS(1000010)) {
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000011)) {
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
}
}

View File

@@ -234,7 +234,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
D1Type, TileElemWiseMuls, TileCopy1>;
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2<ubStages>;
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
D2Type, TileCopy2>;

View File

@@ -22,21 +22,38 @@
#include "catlass/matrix_coord.hpp"
#include "catlass/epilogue/tile/tile_copy.hpp"
#include "utils/block_mmad_preload_async_fixpipe_quant.hpp"
#include "utils/copy_gm_to_l1_custom.hpp"
#include "utils/copy_l0c_to_gm_custom.hpp"
#include "utils/block_epilogue_pertoken_row.hpp"
#include "utils/block_epilogue_pertoken_swiglu.hpp"
#include "utils/hccl_shmem.hpp"
#include "utils/const_args.hpp"
#include "utils/layout3d.hpp"
#include "utils/get_tensor_addr.hpp"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "unpermute/moe_token_unpermute.h"
#ifndef HCCL_COMM
#include "block_mmad_preload_async_fixpipe_quant.hpp"
#include "copy_gm_to_l1_custom.hpp"
#include "copy_l0c_to_gm_custom.hpp"
#include "block_epilogue_pertoken_row.hpp"
#include "block_epilogue_pertoken_v2.hpp"
#include "block_epilogue_pertoken_swiglu.hpp"
#include "hccl_shmem.hpp"
#include "const_args.hpp"
#include "layout3d.hpp"
#include "tiling/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "moe_token_unpermute.h"
#include "get_tensor_addr.hpp"
inline __gm__ struct OpSystemRunCfg g_opSystemRunCfg{Catlass::L2_OFFSET};
#else
#include "utils/block_mmad_preload_async_fixpipe_quant.hpp"
#include "utils/copy_gm_to_l1_custom.hpp"
#include "utils/copy_l0c_to_gm_custom.hpp"
#include "utils/block_epilogue_pertoken_row.hpp"
#include "utils/block_epilogue_pertoken_v2.hpp"
#include "utils/block_epilogue_pertoken_swiglu.hpp"
#include "utils/hccl_shmem.hpp"
#include "utils/const_args.hpp"
#include "utils/layout3d.hpp"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "unpermute/moe_token_unpermute.h"
#include "utils/get_tensor_addr.hpp"
#endif
using namespace AscendC;
@@ -44,7 +61,6 @@ namespace Catlass::Gemm::Kernel {
constexpr uint16_t SYNCFLAGC2V = 9;
constexpr uint16_t SYNCFLAGV2C = 10;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template <
class BlockMmad_,
@@ -104,6 +120,7 @@ public:
uint32_t rank;
uint32_t rankSize;
int32_t ubMoveNum;
GM_ADDR symmetricPtr;
//--------------
GM_ADDR expertIdx;
GM_ADDR moeInitRoutingQuantV2Scale;
@@ -193,9 +210,7 @@ public:
void operator()<AscendC::AIC>(Params const &params)
{
GMM1(params);
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
GMM2(params);
}
@@ -204,32 +219,26 @@ public:
CATLASS_DEVICE
void operator()<AscendC::AIV>(Params const &params)
{
Dispatch(params);
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
Combine(params);
DispatchAndCombine(params);
}
private:
CATLASS_DEVICE void initBuffer(Params const &params) {
#ifndef HCCL_COMM
shmem.initShmem(params.symmetricPtr, params.rank, params.rankSize);
#endif
workspaceInfo = WorkspaceInfo(params);
peermemInfo = PeermemInfo(params, shmem);
cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM));
gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA));
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC));
gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken));
gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2));
gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale));
gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2));
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
}
template<typename T>
@@ -285,6 +294,51 @@ private:
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
// Move tokens and scales together, then write them to different positions respectively
template<typename T>
CATLASS_DEVICE void CopyGMToGMPerToken(
AscendC::GlobalTensor<T> dst,
AscendC::GlobalTensor<float> dstScale,
AscendC::GlobalTensor<T> src,
int32_t rows,
int32_t hiddenSize
)
{
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
constexpr int32_t BufferNum = 2;
AscendC::LocalTensor<T> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<T>(0);
constexpr int tmpBufferOffset = 96 * 1024; // half of UB
AscendC::LocalTensor<T> tmpBuffer2 = resource.ubBuf.template GetBufferByByte<T>(tmpBufferOffset);
uint32_t copyInNum = hiddenSize + ALIGN_512;
// [ReduceScatter] 2. Pre Interface Sync
int pingpongId = 0;
for (uint32_t processIndex = 0; processIndex < rows; ++processIndex) {
AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1;
AscendC::LocalTensor<T> buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2;
AscendC::LocalTensor<float> bufScale = buf[hiddenSize].template ReinterpretCast<float>();
auto inputOffset = processIndex * copyInNum;
auto outputOffset = processIndex * hiddenSize;
// [ReduceScatter] 2. Pre Interface Sync
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
// [ReduceScatter] 3. Start shmem_mte_get_mem_nbi
AscendC::DataCopy(buf, src[inputOffset], copyInNum);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
AscendC::DataCopy(dst[outputOffset], buf, hiddenSize);
AscendC::DataCopyPad(dstScale[processIndex], bufScale, {1, 4, 0, 0, 0});
// [ReduceScatter] 4. Post Interface Sync
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
pingpongId = (pingpongId + 1) % BufferNum;
}
// [ReduceScatter] 4. Post Interface Sync
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
CATLASS_DEVICE
void GetCumsumForMMAIV(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP)
{
@@ -296,7 +350,7 @@ private:
AscendC::DataCopyPad(
tmpBuffer1,
tokenPerExpert[rankId * expertPerRank],
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, ALIGN_128) - expertPerRank) * sizeof(int32_t)), 0},
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
{}
);
@@ -323,6 +377,8 @@ private:
icache_preload(8);
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
float aivFinishGroups = 0.0f;
__gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
@@ -331,11 +387,10 @@ private:
uint32_t syncGroupIdx = 0;
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
uint16_t syncgmmIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
syncgmmIdx++;
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -350,9 +405,7 @@ private:
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
AscendC::PipeBarrier<PIPE_ALL>();
if (currentM <= L1TileShape::M) {
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
@@ -372,6 +425,7 @@ private:
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmmIdx ++;
}
// Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
@@ -399,6 +453,7 @@ private:
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
// Synchronization signal: GMM1 notifies SwiGLU [1]
blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
}
@@ -419,6 +474,7 @@ private:
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
// Synchronization signal: GMM1 notifies SwiGLU [2]
blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
}
@@ -427,7 +483,7 @@ private:
icache_preload(8);
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
@@ -458,11 +514,9 @@ private:
}
AscendC::GlobalTensor<ElementB> gmB2;
AscendC::GlobalTensor<ElementScale> gmS2;
AscendC::PipeBarrier<PIPE_ALL>();
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
if (currentM <= L1TileShape::M) {
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
@@ -482,6 +536,7 @@ private:
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
}
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
if (loopIdx + coreNum >= coreLoops) {
syncLoopIdx = groupIdx;
@@ -502,12 +557,12 @@ private:
int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * n2 : 0); // One scale group per expert
if (currentM > 0) {
blockMmad(
gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
gmS2[gmOffsetS], layoutScale,
actualBlockShape, syncLoopIdx, 0
);
gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
gmS2[gmOffsetS], layoutScale,
actualBlockShape, syncLoopIdx, 0
);
}
}
preCurrentmSum += currentM;
@@ -518,34 +573,34 @@ private:
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
blockMmad.Finalize(params.expertPerRank - 1, 0);
}
CATLASS_DEVICE
void ResetTokenPerExpert(AscendC::GlobalTensor<int32_t> & tokenPerExpert, int32_t num)
{
if (coreIdx != coreNum - 1) {
return;
}
CATLASS_DEVICE
void InitArithProgress(Params const &params) {
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::Duplicate(tmp, 0, num);
AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert, tmp, num);
AscendC::GlobalTensor<float> flagGlobalBase;
flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
}
CATLASS_DEVICE
void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){
void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, ALIGN_128);
AscendC::LocalTensor<int32_t> prevSumBuf = tmpBuffer[numPerCore];
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) {
continue;
@@ -562,11 +617,11 @@ private:
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, numPerCore},
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
@@ -574,35 +629,125 @@ private:
AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, numPerCore},
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, numPerCore},
layout::RowMajor{1, numPerCore});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
}
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) {
continue;
if (dstEpIdx != params.rank) {
int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0);
}
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
} else {
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
}
int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, ALIGN_128); checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0);
AscendC::PipeBarrier<PIPE_ALL>();
int32_t prevSum = 0;
int32_t j = 0;
for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
if (i >= params.rank * params.expertPerRank) {
prevSumBuf(j) = prevSum;
j++;
}
prevSum += tmpBuffer(i);
}
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
AscendC::DataCopyParams{1, static_cast<uint16_t>(params.expertPerRank * sizeof(int32_t)), 0, 0});
}
AscendC::SyncAll<true>();
}
CATLASS_DEVICE
void ResetTokenPerExpert(int32_t num)
{
if (coreIdx != coreNum - 1) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::Duplicate(tmp, 0, num);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert, tmp, num);
}
CATLASS_DEVICE
void UpdateAicFlags(const Params &params)
{
float flagBase = 1.0f * params.expertPerRank;
__gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
float flag = 0.0f;
float lastflag = -1.0f;
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
__gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
AscendC::GlobalTensor<float> flagGM;
flagGM.SetGlobalBuffer(flagPtr);
int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
AscendC::LocalTensor<float> dstValueBuffer = resource.ubBuf.template GetBufferByByte<float>(flagBufferSize);
AscendC::LocalTensor<float> sharedTmpBuffer = resource.ubBuf.template GetBufferByByte<float>((flagBufferSize + 64));
uint64_t mask[1] = {0};
uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
for (int32_t i = 0; i < 4; i ++) {
if (i < params.EP) {
mask[0] |= 1ull * (1ull << (i * 16));
}
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while (flag < flagBase) {
flag = flagBase;
AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
AscendC::ReduceMin<float>(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
flag = min(flag, dstValueBuffer.GetValue(0));
if (flag > lastflag) {
*aicFinishPtr = flag;
gm_dcci(aicFinishPtr);
lastflag = flag;
}
}
AscendC::SyncAll<true>();
}
CATLASS_DEVICE
void Dispatch(Params const &params) {
void CombineSetFlag() {
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
}
CATLASS_DEVICE
void DispatchAndCombine(Params const &params) {
icache_preload(8);
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem
@@ -617,10 +762,19 @@ private:
&params.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey);
AscendC::SyncAll<true>();
CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset);
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
if (coreIdx == 0) {
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
}
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t prevSum = 0;
if (coreIdx < params.EP) {
prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
}
AscendC::SyncAll<true>();
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
@@ -633,24 +787,12 @@ private:
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx++;
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t groupIdxDeq = 0;
if (coreIdx < params.EP) {
for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i));
}
m_prevSumBeforeRank = prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
uint32_t n = params.problemShape.n();
BlockEpilogue1 blockEpilogue(resource, n);
icache_preload(8);
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// The ith core reads data from the ith rank's peermem
groupIdxDeq = groupIdx - 2;
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
@@ -664,24 +806,23 @@ private:
GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
AscendC::GlobalTensor<ElementA> gmRemoteA;
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
MatrixCoord offsetA{rowStart, 0};
MatrixCoord shapeA{rows, params.problemShape.k()};
MatrixCoord offsetPeer{rowSrc, 0};
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
int64_t gmOffsetPeer = rowSrc * (params.problemShape.k() + ALIGN_512);
// Communication data
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
// Communication scale
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
CopyGMToGMPerToken(gmA[gmOffsetA], gmPerTokenScale1[rowStart], gmRemoteA[gmOffsetPeer], rows, params.problemShape.k());
}
}
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
prevGroupSum1 += currentM;
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx ++;
prevGroupSum1 += currentM;
// Token count and truncation logic for the first SwiGLU operation
if (groupIdx + 1 <= params.epilogueGranularity) {
if (dequantSum1 + currentM <= params.maxOutputSize) {
dequantSum1 += currentM;
@@ -689,6 +830,8 @@ private:
dequantSum1 = params.maxOutputSize;
}
}
// Token count and truncation logic for the second SwiGLU operation
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
dequantSum2 += currentM;
@@ -698,21 +841,42 @@ private:
}
}
uint32_t n2 = params.problemShape.k();
typename BlockEpilogue2::Params epilogueParams{
static_cast<int32_t>(params.EP),
static_cast<int32_t>(params.expertPerRank),
static_cast<int32_t>(params.rank),
reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert),
params.layoutD2,
static_cast<int32_t>(n2),
static_cast<int32_t>(L1TileShape::N),
shmem,
static_cast<int32_t>(peermemInfo.offsetD)
};
uint32_t n = params.problemShape.n();
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
BlockEpilogue1 blockEpilogue1(resource, n);
// Synchronous wait: SwiGLU waits for GMM1 [1]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum1 > 0) {
if (dequantSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
LayoutC layoutC{dequantSum1, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
// Synchronization signal: SwiGLU notifies GMM2 [1]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
// Synchronous wait: SwiGLU waits for GMM1 [2]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>();
if (dequantSum2 > 0) {
@@ -723,76 +887,118 @@ private:
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
// Synchronization signal: SwiGLU notifies GMM2 [2]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
blockEpilogue.Finalize();
blockEpilogue1.Finalize();
CombineSetFlag();
CombineV2(params, blockEpilogue2);
AscendC::SyncAll<true>();
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
#endif
shmem.InitStatusTargetSum();
if (get_subblockid() == 0) {
AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
shmem.CrossRankSyncV2Set(ctrBuffer);
} else {
uint32_t uboffset = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t sendRankNum_ = params.EP / aicCoreNum;
uint32_t remainderRankNum = params.EP % aicCoreNum;
if (aicCoreIdx < remainderRankNum) {
sendRankNum_++;
}
AscendC::LocalTensor<float> statusTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += sendRankNum_ * UB_ALIGN;
AscendC::LocalTensor<float> gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += params.EP * sizeof(float);
AscendC::LocalTensor<uint32_t> gatherTmpTensor = resource.ubBuf.template GetBufferByByte<uint32_t>(uboffset);
uboffset += sizeof(uint32_t);
AscendC::LocalTensor<float> statusSumOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
uboffset += sizeof(float);
shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
MoeTokenUnpermuteTilingData tilingData;
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
kernelMoeTokenUnpermuteOp.Process();
}
}
CATLASS_DEVICE
void Combine(Params const &params) {
int32_t prevSumBeforeRank = 0;
if (coreIdx < params.EP) {
prevSumBeforeRank = m_prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
BlockScheduler blockScheduler;
int32_t syncLoopIdx = 0;
uint32_t startCoreIdx = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t aivSubCoreIdx = get_subblockid();
uint32_t preSrcExpertSum = 0;
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
// TODO compute the cumsum of tokenPerExpert
typename BlockEpilogue2::Params epilogueParams{
static_cast<int32_t>(params.EP),
static_cast<int32_t>(params.expertPerRank),
reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace),
static_cast<int32_t>(n2)
};
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
int32_t prevGroupSum2 = 0;
icache_preload(8);
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
AscendC::SyncAll<true>();
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preSrcExpertSum >= params.maxOutputSize) {
currentExpertM = 0;
} else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
currentExpertM = params.maxOutputSize - preSrcExpertSum;
}
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx;
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
__gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx);
AscendC::GlobalTensor<ElementD2> gmRemotePeer;
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr));
uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2;
if (srcRowOffset < params.maxOutputSize) {
uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
if (srcRowOffset + dataRows > params.maxOutputSize) {
dataRows = params.maxOutputSize - srcRowOffset;
}
uint32_t dstRowOffset = prevSum;
prevSum += dataRows;
MatrixCoord offsetC{srcRowOffset, 0};
MatrixCoord offsetPeer{dstRowOffset, 0};
MatrixCoord shapeC{dataRows, n2};
int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
if constexpr (std::is_same_v<ElementA, int8_t>) {
blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
} else {
blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]);
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) {
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
int32_t m0 = 16;
// Block count, the shape of each block is (m0, actualBlockShape.n())
int32_t m_rows = (actualBlockShape.m() + m0 - 1) / m0;
int32_t aiv_m_rows = m_rows / 2;
if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) {
aiv_m_rows += 1;
}
uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset
if(aivSubCoreIdx == 1) {
m_offset += (m_rows / 2) * m0;
}
for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
}
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1};
uint32_t actualm = m0;
if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
}
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
m_offset += m0;
}
}
prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
preSrcExpertSum += currentExpertM;
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
}
blockEpilogue.Finalize();
AscendC::SyncAll<true>();
ResetTokenPerExpert(tokenPerExpert, params.EP * AlignUp(params.EP * params.expertPerRank, ALIGN_128));
shmem.CrossRankSync();
MoeTokenUnpermuteTilingData tilingData;
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum);
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
kernelMoeTokenUnpermuteOp.Process();
}
private:
struct WorkspaceInfo {
GM_ADDR ptrA;
@@ -804,6 +1010,9 @@ private:
GM_ADDR ptrPerTokenScale2;
GM_ADDR expandedRowIdx;
GM_ADDR ptrTokenPerExpert;
GM_ADDR ptrSumBeforeRank;
__gm__ float* ptrSoftFlagBase;
CATLASS_DEVICE
WorkspaceInfo(){}
@@ -831,15 +1040,21 @@ private:
workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
ptrC = params.ptrWorkspace + workspaceOffset;
ptrC2 = ptrC;
workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC),
params.maxOutputSize * n2 * sizeof(ElementC));
workspaceOffset += params.maxOutputSize * params.problemShape.n() * sizeof(ElementC);
ptrC2 = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.maxOutputSize * n2 * sizeof(ElementC);
ptrA = params.ptrWorkspace + workspaceOffset;
ptrPermutedToken = ptrA;
workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA),
params.maxOutputSize * k2 * sizeof(ElementA));
workspaceOffset += params.maxOutputSize * params.problemShape.k() * sizeof(ElementA);
ptrPermutedToken = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
}
};
@@ -866,12 +1081,9 @@ private:
uint32_t coreIdx;
uint32_t coreNum;
Params params;
WorkspaceInfo workspaceInfo;
PeermemInfo peermemInfo;
int64_t m_prevSumBeforeRank;
AscendC::GlobalTensor<ElementA> gmA;
AscendC::GlobalTensor<ElementC> gmC;
@@ -883,10 +1095,11 @@ private:
AscendC::GlobalTensor<int32_t> tokenPerExpert;
AscendC::GlobalTensor<int32_t> cumsumMM;
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
Layout3D tokenPerExpertLayout;
HcclShmem shmem;
};
} // namespace Catlass::Gemm::Kernel
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP
#endif // DISPATCH_FFN_COMBINE_KERNEL_HPP

View File

@@ -44,6 +44,8 @@ constexpr int64_t EXERPT_TOKENS_COUNT = 2;
constexpr int64_t EXERPT_TOKENS_CUMSUM = 1;
constexpr int64_t EXERPT_TOKENS_NONE = 0;
constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1;
constexpr int64_t ALIGN_512 = 512;
constexpr int64_t ALIGN_128 = 128;
const __gm__ int32_t assist[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,

View File

@@ -35,7 +35,6 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
__aicore__ inline void CopyOutIdx();
__aicore__ inline void CopyOutEmpty();
__aicore__ inline void CopyOutXQuant1H();
__aicore__ inline void CopyOutXQuantEH();
__aicore__ inline void ComputeExpertTokenCountOrCumsum();
__aicore__ inline void Compute(LocalTensor<float>& smoothLocal);
@@ -49,6 +48,7 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
int64_t k_;
int64_t n_;
int64_t cols_;
int64_t cols_scale_;
int64_t activateRows_;
int64_t expertNum;
int64_t expertCapacity;
@@ -63,12 +63,10 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
TQue<QuePosition::VECIN, 1> smoothInQueue;
TQue<QuePosition::VECOUT, 1> calcQueue;
TQue<QuePosition::VECOUT, 1> inputXOutQueue;
TQue<QuePosition::VECOUT, 1> scaleOutQueue;
GlobalTensor<T> xGm_;
GlobalTensor<int32_t> expertIdxGm_;
GlobalTensor<float> quantSmoothGm;
GlobalTensor<float> dynamicQuantScaleGm;
GlobalTensor<int8_t> expandedXGm_;
GlobalTensor<int32_t> expandedRowIdxGm_;
@@ -225,7 +223,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>&
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();
LocalTensor<float> dynamicQuantLocal = outLocal[this->cols_].template ReinterpretCast<float>();
if constexpr (!IsSameType<T, float>::value) {
Cast(inLocal, inLocal.ReinterpretCast<T>()[colsAlign], RoundMode::CAST_NONE, this->cols_);
@@ -259,7 +257,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>&
calcQueue.FreeTensor(tempLocal);
inputXOutQueue.EnQue(outLocal);
scaleOutQueue.EnQue(dynamicQuantLocal);
}
template <typename T>
@@ -275,7 +272,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};
DataCopyExtParams intriParams{1, static_cast<uint32_t>((this->cols_ + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0};
LocalTensor<float> smoothLocal;
if (smoothType == 1) {
@@ -295,7 +292,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
xCopyInQueue_.EnQue<T>(xLocal);
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) {
int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
@@ -303,76 +299,15 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) {
continue;
}
DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams);
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
DataCopyPad(expandedXGm_[outIndex * this->cols_scale_], outLocal, intriParams);
}
xCopyInQueue_.FreeTensor(xLocal);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
if (smoothType == 1) {
smoothInQueue.FreeTensor(smoothLocal);
}
expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
}
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuantEH() {
LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
Muls(expandDstToSrcRowLocal.ReinterpretCast<float>(), expandDstToSrcRowLocal.ReinterpretCast<float>(), (float)-1,
this->totalLength);
pipe_barrier(PIPE_V);
LocalTensor<int32_t> sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast<int32_t>();
Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast<float>(), RoundMode::CAST_ROUND, this->totalLength);
int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;
DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};
for (int64_t row = curRowsStart; row <= curRowsEnd; row++) {
if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) {
break;
}
int32_t srcIdx = sortedRowIdx.GetValue(row);
int32_t expertIdx = expandedExpertIdxLocal.GetValue(row);
LocalTensor<T> inLocal = xCopyInQueue_.AllocTensor<T>();
LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
} else {
DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
}
DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0});
xCopyInQueue_.EnQue<T>(inLocal);
smoothInQueue.EnQue(smoothLocal);
smoothLocal = smoothInQueue.DeQue<float>();
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0});
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams);
xCopyInQueue_.FreeTensor(inLocal);
smoothInQueue.FreeTensor(smoothLocal);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);
}
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
@@ -386,6 +321,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
this->k_ = tilingData->k;
this->n_ = tilingData->n;
this->cols_ = tilingData->cols;
this->cols_scale_ = this->cols_ + ALIGN_512;
this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
this->activateRows_ = this->gatherOutTilingData_->activateRows;
@@ -416,7 +352,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
Align(this->expertNum, sizeof(int32_t)));
}
quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
int64_t kvFactor = 2;
int64_t buffSize = this->sortNum_ * sizeof(int32_t);
@@ -440,8 +375,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
}
pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float)));
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float)));
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t)));
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_scale_, sizeof(int8_t)));
}
template <typename T>
@@ -457,11 +391,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Process() {
} else {
CopyOutEmpty();
}
if (smoothType == 2) {
CopyOutXQuantEH();
} else {
CopyOutXQuant1H();
}
CopyOutXQuant1H();
}
}
} // namespace MoeInitRoutingQuantV2

View File

@@ -66,6 +66,7 @@ class MoeV2GatherDynamicQuant {
int64_t needCoreNum;
int64_t blockIdx;
int64_t cols;
int64_t cols_scale_;
int64_t n;
int64_t k;
int64_t totalLength;
@@ -117,7 +118,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& s
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();
LocalTensor<float> dynamicQuantLocal = outLocal[this->cols].template ReinterpretCast<float>();
if constexpr (!IsSameType<T, float>::value) {
Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
@@ -151,7 +152,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& s
calcQueue.FreeTensor(tempLocal);
inputXOutQueue.EnQue(outLocal);
scaleOutQueue.EnQue(dynamicQuantLocal);
}
template <typename T>
@@ -163,7 +163,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
int64_t currentLoopStartRow = initialRow / this->k;
int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>((this->cols + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0};
DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
LocalTensor<float> smoothLocal;
@@ -187,7 +187,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
// Compute quantization
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
@@ -197,15 +196,11 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
continue;
}
DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams);
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
// Scale is placed after the data position
DataCopyPad(expandedXGm[outIndex * cols_scale_], outLocal, copyOutParams);
}
inputXInQueue.FreeTensor(inLocal);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
if (smoothType == 1) {
smoothInQueue.FreeTensor(smoothLocal);
}
expandRowIdxInQueue.FreeTensor(indicesLocal);
}
@@ -463,6 +458,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR
this->needCoreNum = this->gatherOutTilingData->needCoreNum;
this->activateRows = this->gatherOutTilingData->activateRows;
this->cols = tilingData->cols;
this->cols_scale_ = this->cols + ALIGN_512;
this->n = tilingData->n;
this->k = tilingData->k;
this->totalLength = tilingData->n * tilingData->k;
@@ -518,32 +514,15 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR
pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
}
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
if (this->blockIdx < this->needCoreNum) {
currentLoopRows = perLoopRows;
if (colLoops > 1) { // A single row cannot be fully loaded; workspace is required
if (smoothType == 2) {
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
CopyInExpandedExpertIdx(loop);
CopyOutPartialXQuantEH(loop);
}
currentLoopRows = lastLoopRows;
CopyInExpandedExpertIdx(this->rowLoops - 1);
CopyOutPartialXQuantEH(this->rowLoops - 1);
} else {
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
CopyInExpandedRowIdx(loop);
CopyOutPartialXQuant1H(loop);
}
currentLoopRows = lastLoopRows;
CopyInExpandedRowIdx(this->rowLoops - 1);
CopyOutPartialXQuant1H(this->rowLoops - 1);
}
} else { // A single row can be fully loaded
if (colLoops > 1) { // Cannot fit all data in one row, workspace is required
trap(); // Not supported
} else { // All data can fit in one row
if (smoothType == 2) {
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
CopyInExpandedExpertIdx(loop);

View File

@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
GM_ADDR unpermuted_tokens,
const MoeTokenUnpermuteTilingData *__restrict tiling_data)
{
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
this->blockNum = get_block_num() * get_subblockdim();
this->blockIdx = get_block_idx();
this->blockNum = get_block_num();
if (blockIdx >= blockNum) {
return;

View File

@@ -0,0 +1,243 @@
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "hccl_shmem.hpp"
#include "layout3d.hpp"
namespace Catlass::Epilogue::Block {
template <
uint32_t UB_STAGES_,
class CType_,
class LayoutPerTokenScale_,
class DType_,
class TileCopy_
>
class BlockEpilogue <
EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>,
CType_,
Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_,
TileCopy_
> {
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::VectorLayout>>;
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
struct Params {
__gm__ int32_t *ptrTokenPerExpert{nullptr};
int32_t EP;
int32_t expertPerRank;
int32_t n2;
LayoutC layoutC;
int32_t n0;
int32_t rank;
HcclShmem shmem;
int32_t offsetD;
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_,
LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) :
ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_),
expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_),
shmem(shmem_), offsetD(offsetD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
//ub:192KB
n0 = params.n0;
size_t ubOffset = 0;
for(int32_t i = 0; i < 2; i++) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += max_len * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += max_len * sizeof(ElementD);
ubFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += max_len * sizeof(float);
scaleUbList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += (max_len / n0) * sizeof(float);
source_scale_offset[i] = -1;
}
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank);
is_ping = true;
}
CATLASS_DEVICE
void Finalize()
{
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
}
CATLASS_DEVICE
~BlockEpilogue()
{
}
CATLASS_DEVICE
void operator() (
AscendC::GlobalTensor<ElementC> const &gmC,
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
GemmCoord& blockCoord,
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
auto &ubCFp32 = ubFp32List[is_ping];
auto &scaleUb = scaleUbList[is_ping];
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id);
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
layout::VectorLayout scaleLauout{actualBlockShape.m()};
if (source_scale_offset[event_id] != gmScaleOffset) {
source_scale_offset[event_id] = gmScaleOffset;
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // Must wait on MTE2_S here, not MTE2_V;
        // otherwise the scalar reads of scaleUb below can observe zeros and corrupt the result.
AscendC::PipeBarrier<PIPE_V>();
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
float scale = scaleUb(row);
            Muls<float, false>(ubCFp32[n0 * row], ubCFp32[n0 * row], scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
int32_t lenTile = actualBlockShape.m();
int32_t stTile = blockCoord.m();
int32_t edTile = stTile + lenTile;
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id);
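        // Walk the destination ranks that contributed tokens to this expert and intersect each
        // rank's row range with this tile's [stTile, edTile) range; only the overlapping rows
        // are copied out to that rank's peer memory.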
        for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
if (stRankInExpert >= edTile) {
break;
}
else if (edRankInExpert <= stTile) {
continue;
}
int32_t stData = max(stRankInExpert, stTile);
int32_t edData = min(edRankInExpert, edTile);
            uint32_t lenData = edData - stData;
            if (lenData == 0) {
                continue;
            }
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
}
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
AscendC::LocalTensor<float> ubFp32List[UB_STAGES];
AscendC::LocalTensor<float> scaleUbList[UB_STAGES];
int32_t source_scale_offset[UB_STAGES];
int32_t max_len = 8 * 32 / 4 * 128;
int32_t n0;
bool is_ping = false;
int32_t repeat = 128;
CopyGmToUbC copyGmToUbC;
CopyUbToGmD copyUbToGmD;
CopyScaleGmToUb copyScaleGmToUb;
AscendC::GlobalTensor<int32_t> tokenPerExpert;
Layout3D tokenPerExpertLayout;
};
}
#endif


@@ -22,8 +22,6 @@
namespace Catlass::Gemm::Block {
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID)
{
@@ -153,9 +151,11 @@ public:
L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
{
syncGroupIdx = 0;
ptrSoftFlagBase_ = flagPtr;
expertPerRank_ = expertPerRank;
InitL1(resource, l1BufAddrStart);
InitL0A(resource);
InitL0B(resource);
@@ -272,9 +272,21 @@ public:
CATLASS_DEVICE
void Finalize(int32_t target, int32_t flag = 0)
{
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
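        // Soft-flag path: instead of raising the hardware cross-core flag, copy the pre-staged
        // L1 flag row for `target` into this cube core's slot in global memory, so consumers
        // can poll per-group progress at tile granularity.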
if (ptrSoftFlagBase_ != nullptr) {
if (target < 0) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<int32_t> flagGlobal;
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
}
else {
for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / 15 + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
}
}
private:
@@ -291,7 +303,6 @@ private:
layout::VectorLayout layoutScale;
int32_t syncLoopIdx;
int32_t flag;
CATLASS_DEVICE
L1TileMmadParams() = default;
};
@@ -310,11 +321,24 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
if constexpr (std::is_same_v<ElementA, int8_t>) {
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
if (ptrSoftFlagBase_ != nullptr) {
// Initialize the flag matrix (structure as below):
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// ...
// 16 0 0 0 0 0 0 0
// Then move it to L1
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
AscendC::GlobalTensor<int32_t> flagBase;
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
}
}
CATLASS_DEVICE
@@ -463,12 +487,20 @@ private:
if constexpr (std::is_same_v<ElementA, int8_t>) {
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
}
#ifdef __TILE_SYNC__
if (params.flag > 0) {
int32_t flagId = params.flag + params.syncLoopIdx / 8;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
#else
Finalize(params.syncLoopIdx, params.flag);
#endif
}
}
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
AscendC::LocalTensor<uint64_t> l1STensor;
AscendC::LocalTensor<int32_t> l1FTensor;
int32_t syncGroupIdx;
int32_t l1AEventList[L1_STAGES];
int32_t l1BEventList[L1_STAGES];
@@ -497,8 +529,11 @@ private:
CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm;
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
int32_t expertPerRank_;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP


@@ -5,5 +5,7 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
constexpr static int32_t NUMS_PER_FLAG = 16;
constexpr static int32_t CACHE_LINE = 512;
constexpr static int32_t RESET_VAL = 0xffff;
constexpr static int32_t ALIGN_128 = 128;
constexpr static int32_t FLAGSTRIDE = 16;
constexpr static int32_t UB_ALIGN = 32;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
#endif


@@ -33,13 +33,13 @@ namespace Catlass::Epilogue {
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
struct EpilogueAtlasA2PerTokenDequantV2 {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
};


@@ -5,13 +5,28 @@
#include "kernel_operator.h"
#include "const_args.hpp"
#ifdef HCCL_COMM
#include "moe_distribute_base.h"
using namespace AscendC::HcclContextDef;
#ifndef HCCL_COMM
#else
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
constexpr int32_t MAX_RANK_SIZE = 32;
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t STATE_OFFSET = 512;
FORCE_INLINE_AICORE void AicSyncAll() {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
AscendC::CrossCoreWaitFlag<0x0>(8);
}
template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
@@ -23,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache);
}
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
template<typename T>
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
using namespace AscendC;
GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr);
global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
    // Important: the empty asm below acts as a compiler barrier so the dcci is not optimized away
__asm__ __volatile__("");
@@ -37,26 +53,20 @@ FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
gm_dcci((__gm__ uint8_t *)sig_addr);
if (*sig_addr == cmp_val) {
return *sig_addr;
}
        // Also accept cmp_val + 1, in case the peer PE has already entered the next barrier
if (*sig_addr == cmp_val + 1) {
return *sig_addr;
}
} while (true);
// never reach
return -1;
}
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do {
AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr);
@@ -71,59 +81,53 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
}
constexpr int32_t MAX_RANK_SIZE = 32;
class HcclShmem {
public:
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
size_t m_segmentSize;
int32_t m_rank;
int32_t m_rankSize;
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
AscendC::LocalTensor<int32_t> ub;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
FORCE_INLINE_AICORE
HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
}
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize;
}
#else
FORCE_INLINE_AICORE
HcclShmem(){
m_segmentSize = SHMEM_MEM;
}
FORCE_INLINE_AICORE
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
symmetricPtr = symmetricPtr_;
m_rank = rank;
m_rankSize = rankSize;
}
#endif
FORCE_INLINE_AICORE
GM_ADDR operator() () const { // No argument: return local peermem
GM_ADDR operator() () const { // No parameters: return pointer to local peermem
#ifdef HCCL_COMM
return (GM_ADDR)(WinContext_->localWindowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
#ifdef HCCL_COMM
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else
return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif
}
FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
@@ -136,15 +140,28 @@ public:
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else
return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif
}
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
FORCE_INLINE_AICORE
~HcclShmem() {
}
FORCE_INLINE_AICORE
void CrossRankSync() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
@@ -165,12 +182,146 @@ public:
gm_store(sync_base, count);
}
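    // InitStatusTargetSum() flips this core's persistent state between 0 and 1.0f (0x3F800000)
    // on every call and records in sumTarget_ which value this rank publishes and expects from
    // its peers this round, so back-to-back synchronizations need no explicit flag reset.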
FORCE_INLINE_AICORE
void InitStatusTargetSum()
{
using namespace AscendC;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
        // Per-core EP state slot in the reserved flag region
        uint32_t coreIdx = GetBlockIdx();
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
}
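    // CrossRankSyncV2Set(): each vector core writes this rank's current state value into this
    // rank's slot of the flag region on a strided subset of destination ranks; it pairs with
    // CrossRankSyncV2Wait() below.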
FORCE_INLINE_AICORE
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
        uint32_t stateOffset_ = STATE_OFFSET;
        // This rank's state slot inside every peer's flag region
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
int vec_size = get_block_num();
int vec_id = get_block_idx();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
pipe_barrier(PIPE_ALL);
ctrBuffer.SetValue(0, epStateValue_);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
AscendC::GlobalTensor<int32_t> gmDstStates;
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
DataCopy(gmDstStates, ctrBuffer, 8);
}
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
}
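    // CrossRankSyncV2Wait(): each vector core polls the state slots of its assigned ranks,
    // gathers and sums them, and spins until the sum matches sendRankNum_ * sumTarget_
    // (within +/-0.5), i.e. until every peer has published this round's value.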
FORCE_INLINE_AICORE
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
uint32_t stateOffset_ = STATE_OFFSET;
uint32_t sendRankNum_ = m_rankSize / vec_size;
uint32_t remainderRankNum = m_rankSize % vec_size;
uint32_t startRankId_ = sendRankNum_ * vec_id;
if (vec_id < remainderRankNum) {
sendRankNum_++;
startRankId_ += vec_id;
} else {
startRankId_ += remainderRankNum;
}
uint32_t endRankId_ = startRankId_ + sendRankNum_;
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
if (startRankId_ < m_rankSize) {
AscendC::PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>(15), 0};
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
//unpermute
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
}
private:
GM_ADDR symmetricPtr;
int32_t m_rank;
int32_t m_rankSize;
size_t m_segmentSize;
float sumTarget_{0.0};
int32_t epStateValue_;
};
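// Usage sketch (illustrative only, under HCCL_COMM; offsetD is a byte offset into the
// symmetric window, as used by the epilogue above):
//   HcclShmem shmem;                       // binds to the HCCL window context
//   GM_ADDR local  = shmem();              // base of this rank's peermem
//   GM_ADDR remote = shmem(offsetD, 3);    // offsetD bytes into rank 3's peermem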
#endif