diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 631ea4e3..1d0a3742 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -39,12 +39,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then exit 1 fi fi - # dependency: cann-toolkit file moe_distribute_base.h - HCCL_STRUCT_FILE_PATH=$(find -L "${ASCEND_TOOLKIT_HOME}" -name "moe_distribute_base.h" 2>/dev/null | head -n1) - if [ -z "$HCCL_STRUCT_FILE_PATH" ]; then - echo "cannot find moe_distribute_base.h file in CANN env" - exit 1 - fi # for dispatch_gmm_combine_decode yes | cp "${HCCL_STRUCT_FILE_PATH}" "${ROOT_DIR}/csrc/utils/inc/kernel" diff --git a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt index 7039b61f..7a936dcf 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt +++ b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt @@ -47,6 +47,7 @@ target_sources(optiling PRIVATE target_include_directories(optiling PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/utils/inc/kernel ) target_sources(opsproto PRIVATE diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index cd4dd6b9..06f7e62f 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -10,7 +10,7 @@ #ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H #define DISPATCH_GMM_COMBINE_DECODE_BASE_H -#include "../common/moe_distribute_base.h" +#include "moe_distribute_base.h" #define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename WType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG #define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, WType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG diff --git a/csrc/moe_combine_normal/op_host/CMakeLists.txt b/csrc/moe_combine_normal/op_host/CMakeLists.txt index 190adfe1..09720665 100644 --- a/csrc/moe_combine_normal/op_host/CMakeLists.txt +++ b/csrc/moe_combine_normal/op_host/CMakeLists.txt @@ -38,6 +38,7 @@ target_sources(optiling PRIVATE target_include_directories(optiling PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/utils/inc/kernel ) target_sources(opsproto PRIVATE) diff --git a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h b/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h index 156e7248..5235cc6e 100644 --- a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h +++ b/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h @@ -3,7 +3,7 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../common/moe_distribute_base.h" +#include "moe_distribute_base.h" #include "moe_combine_normal_tiling.h" namespace MoeCombineNormalImpl { diff --git a/csrc/moe_dispatch_normal/op_host/CMakeLists.txt b/csrc/moe_dispatch_normal/op_host/CMakeLists.txt index c6afc9f5..3d8ae1f3 100644 --- a/csrc/moe_dispatch_normal/op_host/CMakeLists.txt +++ b/csrc/moe_dispatch_normal/op_host/CMakeLists.txt @@ -38,6 +38,7 @@ target_sources(optiling PRIVATE target_include_directories(optiling PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/utils/inc/kernel ) target_sources(opsproto PRIVATE) diff --git a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h b/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h index 2af4e580..79b73474 100644 --- a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h +++ b/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h @@ -3,7 +3,7 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../common/moe_distribute_base.h" +#include "moe_distribute_base.h" #include "moe_dispatch_normal_tiling.h" namespace MoeDispatchNormalImpl { diff --git a/csrc/utils/inc/kernel/moe_distribute_base.h b/csrc/utils/inc/kernel/moe_distribute_base.h new file mode 100644 index 00000000..b74df9a5 --- /dev/null +++ b/csrc/utils/inc/kernel/moe_distribute_base.h @@ -0,0 +1,288 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_distribute_base.h + * \brief + */ + +#ifndef MOE_DISTRIBUTE_BASE_H +#define MOE_DISTRIBUTE_BASE_H + +constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64; +constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19; +constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2; +constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024; + +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; // HccltagLocalResV2 +}; + +enum class rtFloatOverflowMode_t { + RT_OVERFLOW_MODE_SATURATION = 0, + RT_OVERFLOW_MODE_INFNAN, + RT_OVERFLOW_MODE_UNDEF, +}; + +struct AlgoTopoInfo { + uint32_t userRank; + uint32_t userRankSize; + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; // TopoType + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interHccsDisable = false; + rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF; + uint32_t multiQpThreshold = 512; +}; + +struct HcclMC2WorkSpace { + uint64_t workSpace; + uint64_t workSpaceSize; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HDCommunicateParams { + uint64_t hostAddr { 0 }; + uint64_t deviceAddr { 0 }; + uint64_t readCacheAddr { 0 }; + uint32_t devMemSize{ 0 }; + uint32_t buffLen{ 0 }; + uint32_t flag{ 0 }; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParam { + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; // usrrankid + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; + + // communicate retry + HDCommunicateParams kfcControlTransferH2DParams; + HDCommunicateParams kfcStatusTransferD2HParams; + uint64_t tinyMem; // for all2all + uint64_t tinyMemSize; + uint64_t zeroCopyHeadPtr; + uint64_t zeroCopyTailPtr; + uint64_t zeroCopyRingBuffer; + uint64_t zeroCopyIpcPtrs[16]; + uint32_t zeroCopyDevicePhyId[16]; + + bool utraceStatusFlag; +}; + +// Transport +enum class HcclAiRMAMemType : uint32_t { + LOCAL_INPUT = 0, + REMOTE_INPUT, + LOCAL_OUTPUT, + REMOTE_OUTPUT, + MAX_NUM +}; + +struct HcclAiRMAMemInfo { + uint32_t memMaxNum{0}; + uint32_t sizeOfMemDetails{0}; + uint64_t memDetailPtr{0}; +}; + +// Transport QP/Mem +struct HcclAiRMAInfo { + uint32_t curRankId{0}; + uint32_t rankNum{0}; + uint32_t qpNum{0}; + uint32_t sizeOfAiRMAWQ{0}; // sizeof(HcclAiRMAWQ) + uint32_t sizeOfAiRMACQ{0}; // sizeof(HcclAiRMACQ) + uint32_t sizeOfAiRMAMem{0}; // sizeof(HcclAiRMAMemInfo) + uint64_t sqPtr{0}; + uint64_t scqPtr{0}; + uint64_t rqPtr{0}; + uint64_t rcqPtr{0}; + uint64_t memPtr{0}; +}; + +struct HcclA2CombineOpParam { + uint64_t workSpace; // Address for communication between client and server, + // hccl requests and clears + uint64_t workSpaceSize; // Space for communication between client and server + uint32_t rankId; // id of this rank + uint32_t rankNum; // num of ranks in this comm group + uint64_t winSize; // size of each windows memory + uint64_t windowsIn[AscendC::HCCL_MAX_RANK_NUM]; // windows address for input, windowsIn[rankId] corresponds + // to the local card address, + // and others are cross-card mapping addresses. + uint64_t windowsOut[AscendC::HCCL_MAX_RANK_NUM]; // windows address for output, windowsOut[rankId] corresponds + // to the local card address, + // and others are cross-card mapping addresses. + uint8_t res[8328]; + uint8_t multiFlag; + __gm__ AscendC::IbVerbsData *data; + uint64_t dataSize; + + uint64_t sizeOfAiRMAInfo; + uint64_t aiRMAInfo; +}; + +enum class DBMode : int32_t { + INVALID_DB = -1, + HW_DB = 0, + SW_DB +}; + +struct HcclAiRMAWQ { + uint32_t wqn{0}; + uint64_t bufAddr{0}; + uint32_t wqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; + uint32_t sl{0}; +}; + +struct HcclAiRMACQ { + uint32_t cqn{0}; + uint64_t bufAddr{0}; + uint32_t cqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; +}; + +struct hns_roce_rc_sq_wqe { + uint32_t byte_4; + uint32_t msg_len; + uint32_t immtdata; + uint32_t byte_16; + uint32_t byte_20; + uint32_t rkey; + uint64_t remoteVA; +}; + + +struct hns_roce_lite_wqe_data_seg { + uint32_t len; + uint32_t lkey; + uint64_t localVA; +}; + +__aicore__ inline void cacheWriteThrough(__gm__ uint8_t* sourceAddr, uint64_t length) { + __gm__ uint8_t* start = + (__gm__ uint8_t*)((uint64_t)sourceAddr / AscendC::CACHE_LINE_SIZE * AscendC::CACHE_LINE_SIZE); + __gm__ uint8_t* end = + (__gm__ uint8_t*)(((uint64_t)sourceAddr + length) / AscendC::CACHE_LINE_SIZE * AscendC::CACHE_LINE_SIZE); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(start); + for (uint32_t i = 0; i <= end - start; i += AscendC::CACHE_LINE_SIZE) { + AscendC::DataCacheCleanAndInvalid(global[i]); + } +} + +#endif // MOE_DISTRIBUTE_BASE_H \ No newline at end of file