[Kernel]: Optimize DispatchFFNCombine performance (#6468)

### What this PR does / why we need it?

This PR focuses on performance optimization for the DispatchFFNCombine
operator. The key optimizations include:

1. Improving communication efficiency by merging the transmission of
tokens and scales;
2. Decoupling multi-core dependencies and reducing waiting bubbles in
the combine process through tile-granularity communication;
3. Optimizing the full-card synchronization overhead before the
umpermute operation.

These optimizations aim to reduce the overall execution latency of the
DispatchFFNCombine operator and enhance the runtime performance of the
model inference process on Ascend devices.

### Does this PR introduce _any_ user-facing change?

No. This PR only involves internal performance optimization of the
DispatchFFNCombine operator and does not introduce any changes to
user-facing APIs, interfaces, or behaviors.

### How was this patch tested?

1. Enable the DispatchFFNCombine operator by setting the environment
variable:
```
export VLLM_ASCEND_ENABLE_FUSED_MC2=1
```
2. Run the standard model inference test suite with the above
environment variable enabled;
4. Verify the correctness of model outputs (ensuring no functional
regression) and measure the performance improvement of the
DispatchFFNCombine operator (reduced latency and improved throughput).

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: xulei_ict <xulei292@huawei.com>
Co-authored-by: xulei_ict <xulei292@huawei.com>
This commit is contained in:
xulei
2026-02-09 16:30:34 +08:00
committed by GitHub
parent 9c6d031797
commit 8325528368
13 changed files with 897 additions and 356 deletions

View File

@@ -17,11 +17,11 @@
#include "error_log.h" #include "error_log.h"
#include "hcom_topo_info.h" #include "hcom_topo_info.h"
#include "register/op_def_registry.h" #include "register/op_def_registry.h"
#include "dispatch_ffn_combine_tiling.h" #include "../op_kernel/dispatch_ffn_combine_tiling.h"
#include <vector> #include <vector>
#include <map> #include <map>
#include <algorithm> #include <algorithm>
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" #include "../op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
using namespace AscendC; using namespace AscendC;
using namespace ge; using namespace ge;
@@ -278,8 +278,12 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con
uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) + uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) +
info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 + info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 +
info.maxOutputSize * sizeof(float) * 2 + info.maxOutputSize * sizeof(float) * 2 +
std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) + info.maxOutputSize * info.N * sizeof(int16_t) +
std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t)); info.maxOutputSize * n2 * sizeof(int16_t) +
info.maxOutputSize * info.K * sizeof(int8_t) +
info.maxOutputSize * k2 * sizeof(int8_t) +
info.worldSize * sizeof(int32_t) * 16 +
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace); workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace);

View File

@@ -23,29 +23,11 @@ extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1
GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM) GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{ {
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData); REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
if (TILING_KEY_IS(1000000)) { if (TILING_KEY_IS(1000010)) {
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000001)) {
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000010)) {
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2); KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op; DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM); op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process(); op.Process();
} else if (TILING_KEY_IS(1000011)) {
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} }
} }

View File

@@ -234,7 +234,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType, using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
D1Type, TileElemWiseMuls, TileCopy1>; D1Type, TileElemWiseMuls, TileCopy1>;
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>; using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2<ubStages>;
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>; using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType, using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
D2Type, TileCopy2>; D2Type, TileCopy2>;

View File

@@ -22,21 +22,38 @@
#include "catlass/matrix_coord.hpp" #include "catlass/matrix_coord.hpp"
#include "catlass/epilogue/tile/tile_copy.hpp" #include "catlass/epilogue/tile/tile_copy.hpp"
#ifndef HCCL_COMM
#include "block_mmad_preload_async_fixpipe_quant.hpp"
#include "copy_gm_to_l1_custom.hpp"
#include "copy_l0c_to_gm_custom.hpp"
#include "block_epilogue_pertoken_row.hpp"
#include "block_epilogue_pertoken_v2.hpp"
#include "block_epilogue_pertoken_swiglu.hpp"
#include "hccl_shmem.hpp"
#include "const_args.hpp"
#include "layout3d.hpp"
#include "tiling/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "moe_token_unpermute.h"
#include "get_tensor_addr.hpp"
inline __gm__ struct OpSystemRunCfg g_opSystemRunCfg{Catlass::L2_OFFSET};
#else
#include "utils/block_mmad_preload_async_fixpipe_quant.hpp" #include "utils/block_mmad_preload_async_fixpipe_quant.hpp"
#include "utils/copy_gm_to_l1_custom.hpp" #include "utils/copy_gm_to_l1_custom.hpp"
#include "utils/copy_l0c_to_gm_custom.hpp" #include "utils/copy_l0c_to_gm_custom.hpp"
#include "utils/block_epilogue_pertoken_row.hpp" #include "utils/block_epilogue_pertoken_row.hpp"
#include "utils/block_epilogue_pertoken_v2.hpp"
#include "utils/block_epilogue_pertoken_swiglu.hpp" #include "utils/block_epilogue_pertoken_swiglu.hpp"
#include "utils/hccl_shmem.hpp" #include "utils/hccl_shmem.hpp"
#include "utils/const_args.hpp" #include "utils/const_args.hpp"
#include "utils/layout3d.hpp" #include "utils/layout3d.hpp"
#include "utils/get_tensor_addr.hpp"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" #include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "unpermute/moe_token_unpermute.h" #include "unpermute/moe_token_unpermute.h"
#include "utils/get_tensor_addr.hpp"
#endif
using namespace AscendC; using namespace AscendC;
@@ -44,7 +61,6 @@ namespace Catlass::Gemm::Kernel {
constexpr uint16_t SYNCFLAGC2V = 9; constexpr uint16_t SYNCFLAGC2V = 9;
constexpr uint16_t SYNCFLAGV2C = 10; constexpr uint16_t SYNCFLAGV2C = 10;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template < template <
class BlockMmad_, class BlockMmad_,
@@ -104,6 +120,7 @@ public:
uint32_t rank; uint32_t rank;
uint32_t rankSize; uint32_t rankSize;
int32_t ubMoveNum; int32_t ubMoveNum;
GM_ADDR symmetricPtr;
//-------------- //--------------
GM_ADDR expertIdx; GM_ADDR expertIdx;
GM_ADDR moeInitRoutingQuantV2Scale; GM_ADDR moeInitRoutingQuantV2Scale;
@@ -193,9 +210,7 @@ public:
void operator()<AscendC::AIC>(Params const &params) void operator()<AscendC::AIC>(Params const &params)
{ {
GMM1(params); GMM1(params);
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
GMM2(params); GMM2(params);
} }
@@ -204,32 +219,26 @@ public:
CATLASS_DEVICE CATLASS_DEVICE
void operator()<AscendC::AIV>(Params const &params) void operator()<AscendC::AIV>(Params const &params)
{ {
Dispatch(params); DispatchAndCombine(params);
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
Combine(params);
} }
private: private:
CATLASS_DEVICE void initBuffer(Params const &params) { CATLASS_DEVICE void initBuffer(Params const &params) {
#ifndef HCCL_COMM
shmem.initShmem(params.symmetricPtr, params.rank, params.rankSize);
#endif
workspaceInfo = WorkspaceInfo(params); workspaceInfo = WorkspaceInfo(params);
peermemInfo = PeermemInfo(params, shmem); peermemInfo = PeermemInfo(params, shmem);
cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM));
gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA));
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC));
gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken));
gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2));
gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale));
gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2)); gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2));
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank); tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank));
} }
template<typename T> template<typename T>
@@ -285,6 +294,51 @@ private:
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1); AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
} }
// Move tokens and scales together, then write them to different positions respectively
template<typename T>
CATLASS_DEVICE void CopyGMToGMPerToken(
AscendC::GlobalTensor<T> dst,
AscendC::GlobalTensor<float> dstScale,
AscendC::GlobalTensor<T> src,
int32_t rows,
int32_t hiddenSize
)
{
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
constexpr int32_t BufferNum = 2;
AscendC::LocalTensor<T> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<T>(0);
constexpr int tmpBufferOffset = 96 * 1024; // half of UB
AscendC::LocalTensor<T> tmpBuffer2 = resource.ubBuf.template GetBufferByByte<T>(tmpBufferOffset);
uint32_t copyInNum = hiddenSize + ALIGN_512;
// [ReduceScatter] 2. Pre Interface Sync
int pingpongId = 0;
for (uint32_t processIndex = 0; processIndex < rows; ++processIndex) {
AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1;
AscendC::LocalTensor<T> buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2;
AscendC::LocalTensor<float> bufScale = buf[hiddenSize].template ReinterpretCast<float>();
auto inputOffset = processIndex * copyInNum;
auto outputOffset = processIndex * hiddenSize;
// [ReduceScatter] 2. Pre Interface Sync
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
// [ReduceScatter] 3. Start shmem_mte_get_mem_nbi
AscendC::DataCopy(buf, src[inputOffset], copyInNum);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
AscendC::DataCopy(dst[outputOffset], buf, hiddenSize);
AscendC::DataCopyPad(dstScale[processIndex], bufScale, {1, 4, 0, 0, 0});
// [ReduceScatter] 4. Post Interface Sync
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
pingpongId = (pingpongId + 1) % BufferNum;
}
// [ReduceScatter] 4. Post Interface Sync
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
CATLASS_DEVICE CATLASS_DEVICE
void GetCumsumForMMAIV(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP) void GetCumsumForMMAIV(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP)
{ {
@@ -296,7 +350,7 @@ private:
AscendC::DataCopyPad( AscendC::DataCopyPad(
tmpBuffer1, tmpBuffer1,
tokenPerExpert[rankId * expertPerRank], tokenPerExpert[rankId * expertPerRank],
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, ALIGN_128) - expertPerRank) * sizeof(int32_t)), 0}, {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0},
{} {}
); );
@@ -323,6 +377,8 @@ private:
icache_preload(8); icache_preload(8);
BlockScheduler blockScheduler; BlockScheduler blockScheduler;
BlockMmad blockMmad(resource); BlockMmad blockMmad(resource);
float aivFinishGroups = 0.0f;
__gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
int64_t gmGroupOffsetA = 0; int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0; int64_t gmGroupOffsetB = 0;
@@ -335,7 +391,6 @@ private:
uint16_t syncgmmIdx = 0; uint16_t syncgmmIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
syncgmmIdx++; syncgmmIdx++;
AscendC::PipeBarrier<PIPE_ALL>(); AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -350,9 +405,7 @@ private:
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1))); gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1))); gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
AscendC::PipeBarrier<PIPE_ALL>(); AscendC::PipeBarrier<PIPE_ALL>();
if (currentM <= L1TileShape::M) { if (currentM <= L1TileShape::M) {
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
} }
@@ -372,6 +425,7 @@ private:
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmmIdx ++; syncgmmIdx ++;
} }
// Compute block location // Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
@@ -399,6 +453,7 @@ private:
if constexpr (BlockMmad::DispatchPolicy::ASYNC) { if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock(); blockMmad.SynchronizeBlock();
} }
// Synchronization signal: GMM1 notifies SwiGLU [1]
blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V); blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V);
} }
@@ -419,6 +474,7 @@ private:
if constexpr (BlockMmad::DispatchPolicy::ASYNC) { if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock(); blockMmad.SynchronizeBlock();
} }
// Synchronization signal: GMM1 notifies SwiGLU [2]
blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V); blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V);
} }
@@ -458,11 +514,9 @@ private:
} }
AscendC::GlobalTensor<ElementB> gmB2; AscendC::GlobalTensor<ElementB> gmB2;
AscendC::GlobalTensor<ElementScale> gmS2; AscendC::GlobalTensor<ElementScale> gmS2;
AscendC::PipeBarrier<PIPE_ALL>();
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2))); gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2))); gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
if (currentM <= L1TileShape::M) { if (currentM <= L1TileShape::M) {
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
} }
@@ -482,6 +536,7 @@ private:
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C);
} }
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
if (loopIdx + coreNum >= coreLoops) { if (loopIdx + coreNum >= coreLoops) {
syncLoopIdx = groupIdx; syncLoopIdx = groupIdx;
@@ -518,34 +573,34 @@ private:
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum; startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
} }
if constexpr (BlockMmad::DispatchPolicy::ASYNC) { if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock(); blockMmad.SynchronizeBlock();
} }
blockMmad.Finalize(params.expertPerRank - 1, 0);
} }
CATLASS_DEVICE CATLASS_DEVICE
void ResetTokenPerExpert(AscendC::GlobalTensor<int32_t> & tokenPerExpert, int32_t num) void InitArithProgress(Params const &params) {
{ AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
if (coreIdx != coreNum - 1) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0); AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0); AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE);
AscendC::Duplicate(tmp, 0, num);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0); AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert, tmp, num);
AscendC::GlobalTensor<float> flagGlobalBase;
flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase);
AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE);
} }
CATLASS_DEVICE CATLASS_DEVICE
void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){ void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const &params, int64_t localTokenPerExpertOffset){
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128);
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0); AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, ALIGN_128); AscendC::LocalTensor<int32_t> prevSumBuf = tmpBuffer[numPerCore];
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) { if (dstEpIdx == params.rank) {
continue; continue;
@@ -581,11 +636,9 @@ private:
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
} }
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
if (dstEpIdx == params.rank) { if (dstEpIdx != params.rank) {
continue;
}
int32_t intPer512 = CACHE_LINE / sizeof(int); int32_t intPer512 = CACHE_LINE / sizeof(int);
for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, ALIGN_128); checkIdx += intPer512) { for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) {
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
gm_signal_wait_until_ne(sync_check, 0); gm_signal_wait_until_ne(sync_check, 0);
} }
@@ -593,16 +646,108 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0); AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0); AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
} else {
AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
} }
AscendC::PipeBarrier<PIPE_ALL>();
int32_t prevSum = 0;
int32_t j = 0;
for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) {
if (i >= params.rank * params.expertPerRank) {
prevSumBuf(j) = prevSum;
j++;
}
prevSum += tmpBuffer(i);
}
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf,
AscendC::DataCopyParams{1, static_cast<uint16_t>(params.expertPerRank * sizeof(int32_t)), 0, 0});
}
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
} }
CATLASS_DEVICE
void ResetTokenPerExpert(int32_t num)
{
if (coreIdx != coreNum - 1) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::Duplicate(tmp, 0, num);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::DataCopy(tokenPerExpert, tmp, num);
}
CATLASS_DEVICE
void UpdateAicFlags(const Params &params)
{
float flagBase = 1.0f * params.expertPerRank;
__gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE;
float flag = 0.0f;
float lastflag = -1.0f;
AscendC::LocalTensor<float> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<float>(0);
__gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase;
AscendC::GlobalTensor<float> flagGM;
flagGM.SetGlobalBuffer(flagPtr);
int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE;
AscendC::LocalTensor<float> dstValueBuffer = resource.ubBuf.template GetBufferByByte<float>(flagBufferSize);
AscendC::LocalTensor<float> sharedTmpBuffer = resource.ubBuf.template GetBufferByByte<float>((flagBufferSize + 64));
uint64_t mask[1] = {0};
uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE));
for (int32_t i = 0; i < 4; i ++) {
if (i < params.EP) {
mask[0] |= 1ull * (1ull << (i * 16));
}
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while (flag < flagBase) {
flag = flagBase;
AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::ReduceMin<float>(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
flag = min(flag, dstValueBuffer.GetValue(0));
if (flag > lastflag) {
*aicFinishPtr = flag;
gm_dcci(aicFinishPtr);
lastflag = flag;
}
}
}
CATLASS_DEVICE CATLASS_DEVICE
void Dispatch(Params const &params) { void CombineSetFlag() {
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
}
CATLASS_DEVICE
void DispatchAndCombine(Params const &params) {
icache_preload(8); icache_preload(8);
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t); int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem
@@ -617,10 +762,19 @@ private:
&params.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey); &params.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey);
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset);
CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset);
if (coreIdx == 0) { if (coreIdx == 0) {
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
} }
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t prevSum = 0;
if (coreIdx < params.EP) {
prevSum = preSumBeforeRank(coreIdx * params.expertPerRank);
}
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
AscendC::GlobalTensor<int32_t> ExpertTokenNums; AscendC::GlobalTensor<int32_t> ExpertTokenNums;
@@ -633,24 +787,12 @@ private:
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx++; syncgmm1Idx++;
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t groupIdxDeq = 0;
if (coreIdx < params.EP) {
for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i));
}
m_prevSumBeforeRank = prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0; uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
uint32_t dequantSum = 0; uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
uint32_t n = params.problemShape.n(); icache_preload(8);
BlockEpilogue1 blockEpilogue(resource, n);
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// The ith core reads data from the ith rank's peermem // The ith core reads data from the ith rank's peermem
groupIdxDeq = groupIdx - 2;
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
@@ -664,24 +806,23 @@ private:
GM_ADDR otherRankPtr = shmem(0, dstEpIdx); GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
AscendC::GlobalTensor<ElementA> gmRemoteA; AscendC::GlobalTensor<ElementA> gmRemoteA;
gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
MatrixCoord offsetA{rowStart, 0}; MatrixCoord offsetA{rowStart, 0};
MatrixCoord shapeA{rows, params.problemShape.k()};
MatrixCoord offsetPeer{rowSrc, 0}; MatrixCoord offsetPeer{rowSrc, 0};
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); int64_t gmOffsetPeer = rowSrc * (params.problemShape.k() + ALIGN_512);
// Communication data // Communication data
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); CopyGMToGMPerToken(gmA[gmOffsetA], gmPerTokenScale1[rowStart], gmRemoteA[gmOffsetPeer], rows, params.problemShape.k());
// Communication scale
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
}
} }
}
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
prevGroupSum1 += currentM;
syncgmm1Idx ++; syncgmm1Idx ++;
prevGroupSum1 += currentM;
// Token count and truncation logic for the first SwiGLU operation
if (groupIdx + 1 <= params.epilogueGranularity) { if (groupIdx + 1 <= params.epilogueGranularity) {
if (dequantSum1 + currentM <= params.maxOutputSize) { if (dequantSum1 + currentM <= params.maxOutputSize) {
dequantSum1 += currentM; dequantSum1 += currentM;
@@ -689,6 +830,8 @@ private:
dequantSum1 = params.maxOutputSize; dequantSum1 = params.maxOutputSize;
} }
} }
// Token count and truncation logic for the second SwiGLU operation
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) { if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) { if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
dequantSum2 += currentM; dequantSum2 += currentM;
@@ -698,9 +841,27 @@ private:
} }
} }
uint32_t n2 = params.problemShape.k();
typename BlockEpilogue2::Params epilogueParams{
static_cast<int32_t>(params.EP),
static_cast<int32_t>(params.expertPerRank),
static_cast<int32_t>(params.rank),
reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert),
params.layoutD2,
static_cast<int32_t>(n2),
static_cast<int32_t>(L1TileShape::N),
shmem,
static_cast<int32_t>(peermemInfo.offsetD)
};
uint32_t n = params.problemShape.n();
BlockEpilogue2 blockEpilogue2(resource, epilogueParams);
BlockEpilogue1 blockEpilogue1(resource, n);
// Synchronous wait: SwiGLU waits for GMM1 [1]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
if (dequantSum1 > 0) { if (dequantSum1 > 0) {
uint32_t rowStartThisCore = 0; uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0}; MatrixCoord offsetC{0U, 0};
@@ -708,11 +869,14 @@ private:
LayoutC layoutC{dequantSum1, params.problemShape.n()}; LayoutC layoutC{dequantSum1, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
} }
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
// Synchronization signal: SwiGLU notifies GMM2 [1]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) { if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
// Synchronous wait: SwiGLU waits for GMM1 [2]
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
if (dequantSum2 > 0) { if (dequantSum2 > 0) {
@@ -723,76 +887,118 @@ private:
LayoutC layoutC{dequantLen, params.problemShape.n()}; LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC); int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
} }
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
// Synchronization signal: SwiGLU notifies GMM2 [2]
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
} }
blockEpilogue.Finalize();
}
CATLASS_DEVICE blockEpilogue1.Finalize();
void Combine(Params const &params) {
int32_t prevSumBeforeRank = 0;
if (coreIdx < params.EP) {
prevSumBeforeRank = m_prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
// TODO compute the cumsum of tokenPerExpert CombineSetFlag();
typename BlockEpilogue2::Params epilogueParams{
static_cast<int32_t>(params.EP), CombineV2(params, blockEpilogue2);
static_cast<int32_t>(params.expertPerRank),
reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace),
static_cast<int32_t>(n2)
};
BlockEpilogue2 blockEpilogue(resource, epilogueParams);
int32_t prevGroupSum2 = 0;
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT);
AscendC::SyncAll<true>(); AscendC::SyncAll<true>();
#ifndef __CROSSRANKSYNCANDALLGATHERV1__
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128));
__gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); #endif
AscendC::GlobalTensor<ElementD2> gmRemotePeer; shmem.InitStatusTargetSum();
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); if (get_subblockid() == 0) {
uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; AscendC::LocalTensor<int32_t> ctrBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
if (srcRowOffset < params.maxOutputSize) { shmem.CrossRankSyncV2Set(ctrBuffer);
uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
if (srcRowOffset + dataRows > params.maxOutputSize) {
dataRows = params.maxOutputSize - srcRowOffset;
}
uint32_t dstRowOffset = prevSum;
prevSum += dataRows;
MatrixCoord offsetC{srcRowOffset, 0};
MatrixCoord offsetPeer{dstRowOffset, 0};
MatrixCoord shapeC{dataRows, n2};
int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
if constexpr (std::is_same_v<ElementA, int8_t>) {
blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
} else { } else {
blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); uint32_t uboffset = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t sendRankNum_ = params.EP / aicCoreNum;
uint32_t remainderRankNum = params.EP % aicCoreNum;
if (aicCoreIdx < remainderRankNum) {
sendRankNum_++;
} }
} AscendC::LocalTensor<float> statusTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
} uboffset += sendRankNum_ * UB_ALIGN;
prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); AscendC::LocalTensor<float> gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
} uboffset += params.EP * sizeof(float);
blockEpilogue.Finalize(); AscendC::LocalTensor<uint32_t> gatherTmpTensor = resource.ubBuf.template GetBufferByByte<uint32_t>(uboffset);
AscendC::SyncAll<true>(); uboffset += sizeof(uint32_t);
ResetTokenPerExpert(tokenPerExpert, params.EP * AlignUp(params.EP * params.expertPerRank, ALIGN_128)); AscendC::LocalTensor<float> statusSumOutTensor = resource.ubBuf.template GetBufferByByte<float>(uboffset);
shmem.CrossRankSync(); uboffset += sizeof(float);
shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
MoeTokenUnpermuteTilingData tilingData; MoeTokenUnpermuteTilingData tilingData;
MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp; KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData); kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
kernelMoeTokenUnpermuteOp.Process(); kernelMoeTokenUnpermuteOp.Process();
} }
}
CATLASS_DEVICE
void CombineV2(Params const &params, BlockEpilogue2 & blockEpilogue) {
BlockScheduler blockScheduler;
int32_t syncLoopIdx = 0;
uint32_t startCoreIdx = 0;
uint32_t aicCoreNum = coreNum / 2;
uint32_t aicCoreIdx = get_block_idx();
uint32_t aivSubCoreIdx = get_subblockid();
uint32_t preSrcExpertSum = 0;
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
icache_preload(8);
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preSrcExpertSum >= params.maxOutputSize) {
currentExpertM = 0;
} else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) {
currentExpertM = params.maxOutputSize - preSrcExpertSum;
}
GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx;
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) {
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
int32_t m0 = 16;
// Block count, the shape of each block is (m0, actualBlockShape.n())
int32_t m_rows = (actualBlockShape.m() + m0 - 1) / m0;
int32_t aiv_m_rows = m_rows / 2;
if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) {
aiv_m_rows += 1;
}
uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset
if(aivSubCoreIdx == 1) {
m_offset += (m_rows / 2) * m0;
}
for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) {
int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT;
AscendC::CrossCoreWaitFlag<0x2>(flag_id);
}
for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) {
GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1};
uint32_t actualm = m0;
if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){
actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0;
}
GemmCoord realTileShape{actualm, actualBlockShape.n(), 1};
blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank);
m_offset += m0;
}
}
preSrcExpertSum += currentExpertM;
startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum;
}
blockEpilogue.Finalize();
}
private: private:
struct WorkspaceInfo { struct WorkspaceInfo {
GM_ADDR ptrA; GM_ADDR ptrA;
@@ -804,6 +1010,9 @@ private:
GM_ADDR ptrPerTokenScale2; GM_ADDR ptrPerTokenScale2;
GM_ADDR expandedRowIdx; GM_ADDR expandedRowIdx;
GM_ADDR ptrTokenPerExpert; GM_ADDR ptrTokenPerExpert;
GM_ADDR ptrSumBeforeRank;
__gm__ float* ptrSoftFlagBase;
CATLASS_DEVICE CATLASS_DEVICE
WorkspaceInfo(){} WorkspaceInfo(){}
@@ -831,15 +1040,21 @@ private:
workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
ptrC = params.ptrWorkspace + workspaceOffset; ptrC = params.ptrWorkspace + workspaceOffset;
ptrC2 = ptrC;
workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC), workspaceOffset += params.maxOutputSize * params.problemShape.n() * sizeof(ElementC);
params.maxOutputSize * n2 * sizeof(ElementC)); ptrC2 = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.maxOutputSize * n2 * sizeof(ElementC);
ptrA = params.ptrWorkspace + workspaceOffset; ptrA = params.ptrWorkspace + workspaceOffset;
ptrPermutedToken = ptrA;
workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA), workspaceOffset += params.maxOutputSize * params.problemShape.k() * sizeof(ElementA);
params.maxOutputSize * k2 * sizeof(ElementA)); ptrPermutedToken = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA);
ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset;
workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE;
ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset);
} }
}; };
@@ -866,12 +1081,9 @@ private:
uint32_t coreIdx; uint32_t coreIdx;
uint32_t coreNum; uint32_t coreNum;
Params params;
WorkspaceInfo workspaceInfo; WorkspaceInfo workspaceInfo;
PeermemInfo peermemInfo; PeermemInfo peermemInfo;
int64_t m_prevSumBeforeRank;
AscendC::GlobalTensor<ElementA> gmA; AscendC::GlobalTensor<ElementA> gmA;
AscendC::GlobalTensor<ElementC> gmC; AscendC::GlobalTensor<ElementC> gmC;
@@ -883,10 +1095,11 @@ private:
AscendC::GlobalTensor<int32_t> tokenPerExpert; AscendC::GlobalTensor<int32_t> tokenPerExpert;
AscendC::GlobalTensor<int32_t> cumsumMM; AscendC::GlobalTensor<int32_t> cumsumMM;
AscendC::GlobalTensor<int32_t> preSumBeforeRank;
Layout3D tokenPerExpertLayout; Layout3D tokenPerExpertLayout;
HcclShmem shmem; HcclShmem shmem;
}; };
} // namespace Catlass::Gemm::Kernel } // namespace Catlass::Gemm::Kernel
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP #endif // DISPATCH_FFN_COMBINE_KERNEL_HPP

View File

@@ -44,6 +44,8 @@ constexpr int64_t EXERPT_TOKENS_COUNT = 2;
constexpr int64_t EXERPT_TOKENS_CUMSUM = 1; constexpr int64_t EXERPT_TOKENS_CUMSUM = 1;
constexpr int64_t EXERPT_TOKENS_NONE = 0; constexpr int64_t EXERPT_TOKENS_NONE = 0;
constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1; constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1;
constexpr int64_t ALIGN_512 = 512;
constexpr int64_t ALIGN_128 = 128;
const __gm__ int32_t assist[256] = { const __gm__ int32_t assist[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,

View File

@@ -35,7 +35,6 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
__aicore__ inline void CopyOutIdx(); __aicore__ inline void CopyOutIdx();
__aicore__ inline void CopyOutEmpty(); __aicore__ inline void CopyOutEmpty();
__aicore__ inline void CopyOutXQuant1H(); __aicore__ inline void CopyOutXQuant1H();
__aicore__ inline void CopyOutXQuantEH();
__aicore__ inline void ComputeExpertTokenCountOrCumsum(); __aicore__ inline void ComputeExpertTokenCountOrCumsum();
__aicore__ inline void Compute(LocalTensor<float>& smoothLocal); __aicore__ inline void Compute(LocalTensor<float>& smoothLocal);
@@ -49,6 +48,7 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
int64_t k_; int64_t k_;
int64_t n_; int64_t n_;
int64_t cols_; int64_t cols_;
int64_t cols_scale_;
int64_t activateRows_; int64_t activateRows_;
int64_t expertNum; int64_t expertNum;
int64_t expertCapacity; int64_t expertCapacity;
@@ -63,12 +63,10 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
TQue<QuePosition::VECIN, 1> smoothInQueue; TQue<QuePosition::VECIN, 1> smoothInQueue;
TQue<QuePosition::VECOUT, 1> calcQueue; TQue<QuePosition::VECOUT, 1> calcQueue;
TQue<QuePosition::VECOUT, 1> inputXOutQueue; TQue<QuePosition::VECOUT, 1> inputXOutQueue;
TQue<QuePosition::VECOUT, 1> scaleOutQueue;
GlobalTensor<T> xGm_; GlobalTensor<T> xGm_;
GlobalTensor<int32_t> expertIdxGm_; GlobalTensor<int32_t> expertIdxGm_;
GlobalTensor<float> quantSmoothGm; GlobalTensor<float> quantSmoothGm;
GlobalTensor<float> dynamicQuantScaleGm;
GlobalTensor<int8_t> expandedXGm_; GlobalTensor<int8_t> expandedXGm_;
GlobalTensor<int32_t> expandedRowIdxGm_; GlobalTensor<int32_t> expandedRowIdxGm_;
@@ -225,7 +223,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>&
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>(); LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>(); LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>(); LocalTensor<float> dynamicQuantLocal = outLocal[this->cols_].template ReinterpretCast<float>();
if constexpr (!IsSameType<T, float>::value) { if constexpr (!IsSameType<T, float>::value) {
Cast(inLocal, inLocal.ReinterpretCast<T>()[colsAlign], RoundMode::CAST_NONE, this->cols_); Cast(inLocal, inLocal.ReinterpretCast<T>()[colsAlign], RoundMode::CAST_NONE, this->cols_);
@@ -259,7 +257,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>&
calcQueue.FreeTensor(tempLocal); calcQueue.FreeTensor(tempLocal);
inputXOutQueue.EnQue(outLocal); inputXOutQueue.EnQue(outLocal);
scaleOutQueue.EnQue(dynamicQuantLocal);
} }
template <typename T> template <typename T>
@@ -275,7 +272,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0}; DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0}; DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0}; DataCopyExtParams intriParams{1, static_cast<uint32_t>((this->cols_ + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0};
LocalTensor<float> smoothLocal; LocalTensor<float> smoothLocal;
if (smoothType == 1) { if (smoothType == 1) {
@@ -295,7 +292,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
xCopyInQueue_.EnQue<T>(xLocal); xCopyInQueue_.EnQue<T>(xLocal);
Compute(smoothLocal); Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>(); LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) { while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) {
int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
@@ -303,76 +299,15 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) { if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) {
continue; continue;
} }
DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams); DataCopyPad(expandedXGm_[outIndex * this->cols_scale_], outLocal, intriParams);
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
} }
xCopyInQueue_.FreeTensor(xLocal); xCopyInQueue_.FreeTensor(xLocal);
inputXOutQueue.FreeTensor(outLocal); inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
if (smoothType == 1) {
smoothInQueue.FreeTensor(smoothLocal);
} }
expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
} }
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuantEH() {
LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
Muls(expandDstToSrcRowLocal.ReinterpretCast<float>(), expandDstToSrcRowLocal.ReinterpretCast<float>(), (float)-1,
this->totalLength);
pipe_barrier(PIPE_V);
LocalTensor<int32_t> sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast<int32_t>();
Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast<float>(), RoundMode::CAST_ROUND, this->totalLength);
int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;
DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};
for (int64_t row = curRowsStart; row <= curRowsEnd; row++) {
if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) {
break;
}
int32_t srcIdx = sortedRowIdx.GetValue(row);
int32_t expertIdx = expandedExpertIdxLocal.GetValue(row);
LocalTensor<T> inLocal = xCopyInQueue_.AllocTensor<T>();
LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
} else {
DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
}
DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0});
xCopyInQueue_.EnQue<T>(inLocal);
smoothInQueue.EnQue(smoothLocal);
smoothLocal = smoothInQueue.DeQue<float>();
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0});
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams);
xCopyInQueue_.FreeTensor(inLocal);
smoothInQueue.FreeTensor(smoothLocal);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);
}
template <typename T> template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
@@ -386,6 +321,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
this->k_ = tilingData->k; this->k_ = tilingData->k;
this->n_ = tilingData->n; this->n_ = tilingData->n;
this->cols_ = tilingData->cols; this->cols_ = tilingData->cols;
this->cols_scale_ = this->cols_ + ALIGN_512;
this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
this->activateRows_ = this->gatherOutTilingData_->activateRows; this->activateRows_ = this->gatherOutTilingData_->activateRows;
@@ -416,7 +352,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
Align(this->expertNum, sizeof(int32_t))); Align(this->expertNum, sizeof(int32_t)));
} }
quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
int64_t kvFactor = 2; int64_t kvFactor = 2;
int64_t buffSize = this->sortNum_ * sizeof(int32_t); int64_t buffSize = this->sortNum_ * sizeof(int32_t);
@@ -440,8 +375,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR exp
} }
pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float))); pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float)));
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float)));
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t))); pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_scale_, sizeof(int8_t)));
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
} }
template <typename T> template <typename T>
@@ -457,12 +391,8 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Process() {
} else { } else {
CopyOutEmpty(); CopyOutEmpty();
} }
if (smoothType == 2) {
CopyOutXQuantEH();
} else {
CopyOutXQuant1H(); CopyOutXQuant1H();
} }
} }
}
} // namespace MoeInitRoutingQuantV2 } // namespace MoeInitRoutingQuantV2
#endif // MOE_V2_DYNAMIC_QUANT_FULL_LOAD_H #endif // MOE_V2_DYNAMIC_QUANT_FULL_LOAD_H

View File

@@ -66,6 +66,7 @@ class MoeV2GatherDynamicQuant {
int64_t needCoreNum; int64_t needCoreNum;
int64_t blockIdx; int64_t blockIdx;
int64_t cols; int64_t cols;
int64_t cols_scale_;
int64_t n; int64_t n;
int64_t k; int64_t k;
int64_t totalLength; int64_t totalLength;
@@ -117,7 +118,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& s
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>(); LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>(); LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>(); LocalTensor<float> dynamicQuantLocal = outLocal[this->cols].template ReinterpretCast<float>();
if constexpr (!IsSameType<T, float>::value) { if constexpr (!IsSameType<T, float>::value) {
Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
@@ -151,7 +152,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& s
calcQueue.FreeTensor(tempLocal); calcQueue.FreeTensor(tempLocal);
inputXOutQueue.EnQue(outLocal); inputXOutQueue.EnQue(outLocal);
scaleOutQueue.EnQue(dynamicQuantLocal);
} }
template <typename T> template <typename T>
@@ -163,7 +163,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
int64_t currentLoopStartRow = initialRow / this->k; int64_t currentLoopStartRow = initialRow / this->k;
int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0}; DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0}; DataCopyExtParams copyOutParams{1, static_cast<uint32_t>((this->cols + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0};
DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0}; DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
LocalTensor<float> smoothLocal; LocalTensor<float> smoothLocal;
@@ -187,7 +187,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
// Compute quantization // Compute quantization
Compute(smoothLocal); Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>(); LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
@@ -197,15 +196,11 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
continue; continue;
} }
DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams); // Scale is placed after the data position
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); DataCopyPad(expandedXGm[outIndex * cols_scale_], outLocal, copyOutParams);
} }
inputXInQueue.FreeTensor(inLocal); inputXInQueue.FreeTensor(inLocal);
inputXOutQueue.FreeTensor(outLocal); inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
if (smoothType == 1) {
smoothInQueue.FreeTensor(smoothLocal);
} }
expandRowIdxInQueue.FreeTensor(indicesLocal); expandRowIdxInQueue.FreeTensor(indicesLocal);
} }
@@ -463,6 +458,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR
this->needCoreNum = this->gatherOutTilingData->needCoreNum; this->needCoreNum = this->gatherOutTilingData->needCoreNum;
this->activateRows = this->gatherOutTilingData->activateRows; this->activateRows = this->gatherOutTilingData->activateRows;
this->cols = tilingData->cols; this->cols = tilingData->cols;
this->cols_scale_ = this->cols + ALIGN_512;
this->n = tilingData->n; this->n = tilingData->n;
this->k = tilingData->k; this->k = tilingData->k;
this->totalLength = tilingData->n * tilingData->k; this->totalLength = tilingData->n * tilingData->k;
@@ -518,32 +514,15 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR
pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
} }
template <typename T> template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() { __aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
if (this->blockIdx < this->needCoreNum) { if (this->blockIdx < this->needCoreNum) {
currentLoopRows = perLoopRows; currentLoopRows = perLoopRows;
if (colLoops > 1) { // A single row cannot be fully loaded; workspace is required if (colLoops > 1) { // Cannot fit all data in one row, workspace is required
if (smoothType == 2) { trap(); // Not supported
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { } else { // All data can fit in one row
CopyInExpandedExpertIdx(loop);
CopyOutPartialXQuantEH(loop);
}
currentLoopRows = lastLoopRows;
CopyInExpandedExpertIdx(this->rowLoops - 1);
CopyOutPartialXQuantEH(this->rowLoops - 1);
} else {
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
CopyInExpandedRowIdx(loop);
CopyOutPartialXQuant1H(loop);
}
currentLoopRows = lastLoopRows;
CopyInExpandedRowIdx(this->rowLoops - 1);
CopyOutPartialXQuant1H(this->rowLoops - 1);
}
} else { // A single row can be fully loaded
if (smoothType == 2) { if (smoothType == 2) {
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
CopyInExpandedExpertIdx(loop); CopyInExpandedExpertIdx(loop);

View File

@@ -85,8 +85,8 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
GM_ADDR unpermuted_tokens, GM_ADDR unpermuted_tokens,
const MoeTokenUnpermuteTilingData *__restrict tiling_data) const MoeTokenUnpermuteTilingData *__restrict tiling_data)
{ {
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); this->blockIdx = get_block_idx();
this->blockNum = get_block_num() * get_subblockdim(); this->blockNum = get_block_num();
if (blockIdx >= blockNum) { if (blockIdx >= blockNum) {
return; return;

View File

@@ -0,0 +1,243 @@
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "hccl_shmem.hpp"
#include "layout3d.hpp"
namespace Catlass::Epilogue::Block {
template <
uint32_t UB_STAGES_,
class CType_,
class LayoutPerTokenScale_,
class DType_,
class TileCopy_
>
class BlockEpilogue <
EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>,
CType_,
Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_,
TileCopy_
> {
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::VectorLayout>>;
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
struct Params {
__gm__ int32_t *ptrTokenPerExpert{nullptr};
int32_t EP;
int32_t expertPerRank;
int32_t n2;
LayoutC layoutC;
int32_t n0;
int32_t rank;
HcclShmem shmem;
int32_t offsetD;
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_,
LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) :
ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_),
expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_),
shmem(shmem_), offsetD(offsetD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
//ub:192KB
n0 = params.n0;
size_t ubOffset = 0;
for(int32_t i = 0; i < 2; i++) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += max_len * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += max_len * sizeof(ElementD);
ubFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += max_len * sizeof(float);
scaleUbList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += (max_len / n0) * sizeof(float);
source_scale_offset[i] = -1;
}
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank);
is_ping = true;
}
CATLASS_DEVICE
void Finalize()
{
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
}
CATLASS_DEVICE
~BlockEpilogue()
{
}
CATLASS_DEVICE
void operator() (
AscendC::GlobalTensor<ElementC> const &gmC,
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
GemmCoord& blockCoord,
GemmCoord& actualBlockShape,
int32_t groupIdx,
int32_t preSrcExpertSum,
AscendC::GlobalTensor<int32_t> preSumBeforeRank
){
is_ping = !is_ping;
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
auto &ubC = ubCList[is_ping];
auto &ubD = ubDList[is_ping];
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
auto gmTileC = gmC[gmCOffset];
auto &ubCFp32 = ubFp32List[is_ping];
auto &scaleUb = scaleUbList[is_ping];
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id);
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
layout::VectorLayout scaleLauout{actualBlockShape.m()};
if (source_scale_offset[event_id] != gmScaleOffset) {
source_scale_offset[event_id] = gmScaleOffset;
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // Note that the value must be MTE2_S instead of MTE2_V.
// Otherwise, 0 will be read, causing garbled characters.
AscendC::PipeBarrier<PIPE_V>();
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
float scale = scaleUb(row);
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
int32_t lenTile = actualBlockShape.m();
int32_t stTile = blockCoord.m();
int32_t edTile = stTile + lenTile;
int32_t preSumRankInExpert = 0;
int32_t tileOffset = 0;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id);
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx);
int32_t stRankInExpert = preSumRankInExpert;
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
preSumRankInExpert += lenRankInExpert;
if (stRankInExpert >= edTile) {
break;
}
else if (edRankInExpert <= stTile) {
continue;
}
int32_t stData = max(stRankInExpert, stTile);
int32_t edData = min(edRankInExpert, edTile);
uint32_t lenData = edData - stData;
if (lenData <= 0){
continue;
}
uint32_t dstOffsetInExpert = 0;
if (stTile > stRankInExpert) {
dstOffsetInExpert = stTile - stRankInExpert;
}
AscendC::GlobalTensor<ElementD> gmRemotePeer;
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()};
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
auto gmTileD = gmRemotePeer[gmDstOffset];
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
tileOffset += lenData;
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
AscendC::LocalTensor<float> ubFp32List[UB_STAGES];
AscendC::LocalTensor<float> scaleUbList[UB_STAGES];
int32_t source_scale_offset[UB_STAGES];
int32_t max_len = 8 * 32 / 4 * 128;
int32_t n0;
bool is_ping = false;
int32_t repeat = 128;
CopyGmToUbC copyGmToUbC;
CopyUbToGmD copyUbToGmD;
CopyScaleGmToUb copyScaleGmToUb;
AscendC::GlobalTensor<int32_t> tokenPerExpert;
Layout3D tokenPerExpertLayout;
};
}
#endif

View File

@@ -22,8 +22,6 @@
namespace Catlass::Gemm::Block { namespace Catlass::Gemm::Block {
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
template<AscendC::HardEvent event> template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID) __aicore__ inline void SyncFlagFunc(int32_t eventID)
{ {
@@ -153,9 +151,11 @@ public:
L1TileShape::K, L1TileShape::N); L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0) BlockMmad(Arch::Resource<ArchTag> &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0)
{ {
syncGroupIdx = 0; syncGroupIdx = 0;
ptrSoftFlagBase_ = flagPtr;
expertPerRank_ = expertPerRank;
InitL1(resource, l1BufAddrStart); InitL1(resource, l1BufAddrStart);
InitL0A(resource); InitL0A(resource);
InitL0B(resource); InitL0B(resource);
@@ -272,11 +272,23 @@ public:
CATLASS_DEVICE CATLASS_DEVICE
void Finalize(int32_t target, int32_t flag = 0) void Finalize(int32_t target, int32_t flag = 0)
{ {
if (ptrSoftFlagBase_ != nullptr) {
if (target < 0) {
return;
}
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE3>(EVENT_ID0);
AscendC::GlobalTensor<int32_t> flagGlobal;
flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE);
AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE);
}
else {
for(;syncGroupIdx <= target; syncGroupIdx++) { for(;syncGroupIdx <= target; syncGroupIdx++) {
int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag; int32_t flagId = syncGroupIdx / 15 + flag;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
} }
} }
}
private: private:
struct L1TileMmadParams { struct L1TileMmadParams {
uint32_t l1ListId; uint32_t l1ListId;
@@ -291,7 +303,6 @@ private:
layout::VectorLayout layoutScale; layout::VectorLayout layoutScale;
int32_t syncLoopIdx; int32_t syncLoopIdx;
int32_t flag; int32_t flag;
CATLASS_DEVICE CATLASS_DEVICE
L1TileMmadParams() = default; L1TileMmadParams() = default;
}; };
@@ -310,11 +321,24 @@ private:
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]); AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]); AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
} }
if constexpr (std::is_same_v<ElementA, int8_t>) {
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
if constexpr (std::is_same_v<ElementA, int8_t>) {
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset); l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0); AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
} }
if (ptrSoftFlagBase_ != nullptr) {
// Initialize the flag matrix (structure as below):
// 1 0 0 0 0 0 0 0
// 2 0 0 0 0 0 0 0
// ...
// 16 0 0 0 0 0 0 0
// Then move it to L1
uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE;
l1FTensor = resource.l1Buf.template GetBufferByByte<int32_t>(l1FOffset);
AscendC::GlobalTensor<int32_t> flagBase;
flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_));
AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE);
}
} }
CATLASS_DEVICE CATLASS_DEVICE
@@ -463,12 +487,20 @@ private:
if constexpr (std::is_same_v<ElementA, int8_t>) { if constexpr (std::is_same_v<ElementA, int8_t>) {
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0); AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
} }
#ifdef __TILE_SYNC__
if (params.flag > 0) {
int32_t flagId = params.flag + params.syncLoopIdx / 8;
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
}
#else
Finalize(params.syncLoopIdx, params.flag); Finalize(params.syncLoopIdx, params.flag);
#endif
} }
} }
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES]; AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES]; AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
AscendC::LocalTensor<uint64_t> l1STensor; AscendC::LocalTensor<uint64_t> l1STensor;
AscendC::LocalTensor<int32_t> l1FTensor;
int32_t syncGroupIdx; int32_t syncGroupIdx;
int32_t l1AEventList[L1_STAGES]; int32_t l1AEventList[L1_STAGES];
int32_t l1BEventList[L1_STAGES]; int32_t l1BEventList[L1_STAGES];
@@ -497,6 +529,9 @@ private:
CopyL1ToL0A copyL1ToL0A; CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B; CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm; CopyL0CToGm copyL0CToGm;
__gm__ int32_t* ptrSoftFlagBase_ = nullptr;
int32_t expertPerRank_;
}; };
} // namespace Catlass::Gemm::Block } // namespace Catlass::Gemm::Block

View File

@@ -5,5 +5,7 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
constexpr static int32_t NUMS_PER_FLAG = 16; constexpr static int32_t NUMS_PER_FLAG = 16;
constexpr static int32_t CACHE_LINE = 512; constexpr static int32_t CACHE_LINE = 512;
constexpr static int32_t RESET_VAL = 0xffff; constexpr static int32_t RESET_VAL = 0xffff;
constexpr static int32_t ALIGN_128 = 128; constexpr static int32_t FLAGSTRIDE = 16;
constexpr static int32_t UB_ALIGN = 32;
constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15;
#endif #endif

View File

@@ -33,13 +33,13 @@ namespace Catlass::Epilogue {
}; };
template <uint32_t UB_STAGES_> template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant { struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
using ArchTag = Arch::AtlasA2; using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_; static constexpr uint32_t UB_STAGES = UB_STAGES_;
}; };
template <uint32_t UB_STAGES_> template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant { struct EpilogueAtlasA2PerTokenDequantV2 {
using ArchTag = Arch::AtlasA2; using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_; static constexpr uint32_t UB_STAGES = UB_STAGES_;
}; };

View File

@@ -5,13 +5,28 @@
#include "kernel_operator.h" #include "kernel_operator.h"
#include "const_args.hpp" #include "const_args.hpp"
#ifdef HCCL_COMM
#include "moe_distribute_base.h" #include "moe_distribute_base.h"
using namespace AscendC::HcclContextDef;
#ifndef HCCL_COMM #else
#include "shmem_api.h" #include "shmem_api.h"
#endif #endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ #define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
constexpr int32_t MAX_RANK_SIZE = 32;
constexpr int32_t SHMEM_MEM = 700 * MB_SIZE;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t STATE_OFFSET = 512;
FORCE_INLINE_AICORE void AicSyncAll() {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
AscendC::CrossCoreWaitFlag<0x0>(8);
}
template<typename T> template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) { FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
@@ -23,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
return *((__gm__ T *)cache); return *((__gm__ T *)cache);
} }
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { template<typename T>
FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) {
using namespace AscendC; using namespace AscendC;
GlobalTensor<uint8_t> global; GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(addr); global.SetGlobalBuffer(reinterpret_cast<GM_ADDR>(addr));
// Important: add hint to avoid dcci being optimized by compiler // Important: add hint to avoid dcci being optimized by compiler
__asm__ __volatile__(""); __asm__ __volatile__("");
@@ -37,26 +53,20 @@ FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) { FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do { do {
gm_dcci((__gm__ uint8_t *)sig_addr); gm_dcci((__gm__ uint8_t *)sig_addr);
if (*sig_addr == cmp_val) { if (*sig_addr == cmp_val) {
return *sig_addr; return *sig_addr;
} }
// in case when peer pe enters next barrier
if (*sig_addr == cmp_val + 1) { if (*sig_addr == cmp_val + 1) {
return *sig_addr; return *sig_addr;
} }
} while (true); } while (true);
// never reach
return -1; return -1;
} }
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
do { do {
AscendC::LocalTensor<int32_t> ub; AscendC::LocalTensor<int32_t> ub;
ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN); ub.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
ub.address_.bufferAddr = 0; ub.address_.bufferAddr = 0;
AscendC::GlobalTensor<int32_t> sig; AscendC::GlobalTensor<int32_t> sig;
sig.SetGlobalBuffer(sig_addr); sig.SetGlobalBuffer(sig_addr);
@@ -71,16 +81,12 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32
} }
constexpr int32_t MAX_RANK_SIZE = 32;
class HcclShmem { class HcclShmem {
public: public:
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
__gm__ HcclOpResParamCustom *WinContext_{nullptr}; __gm__ HcclOpResParamCustom *WinContext_{nullptr};
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_; Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
size_t m_segmentSize; AscendC::LocalTensor<int32_t> ub;
int32_t m_rank;
int32_t m_rankSize;
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
HcclShmem(){ HcclShmem(){
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>(); auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
@@ -89,41 +95,39 @@ public:
m_rank = WinContext_->localUsrRankId; m_rank = WinContext_->localUsrRankId;
m_rankSize = WinContext_->rankSize; m_rankSize = WinContext_->rankSize;
m_segmentSize = WinContext_->winSize; m_segmentSize = WinContext_->winSize;
} }
#else
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
size_t SegmentSize() const { HcclShmem(){
return m_segmentSize; m_segmentSize = SHMEM_MEM;
} }
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
int32_t RankSize() const { void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
return m_rankSize; symmetricPtr = symmetricPtr_;
m_rank = rank;
m_rankSize = rankSize;
} }
#endif #endif
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
GM_ADDR operator() () const { // No argument: return local peermem GM_ADDR operator() () const { // No parameters: return pointer to local peermem
#ifdef HCCL_COMM #ifdef HCCL_COMM
return (GM_ADDR)(WinContext_->localWindowsIn); return (GM_ADDR)(WinContext_->localWindowsIn);
#else #else
return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base); return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif #endif
} }
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem
#ifdef HCCL_COMM #ifdef HCCL_COMM
return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn : return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn); ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
#else #else
return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index)); return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif #endif
} }
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
GM_ADDR operator () (int64_t offset, int32_t rankId) const { GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM #ifdef HCCL_COMM
@@ -136,15 +140,28 @@ public:
return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn : return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset; ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
#else #else
return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId); return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif #endif
} }
FORCE_INLINE_AICORE
size_t SegmentSize() const {
return m_segmentSize;
}
FORCE_INLINE_AICORE
int32_t RankSize() const {
return m_rankSize;
}
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
~HcclShmem() { ~HcclShmem() {
} }
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
void CrossRankSync() { void CrossRankSync() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
@@ -165,12 +182,146 @@ public:
gm_store(sync_base, count); gm_store(sync_base, count);
} }
FORCE_INLINE_AICORE
void InitStatusTargetSum()
{
using namespace AscendC;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET;
//uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE);
// ep state
//uint32_t coreIdx = get_block_idx();;
uint32_t coreIdx = GetBlockIdx();
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfStatusTensor[coreIdx * UB_ALIGN]);
__asm__ __volatile__("");
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Set(AscendC::LocalTensor<int32_t> ctrBuffer) {
//subblockid = 0
uint32_t stateOffset_ = STATE_OFFSET;
// uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_;
uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_;
//uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
pipe_barrier(PIPE_ALL);
ctrBuffer.SetValue(0, epStateValue_);
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) {
AscendC::GlobalTensor<int32_t> gmDstStates;
gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx)));
DataCopy(gmDstStates, ctrBuffer, 8);
}
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE
void CrossRankSyncV2Wait(AscendC::LocalTensor<float> statusTensor, AscendC::LocalTensor<float> gatherMaskOutTensor,
AscendC::LocalTensor<uint32_t> gatherTmpTensor, AscendC::LocalTensor<float> statusSumOutTensor) {
uint64_t flag_offset = (m_segmentSize - MB_SIZE);
int vec_size = get_block_num();
int vec_id = get_block_idx();
uint32_t stateOffset_ = STATE_OFFSET;
uint32_t sendRankNum_ = m_rankSize / vec_size;
uint32_t remainderRankNum = m_rankSize % vec_size;
uint32_t startRankId_ = sendRankNum_ * vec_id;
if (vec_id < remainderRankNum) {
sendRankNum_++;
startRankId_ += vec_id;
} else {
startRankId_ += remainderRankNum;
}
uint32_t endRankId_ = startRankId_ + sendRankNum_;
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::GlobalTensor<float> epStatusSpaceGlobalTensor_;
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset));
if (startRankId_ < m_rankSize) {
AscendC::PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
// DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
// static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0};
AscendC::DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>(15), 0};
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_};
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
AscendC::DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
//unpermute
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
FORCE_INLINE_AICORE FORCE_INLINE_AICORE
__gm__ int32_t* SyncBaseAddr() { __gm__ int32_t* SyncBaseAddr() {
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
return (__gm__ int32_t*)(*this)() + flag_offset + 2048; return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
} }
private:
GM_ADDR symmetricPtr;
int32_t m_rank;
int32_t m_rankSize;
size_t m_segmentSize;
float sumTarget_{0.0};
int32_t epStateValue_;
}; };
#endif #endif